1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // See docs in ../ops/nn_ops.cc. 17 #ifdef INTEL_MKL 18 #define EIGEN_USE_THREADS 19 #include "tensorflow/core/framework/op_kernel.h" 20 #include "tensorflow/core/kernels/mkl_pooling_ops_common.h" 21 #include "tensorflow/core/lib/core/errors.h" 22 #include "tensorflow/core/util/mkl_util.h" 23 #include "tensorflow/core/util/padding.h" 24 25 #ifndef INTEL_MKL_ML 26 #include <algorithm> 27 #include "mkldnn.hpp" 28 using mkldnn::algorithm; 29 using mkldnn::engine; 30 using mkldnn::error; 31 using mkldnn::memory; 32 using mkldnn::padding_kind; 33 using mkldnn::pooling_backward; 34 using mkldnn::pooling_forward; 35 using mkldnn::prop_kind; 36 #endif 37 38 namespace tensorflow { 39 40 typedef Eigen::ThreadPoolDevice CPUDevice; 41 42 // MKL-DNN is now default. MKL-ML must be specified explicitly. 43 #ifdef INTEL_MKL_ML 44 45 // An implementation of MaxPooling (forward). 
46 template <typename Device, typename T> 47 class MklMaxPoolingOp : public OpKernel { 48 public: 49 explicit MklMaxPoolingOp(OpKernelConstruction* context) : OpKernel(context) { 50 string data_format; 51 52 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); 53 OP_REQUIRES(context, FormatFromString(data_format, &data_format_), 54 errors::InvalidArgument("Invalid data format")); 55 OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); 56 OP_REQUIRES(context, ksize_.size() == 4, 57 errors::InvalidArgument("Sliding window ksize field must " 58 "specify 4 dimensions")); 59 OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); 60 OP_REQUIRES(context, stride_.size() == 4, 61 errors::InvalidArgument("Sliding window stride field must " 62 "specify 4 dimensions")); 63 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); 64 OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, 65 errors::Unimplemented("Pooling is not yet supported on the " 66 "batch dimension.")); 67 68 workspace_enabled_ = false; 69 // We may not get this attribute for this node if it does not go through 70 // graph rewrite pass. So we do not check for error while retrieving this 71 // attribute value. 
72 context->GetAttr("workspace_enabled", &workspace_enabled_); 73 } 74 75 void Compute(OpKernelContext* context) override { 76 MklMaxPoolingOpContext mkl_context; 77 // Get the input tensor 78 const Tensor& tensor_in = MklGetInput(context, 0); 79 GetMklShape(context, 0, &mkl_context.input_shape); 80 bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor(); 81 82 mkl_context.params.in_dim = 4; 83 MklPoolParameters pool_params; 84 if (input_in_mkl_format == false) { 85 pool_params.Init(context, ksize_, stride_, padding_, data_format_, 86 tensor_in.shape()); 87 OP_REQUIRES( 88 context, (pool_params.depth_window == 1), 89 errors::Unimplemented("Depthwise max pooling not supported by MKL")); 90 91 } else { 92 pool_params.Init(context, ksize_, stride_, padding_, data_format_, 93 &mkl_context.input_shape); 94 } 95 96 // Extract the parameters for the op from the pooling specs 97 98 ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params); 99 100 mkl_context.MklCreateLayoutsAndPrimitives(context); 101 OP_REQUIRES_OK(context, context->status()); 102 103 // Declare output tensor 104 TensorShape tensor_out_shape; 105 MklShape mkl_out_shape, mkl_workspace_shape; 106 mkl_out_shape.SetMklTensor(true); 107 mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_fwd, dnnResourceDst); 108 mkl_out_shape.SetTfLayout(mkl_context.params.in_dim, 109 mkl_context.params.out_sizes, 110 mkl_context.params.out_strides); 111 mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_); 112 113 Tensor* output_tensor = nullptr; 114 tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>( 115 mkl_out_shape.GetMklLayout())) / 116 sizeof(T)); 117 AllocateOutputSetMklShape(context, 0, &output_tensor, tensor_out_shape, 118 mkl_out_shape); 119 120 Tensor* workspace_tensor; 121 void* workspace_buf = nullptr; 122 123 TensorShape workspace_shape; 124 mkl_workspace_shape.SetMklTensor(false); 125 
workspace_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>( 126 mkl_context.lt_workspace)) / 127 sizeof(T)); 128 AllocateOutputSetMklShape(context, 1, &workspace_tensor, workspace_shape, 129 mkl_workspace_shape); 130 131 mkl_context.pooling_res[dnnResourceWorkspace] = const_cast<void*>( 132 static_cast<const void*>(workspace_tensor->flat<T>().data())); 133 mkl_context.pooling_res[dnnResourceSrc] = 134 const_cast<void*>(static_cast<const void*>(tensor_in.flat<T>().data())); 135 mkl_context.pooling_res[dnnResourceDst] = const_cast<void*>( 136 static_cast<const void*>(output_tensor->flat<T>().data())); 137 138 CHECK_EQ( 139 dnnExecute_F32(mkl_context.prim_pooling_fwd, mkl_context.pooling_res), 140 E_SUCCESS); 141 142 mkl_context.MklCleanup(); 143 } 144 145 private: 146 typedef struct { 147 MklPoolingOpParams params; 148 MklShape input_shape; 149 void* pooling_res[dnnResourceNumber]; 150 dnnPrimitive_t prim_pooling_fwd = nullptr; 151 dnnLayout_t lt_user_input = nullptr, lt_workspace = nullptr; 152 153 void MklCreateLayoutsAndPrimitives(OpKernelContext* context) { 154 bool input_in_mkl_format = input_shape.IsMklTensor(); 155 // Create or use existing DNN user layout 156 if (input_in_mkl_format == false) { 157 CHECK_EQ(dnnLayoutCreate_F32(<_user_input, params.in_dim, 158 params.in_sizes, params.in_strides), 159 E_SUCCESS); 160 } else { 161 lt_user_input = (dnnLayout_t)input_shape.GetCurLayout(); 162 } 163 164 dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax; 165 dnnPrimitiveAttributes_t primAttr = nullptr; 166 167 // Create DNN primitives 168 CHECK_EQ(dnnPoolingCreateForward_F32( 169 &prim_pooling_fwd, primAttr, algorithm, lt_user_input, 170 params.kernel_size, params.kernel_stride, params.in_offset, 171 dnnBorderZerosAsymm), 172 E_SUCCESS); 173 174 // Creates layout for the workspace 175 CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(<_workspace, prim_pooling_fwd, 176 dnnResourceWorkspace), 177 E_SUCCESS); 178 } 179 180 void MklCleanup() { 181 bool 
input_in_mkl_format = input_shape.IsMklTensor(); 182 CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS); 183 if (!input_in_mkl_format) { 184 CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS); 185 } 186 CHECK_EQ(dnnLayoutDelete_F32(lt_workspace), E_SUCCESS); 187 } 188 } MklMaxPoolingOpContext; 189 190 std::vector<int32> ksize_; 191 std::vector<int32> stride_; 192 Padding padding_; 193 TensorFormat data_format_; 194 bool workspace_enabled_; 195 }; 196 197 // The operation to compute MaxPool gradients. 198 // It takes three inputs: 199 // - The original input tensor 200 // - The original output tensor 201 // - Backprop tensor for output 202 // It produces one output: backprop tensor for input. 203 template <class Device, class T> 204 class MklMaxPoolingGradOp : public OpKernel { 205 public: 206 explicit MklMaxPoolingGradOp(OpKernelConstruction* context) 207 : OpKernel(context) { 208 string data_format; 209 210 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); 211 OP_REQUIRES(context, FormatFromString(data_format, &data_format_), 212 errors::InvalidArgument("Invalid data format")); 213 OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); 214 OP_REQUIRES(context, ksize_.size() == 4, 215 errors::InvalidArgument("Sliding window ksize field must " 216 "specify 4 dimensions")); 217 OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); 218 OP_REQUIRES(context, stride_.size() == 4, 219 errors::InvalidArgument("Sliding window strides field must " 220 "specify 4 dimensions")); 221 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); 222 OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1, 223 errors::Unimplemented( 224 "Pooling is not yet supported on the batch dimension.")); 225 workspace_enabled_ = false; 226 // We may not get this attribute for this node if it does not go through 227 // graph rewrite pass. So we do not check for error while retrieving this 228 // attribute value. 
229 context->GetAttr("workspace_enabled", &workspace_enabled_); 230 } 231 232 void Compute(OpKernelContext* context) override { 233 MklMaxPoolingGradOpContext mkl_context; 234 // Input - The original input tensor 235 const Tensor& tensor_in = MklGetInput(context, 0); 236 237 // Output - Backprop tensor for input. 238 Tensor* output_tensor = nullptr; 239 240 GetMklShape(context, 0, &mkl_context.input_shape); 241 GetMklShape(context, 2, &mkl_context.output_backprop_shape); 242 bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor(); 243 244 if (input_in_mkl_format == false) 245 mkl_context.params.in_dim = tensor_in.dims(); 246 else 247 mkl_context.params.in_dim = mkl_context.input_shape.GetDimension(); 248 249 MklPoolParameters pool_params; 250 if (input_in_mkl_format == false) { 251 pool_params.Init(context, ksize_, stride_, padding_, data_format_, 252 tensor_in.shape()); 253 OP_REQUIRES( 254 context, (pool_params.depth_window == 1), 255 errors::Unimplemented("Depthwise max pooling not supported by MKL")); 256 257 } else { 258 pool_params.Init(context, ksize_, stride_, padding_, data_format_, 259 &mkl_context.input_shape); 260 } 261 262 // Extract the parameters for the op from the pooling specs 263 ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params); 264 265 mkl_context.MklCreateLayouts(context); 266 OP_REQUIRES_OK(context, context->status()); 267 268 mkl_context.MklCreatePrimitives(context, workspace_enabled_); 269 OP_REQUIRES_OK(context, context->status()); 270 271 mkl_context.MklPrepareInputs(context, workspace_enabled_); 272 OP_REQUIRES_OK(context, context->status()); 273 274 // Create shape for the input back prop output 275 TensorShape mkl_input_backprop; 276 MklShape mkl_output_shape; 277 mkl_output_shape.SetMklTensor(true); 278 mkl_output_shape.SetMklLayout(mkl_context.prim_pooling_bwd, 279 dnnResourceDiffSrc); 280 mkl_output_shape.SetTfLayout(mkl_context.params.in_dim, 281 mkl_context.params.in_sizes, 282 
mkl_context.params.in_strides); 283 mkl_output_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_); 284 285 mkl_input_backprop.AddDim( 286 dnnLayoutGetMemorySize_F32( 287 static_cast<dnnLayout_t>(mkl_output_shape.GetMklLayout())) / 288 sizeof(T)); 289 AllocateOutputSetMklShape(context, 0, &output_tensor, mkl_input_backprop, 290 mkl_output_shape); 291 mkl_context.pooling_res[dnnResourceDiffSrc] = const_cast<void*>( 292 static_cast<const void*>(output_tensor->flat<T>().data())); 293 294 CHECK_EQ( 295 dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res), 296 E_SUCCESS); 297 298 mkl_context.MklCleanup(workspace_enabled_); 299 } 300 301 private: 302 typedef struct { 303 MklPoolingOpParams params; 304 MklShape input_shape, output_backprop_shape; 305 void* pooling_resfwd[dnnResourceNumber]; 306 void* pooling_res[dnnResourceNumber]; 307 dnnPrimitive_t prim_pooling_fwd = nullptr, prim_pooling_bwd = nullptr, 308 convert_input = nullptr, convert_outbackprop = nullptr; 309 dnnLayout_t lt_outbackprop_user = nullptr, lt_outbackprop_prim = nullptr, 310 lt_input_user = nullptr, lt_input_prim = nullptr; 311 void* input_buf; 312 void* outbackprop_buf; 313 Tensor tmp_output_buf_tensor; 314 Tensor workspace_buf_tensor; 315 Tensor input_buf_tensor, outbackprop_buf_tensor; 316 317 void MklCreateLayouts(OpKernelContext* context) { 318 bool input_in_mkl_format = input_shape.IsMklTensor(); 319 bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor(); 320 // Create DNN user layout for input and outbackprop or get existing layout 321 if (input_in_mkl_format == false) { 322 CHECK_EQ(dnnLayoutCreate_F32(<_input_user, params.in_dim, 323 params.in_sizes, params.in_strides), 324 E_SUCCESS); 325 } else { 326 lt_input_user = (dnnLayout_t)input_shape.GetCurLayout(); 327 } 328 329 // We don't care about the output layout for now as we can create it from 330 // primitives for the max pooling fwd prop 331 if (outbackprop_in_mkl_format == false) { 332 
CHECK_EQ(dnnLayoutCreate_F32(<_outbackprop_user, params.in_dim, 333 params.out_sizes, params.out_strides), 334 E_SUCCESS); 335 } else { 336 lt_outbackprop_user = (dnnLayout_t)output_backprop_shape.GetCurLayout(); 337 } 338 } 339 340 // Create DNN primitives 341 void MklCreatePrimitives(OpKernelContext* context, bool workspace_enabled) { 342 dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax; 343 dnnPrimitiveAttributes_t primAttr = nullptr; 344 345 if (workspace_enabled == false) { 346 CHECK_EQ(dnnPoolingCreateForward_F32( 347 &prim_pooling_fwd, primAttr, algorithm, lt_input_user, 348 params.kernel_size, params.kernel_stride, params.in_offset, 349 dnnBorderZerosAsymm), 350 E_SUCCESS); 351 } 352 353 CHECK_EQ(dnnPoolingCreateBackward_F32( 354 &prim_pooling_bwd, primAttr, algorithm, lt_input_user, 355 params.kernel_size, params.kernel_stride, params.in_offset, 356 dnnBorderZerosAsymm), 357 E_SUCCESS); 358 359 // Creates conversions 360 CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( 361 <_outbackprop_prim, prim_pooling_bwd, dnnResourceDiffDst), 362 E_SUCCESS); 363 364 if (workspace_enabled == false) { 365 CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( 366 <_input_prim, prim_pooling_fwd, dnnResourceSrc), 367 E_SUCCESS); 368 if (!dnnLayoutCompare_F32(lt_input_user, lt_input_prim)) { 369 CHECK_EQ(dnnConversionCreate_F32(&convert_input, lt_input_user, 370 lt_input_prim), 371 E_SUCCESS); 372 AllocTmpBuffer(context, &input_buf_tensor, lt_input_prim, &input_buf); 373 } 374 } 375 376 if (!dnnLayoutCompare_F32(lt_outbackprop_user, lt_outbackprop_prim)) { 377 CHECK_EQ( 378 dnnConversionCreate_F32(&convert_outbackprop, lt_outbackprop_user, 379 lt_outbackprop_prim), 380 E_SUCCESS); 381 AllocTmpBuffer(context, &outbackprop_buf_tensor, lt_outbackprop_prim, 382 &outbackprop_buf); 383 } 384 } 385 386 // Compare incoming tensor layouts with MKL preferred layouts and convert 387 // data to the preferred layout if necessary 388 void MklPrepareInputs(OpKernelContext* context, bool 
workspace_enabled) { 389 const Tensor& tensor_in = MklGetInput(context, 0); 390 const Tensor& out_backprop = MklGetInput(context, 2); 391 bool input_in_mkl_format = input_shape.IsMklTensor(); 392 bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor(); 393 394 void* tmp_output_buf = nullptr; 395 void* workspace_buf = nullptr; 396 397 if (workspace_enabled == false) { 398 if (convert_input != nullptr) { 399 if (input_in_mkl_format == false) { 400 CHECK_EQ(dnnConversionExecute_F32( 401 convert_input, 402 const_cast<void*>(static_cast<const void*>( 403 tensor_in.flat<T>().data())), 404 input_buf), 405 E_SUCCESS); 406 CHECK_EQ(dnnDelete_F32(convert_input), E_SUCCESS); 407 convert_input = nullptr; 408 } else { 409 input_shape.GetConvertedFlatData( 410 lt_input_prim, 411 const_cast<void*>( 412 static_cast<const void*>(tensor_in.flat<T>().data())), 413 input_buf); 414 } 415 pooling_resfwd[dnnResourceSrc] = input_buf; 416 } else { 417 pooling_resfwd[dnnResourceSrc] = const_cast<void*>( 418 static_cast<const void*>(tensor_in.flat<T>().data())); 419 } 420 421 dnnLayout_t lt_workspace; 422 CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( 423 <_workspace, prim_pooling_fwd, dnnResourceWorkspace), 424 E_SUCCESS); 425 AllocTmpBuffer(context, &workspace_buf_tensor, lt_workspace, 426 &workspace_buf); 427 pooling_resfwd[dnnResourceWorkspace] = workspace_buf; 428 429 dnnLayoutDelete_F32(lt_workspace); 430 431 // We create the layout for max pooling fwd prop tmp output here 432 AllocTmpBuffer(context, &tmp_output_buf_tensor, lt_outbackprop_prim, 433 &tmp_output_buf); 434 pooling_resfwd[dnnResourceDst] = tmp_output_buf; 435 436 CHECK_EQ(dnnExecute_F32(prim_pooling_fwd, pooling_resfwd), E_SUCCESS); 437 pooling_res[dnnResourceWorkspace] = 438 pooling_resfwd[dnnResourceWorkspace]; 439 } else { 440 const Tensor& workspace = MklGetInput(context, 3); 441 pooling_res[dnnResourceWorkspace] = const_cast<void*>( 442 static_cast<const void*>(workspace.flat<T>().data())); 443 } 444 445 // 
Out backprop conversions if needed 446 if (convert_outbackprop != nullptr) { 447 if (outbackprop_in_mkl_format == false) { 448 CHECK_EQ(dnnConversionExecute_F32( 449 convert_outbackprop, 450 const_cast<void*>(static_cast<const void*>( 451 out_backprop.flat<T>().data())), 452 outbackprop_buf), 453 E_SUCCESS); 454 CHECK_EQ(dnnDelete_F32(convert_outbackprop), E_SUCCESS); 455 } else { 456 output_backprop_shape.GetConvertedFlatData( 457 lt_outbackprop_prim, 458 const_cast<void*>( 459 static_cast<const void*>(out_backprop.flat<T>().data())), 460 outbackprop_buf); 461 } 462 pooling_res[dnnResourceDiffDst] = outbackprop_buf; 463 } else { 464 pooling_res[dnnResourceDiffDst] = const_cast<void*>( 465 static_cast<const void*>(out_backprop.flat<T>().data())); 466 } 467 } 468 469 void MklCleanup(bool workspace_enabled) { 470 bool input_in_mkl_format = input_shape.IsMklTensor(); 471 bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor(); 472 if (workspace_enabled == false) { 473 CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS); 474 } 475 CHECK_EQ(dnnDelete_F32(prim_pooling_bwd), E_SUCCESS); 476 if (outbackprop_in_mkl_format == false) { 477 CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_user), E_SUCCESS); 478 } 479 CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_prim), E_SUCCESS); 480 if (input_in_mkl_format == false) { 481 CHECK_EQ(dnnLayoutDelete_F32(lt_input_user), E_SUCCESS); 482 } 483 if (workspace_enabled == false) { 484 CHECK_EQ(dnnLayoutDelete_F32(lt_input_prim), E_SUCCESS); 485 } 486 } 487 } MklMaxPoolingGradOpContext; 488 489 std::vector<int32> ksize_; 490 std::vector<int32> stride_; 491 Padding padding_; 492 TensorFormat data_format_; 493 494 bool workspace_enabled_; 495 }; // MklMaxPoolingGradOp 496 497 #else 498 499 // An implementation of MaxPooling (forward). 
// An implementation of MaxPooling (forward) on top of MKL-DNN, sharing
// common setup logic via MklPoolingForwardOpBase<T>.
template <typename Device, typename T>
class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
 public:
  explicit MklMaxPoolingOp(OpKernelConstruction* context)
      : MklPoolingForwardOpBase<T>(context) {
    // In Max Pooling, MKLDNN does not allow passing workspace as NULL.
    // So we set workspace_enabled_ to true.
    this->workspace_enabled_ = true;
  }

  // Runs max pooling on the input tensor and also emits a workspace tensor
  // (output index 1) used later by MklMaxPoolingGradOp.
  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);
      const Tensor& input_tensor =
          MklGetInput(context, this->kInputTensorIndexInput);
      MklDnnShape dnn_shape_input;
      GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input);
      this->SanityCheckInput(context, input_tensor, dnn_shape_input);
      if (!context->status().ok()) return;

      MklDnnData<T> dnn_data_input(&cpu_engine);
      MklDnnData<T> dnn_data_output(&cpu_engine);
      MklDnnData<uint8> dnn_data_wksp(&cpu_engine);

      // initialize variables for the pooling op
      MklPoolParameters pool_params;
      // Get the input tensor and initialize the pooling parameters
      this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params,
                           &dnn_data_input);
      OP_REQUIRES_OK(context, context->status());

      // Declare output tensor
      Tensor* output_tensor = nullptr;
      memory::dims output_dims_mkl_order;
      this->GetOutputDims(pool_params, &output_dims_mkl_order);

      // If input is in Mkl layout, then just get the memory format from it
      // directly, instead of using input data_format to MaxPool.
      if (dnn_shape_input.IsMklTensor()) {
        dnn_data_output.SetUsrMem(
            output_dims_mkl_order,
            static_cast<memory::format>(
                dnn_data_input.GetUsrMemDesc().data.format));
      } else {
        dnn_data_output.SetUsrMem(output_dims_mkl_order,
                                  this->data_format_mkldnn_);
      }

      // describe the memory layout; let mkl-dnn choose the best for the op
      dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any);

      // Strides, kernel, and asymmetric (top-left / bottom-right) padding
      // come from pool_params computed by ConfigureInput above.
      auto pool_desc = pooling_forward::desc(
          prop_kind::forward, algorithm::pooling_max,
          dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(),
          memory::dims({pool_params.row_stride, pool_params.col_stride}),
          memory::dims({pool_params.window_rows, pool_params.window_cols}),
          memory::dims({static_cast<int>(pool_params.pad_top),
                        static_cast<int>(pool_params.pad_left)}),
          memory::dims({static_cast<int>(pool_params.pad_bottom),
                        static_cast<int>(pool_params.pad_right)}),
          TFPaddingToMklDnnPadding(this->padding_));
      auto pool_fwd_desc =
          pooling_forward::primitive_desc(pool_desc, cpu_engine);

      this->AllocateOutputTensor(context, pool_fwd_desc, output_dims_mkl_order,
                                 this->data_format_mkldnn_, &output_tensor);
      OP_REQUIRES_OK(context, context->status());
      dnn_data_output.SetUsrMemDataHandle(output_tensor);

      AllocateWorkspaceTensor(context, pool_fwd_desc, &dnn_data_wksp);
      OP_REQUIRES_OK(context, context->status());

      this->PrepareAndExecuteNet(pool_fwd_desc, &dnn_data_input,
                                 &dnn_data_output, &dnn_data_wksp);
    } catch (mkldnn::error& e) {
      // Surface MKL-DNN exceptions as a TF error status instead of crashing.
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
                                              error_msg));
    }
  }  // Compute

 private:
  const int kOutputTensorIndexWorkspace = 1;

  // Allocates the workspace output tensor (a flat byte tensor, never in MKL
  // layout) sized from the forward primitive's workspace descriptor, and
  // binds it to dnn_data_wksp.
  void AllocateWorkspaceTensor(
      OpKernelContext* context,
      const pooling_forward::primitive_desc& pool_fwd_prim_desc,
      MklDnnData<uint8>* dnn_data_wksp) {
    CHECK_NOTNULL(dnn_data_wksp);
    Tensor* workspace_tensor = nullptr;
    memory::primitive_desc workspace_pd =
        pool_fwd_prim_desc.workspace_primitive_desc();
    size_t workspace_bytes = workspace_pd.get_size();
    MklDnnShape workspace_mkl_shape;
    workspace_mkl_shape.SetMklTensor(false);
    TensorShape workspace_tf_shape;
    workspace_tf_shape.AddDim(workspace_bytes);
    AllocateOutputSetMklShape(context, kOutputTensorIndexWorkspace,
                              &workspace_tensor, workspace_tf_shape,
                              workspace_mkl_shape);
    CHECK_NOTNULL(workspace_tensor);
    dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor);
  }
};

