/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifdef INTEL_MKL
#define EIGEN_USE_THREADS

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/mkl_util.h"

#include "tensorflow/core/kernels/mkl_pooling_ops_common.h"

#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::algorithm;
using mkldnn::engine;
using mkldnn::error;
using mkldnn::memory;
using mkldnn::padding_kind;
using mkldnn::pooling_backward;
using mkldnn::pooling_forward;
using mkldnn::prop_kind;
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

#ifdef INTEL_MKL_ML

// Average-pooling forward kernel built on the legacy MKL-ML `dnn*_F32` API.
//
// Input 0 is the tensor to pool (either in TensorFlow layout or in an MKL
// layout described by its MklShape). Output 0 is always produced in the MKL
// layout chosen by the pooling primitive.
template <typename Device, typename T>
class MklAvgPoolingOp : public OpKernel {
 public:
  // Reads and validates the "data_format", "ksize", "strides" and "padding"
  // attributes. Pooling over the batch dimension is rejected.
  explicit MklAvgPoolingOp(OpKernelConstruction* context) : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));

    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window stride field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented("Pooling is not yet supported on the "
                                      "batch dimension."));
  }

  // Runs the forward average pooling: builds the MKL layouts/primitives,
  // converts the input to the primitive's layout when the two differ,
  // allocates the MKL-layout output and executes the primitive.
  void Compute(OpKernelContext* context) override {
    MklAvgPoolingOpContext mkl_context;
    const Tensor& tensor_in = MklGetInput(context, 0);
    GetMklShape(context, 0, &mkl_context.input_shape);
    const bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();

    if (!input_in_mkl_format)
      mkl_context.params.in_dim = tensor_in.dims();
    else
      mkl_context.params.in_dim = mkl_context.input_shape.GetDimension();

    MklPoolParameters pool_params;
    if (!input_in_mkl_format) {
      pool_params.Init(context, ksize_, stride_, padding_, data_format_,
                       tensor_in.shape());
    } else {
      pool_params.Init(context, ksize_, stride_, padding_, data_format_,
                       &mkl_context.input_shape);
    }
    // Nothing has been allocated yet, so it is safe to stop here if an error
    // has already been recorded on the context.
    if (!context->status().ok()) return;

    // Extract the parameters for the op from the pooling specs
    ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);

    Tensor mkl_tmp_input_buf_tensor;
    mkl_context.MklCreateLayoutsAndPrimitives(context,
                                              &mkl_tmp_input_buf_tensor);
    OP_REQUIRES_OK(context, context->status());

    Tensor workspace_tensor;
    void* workspace_buf = nullptr;
    AllocTmpBuffer(context, &workspace_tensor, mkl_context.lt_workspace,
                   &workspace_buf);

    void* const raw_input =
        static_cast<void*>(const_cast<T*>(tensor_in.flat<T>().data()));
    if (mkl_context.convert_input != nullptr) {
      if (!input_in_mkl_format) {
        CHECK_EQ(dnnConversionExecute_F32(mkl_context.convert_input, raw_input,
                                          mkl_context.input_buf),
                 E_SUCCESS);
      } else {
        mkl_context.input_shape.GetConvertedFlatData(
            mkl_context.lt_prim_input, raw_input, mkl_context.input_buf);
      }
      // The conversion primitive is created whenever the layouts differ, so
      // release it on both paths (it used to leak for MKL-layout inputs).
      CHECK_EQ(dnnDelete_F32(mkl_context.convert_input), E_SUCCESS);
      mkl_context.convert_input = nullptr;
      mkl_context.pooling_res[dnnResourceSrc] = mkl_context.input_buf;
    } else {
      mkl_context.pooling_res[dnnResourceSrc] = raw_input;
    }

    // Declare output tensor and allocate memory
    Tensor* output = nullptr;
    TensorShape tensor_out_shape;
    MklShape mkl_out_shape;
    mkl_out_shape.SetMklTensor(true);
    mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_fwd, dnnResourceDst);
    mkl_out_shape.SetTfLayout(mkl_context.params.in_dim,
                              mkl_context.params.out_sizes,
                              mkl_context.params.out_strides);
    mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_);

    // The output is a flat buffer sized by the MKL layout, in elements of T.
    tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
                                mkl_out_shape.GetMklLayout())) /
                            sizeof(T));

    AllocateOutputSetMklShape(context, 0, &output, tensor_out_shape,
                              mkl_out_shape);
    mkl_context.pooling_res[dnnResourceDst] =
        static_cast<void*>(output->flat<T>().data());

    mkl_context.pooling_res[dnnResourceWorkspace] = workspace_buf;

    CHECK_EQ(
        dnnExecute_F32(mkl_context.prim_pooling_fwd, mkl_context.pooling_res),
        E_SUCCESS);

    mkl_context.MklCleanup();
  }  // Compute

 private:
  // Per-invocation state: the MKL layouts, primitives and resource table used
  // by a single Compute() call.
  struct MklAvgPoolingOpContext {
    MklPoolingOpParams params;
    MklShape input_shape;
    dnnPrimitive_t prim_pooling_fwd = nullptr, convert_input = nullptr;
    dnnLayout_t lt_user_input = nullptr, lt_prim_input = nullptr,
                lt_workspace = nullptr;
    void* input_buf = nullptr;
    void* pooling_res[dnnResourceNumber] = {};

    // Creates the user/primitive/workspace layouts and the forward pooling
    // primitive. When the user layout differs from the layout the primitive
    // expects, also creates `convert_input` and allocates `input_buf` (backed
    // by `mkl_tmp_input_buf_tensor`) to hold the converted input.
    void MklCreateLayoutsAndPrimitives(OpKernelContext* context,
                                       Tensor* mkl_tmp_input_buf_tensor) {
      bool input_in_mkl_format = input_shape.IsMklTensor();

      if (!input_in_mkl_format) {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_user_input, params.in_dim,
                                     params.in_sizes, params.in_strides),
                 E_SUCCESS);
      } else {
        // Borrowed from the input's MklShape; not deleted in MklCleanup().
        lt_user_input = (dnnLayout_t)input_shape.GetCurLayout();
      }

      dnnAlgorithm_t algorithm = dnnAlgorithmPoolingAvg;
      dnnPrimitiveAttributes_t primAttr = nullptr;

      // Create DNN primitives
      CHECK_EQ(dnnPoolingCreateForward_F32(
                   &prim_pooling_fwd, primAttr, algorithm, lt_user_input,
                   params.kernel_size, params.kernel_stride, params.in_offset,
                   dnnBorderZerosAsymm),
               E_SUCCESS);

      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
                   &lt_prim_input, prim_pooling_fwd, dnnResourceSrc),
               E_SUCCESS);
      if (!dnnLayoutCompare_F32(lt_user_input, lt_prim_input)) {
        CHECK_EQ(dnnConversionCreate_F32(&convert_input, lt_user_input,
                                         lt_prim_input),
                 E_SUCCESS);

        AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_prim_input,
                       &input_buf);
      }

      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace, prim_pooling_fwd,
                                                dnnResourceWorkspace),
               E_SUCCESS);
    }

    // Releases everything MklCreateLayoutsAndPrimitives() created. The user
    // layout is only deleted when this context created it (TF-layout input).
    void MklCleanup() {
      bool input_in_mkl_format = input_shape.IsMklTensor();
      if (!input_in_mkl_format) {
        CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS);
      }

      CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS);
      CHECK_EQ(dnnLayoutDelete_F32(lt_prim_input), E_SUCCESS);
      // Previously leaked: the workspace layout is created unconditionally
      // above, and the gradient op already deletes its counterpart.
      CHECK_EQ(dnnLayoutDelete_F32(lt_workspace), E_SUCCESS);
    }
  };

  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};

