/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_

#ifdef INTEL_MKL
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"

#ifndef INTEL_MKL_ML_ONLY
#include "mkldnn.hpp"
using mkldnn::memory;
using mkldnn::pooling_backward;
using mkldnn::pooling_forward;
using mkldnn::stream;
#endif

namespace tensorflow {

#ifndef INTEL_MKL_ML_ONLY

using mkldnn::memory;
using mkldnn::pooling_avg;
using mkldnn::pooling_avg_exclude_padding;
using mkldnn::pooling_avg_include_padding;
using mkldnn::pooling_max;
using mkldnn::prop_kind;

// Parameters that uniquely identify a pooling primitive: shapes, strides,
// paddings, algorithm kind, and propagation kind.
struct MklPoolingParams {
  memory::dims src_dims;
  memory::dims dst_dims;
  memory::dims filter_dims;
  memory::dims strides;
  memory::dims padding_left;
  memory::dims padding_right;
  mkldnn::algorithm alg_kind;
  mkldnn::prop_kind prop_kind;

  MklPoolingParams(memory::dims src_dims, memory::dims dst_dims,
                   memory::dims filter_dims, memory::dims strides,
                   memory::dims padding_left, memory::dims padding_right,
                   mkldnn::algorithm alg_kind, mkldnn::prop_kind prop_kind)
      : src_dims(src_dims),
        dst_dims(dst_dims),
        filter_dims(filter_dims),
        strides(strides),
        padding_left(padding_left),
        padding_right(padding_right),
        alg_kind(alg_kind),
        prop_kind(prop_kind) {}
};
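// A minimal sketch of how a caller might fill in MklPoolingParams for a
// 2x2 NCHW max-pool with stride 2 and no padding (the concrete dims below
// are hypothetical, not taken from any particular kernel):
//
//   memory::dims src_dims = {32, 64, 28, 28};  // NCHW input
//   memory::dims dst_dims = {32, 64, 14, 14};  // NCHW output
//   MklPoolingParams fwd_params(src_dims, dst_dims, /*filter_dims=*/{2, 2},
//                               /*strides=*/{2, 2}, /*padding_left=*/{0, 0},
//                               /*padding_right=*/{0, 0}, pooling_max,
//                               prop_kind::forward_training);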
template <typename T>
class MklPoolingFwdPrimitive : public MklPrimitive {
 public:
  explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams)
      : cpu_engine_(engine::cpu, 0) {
    context_.fwd_stream.reset(new stream(stream::kind::eager));
    if (context_.fwd == nullptr) Setup(fwdParams);
  }

  ~MklPoolingFwdPrimitive() {}

  // Pooling forward execute
  //   src_data: input data buffer of src
  //   dst_data: output data buffer of dst
  //   ws_data:  output data buffer of workspace
  void Execute(const T* src_data, T* dst_data, void* ws_data = nullptr);

  std::shared_ptr<mkldnn::pooling_forward::primitive_desc> GetPoolingFwdPd()
      const {
    return context_.fwd_pd;
  }

  memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }

  memory::format GetDstMemoryFormat() const { return context_.dst_fmt; }

 private:
  void Setup(const MklPoolingParams& fwdParams);

  // Primitive reuse context for pooling fwd ops.
  struct PoolingFwdContext {
    // Algorithm (max versus avg).
    mkldnn::algorithm alg_kind;

    // Kind of propagation, forward or backward.
    mkldnn::prop_kind prop_kind;

    // Expected memory formats.
    memory::format src_fmt;
    memory::format dst_fmt;
    memory::format ws_fmt;

    // Workspace shape.
    memory::dims ws_dims;
    memory::data_type ws_dt;
    size_t ws_size;

    // MKL-DNN memory, initially bound to dummy data.
    std::shared_ptr<mkldnn::memory> ws_mem;
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;

    // Op descriptor & primitive descriptor.
    std::shared_ptr<mkldnn::pooling_forward::desc> fwd_desc;
    std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd;

    // Memory descriptors.
    std::shared_ptr<mkldnn::memory::desc> src_md;
    std::shared_ptr<mkldnn::memory::desc> dst_md;

    // Pooling primitive.
    std::shared_ptr<mkldnn::pooling_forward> fwd;
    std::shared_ptr<mkldnn::stream> fwd_stream;
    std::vector<mkldnn::primitive> fwd_primitives;

    PoolingFwdContext()
        : src_fmt(memory::format::any),
          dst_fmt(memory::format::any),
          ws_fmt(memory::format::any),
          ws_mem(nullptr),
          src_mem(nullptr),
          dst_mem(nullptr),
          fwd_desc(nullptr),
          fwd_pd(nullptr),
          src_md(nullptr),
          dst_md(nullptr),
          fwd(nullptr),
          fwd_stream(nullptr) {}
  };

  struct PoolingFwdContext context_;
  engine cpu_engine_;
};
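// For max pooling with prop_kind::forward_training, MKL-DNN also emits a
// workspace that the backward pass consumes, so callers pass a non-null
// ws_data in that case.  Below is a minimal sketch of the reorder check a
// caller would typically perform before Execute; cpu_engine, src_md, and
// src_tensor are hypothetical locals, not members of this class:
//
//   MklDnnData<float> src_dnn_data(&cpu_engine);
//   src_dnn_data.SetUsrMem(src_md, &src_tensor);
//   if (src_dnn_data.IsReorderNeeded(pooling_fwd->GetSrcMemoryFormat())) {
//     src_dnn_data.CheckReorderToOpMem(
//         pooling_fwd->GetPoolingFwdPd()->src_primitive_desc());
//   }
//   const float* src_data =
//       static_cast<const float*>(src_dnn_data.GetOpMem().get_data_handle());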
template <typename T>
class MklPoolingFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklPoolingFwdPrimitive<T>* Get(const MklPoolingParams& fwdParams) {
    MklPoolingFwdPrimitive<T>* pooling_forward = nullptr;

    // Get a pooling primitive from the cached pool.
    pooling_forward = static_cast<MklPoolingFwdPrimitive<T>*>(
        MklPoolingFwdPrimitiveFactory<T>::GetInstance().GetPoolingFwd(
            fwdParams));
    if (pooling_forward == nullptr) {
      pooling_forward = new MklPoolingFwdPrimitive<T>(fwdParams);
      MklPoolingFwdPrimitiveFactory<T>::GetInstance().SetPoolingFwd(
          fwdParams, pooling_forward);
    }
    return pooling_forward;
  }

  static MklPoolingFwdPrimitiveFactory& GetInstance() {
    static MklPoolingFwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklPoolingFwdPrimitiveFactory() {}
  ~MklPoolingFwdPrimitiveFactory() {}

  // The key created here is used to get/set the pooling primitive in the
  // cache so that it can be reused.  A pooling key is a string that
  // concatenates the shape parameters as well as the algorithm kind
  // (max versus avg) and the propagation kind.
  static string CreateKey(const MklPoolingParams& fwdParams) {
    string prefix = "pooling_fwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(fwdParams.src_dims);
    key_creator.AddAsKey(fwdParams.dst_dims);
    key_creator.AddAsKey(fwdParams.filter_dims);
    key_creator.AddAsKey(fwdParams.strides);
    key_creator.AddAsKey(fwdParams.padding_left);
    key_creator.AddAsKey(fwdParams.padding_right);
    key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
    key_creator.AddAsKey<int>(static_cast<int>(fwdParams.prop_kind));
    return key_creator.GetKey();
  }

  MklPrimitive* GetPoolingFwd(const MklPoolingParams& fwdParams) {
    string key = CreateKey(fwdParams);
    return this->GetOp(key);
  }

  void SetPoolingFwd(const MklPoolingParams& fwdParams, MklPrimitive* op) {
    string key = CreateKey(fwdParams);
    this->SetOp(key, op);
  }
};
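// A typical call site, stitching the factory and the primitive together
// (a sketch; fwd_params is the MklPoolingParams example above, and the
// buffer pointers are assumed to point at correctly sized memory):
//
//   MklPoolingFwdPrimitive<float>* pooling_fwd =
//       MklPoolingFwdPrimitiveFactory<float>::Get(fwd_params);
//   pooling_fwd->Execute(src_data, dst_data, ws_data);
//
// Because the factory is a process-wide singleton keyed on CreateKey(),
// repeated calls with identical parameters reuse one primitive instead of
// recreating the MKL-DNN descriptors on every Compute() invocation.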
template <typename T>
class MklPoolingBwdPrimitive : public MklPrimitive {
 public:
  explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams)
      : cpu_engine(engine::cpu, 0) {
    context_.bwd_stream.reset(new stream(stream::kind::eager));
    if (context_.bwd == nullptr) Setup(bwdParams);
  }

  ~MklPoolingBwdPrimitive() {}

  // Pooling backward execute
  //   diff_dst_data: input data buffer of diff_dst
  //   diff_src_data: output data buffer of diff_src
  //   ws_data:       input data buffer of workspace
  void Execute(const T* diff_dst_data, T* diff_src_data,
               const void* ws_data = nullptr);

  std::shared_ptr<mkldnn::pooling_forward::primitive_desc> GetPoolingFwdPd()
      const {
    return context_.fwd_pd;
  }

  std::shared_ptr<mkldnn::pooling_backward::primitive_desc> GetPoolingBwdPd()
      const {
    return context_.bwd_pd;
  }

  memory::format GetDiffDstFormat() const { return context_.diff_dst_fmt; }

  mkldnn::memory::data_type GetWorkspaceDataType() const {
    return context_.ws_dt;
  }

  memory::format GetWorkspaceFormat() const { return context_.ws_fmt; }

 private:
  void Setup(const MklPoolingParams& bwdParams);

  // Primitive reuse context for pooling bwd ops.
  struct PoolingBwdContext {
    // Algorithm (max versus avg).
    mkldnn::algorithm alg_kind;

    // Expected memory formats.
    mkldnn::memory::format diff_src_fmt;
    mkldnn::memory::format diff_dst_fmt;
    mkldnn::memory::format ws_fmt;

    // Workspace attributes.
    mkldnn::memory::dims ws_dims;
    mkldnn::memory::data_type ws_dt;

    // MKL-DNN memory.
    std::shared_ptr<mkldnn::memory> ws_mem;
    std::shared_ptr<mkldnn::memory> diff_src_mem;
    std::shared_ptr<mkldnn::memory> diff_dst_mem;

    // Memory descriptors.
    std::shared_ptr<mkldnn::memory::desc> diff_src_md;
    std::shared_ptr<mkldnn::memory::desc> diff_dst_md;

    // Op descriptors & primitive descriptors.  The backward primitive
    // descriptor needs the forward one as a hint.
    std::shared_ptr<mkldnn::pooling_forward::desc> fwd_desc;
    std::shared_ptr<mkldnn::pooling_backward::desc> bwd_desc;
    std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd;
    std::shared_ptr<mkldnn::pooling_backward::primitive_desc> bwd_pd;

    // Pooling primitive.
    std::shared_ptr<mkldnn::pooling_backward> bwd;
    std::shared_ptr<mkldnn::stream> bwd_stream;

    std::vector<mkldnn::primitive> bwd_primitives;

    PoolingBwdContext()
        : diff_src_fmt(memory::format::any),
          diff_dst_fmt(memory::format::any),
          ws_fmt(memory::format::any),
          ws_mem(nullptr),
          diff_src_mem(nullptr),
          diff_dst_mem(nullptr),
          diff_src_md(nullptr),
          diff_dst_md(nullptr),
          fwd_desc(nullptr),
          bwd_desc(nullptr),
          fwd_pd(nullptr),
          bwd_pd(nullptr),
          bwd(nullptr),
          bwd_stream(nullptr) {}
  };

  struct PoolingBwdContext context_;
  engine cpu_engine;
};
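// A sketch of the matching backward call, using the factory defined just
// below (bwd_params, diff_dst_data, diff_src_data, and ws_data are
// hypothetical; for max pooling, ws_data must be the workspace produced by
// the corresponding forward primitive):
//
//   MklPoolingBwdPrimitive<float>* pooling_bwd =
//       MklPoolingBwdPrimitiveFactory<float>::Get(bwd_params);
//   pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data);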
template <typename T>
class MklPoolingBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklPoolingBwdPrimitive<T>* Get(const MklPoolingParams& bwdParams) {
    MklPoolingBwdPrimitive<T>* pooling_backward = nullptr;

    // Find a pooling backward primitive in the cached pool.
    // If it does not exist, create a new one.
    pooling_backward = static_cast<MklPoolingBwdPrimitive<T>*>(
        MklPoolingBwdPrimitiveFactory<T>::GetInstance().GetPoolingBwd(
            bwdParams));
    if (pooling_backward == nullptr) {
      pooling_backward = new MklPoolingBwdPrimitive<T>(bwdParams);
      MklPoolingBwdPrimitiveFactory<T>::GetInstance().SetPoolingBwd(
          bwdParams, pooling_backward);
    }
    return pooling_backward;
  }

  static MklPoolingBwdPrimitiveFactory& GetInstance() {
    static MklPoolingBwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklPoolingBwdPrimitiveFactory() {}
  ~MklPoolingBwdPrimitiveFactory() {}

  // The key created here is used to get/set the pooling primitive in the
  // cache so that it can be reused.  A pooling key is a string that
  // concatenates the shape parameters as well as the algorithm kind
  // (max versus avg).
  static string CreateKey(const MklPoolingParams& bwdParams) {
    string prefix = "pooling_bwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(bwdParams.src_dims);
    key_creator.AddAsKey(bwdParams.dst_dims);
    key_creator.AddAsKey(bwdParams.filter_dims);
    key_creator.AddAsKey(bwdParams.strides);
    key_creator.AddAsKey(bwdParams.padding_left);
    key_creator.AddAsKey(bwdParams.padding_right);
    key_creator.AddAsKey<int>(static_cast<int>(bwdParams.alg_kind));
    return key_creator.GetKey();
  }

  MklPrimitive* GetPoolingBwd(const MklPoolingParams& bwdParams) {
    string key = CreateKey(bwdParams);
    return this->GetOp(key);
  }

  void SetPoolingBwd(const MklPoolingParams& bwdParams, MklPrimitive* op) {
    string key = CreateKey(bwdParams);
    this->SetOp(key, op);
  }
};
#endif  // INTEL_MKL_ML_ONLY
typedef Eigen::ThreadPoolDevice CPUDevice;

struct MklPoolParameters {
  int depth;

  int tensor_in_planes;  // Pool3D
  int tensor_in_cols;
  int tensor_in_rows;
  int tensor_in_batch;

  int window_planes;  // Pool3D
  int window_rows;
  int window_cols;
  int depth_window;

  int planes_stride;  // Pool3D
  int row_stride;
  int col_stride;
  int depth_stride;

  int64 out_planes;  // Pool3D
  int64 out_height;
  int64 out_width;
  int out_depth;

  int64 pad_P1;  // Pool3D
  int64 pad_P2;  // Pool3D
  int64 pad_left;
  int64 pad_right;
  int64 pad_top;
  int64 pad_bottom;
  int pad_depth;

  TensorFormat data_format;

  MklPoolParameters()
      : depth(0),
        tensor_in_planes(0),
        tensor_in_cols(0),
        tensor_in_rows(0),
        tensor_in_batch(0),
        window_planes(0),
        window_rows(0),
        window_cols(0),
        depth_window(0),
        planes_stride(0),
        row_stride(0),
        col_stride(0),
        depth_stride(0),
        out_planes(0),
        out_height(0),
        out_width(0),
        out_depth(0),
        pad_P1(0),
        pad_P2(0),
        pad_left(0),
        pad_right(0),
        pad_top(0),
        pad_bottom(0),
        pad_depth(0),
        data_format(TensorFormat::FORMAT_NCHW) {}

  // Updates context->status if there is an invalid input.
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format, const TensorShape& tensor_in_shape);
#ifdef INTEL_MKL_ML_ONLY
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format, const MklShape* mkl_in_shape);
#else
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format, const MklDnnShape* mkl_in_shape);
#endif

 private:
  // Common initialization for TensorFlow and MKL formats.
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format);
};
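// A sketch of how a kernel typically calls Init for a plain TensorFlow
// (non-MKL) input; ksize and stride follow the TF attribute layout
// {N, H, W, C} for 2-D pooling, and the values below are hypothetical:
//
//   MklPoolParameters pool_params;
//   std::vector<int32> ksize = {1, 2, 2, 1};
//   std::vector<int32> stride = {1, 2, 2, 1};
//   pool_params.Init(context, ksize, stride, Padding::VALID, FORMAT_NHWC,
//                    input_tensor.shape());
//   if (!context->status().ok()) return;  // Init reported an invalid input.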
#ifndef INTEL_MKL_ML_ONLY

template <class T>
class MklPoolingOpBase : public OpKernel {
 public:
  explicit MklPoolingOpBase(OpKernelConstruction* context)
      : OpKernel(context), workspace_enabled_(false) {
    string data_format;
    if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
      // The current quantized pooling op does not have a data_format
      // attribute; it supports only NHWC.
      data_format = "NHWC";
    } else {
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    }
    OP_REQUIRES(context, FormatFromString(data_format, &this->data_format_tf_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &this->ksize_));
    OP_REQUIRES(context, this->ksize_.size() == 4 || this->ksize_.size() == 5,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 or 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &this->stride_));
    OP_REQUIRES(context,
                this->stride_.size() == 4 || this->stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 or 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &this->padding_));
    OP_REQUIRES(context, this->ksize_[0] == 1 && this->stride_[0] == 1,
                errors::Unimplemented("Pooling is not yet supported on the "
                                      "batch dimension."));
    bool is_pool2d = (this->ksize_.size() == 4);
    this->data_format_mkldnn_ =
        is_pool2d ? TFDataFormatToMklDnnDataFormat(this->data_format_tf_)
                  : TFDataFormatToMklDnn3DDataFormat(this->data_format_tf_);

    // This node only has the workspace_enabled attribute if it went through
    // the graph-rewrite pass, so we deliberately do not check for an error
    // when retrieving it.
    context->GetAttr("workspace_enabled", &this->workspace_enabled_);
  }

  void Compute(OpKernelContext* context) override = 0;

 protected:
  // Calculates the output shape of the pooling op in MKL-DNN and TensorFlow
  // order.  MKL-DNN uses NCHW (Pool2D) or NCDHW (Pool3D) for the output
  // order, while the TensorFlow output is in NHWC/NCHW (Pool2D) or
  // NDHWC/NCDHW (Pool3D) format depending on the data format.  This function
  // expects the output height and width to have already been int32
  // bounds-checked.
  void GetOutputDims(const MklPoolParameters& mkl_pool_params,
                     memory::dims* output_dims_mkl_order) {
    if (this->ksize_.size() == 4) {
      // Pooling2D: MKL-DNN always needs output in NCHW format.
      *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
                                mkl_pool_params.out_depth,
                                static_cast<int>(mkl_pool_params.out_height),
                                static_cast<int>(mkl_pool_params.out_width)};
    } else {
      // Pooling3D: MKL-DNN always needs output in NCDHW format.
      *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
                                mkl_pool_params.out_depth,
                                static_cast<int>(mkl_pool_params.out_planes),
                                static_cast<int>(mkl_pool_params.out_height),
                                static_cast<int>(mkl_pool_params.out_width)};
    }
  }

  void InitMklPoolParameters(OpKernelContext* context,
                             MklPoolParameters* pool_params,
                             const MklDnnShape& original_input_mkl_shape,
                             const TensorShape& input_tensor_shape) {
    if (!original_input_mkl_shape.IsMklTensor()) {
      pool_params->Init(context, this->ksize_, this->stride_, this->padding_,
                        this->data_format_tf_, input_tensor_shape);
    } else {
      pool_params->Init(context, this->ksize_, this->stride_, this->padding_,
                        this->data_format_tf_, &original_input_mkl_shape);
    }
  }

  void PoolParamsToDims(const MklPoolParameters* pool_params,
                        memory::dims* filter_dims, memory::dims* strides,
                        memory::dims* padding_left,
                        memory::dims* padding_right, bool is_pool2d) {
    if (is_pool2d) {
      // Pool2D
      *filter_dims =
          memory::dims({pool_params->window_rows, pool_params->window_cols});
      *strides =
          memory::dims({pool_params->row_stride, pool_params->col_stride});
      *padding_left = memory::dims({static_cast<int>(pool_params->pad_top),
                                    static_cast<int>(pool_params->pad_left)});
      *padding_right =
          memory::dims({static_cast<int>(pool_params->pad_bottom),
                        static_cast<int>(pool_params->pad_right)});
    } else {
      // Pool3D
      *filter_dims =
          memory::dims({pool_params->window_planes, pool_params->window_rows,
                        pool_params->window_cols});
      *strides =
          memory::dims({pool_params->planes_stride, pool_params->row_stride,
                        pool_params->col_stride});

      *padding_left = memory::dims({static_cast<int>(pool_params->pad_P1),
                                    static_cast<int>(pool_params->pad_top),
                                    static_cast<int>(pool_params->pad_left)});
      *padding_right =
          memory::dims({static_cast<int>(pool_params->pad_P2),
                        static_cast<int>(pool_params->pad_bottom),
                        static_cast<int>(pool_params->pad_right)});
    }
  }

  void AllocateEmptyOutputTensor(OpKernelContext* context,
                                 const int kOutputIndex,
                                 MklPoolParameters* pool_params,
                                 const memory::dims output_dims_mkl_order,
                                 Tensor** output_tensor) {
    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(false);
    TensorShape output_tf_shape;
    if (pool_params->data_format == TensorFormat::FORMAT_NCHW) {
      output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order);
    } else {
      memory::dims output_dims_NHWC_order;
      output_dims_NHWC_order = {pool_params->tensor_in_batch,
                                static_cast<int>(pool_params->out_height),
                                static_cast<int>(pool_params->out_width),
                                pool_params->out_depth};
      output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order);
    }
    AllocateOutputSetMklShape(context, kOutputIndex, output_tensor,
                              output_tf_shape, output_mkl_shape);
    CHECK_NOTNULL(*output_tensor);
  }

  // Checks that the memory we need to allocate is a multiple of sizeof(T)
  // and returns the number of elements, rounding up so the buffer always
  // covers pd.get_size() bytes.
  size_t GetNumTElements(const memory::primitive_desc& pd) {
    size_t num_bytes = pd.get_size();
    size_t ret_val = num_bytes / sizeof(T);
    if (num_bytes % sizeof(T) != 0) {
      ret_val++;
    }
    return ret_val;
  }

  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_tf_;
  memory::format data_format_mkldnn_;
  bool workspace_enabled_;
};
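// A worked example of the GetNumTElements rounding above (hypothetical
// numbers): if pd.get_size() returns 1026 bytes and T is float
// (sizeof(T) == 4), then 1026 / 4 == 256 with a remainder of 2, so the
// function returns 257; allocating 257 floats (1028 bytes) is the smallest
// T-aligned buffer that covers the 1026 bytes MKL-DNN asked for.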
template <class T>
class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
 public:
  explicit MklPoolingForwardOpBase(OpKernelConstruction* context)
      : MklPoolingOpBase<T>(context) {}

  void Compute(OpKernelContext* context) override = 0;

 protected:
  void ConfigureInput(OpKernelContext* context,
                      const MklDnnShape& input_mkl_shape,
                      const Tensor& input_tensor,
                      MklPoolParameters* pool_params,
                      MklDnnData<T>* dnn_data_input) {
    CHECK_NOTNULL(pool_params);
    CHECK_NOTNULL(dnn_data_input);
    TensorShape input_tensor_shape = input_tensor.shape();
    if (input_tensor.NumElements() != 0) {
      memory::desc input_md =
          input_mkl_shape.IsMklTensor()
              ? input_mkl_shape.GetMklLayout()
              : memory::desc(
                    (this->ksize_.size() == 4)
                        ? TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
                                                    this->data_format_tf_)
                        : TFShapeToMklDnnDimsInNCDHW(input_tensor_shape,
                                                     this->data_format_tf_),
                    MklDnnType<T>(), this->data_format_mkldnn_);
      dnn_data_input->SetUsrMem(input_md, &input_tensor);

      if (this->ksize_.size() == 5) {
        // Pool3D
        std::vector<int> mkldnn_sizes(5, -1);
        mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_md.data.dims[0];
        mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_md.data.dims[1];
        mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_md.data.dims[2];
        mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_md.data.dims[3];
        mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_md.data.dims[4];
        dnn_data_input->SetOpMemDesc(mkldnn_sizes, this->data_format_mkldnn_);
      }
    }
    this->InitMklPoolParameters(context, pool_params, input_mkl_shape,
                                input_tensor_shape);
  }

  void AllocateOutputTensor(
      OpKernelContext* context,
      const pooling_forward::primitive_desc& pool_fwd_prim_desc,
      const memory::dims output_dims_mkl_order,
      const memory::format& output_tf_format, Tensor** output_tensor) {
    CHECK_NOTNULL(output_tensor);
    memory::primitive_desc dst_pd = pool_fwd_prim_desc.dst_primitive_desc();

    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);
    TensorShape output_tf_shape;

    // Only allocate enough space for the elements we need.
    output_tf_shape.AddDim(this->GetNumTElements(dst_pd));
    AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
                              output_tf_shape, output_mkl_shape);
    CHECK_NOTNULL(*output_tensor);
  }

  void SanityCheckInput(OpKernelContext* context, const Tensor& input_tensor,
                        const MklDnnShape& input_mkl_shape) {
    if (!input_mkl_shape.IsMklTensor()) {
      OP_REQUIRES(context,
                  input_tensor.dims() == 4 || input_tensor.dims() == 5,
                  errors::InvalidArgument("Input must be 4 or 5-dimensional"));
    } else {
      OP_REQUIRES(
          context,
          input_mkl_shape.GetDimension() == 4 ||
              input_mkl_shape.GetDimension() == 5,
          errors::InvalidArgument("Input shape must be 4 or 5-dimensional"));
    }
  }

  // .Input("value: T")
  // .Output("output: T")
  const int kInputTensorIndexInput = 0;
  const int kOutputTensorIndexOutput = 0;
};  // MklPoolingForwardOpBase
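// A sketch of how a derived kernel's Compute() typically chains the helpers
// above (error handling elided; pool_params, dnn_data_input, and the
// primitive creation step are hypothetical locals, not part of this class):
//
//   const Tensor& input_tensor = MklGetInput(context, kInputTensorIndexInput);
//   MklDnnShape input_mkl_shape;
//   GetMklShape(context, kInputTensorIndexInput, &input_mkl_shape);
//   this->SanityCheckInput(context, input_tensor, input_mkl_shape);
//   this->ConfigureInput(context, input_mkl_shape, input_tensor,
//                        &pool_params, &dnn_data_input);
//   memory::dims output_dims_mkl_order;
//   this->GetOutputDims(pool_params, &output_dims_mkl_order);
//   // ...fetch/create the primitive, allocate the output, then Execute...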
template <class T>
class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> {
 public:
  explicit MklPoolingBackwardOpBase(OpKernelConstruction* context)
      : MklPoolingOpBase<T>(context) {}

  void Compute(OpKernelContext* context) override = 0;

 protected:
  const int kOutputTensorIndexOutput = 0;

  void AllocateOutputTensor(
      OpKernelContext* context,
      const pooling_backward::primitive_desc& pool_bkwd_prim_desc,
      const memory::dims output_dims_mkl_order,
      const memory::format& output_tf_format, Tensor** output_tensor) {
    CHECK_NOTNULL(output_tensor);
    memory::primitive_desc dst_pd =
        pool_bkwd_prim_desc.diff_src_primitive_desc();
    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    TensorShape output_tf_shape;
    output_tf_shape.AddDim(this->GetNumTElements(dst_pd));
    AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
                              output_tf_shape, output_mkl_shape);
    CHECK_NOTNULL(*output_tensor);
  }

  memory::desc ConfigureInputGradient(
      const MklDnnShape& input_gradient_mkl_shape,
      const Tensor& input_gradient_tensor,
      MklDnnData<T>* input_gradient_dnn_data,
      const memory::desc& original_output_md) {
    // Configure the gradient as is.
    memory::desc original_input_grad_md =
        input_gradient_mkl_shape.IsMklTensor()
            ? input_gradient_mkl_shape.GetMklLayout()
            : memory::desc(
                  (this->ksize_.size() == 4)
                      ? TFShapeToMklDnnDimsInNCHW(
                            input_gradient_tensor.shape(),
                            this->data_format_tf_)
                      : TFShapeToMklDnnDimsInNCDHW(
                            input_gradient_tensor.shape(),
                            this->data_format_tf_),
                  MklDnnType<T>(), this->data_format_mkldnn_);

    input_gradient_dnn_data->SetUsrMem(original_input_grad_md,
                                       &input_gradient_tensor);

    // Check whether the input gradient (diff_dst) is in the right format.
    // If not, create a new memory descriptor with the same shape as the
    // original but with the format of the other tensors.
    memory::format original_output_format =
        static_cast<memory::format>(original_output_md.data.format);
    bool grad_reorder_needed =
        input_gradient_dnn_data->IsReorderNeeded(original_output_format);
    memory::dims diff_dst_dims =
        input_gradient_mkl_shape.IsMklTensor()
            ? input_gradient_mkl_shape.GetSizesAsMklDnnDims()
            : TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(),
                                        this->data_format_tf_);
    memory::desc target_diff_dst_md =
        memory::desc(diff_dst_dims, MklDnnType<T>(), original_output_format);

    return grad_reorder_needed ? target_diff_dst_md : original_input_grad_md;
  }
};  // MklPoolingBackwardOpBase
#endif  // INTEL_MKL_ML_ONLY

//-------------------------------------------------------------------
// Utility functions

typedef struct {
  size_t in_dim;
  size_t in_sizes[4];
  size_t in_strides[4];
  size_t out_sizes[4];
  size_t out_strides[4];
  int in_offset[4];
  size_t kernel_stride[2];
  size_t kernel_size[2];
} MklPoolingOpParams;

// Transfers the right parameters for pooling to the op parameters.
// Updates context->status if there is an invalid input.
void ExtractMklOpParams(OpKernelContext* context, TensorFormat data_format,
                        const MklPoolParameters& params,
                        MklPoolingOpParams* mkl_params);

}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_