Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      7     http://www.apache.org/licenses/LICENSE-2.0
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     16 // See docs in ../ops/nn_ops.cc.
     17 #ifdef INTEL_MKL
     18 #define EIGEN_USE_THREADS
     19 #include "tensorflow/core/framework/op_kernel.h"
     20 #include "tensorflow/core/kernels/mkl_pooling_ops_common.h"
     21 #include "tensorflow/core/lib/core/errors.h"
     22 #include "tensorflow/core/util/mkl_util.h"
     23 #include "tensorflow/core/util/padding.h"
     25 #ifndef INTEL_MKL_ML_ONLY
     26 #include <algorithm>
     27 #include "mkldnn.hpp"
     28 using mkldnn::algorithm;
     29 using mkldnn::engine;
     30 using mkldnn::error;
     31 using mkldnn::memory;
     32 using mkldnn::padding_kind;
     33 using mkldnn::pooling_backward;
     34 using mkldnn::pooling_forward;
     35 using mkldnn::prop_kind;
     36 #endif
     38 namespace tensorflow {
     40 typedef Eigen::ThreadPoolDevice CPUDevice;
     42 // MKL-DNN is now default. MKL-ML must be specified explicitly.
     43 #ifdef INTEL_MKL_ML_ONLY
     45 // An implementation of MaxPooling (forward).
     46 template <typename Device, typename T>
     47 class MklMaxPoolingOp : public OpKernel {
     48  public:
     49   explicit MklMaxPoolingOp(OpKernelConstruction* context) : OpKernel(context) {
     50     string data_format;
     52     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
     53     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
     54                 errors::InvalidArgument("Invalid data format"));
     55     OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
     56     OP_REQUIRES(context, ksize_.size() == 4,
     57                 errors::InvalidArgument("Sliding window ksize field must "
     58                                         "specify 4 dimensions"));
     59     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
     60     OP_REQUIRES(context, stride_.size() == 4,
     61                 errors::InvalidArgument("Sliding window stride field must "
     62                                         "specify 4 dimensions"));
     63     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
     64     OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
     65                 errors::Unimplemented("Pooling is not yet supported on the "
     66                                       "batch dimension."));
     68     workspace_enabled_ = false;
     69     // We may not get this attribute for this node if it does not go through
     70     // graph rewrite pass. So we do not check for error while retrieving this
     71     // attribute value.
     72     OP_REQUIRES_OK(context,
     73                    context->GetAttr("workspace_enabled", &workspace_enabled_));
     74   }
     76   void Compute(OpKernelContext* context) override {
     77     MklMaxPoolingOpContext mkl_context;
     78     // Get the input tensor
     79     const Tensor& tensor_in = MklGetInput(context, 0);
     80     GetMklShape(context, 0, &mkl_context.input_shape);
     81     bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
     83     mkl_context.params.in_dim = 4;
     84     MklPoolParameters pool_params;
     85     if (input_in_mkl_format == false) {
     86       pool_params.Init(context, ksize_, stride_, padding_, data_format_,
     87                        tensor_in.shape());
     88       OP_REQUIRES(
     89           context, (pool_params.depth_window == 1),
     90           errors::Unimplemented("Depthwise max pooling not supported by MKL"));
     92     } else {
     93       pool_params.Init(context, ksize_, stride_, padding_, data_format_,
     94                        &mkl_context.input_shape);
     95     }
     97     // Extract the parameters for the op from the pooling specs
     99     ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
    101     mkl_context.MklCreateLayoutsAndPrimitives(context);
    102     OP_REQUIRES_OK(context, context->status());
    104     // Declare output tensor
    105     TensorShape tensor_out_shape;
    106     MklShape mkl_out_shape, mkl_workspace_shape;
    107     mkl_out_shape.SetMklTensor(true);
    108     mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_fwd, dnnResourceDst);
    109     mkl_out_shape.SetTfLayout(mkl_context.params.in_dim,
    110                               mkl_context.params.out_sizes,
    111                               mkl_context.params.out_strides);
    112     mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_);
    114     Tensor* output_tensor = nullptr;
    115     tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
    116                                 mkl_out_shape.GetMklLayout())) /
    117                             sizeof(T));
    118     AllocateOutputSetMklShape(context, 0, &output_tensor, tensor_out_shape,
    119                               mkl_out_shape);
    121     Tensor* workspace_tensor;
    122     void* workspace_buf = nullptr;
    124     TensorShape workspace_shape;
    125     mkl_workspace_shape.SetMklTensor(false);
    126     workspace_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
    127                                mkl_context.lt_workspace)) /
    128                            sizeof(T));
    129     AllocateOutputSetMklShape(context, 1, &workspace_tensor, workspace_shape,
    130                               mkl_workspace_shape);
    132     mkl_context.pooling_res[dnnResourceWorkspace] = const_cast<void*>(
    133         static_cast<const void*>(workspace_tensor->flat<T>().data()));
    134     mkl_context.pooling_res[dnnResourceSrc] =
    135         const_cast<void*>(static_cast<const void*>(tensor_in.flat<T>().data()));
    136     mkl_context.pooling_res[dnnResourceDst] = const_cast<void*>(
    137         static_cast<const void*>(output_tensor->flat<T>().data()));
    139     CHECK_EQ(
    140         dnnExecute_F32(mkl_context.prim_pooling_fwd, mkl_context.pooling_res),
    141         E_SUCCESS);
    143     mkl_context.MklCleanup();
    144   }
    146  private:
    147   typedef struct {
    148     MklPoolingOpParams params;
    149     MklShape input_shape;
    150     void* pooling_res[dnnResourceNumber];
    151     dnnPrimitive_t prim_pooling_fwd = nullptr;
    152     dnnLayout_t lt_user_input = nullptr, lt_workspace = nullptr;
    154     void MklCreateLayoutsAndPrimitives(OpKernelContext* context) {
    155       bool input_in_mkl_format = input_shape.IsMklTensor();
    156       // Create or use existing DNN user layout
    157       if (input_in_mkl_format == false) {
    158         CHECK_EQ(dnnLayoutCreate_F32(&lt_user_input, params.in_dim,
    159                                      params.in_sizes, params.in_strides),
    160                  E_SUCCESS);
    161       } else {
    162         lt_user_input = (dnnLayout_t)input_shape.GetCurLayout();
    163       }
    165       dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax;
    166       dnnPrimitiveAttributes_t primAttr = nullptr;
    168       // Create DNN primitives
    169       CHECK_EQ(dnnPoolingCreateForward_F32(
    170                    &prim_pooling_fwd, primAttr, algorithm, lt_user_input,
    171                    params.kernel_size, params.kernel_stride, params.in_offset,
    172                    dnnBorderZerosAsymm),
    173                E_SUCCESS);
    175       // Creates layout for the workspace
    176       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace, prim_pooling_fwd,
    177                                                 dnnResourceWorkspace),
    178                E_SUCCESS);
    179     }
    181     void MklCleanup() {
    182       bool input_in_mkl_format = input_shape.IsMklTensor();
    183       CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS);
    184       if (!input_in_mkl_format) {
    185         CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS);
    186       }
    187       CHECK_EQ(dnnLayoutDelete_F32(lt_workspace), E_SUCCESS);
    188     }
    189   } MklMaxPoolingOpContext;
    191   std::vector<int32> ksize_;
    192   std::vector<int32> stride_;
    193   Padding padding_;
    194   TensorFormat data_format_;
    195   bool workspace_enabled_;
    196 };
    198 // The operation to compute MaxPool gradients.
    199 // It takes three inputs:
    200 //   - The original input tensor
    201 //   - The original output tensor
    202 //   - Backprop tensor for output
    203 // It produces one output: backprop tensor for input.
    204 template <class Device, class T>
    205 class MklMaxPoolingGradOp : public OpKernel {
    206  public:
    207   explicit MklMaxPoolingGradOp(OpKernelConstruction* context)
    208       : OpKernel(context) {
    209     string data_format;
    211     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    212     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
    213                 errors::InvalidArgument("Invalid data format"));
    214     OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    215     OP_REQUIRES(context, ksize_.size() == 4,
    216                 errors::InvalidArgument("Sliding window ksize field must "
    217                                         "specify 4 dimensions"));
    218     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    219     OP_REQUIRES(context, stride_.size() == 4,
    220                 errors::InvalidArgument("Sliding window strides field must "
    221                                         "specify 4 dimensions"));
    222     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    223     OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
    224                 errors::Unimplemented(
    225                     "Pooling is not yet supported on the batch dimension."));
    226     workspace_enabled_ = false;
    227     // We may not get this attribute for this node if it does not go through
    228     // graph rewrite pass. So we do not check for error while retrieving this
    229     // attribute value.
    230     OP_REQUIRES_OK(context,
    231                    context->GetAttr("workspace_enabled", &workspace_enabled_));
    232   }
    234   void Compute(OpKernelContext* context) override {
    235     MklMaxPoolingGradOpContext mkl_context;
    236     // Input - The original input tensor
    237     const Tensor& tensor_in = MklGetInput(context, 0);
    239     // Output - Backprop tensor for input.
    240     Tensor* output_tensor = nullptr;
    242     GetMklShape(context, 0, &mkl_context.input_shape);
    243     GetMklShape(context, 2, &mkl_context.output_backprop_shape);
    244     bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
    246     if (input_in_mkl_format == false)
    247       mkl_context.params.in_dim = tensor_in.dims();
    248     else
    249       mkl_context.params.in_dim = mkl_context.input_shape.GetDimension();
    251     MklPoolParameters pool_params;
    252     if (input_in_mkl_format == false) {
    253       pool_params.Init(context, ksize_, stride_, padding_, data_format_,
    254                        tensor_in.shape());
    255       OP_REQUIRES(
    256           context, (pool_params.depth_window == 1),
    257           errors::Unimplemented("Depthwise max pooling not supported by MKL"));
    259     } else {
    260       pool_params.Init(context, ksize_, stride_, padding_, data_format_,
    261                        &mkl_context.input_shape);
    262     }
    264     // Extract the parameters for the op from the pooling specs
    265     ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
    267     mkl_context.MklCreateLayouts(context);
    268     OP_REQUIRES_OK(context, context->status());
    270     mkl_context.MklCreatePrimitives(context, workspace_enabled_);
    271     OP_REQUIRES_OK(context, context->status());
    273     mkl_context.MklPrepareInputs(context, workspace_enabled_);
    274     OP_REQUIRES_OK(context, context->status());
    276     // Create shape for the input back prop output
    277     TensorShape mkl_input_backprop;
    278     MklShape mkl_output_shape;
    279     mkl_output_shape.SetMklTensor(true);
    280     mkl_output_shape.SetMklLayout(mkl_context.prim_pooling_bwd,
    281                                   dnnResourceDiffSrc);
    282     mkl_output_shape.SetTfLayout(mkl_context.params.in_dim,
    283                                  mkl_context.params.in_sizes,
    284                                  mkl_context.params.in_strides);
    285     mkl_output_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_);
    287     mkl_input_backprop.AddDim(
    288         dnnLayoutGetMemorySize_F32(
    289             static_cast<dnnLayout_t>(mkl_output_shape.GetMklLayout())) /
    290         sizeof(T));
    291     AllocateOutputSetMklShape(context, 0, &output_tensor, mkl_input_backprop,
    292                               mkl_output_shape);
    293     mkl_context.pooling_res[dnnResourceDiffSrc] = const_cast<void*>(
    294         static_cast<const void*>(output_tensor->flat<T>().data()));
    296     CHECK_EQ(
    297         dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res),
    298         E_SUCCESS);
    300     mkl_context.MklCleanup(workspace_enabled_);
    301   }
    303  private:
    304   typedef struct {
    305     MklPoolingOpParams params;
    306     MklShape input_shape, output_backprop_shape;
    307     void* pooling_resfwd[dnnResourceNumber];
    308     void* pooling_res[dnnResourceNumber];
    309     dnnPrimitive_t prim_pooling_fwd = nullptr, prim_pooling_bwd = nullptr,
    310                    convert_input = nullptr, convert_outbackprop = nullptr;
    311     dnnLayout_t lt_outbackprop_user = nullptr, lt_outbackprop_prim = nullptr,
    312                 lt_input_user = nullptr, lt_input_prim = nullptr;
    313     void* input_buf;
    314     void* outbackprop_buf;
    315     Tensor tmp_output_buf_tensor;
    316     Tensor workspace_buf_tensor;
    317     Tensor input_buf_tensor, outbackprop_buf_tensor;
    319     void MklCreateLayouts(OpKernelContext* context) {
    320       bool input_in_mkl_format = input_shape.IsMklTensor();
    321       bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor();
    322       // Create DNN user layout for input and outbackprop or get existing layout
    323       if (input_in_mkl_format == false) {
    324         CHECK_EQ(dnnLayoutCreate_F32(&lt_input_user, params.in_dim,
    325                                      params.in_sizes, params.in_strides),
    326                  E_SUCCESS);
    327       } else {
    328         lt_input_user = (dnnLayout_t)input_shape.GetCurLayout();
    329       }
    331       // We don't care about the output layout for now as we can create it from
    332       // primitives for the max pooling fwd prop
    333       if (outbackprop_in_mkl_format == false) {
    334         CHECK_EQ(dnnLayoutCreate_F32(&lt_outbackprop_user, params.in_dim,
    335                                      params.out_sizes, params.out_strides),
    336                  E_SUCCESS);
    337       } else {
    338         lt_outbackprop_user = (dnnLayout_t)output_backprop_shape.GetCurLayout();
    339       }
    340     }
    342     // Create DNN primitives
    343     void MklCreatePrimitives(OpKernelContext* context, bool workspace_enabled) {
    344       dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax;
    345       dnnPrimitiveAttributes_t primAttr = nullptr;
    347       if (workspace_enabled == false) {
    348         CHECK_EQ(dnnPoolingCreateForward_F32(
    349                      &prim_pooling_fwd, primAttr, algorithm, lt_input_user,
    350                      params.kernel_size, params.kernel_stride, params.in_offset,
    351                      dnnBorderZerosAsymm),
    352                  E_SUCCESS);
    353       }
    355       CHECK_EQ(dnnPoolingCreateBackward_F32(
    356                    &prim_pooling_bwd, primAttr, algorithm, lt_input_user,
    357                    params.kernel_size, params.kernel_stride, params.in_offset,
    358                    dnnBorderZerosAsymm),
    359                E_SUCCESS);
    361       // Creates conversions
    362       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
    363                    &lt_outbackprop_prim, prim_pooling_bwd, dnnResourceDiffDst),
    364                E_SUCCESS);
    366       if (workspace_enabled == false) {
    367         CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
    368                      &lt_input_prim, prim_pooling_fwd, dnnResourceSrc),
    369                  E_SUCCESS);
    370         if (!dnnLayoutCompare_F32(lt_input_user, lt_input_prim)) {
    371           CHECK_EQ(dnnConversionCreate_F32(&convert_input, lt_input_user,
    372                                            lt_input_prim),
    373                    E_SUCCESS);
    374           AllocTmpBuffer(context, &input_buf_tensor, lt_input_prim, &input_buf);
    375         }
    376       }
    378       if (!dnnLayoutCompare_F32(lt_outbackprop_user, lt_outbackprop_prim)) {
    379         CHECK_EQ(
    380             dnnConversionCreate_F32(&convert_outbackprop, lt_outbackprop_user,
    381                                     lt_outbackprop_prim),
    382             E_SUCCESS);
    383         AllocTmpBuffer(context, &outbackprop_buf_tensor, lt_outbackprop_prim,
    384                        &outbackprop_buf);
    385       }
    386     }
    388     // Compare incoming tensor layouts with MKL preferred layouts and convert
    389     // data to the preferred layout if necessary
    390     void MklPrepareInputs(OpKernelContext* context, bool workspace_enabled) {
    391       const Tensor& tensor_in = MklGetInput(context, 0);
    392       const Tensor& out_backprop = MklGetInput(context, 2);
    393       bool input_in_mkl_format = input_shape.IsMklTensor();
    394       bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor();
    396       void* tmp_output_buf = nullptr;
    397       void* workspace_buf = nullptr;
    399       if (workspace_enabled == false) {
    400         if (convert_input != nullptr) {
    401           if (input_in_mkl_format == false) {
    402             CHECK_EQ(dnnConversionExecute_F32(
    403                          convert_input,
    404                          const_cast<void*>(static_cast<const void*>(
    405                              tensor_in.flat<T>().data())),
    406                          input_buf),
    407                      E_SUCCESS);
    408             CHECK_EQ(dnnDelete_F32(convert_input), E_SUCCESS);
    409             convert_input = nullptr;
    410           } else {
    411             input_shape.GetConvertedFlatData(
    412                 lt_input_prim,
    413                 const_cast<void*>(
    414                     static_cast<const void*>(tensor_in.flat<T>().data())),
    415                 input_buf);
    416           }
    417           pooling_resfwd[dnnResourceSrc] = input_buf;
    418         } else {
    419           pooling_resfwd[dnnResourceSrc] = const_cast<void*>(
    420               static_cast<const void*>(tensor_in.flat<T>().data()));
    421         }
    423         dnnLayout_t lt_workspace;
    424         CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
    425                      &lt_workspace, prim_pooling_fwd, dnnResourceWorkspace),
    426                  E_SUCCESS);
    427         AllocTmpBuffer(context, &workspace_buf_tensor, lt_workspace,
    428                        &workspace_buf);
    429         pooling_resfwd[dnnResourceWorkspace] = workspace_buf;
    431         dnnLayoutDelete_F32(lt_workspace);
    433         // We create the layout for max pooling fwd prop tmp output here
    434         AllocTmpBuffer(context, &tmp_output_buf_tensor, lt_outbackprop_prim,
    435                        &tmp_output_buf);
    436         pooling_resfwd[dnnResourceDst] = tmp_output_buf;
    438         CHECK_EQ(dnnExecute_F32(prim_pooling_fwd, pooling_resfwd), E_SUCCESS);
    439         pooling_res[dnnResourceWorkspace] =
    440             pooling_resfwd[dnnResourceWorkspace];
    441       } else {
    442         const Tensor& workspace = MklGetInput(context, 3);
    443         pooling_res[dnnResourceWorkspace] = const_cast<void*>(
    444             static_cast<const void*>(workspace.flat<T>().data()));
    445       }
    447       // Out backprop conversions if needed
    448       if (convert_outbackprop != nullptr) {
    449         if (outbackprop_in_mkl_format == false) {
    450           CHECK_EQ(dnnConversionExecute_F32(
    451                        convert_outbackprop,
    452                        const_cast<void*>(static_cast<const void*>(
    453                            out_backprop.flat<T>().data())),
    454                        outbackprop_buf),
    455                    E_SUCCESS);
    456           CHECK_EQ(dnnDelete_F32(convert_outbackprop), E_SUCCESS);
    457         } else {
    458           output_backprop_shape.GetConvertedFlatData(
    459               lt_outbackprop_prim,
    460               const_cast<void*>(
    461                   static_cast<const void*>(out_backprop.flat<T>().data())),
    462               outbackprop_buf);
    463         }
    464         pooling_res[dnnResourceDiffDst] = outbackprop_buf;
    465       } else {
    466         pooling_res[dnnResourceDiffDst] = const_cast<void*>(
    467             static_cast<const void*>(out_backprop.flat<T>().data()));
    468       }
    469     }
    471     void MklCleanup(bool workspace_enabled) {
    472       bool input_in_mkl_format = input_shape.IsMklTensor();
    473       bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor();
    474       if (workspace_enabled == false) {
    475         CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS);
    476       }
    477       CHECK_EQ(dnnDelete_F32(prim_pooling_bwd), E_SUCCESS);
    478       if (outbackprop_in_mkl_format == false) {
    479         CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_user), E_SUCCESS);
    480       }
    481       CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_prim), E_SUCCESS);
    482       if (input_in_mkl_format == false) {
    483         CHECK_EQ(dnnLayoutDelete_F32(lt_input_user), E_SUCCESS);
    484       }
    485       if (workspace_enabled == false) {
    486         CHECK_EQ(dnnLayoutDelete_F32(lt_input_prim), E_SUCCESS);
    487       }
    488     }
    489   } MklMaxPoolingGradOpContext;
    491   std::vector<int32> ksize_;
    492   std::vector<int32> stride_;
    493   Padding padding_;
    494   TensorFormat data_format_;
    496   bool workspace_enabled_;
    497 };  // MklMaxPoolingGradOp
    499 #else
    501 // An implementation of MaxPooling (forward).
    502 template <typename Device, typename T>
    503 class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
    504  public:
    505   explicit MklMaxPoolingOp(OpKernelConstruction* context)
    506       : MklPoolingForwardOpBase<T>(context) {
    507     // In Max Pooling, MKLDNN does not allow passing workspace as NULL.
    508     // So we set workspace_enabled_ to true.
    509     this->workspace_enabled_ = true;
    510   }
    512   void Compute(OpKernelContext* context) override {
    513     try {
    514       const Tensor& input_tensor =
    515           MklGetInput(context, this->kInputTensorIndexInput);
    516       MklDnnShape dnn_shape_input;
    517       GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input);
    518       this->SanityCheckInput(context, input_tensor, dnn_shape_input);
    519       if (!context->status().ok()) return;
    521       MklDnnData<T> dnn_data_input(&cpu_engine);
    522       MklDnnData<T> dnn_data_output(&cpu_engine);
    524       // initialize variables for the pooling op
    525       MklPoolParameters pool_params;
    526       // check whether pooling is 2D or 3D
    527       bool is_pool2d = (this->ksize_.size() == 4);
    528       // Get the input tensor and initialize the pooling parameters
    529       TensorShape input_tensor_shape = input_tensor.shape();
    530       this->InitMklPoolParameters(context, &pool_params, dnn_shape_input,
    531                                   input_tensor_shape);
    532       OP_REQUIRES_OK(context, context->status());
    534       // Declare output tensor
    535       Tensor* output_tensor = nullptr;
    536       memory::dims output_dims_mkl_order;
    537       this->GetOutputDims(pool_params, &output_dims_mkl_order);
    539       // If input is an empty tensor, allocate an empty output tensor and return
    540       if (input_tensor.NumElements() == 0) {
    541         const int kOutputIndex = 0;
    542         this->AllocateEmptyOutputTensor(context, kOutputIndex, &pool_params,
    543                                         output_dims_mkl_order, &output_tensor);
    544         return;
    545       }
    547       // Get the input memory descriptor
    548       memory::desc input_md =
    549           dnn_shape_input.IsMklTensor()
    550               ? dnn_shape_input.GetMklLayout()
    551               : is_pool2d ? memory::desc(
    552                                 TFShapeToMklDnnDimsInNCHW(
    553                                     input_tensor_shape, this->data_format_tf_),
    554                                 MklDnnType<T>(), this->data_format_mkldnn_)
    555                           : memory::desc(
    556                                 TFShapeToMklDnnDimsInNCDHW(
    557                                     input_tensor_shape, this->data_format_tf_),
    558                                 MklDnnType<T>(), this->data_format_mkldnn_);
    560       // Get src/filter/stride/padding information
    561       memory::dims src_dims =
    562           dnn_shape_input.IsMklTensor()
    563               ? dnn_shape_input.GetSizesAsMklDnnDims()
    564               : is_pool2d ? TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
    565                                                       this->data_format_tf_)
    566                           : TFShapeToMklDnnDimsInNCDHW(input_tensor.shape(),
    567                                                        this->data_format_tf_);
    568       memory::dims filter_dims, strides, padding_left, padding_right;
    569       this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
    570                              &padding_left, &padding_right, is_pool2d);
    572       // Get a pooling op from the cached pool
    573       MklPoolingFwdPrimitive<T>* pooling_fwd = nullptr;
    574       prop_kind pooling_prop_kind;
    575       bool int8_forward_inference =
    576           std::is_same<T, qint8>::value || std::is_same<T, quint8>::value;
    577       if (int8_forward_inference)
    578         pooling_prop_kind = prop_kind::forward_inference;
    579       else
    580         pooling_prop_kind = prop_kind::forward_training;
    581       MklPoolingParams fwdParams(src_dims, output_dims_mkl_order, filter_dims,
    582                                  strides, padding_left, padding_right,
    583                                  algorithm::pooling_max, pooling_prop_kind);
    584       pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams);
    586       // allocate output tensor
    587       this->AllocateOutputTensor(context, *(pooling_fwd->GetPoolingFwdPd()),
    588                                  output_dims_mkl_order,
    589                                  this->data_format_mkldnn_, &output_tensor);
    590       OP_REQUIRES_OK(context, context->status());
    591       dnn_data_output.SetUsrMem(output_dims_mkl_order,
    592                                 pooling_fwd->GetDstMemoryFormat(),
    593                                 output_tensor);
    595       // check wehther we need to reorder src
    596       const T* src_data = input_tensor.flat<T>().data();
    597       if (input_md.data.format != pooling_fwd->GetSrcMemoryFormat()) {
    598         dnn_data_input.SetUsrMem(input_md, &input_tensor);
    599         auto src_target_primitive_desc = memory::primitive_desc(
    600             {{src_dims}, MklDnnType<T>(), pooling_fwd->GetSrcMemoryFormat()},
    601             cpu_engine);
    602         dnn_data_input.CheckReorderToOpMem(src_target_primitive_desc);
    603         src_data = const_cast<T*>(
    604             reinterpret_cast<T*>(dnn_data_input.GetOpMem().get_data_handle()));
    605       }
    607       T* dst_data = output_tensor->flat<T>().data();
    609       if (int8_forward_inference) {
    610         // Execute pooling op
    611         pooling_fwd->Execute(src_data, dst_data);
    613         // pass min, max from input to output
    614         const Tensor& min_input_t = MklGetInput(context, 1);
    615         const Tensor& max_input_t = MklGetInput(context, 2);
    616         const float min_input = min_input_t.flat<float>()(0);
    617         const float max_input = max_input_t.flat<float>()(0);
    619         Tensor* output_min = nullptr;
    620         Tensor* output_max = nullptr;
    621         MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
    622         output_min_mkl_shape.SetMklTensor(false);
    623         output_max_mkl_shape.SetMklTensor(false);
    624         AllocateOutputSetMklShape(context, 1, &output_min, {},
    625                                   output_min_mkl_shape);
    626         AllocateOutputSetMklShape(context, 2, &output_max, {},
    627                                   output_max_mkl_shape);
    628         output_min->flat<float>()(0) = min_input;
    629         output_max->flat<float>()(0) = max_input;
    630       } else {
    631         MklDnnData<uint8> dnn_data_wksp(&cpu_engine);
    632         AllocateWorkspaceTensor(context, *(pooling_fwd->GetPoolingFwdPd()),
    633                                 &dnn_data_wksp);
    634         OP_REQUIRES_OK(context, context->status());
    635         T* ws_data =
    636             static_cast<T*>(dnn_data_wksp.GetOpMem().get_data_handle());
    638         // execute pooling op
    639         pooling_fwd->Execute(src_data, dst_data, ws_data);
    640       }
    641     } catch (mkldnn::error& e) {
    642       string error_msg = "Status: " + std::to_string(e.status) +
    643                          ", message: " + string(e.message) + ", in file " +
    644                          string(__FILE__) + ":" + std::to_string(__LINE__);
    645       OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
    646                                               error_msg));
    647     }
    648   }
    650  private:
    651   const int kOutputTensorIndexWorkspace = 1;
    652   engine cpu_engine = engine(engine::cpu, 0);
    654   void AllocateWorkspaceTensor(
    655       OpKernelContext* context,
    656       const pooling_forward::primitive_desc& pool_fwd_prim_desc,
    657       MklDnnData<uint8>* dnn_data_wksp) {
    658     CHECK_NOTNULL(dnn_data_wksp);
    659     Tensor* workspace_tensor = nullptr;
    660     memory::primitive_desc workspace_pd =
    661         pool_fwd_prim_desc.workspace_primitive_desc();
    662     size_t workspace_bytes = workspace_pd.get_size();
    663     MklDnnShape workspace_mkl_shape;
    664     workspace_mkl_shape.SetMklTensor(false);
    665     TensorShape workspace_tf_shape;
    666     workspace_tf_shape.AddDim(workspace_bytes);
    667     AllocateOutputSetMklShape(context, kOutputTensorIndexWorkspace,
    668                               &workspace_tensor, workspace_tf_shape,
    669                               workspace_mkl_shape);
    670     CHECK_NOTNULL(workspace_tensor);
    671     dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor);
    672   }
    673 };
    675 // The operation to compute MaxPool gradients.
    676 // It takes three inputs:
    677 //   - The original input tensor
    678 //   - The original output tensor
    679 //   - Backprop tensor for output
    680 // It produces one output: backprop tensor for input.
    681 template <class Device, class T>
    682 class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
    683  public:
    684   explicit MklMaxPoolingGradOp(OpKernelConstruction* context)
    685       : MklPoolingBackwardOpBase<T>(context) {}
    686   void Compute(OpKernelContext* context) override {
    687     try {
    688       auto cpu_engine = engine(engine::cpu, 0);
    689       const Tensor& orig_input_tensor =
    690           MklGetInput(context, kInputTensorIndexOrigInput);
    691       const Tensor& grad_tensor =
    692           MklGetInput(context, kInputTensorIndexGradient);
    693       const Tensor& workspace_tensor =
    694           MklGetInput(context, kInputTensorIndexWorkspace);
    695       MklDnnShape orig_input_mkl_shape, grad_mkl_shape;
    696       GetMklShape(context, kInputTensorIndexOrigInput, &orig_input_mkl_shape);
    697       GetMklShape(context, kInputTensorIndexGradient, &grad_mkl_shape);
    698       if (!context->status().ok()) return;
    700       MklDnnData<T> grad_dnn_data(&cpu_engine);
    701       MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
    703       MklPoolParameters pool_params;
    704       TensorShape orig_input_shape = orig_input_tensor.shape();
    706       bool is_pool2d = (this->ksize_.size() == 4);
    707       this->InitMklPoolParameters(context, &pool_params, orig_input_mkl_shape,
    708                                   orig_input_shape);
    710       memory::dims filter_dims, strides, padding_left, padding_right;
    711       this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
    712                              &padding_left, &padding_right, is_pool2d);
    714       memory::dims orig_input_dims_mkl_order =
    715           orig_input_mkl_shape.IsMklTensor()
    716               ? orig_input_mkl_shape.GetSizesAsMklDnnDims()
    717               : is_pool2d ? TFShapeToMklDnnDimsInNCHW(orig_input_shape,
    718                                                       this->data_format_tf_)
    719                           : TFShapeToMklDnnDimsInNCDHW(orig_input_shape,
    720                                                        this->data_format_tf_);
    722       memory::dims diff_dst_dims =
    723           grad_mkl_shape.IsMklTensor()
    724               ? grad_mkl_shape.GetSizesAsMklDnnDims()
    725               : is_pool2d ? TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
    726                                                       this->data_format_tf_)
    727                           : TFShapeToMklDnnDimsInNCDHW(grad_tensor.shape(),
    728                                                        this->data_format_tf_);
    730       memory::dims output_dims_mkl_order;
    731       this->GetOutputDims(pool_params, &output_dims_mkl_order);
    733       MklPoolingParams bwdParams(
    734           orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
    735           strides, padding_left, padding_right, algorithm::pooling_max,
    736           prop_kind::forward_training);
    737       MklPoolingBwdPrimitive<T>* pooling_bwd =
    738           MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
    740       // allocate output tensor and memory primitive
    741       Tensor* output_tensor = nullptr;
    742       this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
    743                                  orig_input_dims_mkl_order,
    744                                  this->data_format_mkldnn_, &output_tensor);
    745       // get diff_dst mem desc
    746       memory::desc diff_dst_md =
    747           grad_mkl_shape.IsMklTensor()
    748               ? grad_mkl_shape.GetMklLayout()
    749               : memory::desc(diff_dst_dims, MklDnnType<T>(),
    750                              this->data_format_mkldnn_);
    751       // check if diff_dst needs to be reordered
    752       const T* diff_dst_data = grad_tensor.flat<T>().data();
    753       if (diff_dst_md.data.format != pooling_bwd->GetDiffDstFormat()) {
    754         auto target_diff_dst = memory::primitive_desc(
    755             {{diff_dst_dims}, MklDnnType<T>(), pooling_bwd->GetDiffDstFormat()},
    756             cpu_engine);
    757         grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor);
    758         grad_dnn_data.CheckReorderToOpMem(target_diff_dst);
    759         diff_dst_data = const_cast<T*>(
    760             reinterpret_cast<T*>(grad_dnn_data.GetOpMem().get_data_handle()));
    761       }
    763       void* ws_data = static_cast<void*>(
    764           const_cast<uint8*>(workspace_tensor.flat<uint8>().data()));
    766       auto ws_md =
    767           pooling_bwd->GetPoolingFwdPd()->workspace_primitive_desc().desc();
    768       if (ws_md.data.format != pooling_bwd->GetWorkspaceFormat()) {
    769         memory::dims ws_dims;
    770         ws_dims.assign(ws_md.data.dims, ws_md.data.dims + ws_md.data.ndims);
    771         auto target_ws =
    772             memory::primitive_desc({{ws_dims},
    773                                     pooling_bwd->GetWorkspaceDataType(),
    774                                     pooling_bwd->GetWorkspaceFormat()},
    775                                    cpu_engine);
    776         workspace_dnn_data.SetUsrMem(ws_md, &workspace_tensor);
    777         workspace_dnn_data.CheckReorderToOpMem(target_ws);
    778         ws_data = workspace_dnn_data.GetOpMem().get_data_handle();
    779       }
    781       T* diff_src_data = output_tensor->flat<T>().data();
    783       // execute pooling
    784       pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data);
    785     } catch (mkldnn::error& e) {
    786       string error_msg = "Status:" + std::to_string(e.status) +
    787                          ", message: " + string(e.message) + ". in file " +
    788                          string(__FILE__) + ":" + std::to_string(__LINE__);
    789       OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
    790                                               error_msg));
    791     }
    792   }
    794  private:
    795   // .Input("orig_input: T")
    796   // .Input("orig_output: T")
    797   // .Input("grad: T")
    798   // .Input("workspace: T")
    799   const int kInputTensorIndexOrigInput = 0;
    800   const int kInputTensorIndexOrigOutput = 1;
    801   const int kInputTensorIndexGradient = 2;
    802   const int kInputTensorIndexWorkspace = 3;
    804   void ConfigureWorkspace(const Tensor& workspace_tensor,
    805                           memory::primitive_desc workspace_pd,
    806                           MklDnnData<uint8>* workspace_dnn_data) {
    807     CHECK_NOTNULL(workspace_dnn_data);
    809     workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor);
    810   }
    812   void SanityCheckInputs(OpKernelContext* context,
    813                          const Tensor& orig_input_tensor,
    814                          const Tensor& orig_output_tensor,
    815                          const Tensor& grad_tensor,
    816                          const Tensor& workspace_tensor,
    817                          const MklDnnShape& orig_input_mkl_shape,
    818                          const MklDnnShape& orig_output_mkl_shape,
    819                          const MklDnnShape& grad_mkl_shape,
    820                          const MklDnnShape& workspace_mkl_shape) {
    821     if (!orig_input_mkl_shape.IsMklTensor()) {
    822       OP_REQUIRES(context, orig_input_tensor.dims() == 4,
    823                   errors::InvalidArgument(
    824                       "Original input shape must be 4-dimensional"));
    825     } else {
    826       OP_REQUIRES(context, orig_input_mkl_shape.GetDimension() == 4,
    827                   errors::InvalidArgument(
    828                       "Original input shape must be 4-dimensional"));
    829     }
    830     if (!orig_output_mkl_shape.IsMklTensor()) {
    831       OP_REQUIRES(
    832           context, orig_output_tensor.dims() == 4,
    833           errors::InvalidArgument("Original output must be 4-dimensional"));
    834     } else {
    835       OP_REQUIRES(
    836           context, orig_output_mkl_shape.GetDimension() == 4,
    837           errors::InvalidArgument("Original output must be 4-dimensional"));
    838     }
    839     if (!grad_mkl_shape.IsMklTensor()) {
    840       OP_REQUIRES(context, grad_tensor.dims() == 4,
    841                   errors::InvalidArgument("Gradient must be 4-dimensional"));
    842     } else {
    843       OP_REQUIRES(context, grad_mkl_shape.GetDimension() == 4,
    844                   errors::InvalidArgument("Gradient must be 4-dimensional"));
    845     }
    846     if (this->workspace_enabled_) {
    847       // The workspace should not be an MKL tensor
    848       OP_REQUIRES(context, workspace_mkl_shape.IsMklTensor() == false,
    849                   errors::InvalidArgument(
    850                       "Workspace tensor should not be an MKL Tensor."));
    851       // It should only have one dimension
    852       OP_REQUIRES(
    853           context, workspace_tensor.dims() == 1,
    854           errors::InvalidArgument("Workspace tensor must be 1-dimensional"));
    855     } else {
    856       OP_REQUIRES(
    857           context, this->workspace_enabled_,
    858           errors::Unimplemented("MKL-DNN Max Pooling does not "
    859                                 "yet support the use case "
    860                                 "where MaxPoolGrad is called without first"
    861                                 " calling MaxPool."));
    862     }
    863   }
    864 };  // MklMaxPoolingGradOp