//-----------------------------------------------------------------------------

// Average-pooling gradient kernel built on the legacy MKL-ML `dnn*_F32` API.
//
// Input 0 is the shape of the original (forward) input as a 1-D int32 tensor
// of 4 elements; input 1 is the gradient w.r.t. the forward output. Output 0
// is the gradient w.r.t. the forward input, in MKL layout.
template <class Device, class T>
class MklAvgPoolingGradOp : public OpKernel {
 public:
  // Reads and validates the "data_format", "ksize", "strides" and "padding"
  // attributes. Pooling over the batch dimension is rejected.
  explicit MklAvgPoolingGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;

    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    OP_REQUIRES(context, ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    OP_REQUIRES(context, stride_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
                errors::Unimplemented("Pooling is not yet supported on the "
                                      "batch dimension."));
  }

  // Runs the backward average pooling: rebuilds the forward input shape from
  // input 0, converts the incoming gradient to the primitive's layout when
  // needed, allocates the MKL-layout output and executes the primitive.
  void Compute(OpKernelContext* context) override {
    MklAvgPoolingGradOpContext mkl_context;
    const Tensor& tensor_in_shape = MklGetInput(context, 0);
    const Tensor& out_backprop = MklGetInput(context, 1);
    GetMklShape(context, 1, &mkl_context.out_backprop_shape);
    const bool outbackprop_in_mkl_format =
        mkl_context.out_backprop_shape.IsMklTensor();

    // Validate the shape tensor before reading it as a vector: the check
    // used to run only later, and only for TF-layout gradients, i.e. after
    // vec<int32>() had already been applied to an unchecked tensor.
    OP_REQUIRES(
        context,
        tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4,
        errors::InvalidArgument("original input shape must be "
                                "1-dimensional and 4 elements"));

    TensorShape output_shape;
    auto shape_vec = tensor_in_shape.vec<int32>();
    for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) {
      output_shape.AddDim(shape_vec(i));
    }