// The operation to compute MaxPool gradients.
// It takes three inputs:
//   - The original input tensor
//   - The original output tensor
//   - Backprop tensor for output
// It produces one output: backprop tensor for input.
template <class Device, class T>
class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
 public:
  explicit MklMaxPoolingGradOp(OpKernelConstruction* context)
      : MklPoolingBackwardOpBase<T>(context) {}

  // Reconstructs the forward pooling descriptor (as a hint), then runs the
  // backward pooling primitive using the workspace produced by MklMaxPoolingOp.
  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);
      const Tensor& orig_input_tensor =
          MklGetInput(context, kInputTensorIndexOrigInput);
      const Tensor& orig_output_tensor =
          MklGetInput(context, kInputTensorIndexOrigOutput);
      const Tensor& grad_tensor =
          MklGetInput(context, kInputTensorIndexGradient);
      const Tensor& workspace_tensor =
          MklGetInput(context, kInputTensorIndexWorkspace);
      MklDnnShape orig_input_mkl_shape, orig_output_mkl_shape, grad_mkl_shape,
          workspace_mkl_shape;
      GetMklShape(context, kInputTensorIndexOrigInput, &orig_input_mkl_shape);
      GetMklShape(context, kInputTensorIndexOrigOutput, &orig_output_mkl_shape);
      GetMklShape(context, kInputTensorIndexGradient, &grad_mkl_shape);
      GetMklShape(context, kInputTensorIndexWorkspace, &workspace_mkl_shape);

      SanityCheckInputs(context, orig_input_tensor, orig_output_tensor,
                        grad_tensor, workspace_tensor, orig_input_mkl_shape,
                        orig_output_mkl_shape, grad_mkl_shape,
                        workspace_mkl_shape);
      if (!context->status().ok()) return;

      MklDnnData<T> grad_dnn_data(&cpu_engine);
      MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
      MklDnnData<T> output_dnn_data(&cpu_engine);
      Tensor* output_tensor = nullptr;
      MklPoolParameters pool_params;
      TensorShape orig_input_shape;
      memory::dims output_dims_mkl_order, orig_input_dims_mkl_order;
      memory::desc original_input_md = ConfigureOriginalInput(
          context, orig_input_tensor, orig_input_mkl_shape,
          &orig_input_dims_mkl_order, &pool_params, &orig_input_shape);

      memory::desc original_output_md = this->ConfigureOriginalOutput(
          pool_params, orig_output_mkl_shape, output_dims_mkl_order);

      memory::desc target_diff_dst_md = this->ConfigureInputGradient(
          grad_mkl_shape, grad_tensor, &grad_dnn_data, original_output_md);

      output_dnn_data.SetUsrMem(original_input_md);

      // Create the forward pooling primitive descriptor so we can
      // pass it as a hint to the backward pooling primitive descriptor
      auto pool_fwd_desc = pooling_forward::desc(
          prop_kind::forward, algorithm::pooling_max, original_input_md,
          original_output_md,
          memory::dims({pool_params.row_stride, pool_params.col_stride}),
          memory::dims({pool_params.window_rows, pool_params.window_cols}),
          memory::dims({static_cast<int>(pool_params.pad_top),
                        static_cast<int>(pool_params.pad_left)}),
          memory::dims({static_cast<int>(pool_params.pad_bottom),
                        static_cast<int>(pool_params.pad_right)}),
          TFPaddingToMklDnnPadding(this->padding_));
      auto pool_fwd_prim_desc =
          pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine);

      auto pool_bkwd_desc = pooling_backward::desc(
          algorithm::pooling_max, output_dnn_data.GetUsrMemDesc(),
          target_diff_dst_md,
          memory::dims({pool_params.row_stride, pool_params.col_stride}),
          memory::dims({pool_params.window_rows, pool_params.window_cols}),
          memory::dims({static_cast<int>(pool_params.pad_top),
                        static_cast<int>(pool_params.pad_left)}),
          memory::dims({static_cast<int>(pool_params.pad_bottom),
                        static_cast<int>(pool_params.pad_right)}),
          TFPaddingToMklDnnPadding(this->padding_));
      auto pool_bkwd_prim_desc = pooling_backward::primitive_desc(
          pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc);

      this->AllocateOutputTensor(context, pool_bkwd_prim_desc,
                                 orig_input_dims_mkl_order,
                                 this->data_format_mkldnn_, &output_tensor);
      output_dnn_data.SetUsrMemDataHandle(output_tensor);

      ConfigureWorkspace(workspace_tensor,
                         pool_fwd_prim_desc.workspace_primitive_desc(),
                         &workspace_dnn_data);
      this->PrepareAndExecuteNet(
          pool_bkwd_prim_desc, &grad_dnn_data, &output_dnn_data,
          memory::primitive_desc(target_diff_dst_md, cpu_engine),
          &workspace_dnn_data);
    } catch (mkldnn::error& e) {
      // Surface MKL-DNN exceptions as a TF error status instead of crashing.
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
                                              error_msg));
    }
  }  // Compute

 private:
  // Input indices mirror the op registration:
  // .Input("orig_input: T")
  // .Input("orig_output: T")
  // .Input("grad: T")
  // .Input("workspace: T")
  const int kInputTensorIndexOrigInput = 0;
  const int kInputTensorIndexOrigOutput = 1;
  const int kInputTensorIndexGradient = 2;
  const int kInputTensorIndexWorkspace = 3;
  // Output("output: T") in Base Class

  // Captures the original input's TF shape, then delegates to the base class
  // to build the memory descriptor and pooling parameters.
  memory::desc ConfigureOriginalInput(
      OpKernelContext* context, const Tensor& tensor_original_input,
      const MklDnnShape& original_input_mkl_shape,
      memory::dims* original_input_dims_mkl_order,
      MklPoolParameters* pool_params, TensorShape* input_tensor_shape) {
    *input_tensor_shape = tensor_original_input.shape();
    return MklPoolingBackwardOpBase<T>::ConfigureOriginalInput(
        context, tensor_original_input, original_input_mkl_shape,
        original_input_dims_mkl_order, pool_params, *input_tensor_shape);
  }

  // Binds the incoming workspace tensor to the backward primitive's
  // expected workspace memory descriptor.
  void ConfigureWorkspace(const Tensor& workspace_tensor,
                          memory::primitive_desc workspace_pd,
                          MklDnnData<uint8>* workspace_dnn_data) {
    CHECK_NOTNULL(workspace_dnn_data);

    workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor);
  }

  // Validates ranks of all four inputs (4-D data tensors, 1-D non-MKL
  // workspace) and rejects the workspace-disabled path, which is unsupported.
  void SanityCheckInputs(OpKernelContext* context,
                         const Tensor& orig_input_tensor,
                         const Tensor& orig_output_tensor,
                         const Tensor& grad_tensor,
                         const Tensor& workspace_tensor,
                         const MklDnnShape& orig_input_mkl_shape,
                         const MklDnnShape& orig_output_mkl_shape,
                         const MklDnnShape& grad_mkl_shape,
                         const MklDnnShape& workspace_mkl_shape) {
    if (!orig_input_mkl_shape.IsMklTensor()) {
      OP_REQUIRES(context, orig_input_tensor.dims() == 4,
                  errors::InvalidArgument("Original input shape must be "
                                          "4-dimensional"));
    } else {
      OP_REQUIRES(context, orig_input_mkl_shape.GetDimension() == 4,
                  errors::InvalidArgument("Original input shape must be "
                                          "4-dimensional"));
    }
    if (!orig_output_mkl_shape.IsMklTensor()) {
      OP_REQUIRES(context, orig_output_tensor.dims() == 4,
                  errors::InvalidArgument("Original output must be "
                                          "4-dimensional"));
    } else {
      OP_REQUIRES(context, orig_output_mkl_shape.GetDimension() == 4,
                  errors::InvalidArgument("Original output must be "
                                          "4-dimensional"));
    }
    if (!grad_mkl_shape.IsMklTensor()) {
      OP_REQUIRES(context, grad_tensor.dims() == 4,
                  errors::InvalidArgument("Gradient must be 4-dimensional"));
    } else {
      OP_REQUIRES(context, grad_mkl_shape.GetDimension() == 4,
                  errors::InvalidArgument("Gradient must be "
                                          "4-dimensional"));
    }
    if (this->workspace_enabled_) {
      // The workspace should not be an MKL tensor
      OP_REQUIRES(context, workspace_mkl_shape.IsMklTensor() == false,
                  errors::InvalidArgument("Workspace tensor should not"
                                          " be an MKL Tensor."));
      // It should only have one dimension
      OP_REQUIRES(context, workspace_tensor.dims() == 1,
                  errors::InvalidArgument("Workspace tensor must be "
                                          "1-dimensional"));
    } else {
      OP_REQUIRES(
          context, this->workspace_enabled_,
          errors::Unimplemented("MKL-DNN Max Pooling does not "
                                "yet support the use case "
                                "where MaxPoolGrad is called without first"
                                " calling MaxPool."));
    }
  }
};  // MklMaxPoolingGradOp

#endif  // INTEL_MKL_ML

// Register both kernels for float on CPU under the MKL op label so the
// graph rewrite pass can route _MklMaxPool/_MklMaxPoolGrad nodes here.
REGISTER_KERNEL_BUILDER(Name("_MklMaxPool")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklOpLabel),
                        MklMaxPoolingOp<CPUDevice, float>);

REGISTER_KERNEL_BUILDER(Name("_MklMaxPoolGrad")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklOpLabel),
                        MklMaxPoolingGradOp<CPUDevice, float>);

}  // namespace tensorflow
#endif  // INTEL_MKL