1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/nn_ops.cc. 17 #ifdef INTEL_MKL 18 #define EIGEN_USE_THREADS 19 #include "tensorflow/core/framework/op_kernel.h" 20 #include "tensorflow/core/kernels/mkl_pooling_ops_common.h" 21 #include "tensorflow/core/lib/core/errors.h" 22 #include "tensorflow/core/util/mkl_util.h" 23 #include "tensorflow/core/util/padding.h" 24 25 #ifndef INTEL_MKL_ML_ONLY 26 #include <algorithm> 27 #include "mkldnn.hpp" 28 using mkldnn::algorithm; 29 using mkldnn::engine; 30 using mkldnn::error; 31 using mkldnn::memory; 32 using mkldnn::padding_kind; 33 using mkldnn::pooling_backward; 34 using mkldnn::pooling_forward; 35 using mkldnn::prop_kind; 36 #endif 37 38 namespace tensorflow { 39 40 typedef Eigen::ThreadPoolDevice CPUDevice; 41 42 // MKL-DNN is now default. MKL-ML must be specified explicitly. 43 #ifdef INTEL_MKL_ML_ONLY 44 45 // An implementation of MaxPooling (forward). 46 template <typename Device, typename T> 47 class MklMaxPoolingOp : public OpKernel { 48 public: 49 explicit MklMaxPoolingOp(OpKernelConstruction* context) : OpKernel(context) { 50 string data_format; 51 52 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); 53 OP_REQUIRES(context, FormatFromString(data_format, &data_format_), 54 errors::InvalidArgument("Invalid data format")); 55 OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); 56 OP_REQUIRES(context, ksize_.size() == 4, 57 errors::InvalidArgument("Sliding window ksize field must " 58 "specify 4 dimensions")); 59 OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); 60 OP_REQUIRES(context, stride_.size() == 4, 61 errors::InvalidArgument("Sliding window stride field must " 62 "specify 4 dimensions")); 63 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); 64 OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, 65 errors::Unimplemented("Pooling is not yet supported on the " 66 "batch dimension.")); 67 68 workspace_enabled_ = false; 69 // We may not get this attribute for this node if it does not go through 70 // graph rewrite pass. So we do not check for error while retrieving this 71 // attribute value. 72 OP_REQUIRES_OK(context, 73 context->GetAttr("workspace_enabled", &workspace_enabled_)); 74 } 75 76 void Compute(OpKernelContext* context) override { 77 MklMaxPoolingOpContext mkl_context; 78 // Get the input tensor 79 const Tensor& tensor_in = MklGetInput(context, 0); 80 GetMklShape(context, 0, &mkl_context.input_shape); 81 bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor(); 82 83 mkl_context.params.in_dim = 4; 84 MklPoolParameters pool_params; 85 if (input_in_mkl_format == false) { 86 pool_params.Init(context, ksize_, stride_, padding_, data_format_, 87 tensor_in.shape()); 88 OP_REQUIRES( 89 context, (pool_params.depth_window == 1), 90 errors::Unimplemented("Depthwise max pooling not supported by MKL")); 91 92 } else { 93 pool_params.Init(context, ksize_, stride_, padding_, data_format_, 94 &mkl_context.input_shape); 95 } 96 97 // Extract the parameters for the op from the pooling specs 98 99 ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params); 100 101 mkl_context.MklCreateLayoutsAndPrimitives(context); 102 OP_REQUIRES_OK(context, context->status()); 103 104 // Declare output tensor 105 TensorShape tensor_out_shape; 106 MklShape mkl_out_shape, mkl_workspace_shape; 107 mkl_out_shape.SetMklTensor(true); 108 mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_fwd, dnnResourceDst); 109 mkl_out_shape.SetTfLayout(mkl_context.params.in_dim, 110 mkl_context.params.out_sizes, 111 mkl_context.params.out_strides); 112 mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_); 113 114 Tensor* output_tensor = nullptr; 115 tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>( 116 mkl_out_shape.GetMklLayout())) / 117 sizeof(T)); 118 AllocateOutputSetMklShape(context, 0, &output_tensor, tensor_out_shape, 119 mkl_out_shape); 120 121 Tensor* workspace_tensor; 122 void* workspace_buf = nullptr; 123 124 TensorShape workspace_shape; 125 mkl_workspace_shape.SetMklTensor(false); 126 workspace_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>( 127 mkl_context.lt_workspace)) / 128 sizeof(T)); 129 AllocateOutputSetMklShape(context, 1, &workspace_tensor, workspace_shape, 130 mkl_workspace_shape); 131 132 mkl_context.pooling_res[dnnResourceWorkspace] = const_cast<void*>( 133 static_cast<const void*>(workspace_tensor->flat<T>().data())); 134 mkl_context.pooling_res[dnnResourceSrc] = 135 const_cast<void*>(static_cast<const void*>(tensor_in.flat<T>().data())); 136 mkl_context.pooling_res[dnnResourceDst] = const_cast<void*>( 137 static_cast<const void*>(output_tensor->flat<T>().data())); 138 139 CHECK_EQ( 140 dnnExecute_F32(mkl_context.prim_pooling_fwd, mkl_context.pooling_res), 141 E_SUCCESS); 142 143 mkl_context.MklCleanup(); 144 } 145 146 private: 147 typedef struct { 148 MklPoolingOpParams params; 149 MklShape input_shape; 150 void* pooling_res[dnnResourceNumber]; 151 dnnPrimitive_t prim_pooling_fwd = nullptr; 152 dnnLayout_t lt_user_input = nullptr, lt_workspace = nullptr; 153 154 void MklCreateLayoutsAndPrimitives(OpKernelContext* context) { 155 bool input_in_mkl_format = input_shape.IsMklTensor(); 156 // Create or use existing DNN user layout 157 if (input_in_mkl_format == false) { 158 CHECK_EQ(dnnLayoutCreate_F32(<_user_input, params.in_dim, 159 params.in_sizes, params.in_strides), 160 E_SUCCESS); 161 } else { 162 lt_user_input = (dnnLayout_t)input_shape.GetCurLayout(); 163 } 164 165 dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax; 166 dnnPrimitiveAttributes_t primAttr = nullptr; 167 168 // Create DNN primitives 169 CHECK_EQ(dnnPoolingCreateForward_F32( 170 &prim_pooling_fwd, primAttr, algorithm, lt_user_input, 171 params.kernel_size, params.kernel_stride, params.in_offset, 172 dnnBorderZerosAsymm), 173 E_SUCCESS); 174 175 // Creates layout for the workspace 176 CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(<_workspace, prim_pooling_fwd, 177 dnnResourceWorkspace), 178 E_SUCCESS); 179 } 180 181 void MklCleanup() { 182 bool input_in_mkl_format = input_shape.IsMklTensor(); 183 CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS); 184 if (!input_in_mkl_format) { 185 CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS); 186 } 187 CHECK_EQ(dnnLayoutDelete_F32(lt_workspace), E_SUCCESS); 188 } 189 } MklMaxPoolingOpContext; 190 191 std::vector<int32> ksize_; 192 std::vector<int32> stride_; 193 Padding padding_; 194 TensorFormat data_format_; 195 bool workspace_enabled_; 196 }; 197 198 // The operation to compute MaxPool gradients. 199 // It takes three inputs: 200 // - The original input tensor 201 // - The original output tensor 202 // - Backprop tensor for output 203 // It produces one output: backprop tensor for input. 204 template <class Device, class T> 205 class MklMaxPoolingGradOp : public OpKernel { 206 public: 207 explicit MklMaxPoolingGradOp(OpKernelConstruction* context) 208 : OpKernel(context) { 209 string data_format; 210 211 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); 212 OP_REQUIRES(context, FormatFromString(data_format, &data_format_), 213 errors::InvalidArgument("Invalid data format")); 214 OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); 215 OP_REQUIRES(context, ksize_.size() == 4, 216 errors::InvalidArgument("Sliding window ksize field must " 217 "specify 4 dimensions")); 218 OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); 219 OP_REQUIRES(context, stride_.size() == 4, 220 errors::InvalidArgument("Sliding window strides field must " 221 "specify 4 dimensions")); 222 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); 223 OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, 224 errors::Unimplemented( 225 "Pooling is not yet supported on the batch dimension.")); 226 workspace_enabled_ = false; 227 // We may not get this attribute for this node if it does not go through 228 // graph rewrite pass. So we do not check for error while retrieving this 229 // attribute value. 230 OP_REQUIRES_OK(context, 231 context->GetAttr("workspace_enabled", &workspace_enabled_)); 232 } 233 234 void Compute(OpKernelContext* context) override { 235 MklMaxPoolingGradOpContext mkl_context; 236 // Input - The original input tensor 237 const Tensor& tensor_in = MklGetInput(context, 0); 238 239 // Output - Backprop tensor for input. 240 Tensor* output_tensor = nullptr; 241 242 GetMklShape(context, 0, &mkl_context.input_shape); 243 GetMklShape(context, 2, &mkl_context.output_backprop_shape); 244 bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor(); 245 246 if (input_in_mkl_format == false) 247 mkl_context.params.in_dim = tensor_in.dims(); 248 else 249 mkl_context.params.in_dim = mkl_context.input_shape.GetDimension(); 250 251 MklPoolParameters pool_params; 252 if (input_in_mkl_format == false) { 253 pool_params.Init(context, ksize_, stride_, padding_, data_format_, 254 tensor_in.shape()); 255 OP_REQUIRES( 256 context, (pool_params.depth_window == 1), 257 errors::Unimplemented("Depthwise max pooling not supported by MKL")); 258 259 } else { 260 pool_params.Init(context, ksize_, stride_, padding_, data_format_, 261 &mkl_context.input_shape); 262 } 263 264 // Extract the parameters for the op from the pooling specs 265 ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params); 266 267 mkl_context.MklCreateLayouts(context); 268 OP_REQUIRES_OK(context, context->status()); 269 270 mkl_context.MklCreatePrimitives(context, workspace_enabled_); 271 OP_REQUIRES_OK(context, context->status()); 272 273 mkl_context.MklPrepareInputs(context, workspace_enabled_); 274 OP_REQUIRES_OK(context, context->status()); 275 276 // Create shape for the input back prop output 277 TensorShape mkl_input_backprop; 278 MklShape mkl_output_shape; 279 mkl_output_shape.SetMklTensor(true); 280 mkl_output_shape.SetMklLayout(mkl_context.prim_pooling_bwd, 281 dnnResourceDiffSrc); 282 mkl_output_shape.SetTfLayout(mkl_context.params.in_dim, 283 mkl_context.params.in_sizes, 284 mkl_context.params.in_strides); 285 mkl_output_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_); 286 287 mkl_input_backprop.AddDim( 288 dnnLayoutGetMemorySize_F32( 289 static_cast<dnnLayout_t>(mkl_output_shape.GetMklLayout())) / 290 sizeof(T)); 291 AllocateOutputSetMklShape(context, 0, &output_tensor, mkl_input_backprop, 292 mkl_output_shape); 293 mkl_context.pooling_res[dnnResourceDiffSrc] = const_cast<void*>( 294 static_cast<const void*>(output_tensor->flat<T>().data())); 295 296 CHECK_EQ( 297 dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res), 298 E_SUCCESS); 299 300 mkl_context.MklCleanup(workspace_enabled_); 301 } 302 303 private: 304 typedef struct { 305 MklPoolingOpParams params; 306 MklShape input_shape, output_backprop_shape; 307 void* pooling_resfwd[dnnResourceNumber]; 308 void* pooling_res[dnnResourceNumber]; 309 dnnPrimitive_t prim_pooling_fwd = nullptr, prim_pooling_bwd = nullptr, 310 convert_input = nullptr, convert_outbackprop = nullptr; 311 dnnLayout_t lt_outbackprop_user = nullptr, lt_outbackprop_prim = nullptr, 312 lt_input_user = nullptr, lt_input_prim = nullptr; 313 void* input_buf; 314 void* outbackprop_buf; 315 Tensor tmp_output_buf_tensor; 316 Tensor workspace_buf_tensor; 317 Tensor input_buf_tensor, outbackprop_buf_tensor; 318 319 void MklCreateLayouts(OpKernelContext* context) { 320 bool input_in_mkl_format = input_shape.IsMklTensor(); 321 bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor(); 322 // Create DNN user layout for input and outbackprop or get existing layout 323 if (input_in_mkl_format == false) { 324 CHECK_EQ(dnnLayoutCreate_F32(<_input_user, params.in_dim, 325 params.in_sizes, params.in_strides), 326 E_SUCCESS); 327 } else { 328 lt_input_user = (dnnLayout_t)input_shape.GetCurLayout(); 329 } 330 331 // We don't care about the output layout for now as we can create it from 332 // primitives for the max pooling fwd prop 333 if (outbackprop_in_mkl_format == false) { 334 CHECK_EQ(dnnLayoutCreate_F32(<_outbackprop_user, params.in_dim, 335 params.out_sizes, params.out_strides), 336 E_SUCCESS); 337 } else { 338 lt_outbackprop_user = (dnnLayout_t)output_backprop_shape.GetCurLayout(); 339 } 340 } 341 342 // Create DNN primitives 343 void MklCreatePrimitives(OpKernelContext* context, bool workspace_enabled) { 344 dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax; 345 dnnPrimitiveAttributes_t primAttr = nullptr; 346 347 if (workspace_enabled == false) { 348 CHECK_EQ(dnnPoolingCreateForward_F32( 349 &prim_pooling_fwd, primAttr, algorithm, lt_input_user, 350 params.kernel_size, params.kernel_stride, params.in_offset, 351 dnnBorderZerosAsymm), 352 E_SUCCESS); 353 } 354 355 CHECK_EQ(dnnPoolingCreateBackward_F32( 356 &prim_pooling_bwd, primAttr, algorithm, lt_input_user, 357 params.kernel_size, params.kernel_stride, params.in_offset, 358 dnnBorderZerosAsymm), 359 E_SUCCESS); 360 361 // Creates conversions 362 CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( 363 <_outbackprop_prim, prim_pooling_bwd, dnnResourceDiffDst), 364 E_SUCCESS); 365 366 if (workspace_enabled == false) { 367 CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( 368 <_input_prim, prim_pooling_fwd, dnnResourceSrc), 369 E_SUCCESS); 370 if (!dnnLayoutCompare_F32(lt_input_user, lt_input_prim)) { 371 CHECK_EQ(dnnConversionCreate_F32(&convert_input, lt_input_user, 372 lt_input_prim), 373 E_SUCCESS); 374 AllocTmpBuffer(context, &input_buf_tensor, lt_input_prim, &input_buf); 375 } 376 } 377 378 if (!dnnLayoutCompare_F32(lt_outbackprop_user, lt_outbackprop_prim)) { 379 CHECK_EQ( 380 dnnConversionCreate_F32(&convert_outbackprop, lt_outbackprop_user, 381 lt_outbackprop_prim), 382 E_SUCCESS); 383 AllocTmpBuffer(context, &outbackprop_buf_tensor, lt_outbackprop_prim, 384 &outbackprop_buf); 385 } 386 } 387 388 // Compare incoming tensor layouts with MKL preferred layouts and convert 389 // data to the preferred layout if necessary 390 void MklPrepareInputs(OpKernelContext* context, bool workspace_enabled) { 391 const Tensor& tensor_in = MklGetInput(context, 0); 392 const Tensor& out_backprop = MklGetInput(context, 2); 393 bool input_in_mkl_format = input_shape.IsMklTensor(); 394 bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor(); 395 396 void* tmp_output_buf = nullptr; 397 void* workspace_buf = nullptr; 398 399 if (workspace_enabled == false) { 400 if (convert_input != nullptr) { 401 if (input_in_mkl_format == false) { 402 CHECK_EQ(dnnConversionExecute_F32( 403 convert_input, 404 const_cast<void*>(static_cast<const void*>( 405 tensor_in.flat<T>().data())), 406 input_buf), 407 E_SUCCESS); 408 CHECK_EQ(dnnDelete_F32(convert_input), E_SUCCESS); 409 convert_input = nullptr; 410 } else { 411 input_shape.GetConvertedFlatData( 412 lt_input_prim, 413 const_cast<void*>( 414 static_cast<const void*>(tensor_in.flat<T>().data())), 415 input_buf); 416 } 417 pooling_resfwd[dnnResourceSrc] = input_buf; 418 } else { 419 pooling_resfwd[dnnResourceSrc] = const_cast<void*>( 420 static_cast<const void*>(tensor_in.flat<T>().data())); 421 } 422 423 dnnLayout_t lt_workspace; 424 CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( 425 <_workspace, prim_pooling_fwd, dnnResourceWorkspace), 426 E_SUCCESS); 427 AllocTmpBuffer(context, &workspace_buf_tensor, lt_workspace, 428 &workspace_buf); 429 pooling_resfwd[dnnResourceWorkspace] = workspace_buf; 430 431 dnnLayoutDelete_F32(lt_workspace); 432 433 // We create the layout for max pooling fwd prop tmp output here 434 AllocTmpBuffer(context, &tmp_output_buf_tensor, lt_outbackprop_prim, 435 &tmp_output_buf); 436 pooling_resfwd[dnnResourceDst] = tmp_output_buf; 437 438 CHECK_EQ(dnnExecute_F32(prim_pooling_fwd, pooling_resfwd), E_SUCCESS); 439 pooling_res[dnnResourceWorkspace] = 440 pooling_resfwd[dnnResourceWorkspace]; 441 } else { 442 const Tensor& workspace = MklGetInput(context, 3); 443 pooling_res[dnnResourceWorkspace] = const_cast<void*>( 444 static_cast<const void*>(workspace.flat<T>().data())); 445 } 446 447 // Out backprop conversions if needed 448 if (convert_outbackprop != nullptr) { 449 if (outbackprop_in_mkl_format == false) { 450 CHECK_EQ(dnnConversionExecute_F32( 451 convert_outbackprop, 452 const_cast<void*>(static_cast<const void*>( 453 out_backprop.flat<T>().data())), 454 outbackprop_buf), 455 E_SUCCESS); 456 CHECK_EQ(dnnDelete_F32(convert_outbackprop), E_SUCCESS); 457 } else { 458 output_backprop_shape.GetConvertedFlatData( 459 lt_outbackprop_prim, 460 const_cast<void*>( 461 static_cast<const void*>(out_backprop.flat<T>().data())), 462 outbackprop_buf); 463 } 464 pooling_res[dnnResourceDiffDst] = outbackprop_buf; 465 } else { 466 pooling_res[dnnResourceDiffDst] = const_cast<void*>( 467 static_cast<const void*>(out_backprop.flat<T>().data())); 468 } 469 } 470 471 void MklCleanup(bool workspace_enabled) { 472 bool input_in_mkl_format = input_shape.IsMklTensor(); 473 bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor(); 474 if (workspace_enabled == false) { 475 CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS); 476 } 477 CHECK_EQ(dnnDelete_F32(prim_pooling_bwd), E_SUCCESS); 478 if (outbackprop_in_mkl_format == false) { 479 CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_user), E_SUCCESS); 480 } 481 CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_prim), E_SUCCESS); 482 if (input_in_mkl_format == false) { 483 CHECK_EQ(dnnLayoutDelete_F32(lt_input_user), E_SUCCESS); 484 } 485 if (workspace_enabled == false) { 486 CHECK_EQ(dnnLayoutDelete_F32(lt_input_prim), E_SUCCESS); 487 } 488 } 489 } MklMaxPoolingGradOpContext; 490 491 std::vector<int32> ksize_; 492 std::vector<int32> stride_; 493 Padding padding_; 494 TensorFormat data_format_; 495 496 bool workspace_enabled_; 497 }; // MklMaxPoolingGradOp 498 499 #else 500 501 // An implementation of MaxPooling (forward). 502 template <typename Device, typename T> 503 class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> { 504 public: 505 explicit MklMaxPoolingOp(OpKernelConstruction* context) 506 : MklPoolingForwardOpBase<T>(context) { 507 // In Max Pooling, MKLDNN does not allow passing workspace as NULL. 508 // So we set workspace_enabled_ to true. 509 this->workspace_enabled_ = true; 510 } 511 512 void Compute(OpKernelContext* context) override { 513 try { 514 const Tensor& input_tensor = 515 MklGetInput(context, this->kInputTensorIndexInput); 516 MklDnnShape dnn_shape_input; 517 GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input); 518 this->SanityCheckInput(context, input_tensor, dnn_shape_input); 519 if (!context->status().ok()) return; 520 521 MklDnnData<T> dnn_data_input(&cpu_engine); 522 MklDnnData<T> dnn_data_output(&cpu_engine); 523 524 // initialize variables for the pooling op 525 MklPoolParameters pool_params; 526 // check whether pooling is 2D or 3D 527 bool is_pool2d = (this->ksize_.size() == 4); 528 // Get the input tensor and initialize the pooling parameters 529 TensorShape input_tensor_shape = input_tensor.shape(); 530 this->InitMklPoolParameters(context, &pool_params, dnn_shape_input, 531 input_tensor_shape); 532 OP_REQUIRES_OK(context, context->status()); 533 534 // Declare output tensor 535 Tensor* output_tensor = nullptr; 536 memory::dims output_dims_mkl_order; 537 this->GetOutputDims(pool_params, &output_dims_mkl_order); 538 539 // If input is an empty tensor, allocate an empty output tensor and return 540 if (input_tensor.NumElements() == 0) { 541 const int kOutputIndex = 0; 542 this->AllocateEmptyOutputTensor(context, kOutputIndex, &pool_params, 543 output_dims_mkl_order, &output_tensor); 544 return; 545 } 546 547 // Get the input memory descriptor 548 memory::desc input_md = 549 dnn_shape_input.IsMklTensor() 550 ? dnn_shape_input.GetMklLayout() 551 : is_pool2d ? memory::desc( 552 TFShapeToMklDnnDimsInNCHW( 553 input_tensor_shape, this->data_format_tf_), 554 MklDnnType<T>(), this->data_format_mkldnn_) 555 : memory::desc( 556 TFShapeToMklDnnDimsInNCDHW( 557 input_tensor_shape, this->data_format_tf_), 558 MklDnnType<T>(), this->data_format_mkldnn_); 559 560 // Get src/filter/stride/padding information 561 memory::dims src_dims = 562 dnn_shape_input.IsMklTensor() 563 ? dnn_shape_input.GetSizesAsMklDnnDims() 564 : is_pool2d ? TFShapeToMklDnnDimsInNCHW(input_tensor.shape(), 565 this->data_format_tf_) 566 : TFShapeToMklDnnDimsInNCDHW(input_tensor.shape(), 567 this->data_format_tf_); 568 memory::dims filter_dims, strides, padding_left, padding_right; 569 this->PoolParamsToDims(&pool_params, &filter_dims, &strides, 570 &padding_left, &padding_right, is_pool2d); 571 572 // Get a pooling op from the cached pool 573 MklPoolingFwdPrimitive<T>* pooling_fwd = nullptr; 574 prop_kind pooling_prop_kind; 575 bool int8_forward_inference = 576 std::is_same<T, qint8>::value || std::is_same<T, quint8>::value; 577 if (int8_forward_inference) 578 pooling_prop_kind = prop_kind::forward_inference; 579 else 580 pooling_prop_kind = prop_kind::forward_training; 581 MklPoolingParams fwdParams(src_dims, output_dims_mkl_order, filter_dims, 582 strides, padding_left, padding_right, 583 algorithm::pooling_max, pooling_prop_kind); 584 pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams); 585 586 // allocate output tensor 587 this->AllocateOutputTensor(context, *(pooling_fwd->GetPoolingFwdPd()), 588 output_dims_mkl_order, 589 this->data_format_mkldnn_, &output_tensor); 590 OP_REQUIRES_OK(context, context->status()); 591 dnn_data_output.SetUsrMem(output_dims_mkl_order, 592 pooling_fwd->GetDstMemoryFormat(), 593 output_tensor); 594 595 // check wehther we need to reorder src 596 const T* src_data = input_tensor.flat<T>().data(); 597 if (input_md.data.format != pooling_fwd->GetSrcMemoryFormat()) { 598 dnn_data_input.SetUsrMem(input_md, &input_tensor); 599 auto src_target_primitive_desc = memory::primitive_desc( 600 {{src_dims}, MklDnnType<T>(), pooling_fwd->GetSrcMemoryFormat()}, 601 cpu_engine); 602 dnn_data_input.CheckReorderToOpMem(src_target_primitive_desc); 603 src_data = const_cast<T*>( 604 reinterpret_cast<T*>(dnn_data_input.GetOpMem().get_data_handle())); 605 } 606 607 T* dst_data = output_tensor->flat<T>().data(); 608 609 if (int8_forward_inference) { 610 // Execute pooling op 611 pooling_fwd->Execute(src_data, dst_data); 612 613 // pass min, max from input to output 614 const Tensor& min_input_t = MklGetInput(context, 1); 615 const Tensor& max_input_t = MklGetInput(context, 2); 616 const float min_input = min_input_t.flat<float>()(0); 617 const float max_input = max_input_t.flat<float>()(0); 618 619 Tensor* output_min = nullptr; 620 Tensor* output_max = nullptr; 621 MklDnnShape output_min_mkl_shape, output_max_mkl_shape; 622 output_min_mkl_shape.SetMklTensor(false); 623 output_max_mkl_shape.SetMklTensor(false); 624 AllocateOutputSetMklShape(context, 1, &output_min, {}, 625 output_min_mkl_shape); 626 AllocateOutputSetMklShape(context, 2, &output_max, {}, 627 output_max_mkl_shape); 628 output_min->flat<float>()(0) = min_input; 629 output_max->flat<float>()(0) = max_input; 630 } else { 631 MklDnnData<uint8> dnn_data_wksp(&cpu_engine); 632 AllocateWorkspaceTensor(context, *(pooling_fwd->GetPoolingFwdPd()), 633 &dnn_data_wksp); 634 OP_REQUIRES_OK(context, context->status()); 635 T* ws_data = 636 static_cast<T*>(dnn_data_wksp.GetOpMem().get_data_handle()); 637 638 // execute pooling op 639 pooling_fwd->Execute(src_data, dst_data, ws_data); 640 } 641 } catch (mkldnn::error& e) { 642 string error_msg = "Status: " + std::to_string(e.status) + 643 ", message: " + string(e.message) + ", in file " + 644 string(__FILE__) + ":" + std::to_string(__LINE__); 645 OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:", 646 error_msg)); 647 } 648 } 649 650 private: 651 const int kOutputTensorIndexWorkspace = 1; 652 engine cpu_engine = engine(engine::cpu, 0); 653 654 void AllocateWorkspaceTensor( 655 OpKernelContext* context, 656 const pooling_forward::primitive_desc& pool_fwd_prim_desc, 657 MklDnnData<uint8>* dnn_data_wksp) { 658 CHECK_NOTNULL(dnn_data_wksp); 659 Tensor* workspace_tensor = nullptr; 660 memory::primitive_desc workspace_pd = 661 pool_fwd_prim_desc.workspace_primitive_desc(); 662 size_t workspace_bytes = workspace_pd.get_size(); 663 MklDnnShape workspace_mkl_shape; 664 workspace_mkl_shape.SetMklTensor(false); 665 TensorShape workspace_tf_shape; 666 workspace_tf_shape.AddDim(workspace_bytes); 667 AllocateOutputSetMklShape(context, kOutputTensorIndexWorkspace, 668 &workspace_tensor, workspace_tf_shape, 669 workspace_mkl_shape); 670 CHECK_NOTNULL(workspace_tensor); 671 dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor); 672 } 673 }; 674 675 // The operation to compute MaxPool gradients. 676 // It takes three inputs: 677 // - The original input tensor 678 // - The original output tensor 679 // - Backprop tensor for output 680 // It produces one output: backprop tensor for input. 681 template <class Device, class T> 682 class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> { 683 public: 684 explicit MklMaxPoolingGradOp(OpKernelConstruction* context) 685 : MklPoolingBackwardOpBase<T>(context) {} 686 void Compute(OpKernelContext* context) override { 687 try { 688 auto cpu_engine = engine(engine::cpu, 0); 689 const Tensor& orig_input_tensor = 690 MklGetInput(context, kInputTensorIndexOrigInput); 691 const Tensor& grad_tensor = 692 MklGetInput(context, kInputTensorIndexGradient); 693 const Tensor& workspace_tensor = 694 MklGetInput(context, kInputTensorIndexWorkspace); 695 MklDnnShape orig_input_mkl_shape, grad_mkl_shape; 696 GetMklShape(context, kInputTensorIndexOrigInput, &orig_input_mkl_shape); 697 GetMklShape(context, kInputTensorIndexGradient, &grad_mkl_shape); 698 if (!context->status().ok()) return; 699 700 MklDnnData<T> grad_dnn_data(&cpu_engine); 701 MklDnnData<uint8> workspace_dnn_data(&cpu_engine); 702 703 MklPoolParameters pool_params; 704 TensorShape orig_input_shape = orig_input_tensor.shape(); 705 706 bool is_pool2d = (this->ksize_.size() == 4); 707 this->InitMklPoolParameters(context, &pool_params, orig_input_mkl_shape, 708 orig_input_shape); 709 710 memory::dims filter_dims, strides, padding_left, padding_right; 711 this->PoolParamsToDims(&pool_params, &filter_dims, &strides, 712 &padding_left, &padding_right, is_pool2d); 713 714 memory::dims orig_input_dims_mkl_order = 715 orig_input_mkl_shape.IsMklTensor() 716 ? orig_input_mkl_shape.GetSizesAsMklDnnDims() 717 : is_pool2d ? TFShapeToMklDnnDimsInNCHW(orig_input_shape, 718 this->data_format_tf_) 719 : TFShapeToMklDnnDimsInNCDHW(orig_input_shape, 720 this->data_format_tf_); 721 722 memory::dims diff_dst_dims = 723 grad_mkl_shape.IsMklTensor() 724 ? grad_mkl_shape.GetSizesAsMklDnnDims() 725 : is_pool2d ? TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(), 726 this->data_format_tf_) 727 : TFShapeToMklDnnDimsInNCDHW(grad_tensor.shape(), 728 this->data_format_tf_); 729 730 memory::dims output_dims_mkl_order; 731 this->GetOutputDims(pool_params, &output_dims_mkl_order); 732 733 MklPoolingParams bwdParams( 734 orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, 735 strides, padding_left, padding_right, algorithm::pooling_max, 736 prop_kind::forward_training); 737 MklPoolingBwdPrimitive<T>* pooling_bwd = 738 MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams); 739 740 // allocate output tensor and memory primitive 741 Tensor* output_tensor = nullptr; 742 this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()), 743 orig_input_dims_mkl_order, 744 this->data_format_mkldnn_, &output_tensor); 745 // get diff_dst mem desc 746 memory::desc diff_dst_md = 747 grad_mkl_shape.IsMklTensor() 748 ? grad_mkl_shape.GetMklLayout() 749 : memory::desc(diff_dst_dims, MklDnnType<T>(), 750 this->data_format_mkldnn_); 751 // check if diff_dst needs to be reordered 752 const T* diff_dst_data = grad_tensor.flat<T>().data(); 753 if (diff_dst_md.data.format != pooling_bwd->GetDiffDstFormat()) { 754 auto target_diff_dst = memory::primitive_desc( 755 {{diff_dst_dims}, MklDnnType<T>(), pooling_bwd->GetDiffDstFormat()}, 756 cpu_engine); 757 grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor); 758 grad_dnn_data.CheckReorderToOpMem(target_diff_dst); 759 diff_dst_data = const_cast<T*>( 760 reinterpret_cast<T*>(grad_dnn_data.GetOpMem().get_data_handle())); 761 } 762 763 void* ws_data = static_cast<void*>( 764 const_cast<uint8*>(workspace_tensor.flat<uint8>().data())); 765 766 auto ws_md = 767 pooling_bwd->GetPoolingFwdPd()->workspace_primitive_desc().desc(); 768 if (ws_md.data.format != pooling_bwd->GetWorkspaceFormat()) { 769 memory::dims ws_dims; 770 ws_dims.assign(ws_md.data.dims, ws_md.data.dims + ws_md.data.ndims); 771 auto target_ws = 772 memory::primitive_desc({{ws_dims}, 773 pooling_bwd->GetWorkspaceDataType(), 774 pooling_bwd->GetWorkspaceFormat()}, 775 cpu_engine); 776 workspace_dnn_data.SetUsrMem(ws_md, &workspace_tensor); 777 workspace_dnn_data.CheckReorderToOpMem(target_ws); 778 ws_data = workspace_dnn_data.GetOpMem().get_data_handle(); 779 } 780 781 T* diff_src_data = output_tensor->flat<T>().data(); 782 783 // execute pooling 784 pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data); 785 } catch (mkldnn::error& e) { 786 string error_msg = "Status:" + std::to_string(e.status) + 787 ", message: " + string(e.message) + ". in file " + 788 string(__FILE__) + ":" + std::to_string(__LINE__); 789 OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:", 790 error_msg)); 791 } 792 } 793 794 private: 795 // .Input("orig_input: T") 796 // .Input("orig_output: T") 797 // .Input("grad: T") 798 // .Input("workspace: T") 799 const int kInputTensorIndexOrigInput = 0; 800 const int kInputTensorIndexOrigOutput = 1; 801 const int kInputTensorIndexGradient = 2; 802 const int kInputTensorIndexWorkspace = 3; 803 804 void ConfigureWorkspace(const Tensor& workspace_tensor, 805 memory::primitive_desc workspace_pd, 806 MklDnnData<uint8>* workspace_dnn_data) { 807 CHECK_NOTNULL(workspace_dnn_data); 808 809 workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor); 810 } 811 812 void SanityCheckInputs(OpKernelContext* context, 813 const Tensor& orig_input_tensor, 814 const Tensor& orig_output_tensor, 815 const Tensor& grad_tensor, 816 const Tensor& workspace_tensor, 817 const MklDnnShape& orig_input_mkl_shape, 818 const MklDnnShape& orig_output_mkl_shape, 819 const MklDnnShape& grad_mkl_shape, 820 const MklDnnShape& workspace_mkl_shape) { 821 if (!orig_input_mkl_shape.IsMklTensor()) { 822 OP_REQUIRES(context, orig_input_tensor.dims() == 4, 823 errors::InvalidArgument( 824 "Original input shape must be 4-dimensional")); 825 } else { 826 OP_REQUIRES(context, orig_input_mkl_shape.GetDimension() == 4, 827 errors::InvalidArgument( 828 "Original input shape must be 4-dimensional")); 829 } 830 if (!orig_output_mkl_shape.IsMklTensor()) { 831 OP_REQUIRES( 832 context, orig_output_tensor.dims() == 4, 833 errors::InvalidArgument("Original output must be 4-dimensional")); 834 } else { 835 OP_REQUIRES( 836 context, orig_output_mkl_shape.GetDimension() == 4, 837 errors::InvalidArgument("Original output must be 4-dimensional")); 838 } 839 if (!grad_mkl_shape.IsMklTensor()) { 840 OP_REQUIRES(context, grad_tensor.dims() == 4, 841 errors::InvalidArgument("Gradient must be 4-dimensional")); 842 } else { 843 OP_REQUIRES(context, grad_mkl_shape.GetDimension() == 4, 844 errors::InvalidArgument("Gradient must be 4-dimensional")); 845 } 846 if (this->workspace_enabled_) { 847 // The workspace should not be an MKL tensor 848 OP_REQUIRES(context, workspace_mkl_shape.IsMklTensor() == false, 849 errors::InvalidArgument( 850 "Workspace tensor should not be an MKL Tensor.")); 851 // It should only have one dimension 852 OP_REQUIRES( 853 context, workspace_tensor.dims() == 1, 854 errors::InvalidArgument("Workspace tensor must be 1-dimensional")); 855 } else { 856 OP_REQUIRES( 857 context, this->workspace_enabled_, 858 errors::Unimplemented("MKL-DNN Max Pooling does not " 859 "yet support the use case " 860 "where MaxPoolGrad is called without first" 861 " calling MaxPool.")); 862 } 863 } 864 }; // MklMaxPoolingGradOp 865 866 REGISTER_KERNEL_BUILDER(Name("_MklMaxPool3D") 867 .Device(DEVICE_CPU) 868 .TypeConstraint<float>("T") 869 .Label(mkl_op_registry::kMklOpLabel), 870 MklMaxPoolingOp<CPUDevice, float>); 871 872 REGISTER_KERNEL_BUILDER(Name("_MklMaxPool3DGrad") 873 .Device(DEVICE_CPU) 874 .TypeConstraint<float>("T") 875 .Label(mkl_op_registry::kMklOpLabel), 876 MklMaxPoolingGradOp<CPUDevice, float>); 877 878 #endif // INTEL_MKL_ML_ONLY 879 880 REGISTER_KERNEL_BUILDER(Name("_MklMaxPool") 881 .Device(DEVICE_CPU) 882 .TypeConstraint<float>("T") 883 .Label(mkl_op_registry::kMklOpLabel), 884 MklMaxPoolingOp<CPUDevice, float>); 885 886 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedMaxPool") 887 .Device(DEVICE_CPU) 888 .TypeConstraint<quint8>("T") 889 .Label(mkl_op_registry::kMklQuantizedOpLabel), 890 MklMaxPoolingOp<CPUDevice, quint8>); 891 892 REGISTER_KERNEL_BUILDER(Name("_MklQuantizedMaxPool") 893 .Device(DEVICE_CPU) 894 .TypeConstraint<qint8>("T") 895 .Label(mkl_op_registry::kMklQuantizedOpLabel), 896 MklMaxPoolingOp<CPUDevice, qint8>); 897 898 REGISTER_KERNEL_BUILDER(Name("_MklMaxPoolGrad") 899 .Device(DEVICE_CPU) 900 .TypeConstraint<float>("T") 901 .Label(mkl_op_registry::kMklOpLabel), 902 MklMaxPoolingGradOp<CPUDevice, float>); 903 904 } // namespace tensorflow 905 #endif // INTEL_MKL 906