Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3    Licensed under the Apache License, Version 2.0 (the "License");
      4    you may not use this file except in compliance with the License.
      5    You may obtain a copy of the License at
      6 
      7    http://www.apache.org/licenses/LICENSE-2.0
      8 
      9    Unless required by applicable law or agreed to in writing, software
     10    distributed under the License is distributed on an "AS IS" BASIS,
     11    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12    See the License for the specific language governing permissions and
     13    limitations under the License.
     14    ==============================================================================*/
     15 
     16 #ifdef INTEL_MKL
     17 #define EIGEN_USE_THREADS
     18 
     19 #include "tensorflow/core/common_runtime/device.h"
     20 #include "tensorflow/core/framework/common_shape_fns.h"
     21 #include "tensorflow/core/framework/numeric_op.h"
     22 #include "tensorflow/core/framework/register_types.h"
     23 #include "tensorflow/core/util/mkl_util.h"
     24 
     25 #include "tensorflow/core/kernels/mkl_pooling_ops_common.h"
     26 
     27 #ifndef INTEL_MKL_ML
     28 #include "mkldnn.hpp"
     29 using mkldnn::algorithm;
     30 using mkldnn::engine;
     31 using mkldnn::error;
     32 using mkldnn::memory;
     33 using mkldnn::padding_kind;
     34 using mkldnn::pooling_backward;
     35 using mkldnn::pooling_forward;
     36 using mkldnn::prop_kind;
     37 #endif
     38 
     39 namespace tensorflow {
     40 
     41 typedef Eigen::ThreadPoolDevice CPUDevice;
     42 
     43 #ifdef INTEL_MKL_ML
     44 
     45 template <typename Device, typename T>
     46 class MklAvgPoolingOp : public OpKernel {
     47  public:
     48   explicit MklAvgPoolingOp(OpKernelConstruction* context) : OpKernel(context) {
     49     string data_format;
     50     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
     51     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
     52                 errors::InvalidArgument("Invalid data format"));
     53 
     54     OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
     55     OP_REQUIRES(context, ksize_.size() == 4,
     56                 errors::InvalidArgument("Sliding window ksize field must "
     57                                         "specify 4 dimensions"));
     58     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
     59     OP_REQUIRES(context, stride_.size() == 4,
     60                 errors::InvalidArgument("Sliding window stride field must "
     61                                         "specify 4 dimensions"));
     62     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
     63     OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
     64                 errors::Unimplemented("Pooling is not yet supported on the "
     65                                       "batch dimension."));
     66   }
     67 
     68   void Compute(OpKernelContext* context) override {
     69     MklAvgPoolingOpContext mkl_context;
     70     const Tensor& tensor_in = MklGetInput(context, 0);
     71     GetMklShape(context, 0, &mkl_context.input_shape);
     72     bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
     73 
     74     if (!input_in_mkl_format)
     75       mkl_context.params.in_dim = tensor_in.dims();
     76     else
     77       mkl_context.params.in_dim = mkl_context.input_shape.GetDimension();
     78 
     79     MklPoolParameters pool_params;
     80     if (!input_in_mkl_format) {
     81       pool_params.Init(context, ksize_, stride_, padding_, data_format_,
     82                        tensor_in.shape());
     83     } else {
     84       pool_params.Init(context, ksize_, stride_, padding_, data_format_,
     85                        &mkl_context.input_shape);
     86     }
     87 
     88     // Extract the parameters for the op from the pooling specs
     89     ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
     90 
     91     Tensor mkl_tmp_input_buf_tensor_;
     92     mkl_context.MklCreateLayoutsAndPrimitives(context,
     93                                               &mkl_tmp_input_buf_tensor_);
     94     OP_REQUIRES_OK(context, context->status());
     95 
     96     Tensor workspace_tensor;
     97     void* workspace_buf;
     98     AllocTmpBuffer(context, &workspace_tensor, mkl_context.lt_workspace,
     99                    &workspace_buf);
    100 
    101     if (mkl_context.convert_input != nullptr) {
    102       if (input_in_mkl_format == false) {
    103         CHECK_EQ(
    104             dnnConversionExecute_F32(
    105                 mkl_context.convert_input,
    106                 static_cast<void*>(const_cast<T*>(tensor_in.flat<T>().data())),
    107                 mkl_context.input_buf),
    108             E_SUCCESS);
    109         CHECK_EQ(dnnDelete_F32(mkl_context.convert_input), E_SUCCESS);
    110       } else {
    111         mkl_context.input_shape.GetConvertedFlatData(
    112             mkl_context.lt_prim_input,
    113             static_cast<void*>(const_cast<T*>(tensor_in.flat<T>().data())),
    114             mkl_context.input_buf);
    115       }
    116       mkl_context.pooling_res[dnnResourceSrc] = mkl_context.input_buf;
    117     } else {
    118       mkl_context.pooling_res[dnnResourceSrc] =
    119           static_cast<void*>(const_cast<T*>(tensor_in.flat<T>().data()));
    120     }
    121 
    122     // Declare output tensor and allocate memory
    123     Tensor* output = nullptr;
    124     TensorShape tensor_out_shape;
    125     MklShape mkl_out_shape;
    126     mkl_out_shape.SetMklTensor(true);
    127     mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_fwd, dnnResourceDst);
    128     mkl_out_shape.SetTfLayout(mkl_context.params.in_dim,
    129                               mkl_context.params.out_sizes,
    130                               mkl_context.params.out_strides);
    131     mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_);
    132 
    133     tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
    134                                 mkl_out_shape.GetMklLayout())) /
    135                             sizeof(T));
    136 
    137     AllocateOutputSetMklShape(context, 0, &output, tensor_out_shape,
    138                               mkl_out_shape);
    139     mkl_context.pooling_res[dnnResourceDst] =
    140         static_cast<void*>(output->flat<T>().data());
    141 
    142     mkl_context.pooling_res[dnnResourceWorkspace] = workspace_buf;
    143 
    144     CHECK_EQ(
    145         dnnExecute_F32(mkl_context.prim_pooling_fwd, mkl_context.pooling_res),
    146         E_SUCCESS);
    147 
    148     mkl_context.MklCleanup();
    149   }  // Compute
    150 
    151  private:
    152   typedef struct {
    153     MklPoolingOpParams params;
    154     MklShape input_shape;
    155     dnnPrimitive_t prim_pooling_fwd = nullptr, convert_input = nullptr;
    156     dnnLayout_t lt_user_input = nullptr, lt_prim_input = nullptr,
    157                 lt_workspace = nullptr;
    158     void* input_buf = nullptr;
    159     void* pooling_res[dnnResourceNumber];
    160 
    161     void MklCreateLayoutsAndPrimitives(OpKernelContext* context,
    162                                        Tensor* mkl_tmp_input_buf_tensor) {
    163       bool input_in_mkl_format = input_shape.IsMklTensor();
    164 
    165       if (!input_in_mkl_format) {
    166         CHECK_EQ(dnnLayoutCreate_F32(&lt_user_input, params.in_dim,
    167                                      params.in_sizes, params.in_strides),
    168                  E_SUCCESS);
    169       } else {
    170         lt_user_input = (dnnLayout_t)input_shape.GetCurLayout();
    171       }
    172 
    173       dnnAlgorithm_t algorithm = dnnAlgorithmPoolingAvg;
    174       dnnPrimitiveAttributes_t primAttr = nullptr;
    175 
    176       // Create DNN primitives
    177       CHECK_EQ(dnnPoolingCreateForward_F32(
    178                    &prim_pooling_fwd, primAttr, algorithm, lt_user_input,
    179                    params.kernel_size, params.kernel_stride, params.in_offset,
    180                    dnnBorderZerosAsymm),
    181                E_SUCCESS);
    182 
    183       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
    184                    &lt_prim_input, prim_pooling_fwd, dnnResourceSrc),
    185                E_SUCCESS);
    186       if (!dnnLayoutCompare_F32(lt_user_input, lt_prim_input)) {
    187         CHECK_EQ(dnnConversionCreate_F32(&convert_input, lt_user_input,
    188                                          lt_prim_input),
    189                  E_SUCCESS);
    190 
    191         AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_prim_input,
    192                        &input_buf);
    193       }
    194 
    195       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace, prim_pooling_fwd,
    196                                                 dnnResourceWorkspace),
    197                E_SUCCESS);
    198     }
    199 
    200     void MklCleanup() {
    201       bool input_in_mkl_format = input_shape.IsMklTensor();
    202       if (!input_in_mkl_format) {
    203         CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS);
    204       }
    205 
    206       CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS);
    207       CHECK_EQ(dnnLayoutDelete_F32(lt_prim_input), E_SUCCESS);
    208     }
    209   } MklAvgPoolingOpContext;
    210 
    211   std::vector<int32> ksize_;
    212   std::vector<int32> stride_;
    213   Padding padding_;
    214   TensorFormat data_format_;
    215 };
    216 
    217 //-----------------------------------------------------------------------------
    218 
    219 template <class Device, class T>
    220 class MklAvgPoolingGradOp : public OpKernel {
    221  public:
    222   explicit MklAvgPoolingGradOp(OpKernelConstruction* context)
    223       : OpKernel(context) {
    224     string data_format;
    225 
    226     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    227     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
    228                 errors::InvalidArgument("Invalid data format"));
    229     OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    230     OP_REQUIRES(context, ksize_.size() == 4,
    231                 errors::InvalidArgument("Sliding window ksize field must "
    232                                         "specify 4 dimensions"));
    233     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    234     OP_REQUIRES(context, stride_.size() == 4,
    235                 errors::InvalidArgument("Sliding window strides field must "
    236                                         "specify 4 dimensions"));
    237     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    238     OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
    239                 errors::Unimplemented("Pooling is not yet supported on the "
    240                                       "batch dimension."));
    241   }
    242 
    243   void Compute(OpKernelContext* context) override {
    244     MklAvgPoolingGradOpContext mkl_context;
    245     const Tensor& tensor_in_shape = MklGetInput(context, 0);
    246     const Tensor& out_backprop = MklGetInput(context, 1);
    247     GetMklShape(context, 1, &mkl_context.out_backprop_shape);
    248     bool outbackprop_in_mkl_format =
    249         mkl_context.out_backprop_shape.IsMklTensor();
    250 
    251     TensorShape output_shape;
    252     auto shape_vec = tensor_in_shape.vec<int32>();
    253     for (int64 i = 0; i < tensor_in_shape.NumElements(); ++i) {
    254       output_shape.AddDim(shape_vec(i));
    255     }
    256 
    257     MklPoolParameters pool_params;
    258     pool_params.Init(context, ksize_, stride_, padding_, data_format_,
    259                      output_shape);
    260 
    261     if (outbackprop_in_mkl_format == false)
    262       mkl_context.params.in_dim = out_backprop.dims();
    263     else
    264       mkl_context.params.in_dim = mkl_context.out_backprop_shape.GetDimension();
    265 
    266     // Extract the parameters for the op from the pooling specs
    267     ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
    268 
    269     // Tensors needed to create temporary buffers
    270     Tensor outbackprop_buf_tensor;
    271     void* outbackprop_buf;
    272     mkl_context.MklCreateLayoutsAndPrimitives(context);
    273     OP_REQUIRES_OK(context, context->status());
    274 
    275     // Check if outbackprop layout requires conversion.
    276     if (!dnnLayoutCompare_F32(mkl_context.lt_user_outbackprop,
    277                               mkl_context.lt_prim_outbackprop)) {
    278       CHECK_EQ(dnnConversionCreate_F32(&mkl_context.convert_outbackprop,
    279                                        mkl_context.lt_user_outbackprop,
    280                                        mkl_context.lt_prim_outbackprop),
    281                E_SUCCESS);
    282 
    283       AllocTmpBuffer(context, &outbackprop_buf_tensor,
    284                      mkl_context.lt_prim_outbackprop, &outbackprop_buf);
    285 
    286       if (!outbackprop_in_mkl_format) {
    287         CHECK_EQ(dnnConversionExecute_F32(mkl_context.convert_outbackprop,
    288                                           static_cast<void*>(const_cast<T*>(
    289                                               out_backprop.flat<T>().data())),
    290                                           outbackprop_buf),
    291                  E_SUCCESS);
    292         CHECK_EQ(dnnDelete_F32(mkl_context.convert_outbackprop), E_SUCCESS);
    293       } else {
    294         mkl_context.out_backprop_shape.GetConvertedFlatData(
    295             mkl_context.lt_prim_outbackprop,
    296             static_cast<void*>(const_cast<T*>(out_backprop.flat<T>().data())),
    297             outbackprop_buf);
    298       }
    299       mkl_context.pooling_res[dnnResourceDiffDst] = outbackprop_buf;
    300     } else {
    301       mkl_context.pooling_res[dnnResourceDiffDst] =
    302           static_cast<void*>(const_cast<T*>(out_backprop.flat<T>().data()));
    303     }
    304 
    305     // Handle workspace requirements.
    306     Tensor workspace_buf_tensor;
    307     void* workspace_buf;
    308     AllocTmpBuffer(context, &workspace_buf_tensor, mkl_context.lt_workspace,
    309                    &workspace_buf);
    310     mkl_context.pooling_res[dnnResourceWorkspace] = workspace_buf;
    311 
    312     // Handle MKL output tensor setup.
    313     Tensor* output = nullptr;
    314     TensorShape tensor_out_shape;
    315     MklShape mkl_out_shape;
    316     mkl_out_shape.SetMklTensor(true);
    317     mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_bwd,
    318                                dnnResourceDiffSrc);
    319     mkl_out_shape.SetTfLayout(mkl_context.params.in_dim,
    320                               mkl_context.params.in_sizes,
    321                               mkl_context.params.in_strides);
    322     mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_);
    323 
    324     tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
    325                                 mkl_out_shape.GetMklLayout())) /
    326                             sizeof(T));
    327 
    328     AllocateOutputSetMklShape(context, 0, &output, tensor_out_shape,
    329                               mkl_out_shape);
    330 
    331     // Set output tensor.
    332     mkl_context.pooling_res[dnnResourceDiffSrc] =
    333         static_cast<void*>(output->flat<T>().data());
    334 
    335     // Execute primitive.
    336     CHECK_EQ(
    337         dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res),
    338         E_SUCCESS);
    339 
    340     mkl_context.MklCleanup();
    341   }
    342 
    343  private:
    344   typedef struct {
    345     MklPoolingOpParams params;
    346     MklShape out_backprop_shape;
    347     dnnPrimitive_t prim_pooling_bwd = nullptr, convert_outbackprop = nullptr;
    348     void* pooling_res[dnnResourceNumber];
    349     dnnLayout_t lt_user_input = nullptr, lt_user_outbackprop = nullptr,
    350                 lt_prim_outbackprop = nullptr, lt_workspace = nullptr;
    351 
    352     void MklCreateLayoutsAndPrimitives(OpKernelContext* context) {
    353       const Tensor& tensor_in_shape = MklGetInput(context, 0);
    354       const Tensor& out_backprop = MklGetInput(context, 1);
    355       bool outbackprop_in_mkl_format = out_backprop_shape.IsMklTensor();
    356 
    357       if (!outbackprop_in_mkl_format) {
    358         // For avgpooling, tensor_in_shape should have 1 dimension, and 4
    359         // elements.
    360         OP_REQUIRES(
    361             context,
    362             tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4,
    363             errors::InvalidArgument("original input shape must be "
    364                                     "1-dimensional and 4 elements"));
    365 
    366         // For avgpooling, out_backprop should have 4 dimensions.
    367         OP_REQUIRES(context, out_backprop.dims() == 4,
    368                     errors::InvalidArgument("out_backprop must be "
    369                                             "4-dimensional"));
    370       } else {
    371         // Input in MKL format.
    372         // For avgpooling, out_backprop should have 4 dimensions.
    373         OP_REQUIRES(context, out_backprop_shape.GetDimension() == 4,
    374                     errors::InvalidArgument("out_backprop must be "
    375                                             "4-dimensional"));
    376       }
    377 
    378       // TODO(inteltf): Get outbackprop layout.
    379       // Do we need to create layout in every invocation?
    380       if (!outbackprop_in_mkl_format) {
    381         CHECK_EQ(dnnLayoutCreate_F32(&lt_user_outbackprop, params.in_dim,
    382                                      params.out_sizes, params.out_strides),
    383                  E_SUCCESS);
    384       } else {
    385         lt_user_outbackprop = (dnnLayout_t)out_backprop_shape.GetCurLayout();
    386       }
    387 
    388       // Create the backward primitive
    389       // Create DNN user layout
    390       CHECK_EQ(dnnLayoutCreate_F32(&lt_user_input, params.in_dim,
    391                                    params.in_sizes, params.in_strides),
    392                E_SUCCESS);
    393 
    394       // Create PoolingBackward primitive
    395       dnnAlgorithm_t algorithm = dnnAlgorithmPoolingAvg;
    396       dnnPrimitiveAttributes_t primAttr = nullptr;
    397       CHECK_EQ(dnnPoolingCreateBackward_F32(
    398                    &prim_pooling_bwd, primAttr, algorithm, lt_user_input,
    399                    params.kernel_size, params.kernel_stride, params.in_offset,
    400                    dnnBorderZerosAsymm),
    401                E_SUCCESS);
    402 
    403       // Create expected outbackprop layout from the primitive.
    404       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
    405                    &lt_prim_outbackprop, prim_pooling_bwd, dnnResourceDiffDst),
    406                E_SUCCESS);
    407 
    408       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace, prim_pooling_bwd,
    409                                                 dnnResourceWorkspace),
    410                E_SUCCESS);
    411     }
    412 
    413     void MklCleanup() {
    414       bool outbackprop_in_mkl_format = out_backprop_shape.IsMklTensor();
    415       CHECK_EQ(dnnDelete_F32(prim_pooling_bwd), E_SUCCESS);
    416       CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS);
    417       if (!outbackprop_in_mkl_format) {
    418         CHECK_EQ(dnnLayoutDelete_F32(lt_user_outbackprop), E_SUCCESS);
    419       }
    420       CHECK_EQ(dnnLayoutDelete_F32(lt_prim_outbackprop), E_SUCCESS);
    421       CHECK_EQ(dnnLayoutDelete_F32(lt_workspace), E_SUCCESS);
    422     }
    423   } MklAvgPoolingGradOpContext;
    424 
    425   std::vector<int32> ksize_;
    426   std::vector<int32> stride_;
    427   Padding padding_;
    428   TensorFormat data_format_;
    429 };  // MklAvgPoolingGradOp
    430 
    431 #else
    432 
    433 template <typename Device, typename T>
    434 class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
    435  public:
    436   explicit MklAvgPoolingOp(OpKernelConstruction* context)
    437       : MklPoolingForwardOpBase<T>(context) {
    438     // Workspace is an MKLDNN construct that is only used in Max Pooling.
    439     // So set workspace_enabled_ to false.
    440     this->workspace_enabled_ = false;
    441   }
    442 
    443   void Compute(OpKernelContext* context) override {
    444     try {
    445       auto cpu_engine = engine(engine::cpu, 0);
    446       const Tensor& input_tensor =
    447           MklGetInput(context, this->kInputTensorIndexInput);
    448       MklDnnShape dnn_shape_input;
    449       GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input);
    450       this->SanityCheckInput(context, input_tensor, dnn_shape_input);
    451       if (!context->status().ok()) return;
    452 
    453       MklDnnData<T> dnn_data_input(&cpu_engine);
    454       MklDnnData<T> dnn_data_output(&cpu_engine);
    455 
    456       // initialize variables for the pooling op
    457       MklPoolParameters pool_params;
    458       // Get the input tensor and initialize the pooling parameters
    459       this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params,
    460                            &dnn_data_input);
    461       OP_REQUIRES_OK(context, context->status());
    462 
    463       // Declare output tensor
    464       Tensor* output_tensor = nullptr;
    465       memory::dims output_dims_mkl_order;
    466       this->GetOutputDims(pool_params, &output_dims_mkl_order);
    467 
    468       // If input is an empty tensor, allocate an empty output tensor and return
    469       if (input_tensor.NumElements() == 0) {
    470         MklDnnShape output_mkl_shape;
    471         output_mkl_shape.SetMklTensor(false);
    472         TensorShape output_tf_shape;
    473         if (pool_params.data_format == TensorFormat::FORMAT_NCHW) {
    474           output_tf_shape = MklDnnDimsToTFShape(output_dims_mkl_order);
    475         } else {
    476           memory::dims output_dims_NHWC_order;
    477           output_dims_NHWC_order = {pool_params.tensor_in_batch,
    478                                     static_cast<int>(pool_params.out_height),
    479                                     static_cast<int>(pool_params.out_width),
    480                                     pool_params.out_depth};
    481           output_tf_shape = MklDnnDimsToTFShape(output_dims_NHWC_order);
    482         }
    483         const int kOutputIndex = 0;
    484         AllocateOutputSetMklShape(context, kOutputIndex, &output_tensor,
    485                                   output_tf_shape, output_mkl_shape);
    486         CHECK_NOTNULL(output_tensor);
    487         return;
    488       }
    489 
    490       // If input is in Mkl layout, then just get the memory format from it
    491       // directly, instead of using input data_format to AvgPool.
    492       if (dnn_shape_input.IsMklTensor()) {
    493         dnn_data_output.SetUsrMem(
    494             output_dims_mkl_order,
    495             static_cast<memory::format>(
    496                 dnn_data_input.GetUsrMemDesc().data.format));
    497 
    498       } else {
    499         dnn_data_output.SetUsrMem(output_dims_mkl_order,
    500                                   this->data_format_mkldnn_);
    501       }
    502 
    503       // describe the memory layout
    504       dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any);
    505 
    506       // 3. create a pooling primitive descriptor
    507       auto pool_desc = pooling_forward::desc(
    508           prop_kind::forward, algorithm::pooling_avg_exclude_padding,
    509           dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(),
    510           memory::dims({pool_params.row_stride, pool_params.col_stride}),
    511           memory::dims({pool_params.window_rows, pool_params.window_cols}),
    512           memory::dims({static_cast<int>(pool_params.pad_top),
    513                         static_cast<int>(pool_params.pad_left)}),
    514           memory::dims({static_cast<int>(pool_params.pad_bottom),
    515                         static_cast<int>(pool_params.pad_right)}),
    516           TFPaddingToMklDnnPadding(this->padding_));
    517       auto pool_prim_desc =
    518           pooling_forward::primitive_desc(pool_desc, cpu_engine);
    519 
    520       this->AllocateOutputTensor(context, pool_prim_desc, output_dims_mkl_order,
    521                                  this->data_format_mkldnn_, &output_tensor);
    522       CHECK_NOTNULL(output_tensor);
    523 
    524       OP_REQUIRES_OK(context, context->status());
    525       dnn_data_output.SetUsrMemDataHandle(output_tensor);
    526 
    527       this->PrepareAndExecuteNet(pool_prim_desc, &dnn_data_input,
    528                                  &dnn_data_output);
    529     } catch (mkldnn::error& e) {
    530       string error_msg = "Status: " + std::to_string(e.status) +
    531                          ", message: " + string(e.message) + ", in file " +
    532                          string(__FILE__) + ":" + std::to_string(__LINE__);
    533       OP_REQUIRES_OK(
    534           context,
    535           errors::Aborted("Operation received an exception:", error_msg));
    536     }
    537   }  // Compute
    538 };   // MklAvgPoolingOp
    539 
    540 //-----------------------------------------------------------------------------
    541 
    542 template <class Device, class T>
    543 class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
    544  public:
  // Forwards the construction context to MklPoolingBackwardOpBase<T>; the
  // constructor body itself is empty.
  explicit MklAvgPoolingGradOp(OpKernelConstruction* context)
      : MklPoolingBackwardOpBase<T>(context) {}
    547 
    548   void Compute(OpKernelContext* context) override {
    549     try {
    550       auto cpu_engine = engine(engine::cpu, 0);
    551       MklDnnShape original_input_mkl_shape, input_gradient_mkl_shape;
    552       const Tensor& tensor_in_shape =
    553           MklGetInput(context, kInputTensorIndexInputShape);
    554       const Tensor& input_gradient_tensor =
    555           MklGetInput(context, kInputTensorIndexInputGradient);
    556       GetMklShape(context, kInputTensorIndexInputShape,
    557                   &original_input_mkl_shape);
    558       GetMklShape(context, kInputTensorIndexInputGradient,
    559                   &input_gradient_mkl_shape);
    560 
    561       SanityCheckInputs(context, tensor_in_shape, input_gradient_tensor,
    562                         original_input_mkl_shape, input_gradient_mkl_shape);
    563       if (!context->status().ok()) return;
    564 
    565       // Used to allocate output_diff_src/diff_src
    566       // and create pool_fwd mdm desc
    567       // 0. Input("orig_input_shape: int32") //NOT a T Tensor!
    568       // 1. Input("grad: T")
    569 
    570       MklDnnData<T> input_gradient_diff_dst(&cpu_engine);
    571       MklDnnData<T> output_diff_src(&cpu_engine);
    572       Tensor* output_tensor_diff_src = nullptr;
    573       TensorShape original_input_shape;
    574       MklPoolParameters pool_params;
    575       memory::dims output_dims_mkl_order, original_input_dims_nchw;
    576       // Configure the original input memory descriptor
    577       memory::desc original_input_md = ConfigureOriginalInput(
    578           context, tensor_in_shape, original_input_mkl_shape,
    579           &original_input_dims_nchw, &pool_params, &original_input_shape);
    580 
    581       // configure the original output memory descriptor
    582       // by definition, the shape of the original output is the same
    583       // as the shape of the gradient diff_dst
    584       memory::desc original_output_md = this->ConfigureOriginalOutput(
    585           pool_params, input_gradient_mkl_shape, output_dims_mkl_order);
    586 
    587       memory::desc target_diff_dst_md = this->ConfigureInputGradient(
    588           input_gradient_mkl_shape, input_gradient_tensor,
    589           &input_gradient_diff_dst, original_output_md);
    590       // The shape of the output diff src needs to be the same shape as the
    591       // original input. But we will set its format to be same as the format of
    592       // input gradient. We won't use format of original input since it will
    593       // always be in Tensorflow layout (given that AvgPoolGrad gets shape of
    594       // the input rather than actual input).
    595       output_diff_src.SetUsrMem(
    596           original_input_dims_nchw,
    597           static_cast<memory::format>(target_diff_dst_md.data.format));
    598 
    599       // Create the forward pooling primitive descriptor so we can reference it
    600       // in the backward pooling primitive descriptor
    601       auto pool_fwd_desc = pooling_forward::desc(
    602           prop_kind::forward, algorithm::pooling_avg_exclude_padding,
    603           original_input_md, original_output_md,
    604           memory::dims({pool_params.row_stride, pool_params.col_stride}),
    605           memory::dims({pool_params.window_rows, pool_params.window_cols}),
    606           memory::dims({static_cast<int>(pool_params.pad_top),
    607                         static_cast<int>(pool_params.pad_left)}),
    608           memory::dims({static_cast<int>(pool_params.pad_bottom),
    609                         static_cast<int>(pool_params.pad_right)}),
    610           TFPaddingToMklDnnPadding(this->padding_));
    611       auto pool_fwd_prim_desc =
    612           pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine);
    613 
    614       auto pool_bkwd_desc = pooling_backward::desc(
    615           algorithm::pooling_avg_exclude_padding,
    616           output_diff_src.GetUsrMemDesc(), target_diff_dst_md,
    617           memory::dims({pool_params.row_stride, pool_params.col_stride}),
    618           memory::dims({pool_params.window_rows, pool_params.window_cols}),
    619           memory::dims({static_cast<int>(pool_params.pad_top),
    620                         static_cast<int>(pool_params.pad_left)}),
    621           memory::dims({static_cast<int>(pool_params.pad_bottom),
    622                         static_cast<int>(pool_params.pad_right)}),
    623           TFPaddingToMklDnnPadding(this->padding_));
    624       auto pool_bkwd_prim_desc = pooling_backward::primitive_desc(
    625           pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc);
    626       this->AllocateOutputTensor(
    627           context, pool_bkwd_prim_desc, original_input_dims_nchw,
    628           this->data_format_mkldnn_, &output_tensor_diff_src);
    629 
    630       output_diff_src.SetUsrMemDataHandle(output_tensor_diff_src);
    631 
    632       this->PrepareAndExecuteNet(
    633           pool_bkwd_prim_desc, &input_gradient_diff_dst, &output_diff_src,
    634           memory::primitive_desc(target_diff_dst_md, cpu_engine));
    635     } catch (mkldnn::error& e) {
    636       string error_msg = "Status: " + std::to_string(e.status) +
    637                          ", message: " + string(e.message) + ", in file " +
    638                          string(__FILE__) + ":" + std::to_string(__LINE__);
    639       OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
    640                                               error_msg));
    641     }
    642   }  // Compute
    643 
    644  private:
    // Positions of this kernel's inputs:
    //   0. Input("orig_input_shape: int32")
    //   1. Input("grad: T")
    const int kInputTensorIndexInputShape{0};
    const int kInputTensorIndexInputGradient{1};
    649 
    650   memory::desc ConfigureOriginalInput(
    651       OpKernelContext* context, const Tensor& tensor_original_input_shape,
    652       const MklDnnShape& original_input_mkl_shape,
    653       memory::dims* original_input_dims_mkl_order,
    654       MklPoolParameters* pool_params, TensorShape* input_tensor_shape) {
    655     CHECK_NOTNULL(original_input_dims_mkl_order);
    656     CHECK_NOTNULL(pool_params);
    657     CHECK_NOTNULL(input_tensor_shape);
    658     // For AvgPoolGrad, we only get the size of the original input because
    659     // The original data is irrelvant.
    660     auto shape_vec = tensor_original_input_shape.vec<int32>();
    661     for (int64 i = 0; i < tensor_original_input_shape.NumElements(); ++i) {
    662       input_tensor_shape->AddDim(shape_vec(i));
    663     }
    664 
    665     return MklPoolingBackwardOpBase<T>::ConfigureOriginalInput(
    666         context, tensor_original_input_shape, original_input_mkl_shape,
    667         original_input_dims_mkl_order, pool_params, *input_tensor_shape);
    668   }
    669 
    670   void SanityCheckInputs(OpKernelContext* context,
    671                          const Tensor& tensor_in_shape,
    672                          const Tensor& input_gradient_tensor,
    673                          const MklDnnShape& original_input_mkl_shape,
    674                          const MklDnnShape& input_gradient_mkl_shape) {
    675     if (!original_input_mkl_shape.IsMklTensor()) {
    676       OP_REQUIRES(
    677           context,
    678           tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4,
    679           errors::InvalidArgument("original input shape must be "
    680                                   "1-dimensional and 4 elements"));
    681     } else {
    682       OP_REQUIRES(context,
    683                   original_input_mkl_shape.GetDimension() == 1 &&
    684                       original_input_mkl_shape.DimSize(0) == 4,
    685                   errors::InvalidArgument("original input shape must be "
    686                                           "1-dimensional and 4 elements"));
    687     }
    688 
    689     if (!input_gradient_mkl_shape.IsMklTensor()) {
    690       // For avgpooling, input_gradient_diff_dst should have 4 dimensions.
    691       OP_REQUIRES(context, input_gradient_tensor.dims() == 4,
    692                   errors::InvalidArgument("Gradient shape must be "
    693                                           "4-dimensional"));
    694     } else {
    695       OP_REQUIRES(context, input_gradient_mkl_shape.GetDimension() == 4,
    696                   errors::InvalidArgument("Gradient shape must be "
    697                                           "4-dimensional"));
    698     }
    699   }
    700 };  // MklAvgPoolingGradOp
    701 
    702 #endif  // INTEL_MKL_ML
    703 
    704 REGISTER_KERNEL_BUILDER(Name("_MklAvgPool")
    705                             .Device(DEVICE_CPU)
    706                             .TypeConstraint<float>("T")
    707                             .Label(mkl_op_registry::kMklOpLabel),
    708                         MklAvgPoolingOp<CPUDevice, float>);
    709 
    710 REGISTER_KERNEL_BUILDER(Name("_MklAvgPoolGrad")
    711                             .Device(DEVICE_CPU)
    712                             .TypeConstraint<float>("T")
    713                             .Label(mkl_op_registry::kMklOpLabel),
    714                         MklAvgPoolingGradOp<CPUDevice, float>);
    715 
    716 }  // namespace tensorflow
    717 #endif  // INTEL_MKL
    718