    MklPoolParameters pool_params;
    pool_params.Init(context, ksize_, stride_, padding_, data_format_,
                     output_shape);
    // Nothing has been allocated yet, so it is safe to stop here if an error
    // has already been recorded on the context.
    if (!context->status().ok()) return;

    if (!outbackprop_in_mkl_format)
      mkl_context.params.in_dim = out_backprop.dims();
    else
      mkl_context.params.in_dim = mkl_context.out_backprop_shape.GetDimension();

    // Extract the parameters for the op from the pooling specs
    ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);

    // Tensors needed to create temporary buffers
    Tensor outbackprop_buf_tensor;
    void* outbackprop_buf = nullptr;
    mkl_context.MklCreateLayoutsAndPrimitives(context);
    OP_REQUIRES_OK(context, context->status());

    void* const raw_outbackprop =
        static_cast<void*>(const_cast<T*>(out_backprop.flat<T>().data()));

    // Check if outbackprop layout requires conversion.
    if (!dnnLayoutCompare_F32(mkl_context.lt_user_outbackprop,
                              mkl_context.lt_prim_outbackprop)) {
      CHECK_EQ(dnnConversionCreate_F32(&mkl_context.convert_outbackprop,
                                       mkl_context.lt_user_outbackprop,
                                       mkl_context.lt_prim_outbackprop),
               E_SUCCESS);

      AllocTmpBuffer(context, &outbackprop_buf_tensor,
                     mkl_context.lt_prim_outbackprop, &outbackprop_buf);

      if (!outbackprop_in_mkl_format) {
        CHECK_EQ(dnnConversionExecute_F32(mkl_context.convert_outbackprop,
                                          raw_outbackprop, outbackprop_buf),
                 E_SUCCESS);
      } else {
        mkl_context.out_backprop_shape.GetConvertedFlatData(
            mkl_context.lt_prim_outbackprop, raw_outbackprop, outbackprop_buf);
      }
      // Release the conversion primitive on both paths (it used to leak when
      // the gradient arrived in MKL layout).
      CHECK_EQ(dnnDelete_F32(mkl_context.convert_outbackprop), E_SUCCESS);
      mkl_context.convert_outbackprop = nullptr;
      mkl_context.pooling_res[dnnResourceDiffDst] = outbackprop_buf;
    } else {
      mkl_context.pooling_res[dnnResourceDiffDst] = raw_outbackprop;
    }

    // Handle workspace requirements.
    Tensor workspace_buf_tensor;
    void* workspace_buf = nullptr;
    AllocTmpBuffer(context, &workspace_buf_tensor, mkl_context.lt_workspace,
                   &workspace_buf);
    mkl_context.pooling_res[dnnResourceWorkspace] = workspace_buf;

    // Handle MKL output tensor setup.
    Tensor* output = nullptr;
    TensorShape tensor_out_shape;
    MklShape mkl_out_shape;
    mkl_out_shape.SetMklTensor(true);
    mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_bwd,
                               dnnResourceDiffSrc);
    mkl_out_shape.SetTfLayout(mkl_context.params.in_dim,
                              mkl_context.params.in_sizes,
                              mkl_context.params.in_strides);
    mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_);

    // The output is a flat buffer sized by the MKL layout, in elements of T.
    tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
                                mkl_out_shape.GetMklLayout())) /
                            sizeof(T));

    AllocateOutputSetMklShape(context, 0, &output, tensor_out_shape,
                              mkl_out_shape);

    // Set output tensor.
    mkl_context.pooling_res[dnnResourceDiffSrc] =
        static_cast<void*>(output->flat<T>().data());

    // Execute primitive.
    CHECK_EQ(
        dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res),
        E_SUCCESS);

    mkl_context.MklCleanup();
  }

 private:
  // Per-invocation state: the MKL layouts, primitives and resource table used
  // by a single Compute() call.
  struct MklAvgPoolingGradOpContext {
    MklPoolingOpParams params;
    MklShape out_backprop_shape;
    dnnPrimitive_t prim_pooling_bwd = nullptr, convert_outbackprop = nullptr;
    void* pooling_res[dnnResourceNumber] = {};
    dnnLayout_t lt_user_input = nullptr, lt_user_outbackprop = nullptr,
                lt_prim_outbackprop = nullptr, lt_workspace = nullptr;

    // Validates the rank of out_backprop, then creates the user layouts, the
    // backward pooling primitive and the layouts derived from it. On a
    // validation failure the error is recorded on `context` and nothing is
    // created. The original-input-shape tensor is validated by Compute()
    // before this is called.
    void MklCreateLayoutsAndPrimitives(OpKernelContext* context) {
      const Tensor& out_backprop = MklGetInput(context, 1);
      bool outbackprop_in_mkl_format = out_backprop_shape.IsMklTensor();

      if (!outbackprop_in_mkl_format) {
        // For avgpooling, out_backprop should have 4 dimensions.
        OP_REQUIRES(context, out_backprop.dims() == 4,
                    errors::InvalidArgument("out_backprop must be "
                                            "4-dimensional"));
      } else {
        // Input in MKL format.
        // For avgpooling, out_backprop should have 4 dimensions.
        OP_REQUIRES(context, out_backprop_shape.GetDimension() == 4,
                    errors::InvalidArgument("out_backprop must be "
                                            "4-dimensional"));
      }

      // TODO(inteltf): Get outbackprop layout.
      // Do we need to create layout in every invocation?
      if (!outbackprop_in_mkl_format) {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_user_outbackprop, params.in_dim,
                                     params.out_sizes, params.out_strides),
                 E_SUCCESS);
      } else {
        // Borrowed from the gradient's MklShape; not deleted in MklCleanup().
        lt_user_outbackprop = (dnnLayout_t)out_backprop_shape.GetCurLayout();
      }

      // Create the backward primitive
      // Create DNN user layout
      CHECK_EQ(dnnLayoutCreate_F32(&lt_user_input, params.in_dim,
                                   params.in_sizes, params.in_strides),
               E_SUCCESS);

      // Create PoolingBackward primitive
      dnnAlgorithm_t algorithm = dnnAlgorithmPoolingAvg;
      dnnPrimitiveAttributes_t primAttr = nullptr;
      CHECK_EQ(dnnPoolingCreateBackward_F32(
                   &prim_pooling_bwd, primAttr, algorithm, lt_user_input,
                   params.kernel_size, params.kernel_stride, params.in_offset,
                   dnnBorderZerosAsymm),
               E_SUCCESS);

      // Create expected outbackprop layout from the primitive.
      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
                   &lt_prim_outbackprop, prim_pooling_bwd, dnnResourceDiffDst),
               E_SUCCESS);

      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace, prim_pooling_bwd,
                                                dnnResourceWorkspace),
               E_SUCCESS);
    }

    // Releases everything MklCreateLayoutsAndPrimitives() created. The
    // gradient's user layout is only deleted when this context created it.
    void MklCleanup() {
      bool outbackprop_in_mkl_format = out_backprop_shape.IsMklTensor();
      CHECK_EQ(dnnDelete_F32(prim_pooling_bwd), E_SUCCESS);
      CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS);
      if (!outbackprop_in_mkl_format) {
        CHECK_EQ(dnnLayoutDelete_F32(lt_user_outbackprop), E_SUCCESS);
      }
      CHECK_EQ(dnnLayoutDelete_F32(lt_prim_outbackprop), E_SUCCESS);
      CHECK_EQ(dnnLayoutDelete_F32(lt_workspace), E_SUCCESS);
    }
  };

  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_;
};  // MklAvgPoolingGradOp

