/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_

#ifdef INTEL_MKL
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"

#ifndef INTEL_MKL_ML_ONLY
#include "mkldnn.hpp"
using mkldnn::memory;
using mkldnn::pooling_backward;
using mkldnn::pooling_forward;
using mkldnn::stream;
#endif

namespace tensorflow {

#ifndef INTEL_MKL_ML_ONLY

using mkldnn::memory;
using mkldnn::pooling_avg;
using mkldnn::pooling_avg_exclude_padding;
using mkldnn::pooling_avg_include_padding;
using mkldnn::pooling_max;
using mkldnn::prop_kind;

struct MklPoolingParams {
  memory::dims src_dims;
  memory::dims dst_dims;
  memory::dims filter_dims;
  memory::dims strides;
  memory::dims padding_left;
  memory::dims padding_right;
  mkldnn::algorithm alg_kind;
  mkldnn::prop_kind prop_kind;

  MklPoolingParams(memory::dims src_dims, memory::dims dst_dims,
                   memory::dims filter_dims, memory::dims strides,
                   memory::dims padding_left, memory::dims padding_right,
                   mkldnn::algorithm alg_kind, mkldnn::prop_kind prop_kind)
      : src_dims(src_dims),
        dst_dims(dst_dims),
        filter_dims(filter_dims),
        strides(strides),
        padding_left(padding_left),
        padding_right(padding_right),
        alg_kind(alg_kind),
        prop_kind(prop_kind) {}
};
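
// Illustrative sketch (values are made up, not taken from this file):
// parameters for a 2x2 max-pool with stride 2 and no padding over a
// 1x16x28x28 (NCHW) input, producing a 1x16x14x14 output:
//
//   MklPoolingParams fwd_params(
//       {1, 16, 28, 28} /* src_dims */, {1, 16, 14, 14} /* dst_dims */,
//       {2, 2} /* filter_dims */, {2, 2} /* strides */,
//       {0, 0} /* padding_left */, {0, 0} /* padding_right */,
//       pooling_max, prop_kind::forward_training);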

template <typename T>
class MklPoolingFwdPrimitive : public MklPrimitive {
 public:
  explicit MklPoolingFwdPrimitive(const MklPoolingParams& fwdParams)
      : cpu_engine_(engine::cpu, 0) {
    context_.fwd_stream.reset(new stream(stream::kind::eager));
    if (context_.fwd == nullptr) Setup(fwdParams);
  }

  ~MklPoolingFwdPrimitive() {}

  // Pooling forward execute
  //   src_data:  input data buffer of src
  //   ws_data:   output data buffer of workspace
  //   dst_data:  output data buffer of dst
  void Execute(const T* src_data, T* dst_data, void* ws_data = nullptr);

  std::shared_ptr<mkldnn::pooling_forward::primitive_desc> GetPoolingFwdPd()
      const {
    return context_.fwd_pd;
  }

  memory::format GetSrcMemoryFormat() const { return context_.src_fmt; }

  memory::format GetDstMemoryFormat() const { return context_.dst_fmt; }

 private:
  void Setup(const MklPoolingParams& fwdParams);

  struct PoolingFwdContext {
    // algorithm
    mkldnn::algorithm alg_kind;

    // Kind of propagation, forward or backward
    mkldnn::prop_kind prop_kind;

    // expected memory format
    memory::format src_fmt;
    memory::format dst_fmt;
    memory::format ws_fmt;

    // workspace shape
    memory::dims ws_dims;
    memory::data_type ws_dt;
    size_t ws_size;

    // MKL-DNN memory, just dummy data
    std::shared_ptr<mkldnn::memory> ws_mem;
    std::shared_ptr<mkldnn::memory> src_mem;
    std::shared_ptr<mkldnn::memory> dst_mem;

    // desc & primitive desc
    std::shared_ptr<mkldnn::pooling_forward::desc> fwd_desc;
    std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd;

    // memory desc
    std::shared_ptr<mkldnn::memory::desc> src_md;
    std::shared_ptr<mkldnn::memory::desc> dst_md;

    // Pooling primitive
    std::shared_ptr<mkldnn::pooling_forward> fwd;
    std::shared_ptr<mkldnn::stream> fwd_stream;
    std::vector<mkldnn::primitive> fwd_primitives;

    PoolingFwdContext()
        : src_fmt(memory::format::any),
          dst_fmt(memory::format::any),
          ws_fmt(memory::format::any),
          ws_mem(nullptr),
          src_mem(nullptr),
          dst_mem(nullptr),
          fwd_desc(nullptr),
          fwd_pd(nullptr),
          src_md(nullptr),
          dst_md(nullptr),
          fwd(nullptr),
          fwd_stream(nullptr) {}
  };

  struct PoolingFwdContext context_;
  engine cpu_engine_;
};

template <typename T>
class MklPoolingFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklPoolingFwdPrimitive<T>* Get(const MklPoolingParams& fwdParams) {
    MklPoolingFwdPrimitive<T>* pooling_forward = nullptr;

    // Get pooling primitive from the pool
    pooling_forward = static_cast<MklPoolingFwdPrimitive<T>*>(
        MklPoolingFwdPrimitiveFactory<T>::GetInstance().GetPoolingFwd(
            fwdParams));

    if (pooling_forward == nullptr) {
      pooling_forward = new MklPoolingFwdPrimitive<T>(fwdParams);
      MklPoolingFwdPrimitiveFactory<T>::GetInstance().SetPoolingFwd(
          fwdParams, pooling_forward);
    }
    return pooling_forward;
  }

  static MklPoolingFwdPrimitiveFactory& GetInstance() {
    static MklPoolingFwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklPoolingFwdPrimitiveFactory() {}
  ~MklPoolingFwdPrimitiveFactory() {}

  // The key created here is used to get/set a pooling
  // primitive op from the reuse pool.
  // A pooling key is a string that concatenates the key parameters
  // as well as the algorithm kind (max versus avg).
  static string CreateKey(const MklPoolingParams& fwdParams) {
    string prefix = "pooling_fwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(fwdParams.src_dims);
    key_creator.AddAsKey(fwdParams.dst_dims);
    key_creator.AddAsKey(fwdParams.filter_dims);
    key_creator.AddAsKey(fwdParams.strides);
    key_creator.AddAsKey(fwdParams.padding_left);
    key_creator.AddAsKey(fwdParams.padding_right);
    key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
    key_creator.AddAsKey<int>(static_cast<int>(fwdParams.prop_kind));
    return key_creator.GetKey();
  }

  MklPrimitive* GetPoolingFwd(const MklPoolingParams& fwdParams) {
    string key = CreateKey(fwdParams);
    return this->GetOp(key);
  }

  void SetPoolingFwd(const MklPoolingParams& fwdParams, MklPrimitive* op) {
    string key = CreateKey(fwdParams);
    this->SetOp(key, op);
  }
};
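
// Usage sketch (illustrative; assumes fwd_params is an MklPoolingParams as
// above and src/dst point to float buffers of matching sizes):
//
//   MklPoolingFwdPrimitive<float>* pooling_fwd =
//       MklPoolingFwdPrimitiveFactory<float>::Get(fwd_params);
//   pooling_fwd->Execute(src, dst);  // ws_data defaults to nullptr
//
// The factory caches the primitive under the string key built by CreateKey(),
// so a later Get() with identical parameters returns the cached primitive.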

template <typename T>
class MklPoolingBwdPrimitive : public MklPrimitive {
 public:
  explicit MklPoolingBwdPrimitive(const MklPoolingParams& bwdParams)
      : cpu_engine(engine::cpu, 0) {
    context_.bwd_stream.reset(new stream(stream::kind::eager));
    if (context_.bwd == nullptr) Setup(bwdParams);
  }

  ~MklPoolingBwdPrimitive() {}

  // Pooling backward execute
  //   diff_dst_data:  input data buffer of diff_dst
  //   diff_src_data:  output data buffer of diff_src
  //   ws_data:        input data buffer of workspace
  void Execute(const T* diff_dst_data, T* diff_src_data,
               const void* ws_data = nullptr);

 public:
  std::shared_ptr<mkldnn::pooling_forward::primitive_desc> GetPoolingFwdPd()
      const {
    return context_.fwd_pd;
  }
  std::shared_ptr<mkldnn::pooling_backward::primitive_desc> GetPoolingBwdPd()
      const {
    return context_.bwd_pd;
  }

  memory::format GetDiffDstFormat() const { return context_.diff_dst_fmt; }

  mkldnn::memory::data_type GetWorkspaceDataType() const {
    return context_.ws_dt;
  }
  memory::format GetWorkspaceFormat() const { return context_.ws_fmt; }

 private:
  void Setup(const MklPoolingParams& bwdParams);

  // Primitive reuse context for pooling bwd ops
  struct PoolingBwdContext {
    // algorithm
    mkldnn::algorithm alg_kind;

    // expected memory format
    mkldnn::memory::format diff_src_fmt;
    mkldnn::memory::format diff_dst_fmt;
    mkldnn::memory::format ws_fmt;

    // workspace attribute
    mkldnn::memory::dims ws_dims;
    mkldnn::memory::data_type ws_dt;

    // MKL-DNN memory
    std::shared_ptr<mkldnn::memory> ws_mem;
    std::shared_ptr<mkldnn::memory> diff_src_mem;
    std::shared_ptr<mkldnn::memory> diff_dst_mem;

    // memory desc
    std::shared_ptr<mkldnn::memory::desc> diff_src_md;
    std::shared_ptr<mkldnn::memory::desc> diff_dst_md;

    // desc & primitive desc
    std::shared_ptr<mkldnn::pooling_forward::desc> fwd_desc;
    std::shared_ptr<mkldnn::pooling_backward::desc> bwd_desc;
    std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd;
    std::shared_ptr<mkldnn::pooling_backward::primitive_desc> bwd_pd;

    // pooling primitive
    std::shared_ptr<mkldnn::pooling_backward> bwd;
    std::shared_ptr<mkldnn::stream> bwd_stream;

    std::vector<mkldnn::primitive> bwd_primitives;

    PoolingBwdContext()
        : diff_src_fmt(memory::format::any),
          diff_dst_fmt(memory::format::any),
          ws_fmt(memory::format::any),
          ws_mem(nullptr),
          diff_src_mem(nullptr),
          diff_dst_mem(nullptr),
          diff_src_md(nullptr),
          diff_dst_md(nullptr),
          fwd_desc(nullptr),
          bwd_desc(nullptr),
          fwd_pd(nullptr),
          bwd_pd(nullptr),
          bwd(nullptr),
          bwd_stream(nullptr) {}
  };

  struct PoolingBwdContext context_;
  engine cpu_engine;
};

template <typename T>
class MklPoolingBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklPoolingBwdPrimitive<T>* Get(const MklPoolingParams& bwdParams) {
    MklPoolingBwdPrimitive<T>* pooling_backward = nullptr;

    // Find a pooling backward primitive from the pool.
    // If it does not exist, create a new one.
    pooling_backward = static_cast<MklPoolingBwdPrimitive<T>*>(
        MklPoolingBwdPrimitiveFactory<T>::GetInstance().GetPoolingBwd(
            bwdParams));
    if (pooling_backward == nullptr) {
      pooling_backward = new MklPoolingBwdPrimitive<T>(bwdParams);
      MklPoolingBwdPrimitiveFactory<T>::GetInstance().SetPoolingBwd(
          bwdParams, pooling_backward);
    }
    return pooling_backward;
  }

  static MklPoolingBwdPrimitiveFactory& GetInstance() {
    static MklPoolingBwdPrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklPoolingBwdPrimitiveFactory() {}
  ~MklPoolingBwdPrimitiveFactory() {}

  // The key created here is used to get/set a pooling
  // primitive op from the reuse pool.
  // A pooling key is a string that concatenates the key parameters
  // as well as the algorithm kind (max versus avg).
  static string CreateKey(const MklPoolingParams& bwdParams) {
    string prefix = "pooling_bwd";
    FactoryKeyCreator key_creator;
    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(bwdParams.src_dims);
    key_creator.AddAsKey(bwdParams.dst_dims);
    key_creator.AddAsKey(bwdParams.filter_dims);
    key_creator.AddAsKey(bwdParams.strides);
    key_creator.AddAsKey(bwdParams.padding_left);
    key_creator.AddAsKey(bwdParams.padding_right);
    key_creator.AddAsKey<int>(static_cast<int>(bwdParams.alg_kind));
    return key_creator.GetKey();
  }

  MklPrimitive* GetPoolingBwd(const MklPoolingParams& bwdParams) {
    string key = CreateKey(bwdParams);
    return this->GetOp(key);
  }

  void SetPoolingBwd(const MklPoolingParams& bwdParams, MklPrimitive* op) {
    string key = CreateKey(bwdParams);
    this->SetOp(key, op);
  }
};
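
// Usage sketch (illustrative; diff_dst holds the incoming output gradient,
// diff_src receives the computed input gradient, and ws is the workspace
// produced by the corresponding forward max-pooling primitive):
//
//   MklPoolingBwdPrimitive<float>* pooling_bwd =
//       MklPoolingBwdPrimitiveFactory<float>::Get(bwd_params);
//   pooling_bwd->Execute(diff_dst, diff_src, ws);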
#endif

typedef Eigen::ThreadPoolDevice CPUDevice;

struct MklPoolParameters {
  int depth;

  int tensor_in_planes;  // Pool3D
  int tensor_in_cols;
  int tensor_in_rows;
  int tensor_in_batch;

  int window_planes;  // Pool3D
  int window_rows;
  int window_cols;
  int depth_window;

  int planes_stride;  // Pool3D
  int row_stride;
  int col_stride;
  int depth_stride;

  int64 out_planes;  // Pool3D
  int64 out_height;
  int64 out_width;
  int out_depth;

  int64 pad_P1;  // Pool3D
  int64 pad_P2;  // Pool3D
  int64 pad_left;
  int64 pad_right;
  int64 pad_top;
  int64 pad_bottom;
  int pad_depth;

  TensorFormat data_format;
  MklPoolParameters()
      : depth(0),
        tensor_in_planes(0),
        tensor_in_cols(0),
        tensor_in_rows(0),
        tensor_in_batch(0),
        window_planes(0),
        window_rows(0),
        window_cols(0),
        depth_window(0),
        planes_stride(0),
        row_stride(0),
        col_stride(0),
        depth_stride(0),
        out_planes(0),
        out_height(0),
        out_width(0),
        out_depth(0),
        pad_P1(0),
        pad_P2(0),
        pad_left(0),
        pad_right(0),
        pad_top(0),
        pad_bottom(0),
        pad_depth(0),
        data_format(TensorFormat::FORMAT_NCHW) {}

  // Updates context->status if there is an invalid input.
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format, const TensorShape& tensor_in_shape);
#ifdef INTEL_MKL_ML_ONLY
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format, const MklShape* mkl_in_shape);
#else
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format, const MklDnnShape* mkl_in_shape);
#endif

 private:
  // Common initialization for TensorFlow and MKL formats
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format);
};

#ifndef INTEL_MKL_ML_ONLY

template <class T>
class MklPoolingOpBase : public OpKernel {
 public:
  explicit MklPoolingOpBase(OpKernelConstruction* context)
      : OpKernel(context), workspace_enabled_(false) {
    string data_format;
    if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
      // Current quantized pooling ops do not carry a data_format attribute,
      // so assume NHWC.
      data_format = "NHWC";
    } else {
      OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    }
    OP_REQUIRES(context, FormatFromString(data_format, &this->data_format_tf_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &this->ksize_));
    OP_REQUIRES(context, this->ksize_.size() == 4 || this->ksize_.size() == 5,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 or 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &this->stride_));
    OP_REQUIRES(context, this->stride_.size() == 4 || this->stride_.size() == 5,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 or 5 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &this->padding_));
    OP_REQUIRES(context, this->ksize_[0] == 1 && this->stride_[0] == 1,
                errors::Unimplemented("Pooling is not yet supported on the "
                                      "batch dimension."));
    bool is_pool2d = (this->ksize_.size() == 4);
    this->data_format_mkldnn_ =
        is_pool2d ? TFDataFormatToMklDnnDataFormat(this->data_format_tf_)
                  : TFDataFormatToMklDnn3DDataFormat(this->data_format_tf_);

    // We may not get this attribute for this node if it does not go through
    // the graph rewrite pass, so we do not check for errors while retrieving
    // this attribute value.
    context->GetAttr("workspace_enabled", &this->workspace_enabled_);
  }
  void Compute(OpKernelContext* context) override = 0;

 protected:
  // Calculate output shape of pooling op in MKL-DNN and TensorFlow order.
  // MKL-DNN uses NCHW (Pool2D) or NCDHW (Pool3D) for output order.
  // But TensorFlow output will be in NHWC/NCHW (Pool2D) or
  // NDHWC/NCDHW (Pool3D) format depending on data format. Function expects
  // output height and width to have already been int32 bounds-checked.
  void GetOutputDims(const MklPoolParameters& mkl_pool_params,
                     memory::dims* output_dims_mkl_order) {
    if (this->ksize_.size() == 4) {
      // Pooling2D: MKL-DNN always needs output in NCHW format.
      *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
                                mkl_pool_params.out_depth,
                                static_cast<int>(mkl_pool_params.out_height),
                                static_cast<int>(mkl_pool_params.out_width)};
    } else {
      // Pooling3D: MKL-DNN always needs output in NCDHW format.
      *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
                                mkl_pool_params.out_depth,
                                static_cast<int>(mkl_pool_params.out_planes),
                                static_cast<int>(mkl_pool_params.out_height),
                                static_cast<int>(mkl_pool_params.out_width)};
    }
  }
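
  // For example (illustrative values): a Pool2D with batch 1, 16 output
  // channels, and a 14x14 spatial output yields
  // *output_dims_mkl_order = {1, 16, 14, 14} (NCHW), regardless of whether
  // the TensorFlow tensors use NHWC or NCHW.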

  void InitMklPoolParameters(OpKernelContext* context,
                             MklPoolParameters* pool_params,
                             const MklDnnShape& original_input_mkl_shape,
                             const TensorShape& input_tensor_shape) {
    if (!original_input_mkl_shape.IsMklTensor()) {
      pool_params->Init(context, this->ksize_, this->stride_, this->padding_,
                        this->data_format_tf_, input_tensor_shape);
    } else {
      pool_params->Init(context, this->ksize_, this->stride_, this->padding_,
                        this->data_format_tf_, &original_input_mkl_shape);
    }
  }

  void PoolParamsToDims(const MklPoolParameters* pool_params,
                        memory::dims* filter_dims, memory::dims* strides,
                        memory::dims* padding_left, memory::dims* padding_right,
                        bool is_pool2d) {
    if (is_pool2d) {
      // Pool2D
      *filter_dims =
          memory::dims({pool_params->window_rows, pool_params->window_cols});
      *strides =
          memory::dims({pool_params->row_stride, pool_params->col_stride});
      *padding_left = memory::dims({static_cast<int>(pool_params->pad_top),
                                    static_cast<int>(pool_params->pad_left)});
      *padding_right = memory::dims({static_cast<int>(pool_params->pad_bottom),
                                     static_cast<int>(pool_params->pad_right)});
    } else {
      // Pool3D
      *filter_dims =
          memory::dims({pool_params->window_planes, pool_params->window_rows,
                        pool_params->window_cols});
      *strides =
          memory::dims({pool_params->planes_stride, pool_params->row_stride,
                        pool_params->col_stride});

      *padding_left = memory::dims({static_cast<int>(pool_params->pad_P1),
                                    static_cast<int>(pool_params->pad_top),
                                    static_cast<int>(pool_params->pad_left)});
      *padding_right = memory::dims({static_cast<int>(pool_params->pad_P2),
                                     static_cast<int>(pool_params->pad_bottom),
                                     static_cast<int>(pool_params->pad_right)});
    }
  }
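
  // For example (illustrative values): a 3x3 Pool2D window with stride 2 and
  // symmetric padding of 1 maps to filter_dims = {3, 3}, strides = {2, 2},
  // padding_left = {1, 1} (top, left) and padding_right = {1, 1}
  // (bottom, right); Pool3D prepends the planes dimension (P1/P2) to each.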

  void AllocateEmptyOutputTensor(OpKernelContext* context,
                                 const int kOutputIndex,
                                 MklPoolParameters* pool_params,
                                 const memory::dims output_dims_mkl_order,
                                 Tensor** output_tensor) {
    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(false);
    TensorShape output_tf_shape;
    if (pool_params->data_format == TensorFormat::FORMAT_NCHW) {
      output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order);
    } else {
      memory::dims output_dims_NHWC_order;
      output_dims_NHWC_order = {pool_params->tensor_in_batch,
                                static_cast<int>(pool_params->out_height),
                                static_cast<int>(pool_params->out_width),
                                pool_params->out_depth};
      output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order);
    }
    AllocateOutputSetMklShape(context, kOutputIndex, output_tensor,
                              output_tf_shape, output_mkl_shape);
    CHECK_NOTNULL(*output_tensor);
  }

  // Returns the number of elements of type T needed to cover the memory
  // described by pd, rounding up when the byte count is not an exact
  // multiple of sizeof(T).
  size_t GetNumTElements(const memory::primitive_desc& pd) {
    size_t num_bytes = pd.get_size();
    size_t ret_val = num_bytes / sizeof(T);
    if (num_bytes % sizeof(T) != 0) {
      ret_val++;
    }
    return ret_val;
  }
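
  // For example, with sizeof(T) == 4 and pd.get_size() == 10 bytes,
  // GetNumTElements() returns 3: 10 / 4 == 2 with a remainder, so the count
  // rounds up to 3 elements (12 bytes), enough to cover the 10 requested.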

  std::vector<int32> ksize_;
  std::vector<int32> stride_;
  Padding padding_;
  TensorFormat data_format_tf_;
  memory::format data_format_mkldnn_;
  bool workspace_enabled_;
};

template <class T>
class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
 public:
  explicit MklPoolingForwardOpBase(OpKernelConstruction* context)
      : MklPoolingOpBase<T>(context) {}
  void Compute(OpKernelContext* context) override = 0;

 protected:
  void ConfigureInput(OpKernelContext* context,
                      const MklDnnShape& input_mkl_shape,
                      const Tensor& input_tensor,
                      MklPoolParameters* pool_params,
                      MklDnnData<T>* dnn_data_input) {
    CHECK_NOTNULL(pool_params);
    CHECK_NOTNULL(dnn_data_input);
    TensorShape input_tensor_shape = input_tensor.shape();
    if (input_tensor.NumElements() != 0) {
      memory::desc input_md =
          input_mkl_shape.IsMklTensor()
              ? input_mkl_shape.GetMklLayout()
              : memory::desc(
                    (this->ksize_.size() == 4)
                        ? TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
                                                    this->data_format_tf_)
                        : TFShapeToMklDnnDimsInNCDHW(input_tensor_shape,
                                                     this->data_format_tf_),
                    MklDnnType<T>(), this->data_format_mkldnn_);
      dnn_data_input->SetUsrMem(input_md, &input_tensor);

      if (this->ksize_.size() == 5) {
        // Pool3D
        std::vector<int> mkldnn_sizes(5, -1);
        mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_md.data.dims[0];
        mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_md.data.dims[1];
        mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_md.data.dims[2];
        mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_md.data.dims[3];
        mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_md.data.dims[4];
        dnn_data_input->SetOpMemDesc(mkldnn_sizes, this->data_format_mkldnn_);
      }
    }
    this->InitMklPoolParameters(context, pool_params, input_mkl_shape,
                                input_tensor_shape);
  }

  void AllocateOutputTensor(
      OpKernelContext* context,
      const pooling_forward::primitive_desc& pool_fwd_prim_desc,
      const memory::dims output_dims_mkl_order,
      const memory::format& output_tf_format, Tensor** output_tensor) {
    CHECK_NOTNULL(output_tensor);
    memory::primitive_desc dst_pd = pool_fwd_prim_desc.dst_primitive_desc();

    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);
    TensorShape output_tf_shape;

    // Only allocate enough space for the elements we need.
    output_tf_shape.AddDim(this->GetNumTElements(dst_pd));
    AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
                              output_tf_shape, output_mkl_shape);
    CHECK_NOTNULL(*output_tensor);
  }

  void SanityCheckInput(OpKernelContext* context, const Tensor& input_tensor,
                        const MklDnnShape& input_mkl_shape) {
    if (!input_mkl_shape.IsMklTensor()) {
      OP_REQUIRES(context, input_tensor.dims() == 4 || input_tensor.dims() == 5,
                  errors::InvalidArgument("Input must be 4 or 5-dimensional"));
    } else {
      OP_REQUIRES(
          context,
          input_mkl_shape.GetDimension() == 4 ||
              input_mkl_shape.GetDimension() == 5,
          errors::InvalidArgument("Input shape must be 4 or 5-dimensional"));
    }
  }
  // .Input("value: T")
  // .Output("output: T")
  const int kInputTensorIndexInput = 0;
  const int kOutputTensorIndexOutput = 0;
};  // MklPoolingForwardOpBase

template <class T>
class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> {
 public:
  explicit MklPoolingBackwardOpBase(OpKernelConstruction* context)
      : MklPoolingOpBase<T>(context) {}
  void Compute(OpKernelContext* context) override = 0;

 protected:
  const int kOutputTensorIndexOutput = 0;

  void AllocateOutputTensor(
      OpKernelContext* context,
      const pooling_backward::primitive_desc& pool_bkwd_prim_desc,
      const memory::dims output_dims_mkl_order,
      const memory::format& output_tf_format, Tensor** output_tensor) {
    CHECK_NOTNULL(output_tensor);
    memory::primitive_desc dst_pd =
        pool_bkwd_prim_desc.diff_src_primitive_desc();
    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    TensorShape output_tf_shape;
    output_tf_shape.AddDim(this->GetNumTElements(dst_pd));
    AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
                              output_tf_shape, output_mkl_shape);
    CHECK_NOTNULL(*output_tensor);
  }

  memory::desc ConfigureInputGradient(
      const MklDnnShape& input_gradient_mkl_shape,
      const Tensor& input_gradient_tensor,
      MklDnnData<T>* input_gradient_dnn_data,
      const memory::desc& original_output_md) {
    // Configure the gradient as-is.
    memory::desc original_input_grad_md =
        input_gradient_mkl_shape.IsMklTensor()
            ? input_gradient_mkl_shape.GetMklLayout()
            : memory::desc(
                  (this->ksize_.size() == 4)
                      ? TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(),
                                                  this->data_format_tf_)
                      : TFShapeToMklDnnDimsInNCDHW(
                            input_gradient_tensor.shape(),
                            this->data_format_tf_),
                  MklDnnType<T>(), this->data_format_mkldnn_);

    input_gradient_dnn_data->SetUsrMem(original_input_grad_md,
                                       &input_gradient_tensor);

    // Check whether the input gradient (diff_dst) is already in the format
    // we need. If not, create a new memory descriptor with the same shape
    // as the original but with the format of the other tensors.
    memory::format original_output_format =
        static_cast<memory::format>(original_output_md.data.format);
    bool grad_reorder_needed =
        input_gradient_dnn_data->IsReorderNeeded(original_output_format);
    memory::dims diff_dst_dims =
        input_gradient_mkl_shape.IsMklTensor()
            ? input_gradient_mkl_shape.GetSizesAsMklDnnDims()
            : TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(),
                                        this->data_format_tf_);
    memory::desc target_diff_dst_md =
        memory::desc(diff_dst_dims, MklDnnType<T>(), original_output_format);

    return grad_reorder_needed ? target_diff_dst_md : original_input_grad_md;
  }
};
#endif  // INTEL_MKL_ML_ONLY

//-------------------------------------------------------------------
// Utility functions

typedef struct {
  size_t in_dim;
  size_t in_sizes[4];
  size_t in_strides[4];
  size_t out_sizes[4];
  size_t out_strides[4];
  int in_offset[4];
  size_t kernel_stride[2];
  size_t kernel_size[2];
} MklPoolingOpParams;

// Transfers the right parameters for pooling to the op parameters.
// Updates context->status if there is an invalid input.
void ExtractMklOpParams(OpKernelContext* context, TensorFormat data_format,
                        const MklPoolParameters& params,
                        MklPoolingOpParams* mkl_params);
}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_