// 3D max pooling: float-only CPU registrations of the forward kernel
// (MklMaxPoolingOp) and its gradient (MklMaxPoolingGradOp), both selected by
// the kMklOpLabel label.
REGISTER_KERNEL_BUILDER(Name("_MklMaxPool3D")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklOpLabel),
                        MklMaxPoolingOp<CPUDevice, float>);

REGISTER_KERNEL_BUILDER(Name("_MklMaxPool3DGrad")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklOpLabel),
                        MklMaxPoolingGradOp<CPUDevice, float>);
    878 #endif  // INTEL_MKL_ML_ONLY
// 2D max pooling, float forward kernel (kMklOpLabel).
REGISTER_KERNEL_BUILDER(Name("_MklMaxPool")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklOpLabel),
                        MklMaxPoolingOp<CPUDevice, float>);

// Quantized max pooling: the same MklMaxPoolingOp template instantiated for
// quint8 and qint8, selected by the kMklQuantizedOpLabel label.
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedMaxPool")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklMaxPoolingOp<CPUDevice, quint8>);

REGISTER_KERNEL_BUILDER(Name("_MklQuantizedMaxPool")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint8>("T")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklMaxPoolingOp<CPUDevice, qint8>);

// 2D max pooling gradient, float only (kMklOpLabel).
REGISTER_KERNEL_BUILDER(Name("_MklMaxPoolGrad")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklOpLabel),
                        MklMaxPoolingGradOp<CPUDevice, float>);
    904 }  // namespace tensorflow
    905 #endif  // INTEL_MKL