#else

// Average-pooling forward kernel built on MKL-DNN. Attribute parsing and the
// shared input/output plumbing come from MklPoolingForwardOpBase<T>.
template <typename Device, typename T>
class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
 public:
  explicit MklAvgPoolingOp(OpKernelConstruction* context)
      : MklPoolingForwardOpBase<T>(context) {
    // Workspace is an MKLDNN construct that is only used in Max Pooling.
    // So set workspace_enabled_ to false.
    this->workspace_enabled_ = false;
  }

  // Runs forward average pooling (padding excluded from the average). An
  // empty input yields an empty TF-layout output. Any mkldnn::error thrown
  // along the way is reported as an Aborted status on the context.
  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);
      const Tensor& input_tensor =
          MklGetInput(context, this->kInputTensorIndexInput);
      MklDnnShape dnn_shape_input;
      GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input);
      this->SanityCheckInput(context, input_tensor, dnn_shape_input);
      if (!context->status().ok()) return;

      MklDnnData<T> dnn_data_input(&cpu_engine);
      MklDnnData<T> dnn_data_output(&cpu_engine);

      // initialize variables for the pooling op
      MklPoolParameters pool_params;
      // Get the input tensor and initialize the pooling parameters
      this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params,
                           &dnn_data_input);
      OP_REQUIRES_OK(context, context->status());

      // Declare output tensor
      Tensor* output_tensor = nullptr;
      memory::dims output_dims_mkl_order;
      this->GetOutputDims(pool_params, &output_dims_mkl_order);

      // If input is an empty tensor, allocate an empty output tensor and return
      if (input_tensor.NumElements() == 0) {
        MklDnnShape output_mkl_shape;
        output_mkl_shape.SetMklTensor(false);
        TensorShape output_tf_shape;
        if (pool_params.data_format == TensorFormat::FORMAT_NCHW) {
          output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order);
        } else {
          memory::dims output_dims_NHWC_order;
          output_dims_NHWC_order = {pool_params.tensor_in_batch,
                                    static_cast<int>(pool_params.out_height),
                                    static_cast<int>(pool_params.out_width),
                                    pool_params.out_depth};
          output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order);
        }
        const int kOutputIndex = 0;
        AllocateOutputSetMklShape(context, kOutputIndex, &output_tensor,
                                  output_tf_shape, output_mkl_shape);
        CHECK_NOTNULL(output_tensor);
        return;
      }

      // If input is in Mkl layout, then just get the memory format from it
      // directly, instead of using input data_format to AvgPool.
      if (dnn_shape_input.IsMklTensor()) {
        dnn_data_output.SetUsrMem(
            output_dims_mkl_order,
            static_cast<memory::format>(
                dnn_data_input.GetUsrMemDesc().data.format));

      } else {
        dnn_data_output.SetUsrMem(output_dims_mkl_order,
                                  this->data_format_mkldnn_);
      }

      // describe the memory layout
      dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any);

      // 3. create a pooling primitive descriptor
      auto pool_desc = pooling_forward::desc(
          prop_kind::forward, algorithm::pooling_avg_exclude_padding,
          dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(),
          memory::dims({pool_params.row_stride, pool_params.col_stride}),
          memory::dims({pool_params.window_rows, pool_params.window_cols}),
          memory::dims({static_cast<int>(pool_params.pad_top),
                        static_cast<int>(pool_params.pad_left)}),
          memory::dims({static_cast<int>(pool_params.pad_bottom),
                        static_cast<int>(pool_params.pad_right)}),
          TFPaddingToMklDnnPadding(this->padding_));
      auto pool_prim_desc =
          pooling_forward::primitive_desc(pool_desc, cpu_engine);

      this->AllocateOutputTensor(context, pool_prim_desc, output_dims_mkl_order,
                                 this->data_format_mkldnn_, &output_tensor);
      // Report a recorded error before asserting on the pointer, so that a
      // failed allocation is returned as a status rather than aborting.
      OP_REQUIRES_OK(context, context->status());
      CHECK_NOTNULL(output_tensor);

      dnn_data_output.SetUsrMemDataHandle(output_tensor);

      this->PrepareAndExecuteNet(pool_prim_desc, &dnn_data_input,
                                 &dnn_data_output);
    } catch (const mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }  // Compute
};   // MklAvgPoolingOp

//-----------------------------------------------------------------------------

// Average-pooling gradient kernel built on MKL-DNN. Shared plumbing comes
// from MklPoolingBackwardOpBase<T>.
template <class Device, class T>
class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
 public:
  explicit MklAvgPoolingGradOp(OpKernelConstruction* context)
      : MklPoolingBackwardOpBase<T>(context) {}

  // Runs backward average pooling. Any mkldnn::error thrown along the way is
  // reported as an Aborted status on the context.
  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);
      MklDnnShape original_input_mkl_shape, input_gradient_mkl_shape;
      const Tensor& tensor_in_shape =
          MklGetInput(context, kInputTensorIndexInputShape);
      const Tensor& input_gradient_tensor =
          MklGetInput(context, kInputTensorIndexInputGradient);
      GetMklShape(context, kInputTensorIndexInputShape,
                  &original_input_mkl_shape);
      GetMklShape(context, kInputTensorIndexInputGradient,
                  &input_gradient_mkl_shape);

      SanityCheckInputs(context, tensor_in_shape, input_gradient_tensor,
                        original_input_mkl_shape, input_gradient_mkl_shape);
      if (!context->status().ok()) return;

      // Used to allocate output_diff_src/diff_src
      // and create pool_fwd mdm desc
      // 0. Input("orig_input_shape: int32") //NOT a T Tensor!
      // 1. Input("grad: T")

      MklDnnData<T> input_gradient_diff_dst(&cpu_engine);
      MklDnnData<T> output_diff_src(&cpu_engine);
      Tensor* output_tensor_diff_src = nullptr;
      TensorShape original_input_shape;
      MklPoolParameters pool_params;
      memory::dims output_dims_mkl_order, original_input_dims_nchw;
      // Configure the original input memory descriptor
      memory::desc original_input_md = ConfigureOriginalInput(
          context, tensor_in_shape, original_input_mkl_shape,
          &original_input_dims_nchw, &pool_params, &original_input_shape);

      // configure the original output memory descriptor
      // by definition, the shape of the original output is the same
      // as the shape of the gradient diff_dst
      memory::desc original_output_md = this->ConfigureOriginalOutput(
          pool_params, input_gradient_mkl_shape, output_dims_mkl_order);

      memory::desc target_diff_dst_md = this->ConfigureInputGradient(
          input_gradient_mkl_shape, input_gradient_tensor,
          &input_gradient_diff_dst, original_output_md);
      // The shape of the output diff src needs to be the same shape as the
      // original input. But we will set its format to be same as the format of
      // input gradient. We won't use format of original input since it will
      // always be in Tensorflow layout (given that AvgPoolGrad gets shape of
      // the input rather than actual input).
      output_diff_src.SetUsrMem(
          original_input_dims_nchw,
          static_cast<memory::format>(target_diff_dst_md.data.format));

      // Create the forward pooling primitive descriptor so we can reference it
      // in the backward pooling primitive descriptor
      auto pool_fwd_desc = pooling_forward::desc(
          prop_kind::forward, algorithm::pooling_avg_exclude_padding,
          original_input_md, original_output_md,
          memory::dims({pool_params.row_stride, pool_params.col_stride}),
          memory::dims({pool_params.window_rows, pool_params.window_cols}),
          memory::dims({static_cast<int>(pool_params.pad_top),
                        static_cast<int>(pool_params.pad_left)}),
          memory::dims({static_cast<int>(pool_params.pad_bottom),
                        static_cast<int>(pool_params.pad_right)}),
          TFPaddingToMklDnnPadding(this->padding_));
      auto pool_fwd_prim_desc =
          pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine);

      auto pool_bkwd_desc = pooling_backward::desc(
          algorithm::pooling_avg_exclude_padding,
          output_diff_src.GetUsrMemDesc(), target_diff_dst_md,
          memory::dims({pool_params.row_stride, pool_params.col_stride}),
          memory::dims({pool_params.window_rows, pool_params.window_cols}),
          memory::dims({static_cast<int>(pool_params.pad_top),
                        static_cast<int>(pool_params.pad_left)}),
          memory::dims({static_cast<int>(pool_params.pad_bottom),
                        static_cast<int>(pool_params.pad_right)}),
          TFPaddingToMklDnnPadding(this->padding_));
      auto pool_bkwd_prim_desc = pooling_backward::primitive_desc(
          pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc);
      this->AllocateOutputTensor(
          context, pool_bkwd_prim_desc, original_input_dims_nchw,
          this->data_format_mkldnn_, &output_tensor_diff_src);
      // Mirror the forward op: do not touch the output if an error has been
      // recorded on the context.
      OP_REQUIRES_OK(context, context->status());

      output_diff_src.SetUsrMemDataHandle(output_tensor_diff_src);

      this->PrepareAndExecuteNet(
          pool_bkwd_prim_desc, &input_gradient_diff_dst, &output_diff_src,
          memory::primitive_desc(target_diff_dst_md, cpu_engine));
    } catch (const mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
                                              error_msg));
    }
  }  // Compute

 private:
  // 0. Input("orig_input_shape: int32")
  // 1. Input("grad: T")
  const int kInputTensorIndexInputShape = 0;
  const int kInputTensorIndexInputGradient = 1;

  // Fills `*input_tensor_shape` from the int32 values stored in
  // `tensor_original_input_shape`, then delegates to the base-class
  // ConfigureOriginalInput() and returns its memory descriptor.
  memory::desc ConfigureOriginalInput(
      OpKernelContext* context, const Tensor& tensor_original_input_shape,
      const MklDnnShape& original_input_mkl_shape,
      memory::dims* original_input_dims_mkl_order,
      MklPoolParameters* pool_params, TensorShape* input_tensor_shape) {
    CHECK_NOTNULL(original_input_dims_mkl_order);
    CHECK_NOTNULL(pool_params);
    CHECK_NOTNULL(input_tensor_shape);
    // For AvgPoolGrad, we only get the size of the original input because
    // the original data is irrelevant.
    auto shape_vec = tensor_original_input_shape.vec<int32>();
    for (int64 i = 0; i < tensor_original_input_shape.NumElements(); ++i) {
      input_tensor_shape->AddDim(shape_vec(i));
    }

    return MklPoolingBackwardOpBase<T>::ConfigureOriginalInput(
        context, tensor_original_input_shape, original_input_mkl_shape,
        original_input_dims_mkl_order, pool_params, *input_tensor_shape);
  }

  // Records InvalidArgument on `context` unless the original-input-shape
  // tensor is 1-D with 4 elements and the gradient is 4-dimensional. Each
  // check uses the MklDnnShape when the tensor is in MKL layout and the TF
  // tensor otherwise.
  void SanityCheckInputs(OpKernelContext* context,
                         const Tensor& tensor_in_shape,
                         const Tensor& input_gradient_tensor,
                         const MklDnnShape& original_input_mkl_shape,
                         const MklDnnShape& input_gradient_mkl_shape) {
    if (!original_input_mkl_shape.IsMklTensor()) {
      OP_REQUIRES(
          context,
          tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4,
          errors::InvalidArgument("original input shape must be "
                                  "1-dimensional and 4 elements"));
    } else {
      OP_REQUIRES(context,
                  original_input_mkl_shape.GetDimension() == 1 &&
                      original_input_mkl_shape.DimSize(0) == 4,
                  errors::InvalidArgument("original input shape must be "
                                          "1-dimensional and 4 elements"));
    }

    if (!input_gradient_mkl_shape.IsMklTensor()) {
      // For avgpooling, input_gradient_diff_dst should have 4 dimensions.
      OP_REQUIRES(context, input_gradient_tensor.dims() == 4,
                  errors::InvalidArgument("Gradient shape must be "
                                          "4-dimensional"));
    } else {
      OP_REQUIRES(context, input_gradient_mkl_shape.GetDimension() == 4,
                  errors::InvalidArgument("Gradient shape must be "
                                          "4-dimensional"));
    }
  }
};  // MklAvgPoolingGradOp

#endif  // INTEL_MKL_ML

REGISTER_KERNEL_BUILDER(Name("_MklAvgPool")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklOpLabel),
                        MklAvgPoolingOp<CPUDevice, float>);

REGISTER_KERNEL_BUILDER(Name("_MklAvgPoolGrad")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklOpLabel),
                        MklAvgPoolingGradOp<CPUDevice, float>);

}  // namespace tensorflow
#endif  // INTEL_MKL