Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/nn_ops.cc.
     17 #ifdef INTEL_MKL
     18 #define EIGEN_USE_THREADS
     19 #include "tensorflow/core/framework/op_kernel.h"
     20 #include "tensorflow/core/kernels/mkl_pooling_ops_common.h"
     21 #include "tensorflow/core/lib/core/errors.h"
     22 #include "tensorflow/core/util/mkl_util.h"
     23 #include "tensorflow/core/util/padding.h"
     24 
     25 #ifndef INTEL_MKL_ML_ONLY
     26 #include <algorithm>
     27 #include "mkldnn.hpp"
     28 using mkldnn::algorithm;
     29 using mkldnn::engine;
     30 using mkldnn::error;
     31 using mkldnn::memory;
     32 using mkldnn::padding_kind;
     33 using mkldnn::pooling_backward;
     34 using mkldnn::pooling_forward;
     35 using mkldnn::prop_kind;
     36 #endif
     37 
     38 namespace tensorflow {
     39 
     40 typedef Eigen::ThreadPoolDevice CPUDevice;
     41 
     42 // MKL-DNN is now default. MKL-ML must be specified explicitly.
     43 #ifdef INTEL_MKL_ML_ONLY
     44 
     45 // An implementation of MaxPooling (forward).
     46 template <typename Device, typename T>
     47 class MklMaxPoolingOp : public OpKernel {
     48  public:
     49   explicit MklMaxPoolingOp(OpKernelConstruction* context) : OpKernel(context) {
     50     string data_format;
     51 
     52     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
     53     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
     54                 errors::InvalidArgument("Invalid data format"));
     55     OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
     56     OP_REQUIRES(context, ksize_.size() == 4,
     57                 errors::InvalidArgument("Sliding window ksize field must "
     58                                         "specify 4 dimensions"));
     59     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
     60     OP_REQUIRES(context, stride_.size() == 4,
     61                 errors::InvalidArgument("Sliding window stride field must "
     62                                         "specify 4 dimensions"));
     63     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
     64     OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
     65                 errors::Unimplemented("Pooling is not yet supported on the "
     66                                       "batch dimension."));
     67 
     68     workspace_enabled_ = false;
     69     // We may not get this attribute for this node if it does not go through
     70     // graph rewrite pass. So we do not check for error while retrieving this
     71     // attribute value.
     72     OP_REQUIRES_OK(context,
     73                    context->GetAttr("workspace_enabled", &workspace_enabled_));
     74   }
     75 
     76   void Compute(OpKernelContext* context) override {
     77     MklMaxPoolingOpContext mkl_context;
     78     // Get the input tensor
     79     const Tensor& tensor_in = MklGetInput(context, 0);
     80     GetMklShape(context, 0, &mkl_context.input_shape);
     81     bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
     82 
     83     mkl_context.params.in_dim = 4;
     84     MklPoolParameters pool_params;
     85     if (input_in_mkl_format == false) {
     86       pool_params.Init(context, ksize_, stride_, padding_, data_format_,
     87                        tensor_in.shape());
     88       OP_REQUIRES(
     89           context, (pool_params.depth_window == 1),
     90           errors::Unimplemented("Depthwise max pooling not supported by MKL"));
     91 
     92     } else {
     93       pool_params.Init(context, ksize_, stride_, padding_, data_format_,
     94                        &mkl_context.input_shape);
     95     }
     96 
     97     // Extract the parameters for the op from the pooling specs
     98 
     99     ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
    100 
    101     mkl_context.MklCreateLayoutsAndPrimitives(context);
    102     OP_REQUIRES_OK(context, context->status());
    103 
    104     // Declare output tensor
    105     TensorShape tensor_out_shape;
    106     MklShape mkl_out_shape, mkl_workspace_shape;
    107     mkl_out_shape.SetMklTensor(true);
    108     mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_fwd, dnnResourceDst);
    109     mkl_out_shape.SetTfLayout(mkl_context.params.in_dim,
    110                               mkl_context.params.out_sizes,
    111                               mkl_context.params.out_strides);
    112     mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_);
    113 
    114     Tensor* output_tensor = nullptr;
    115     tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
    116                                 mkl_out_shape.GetMklLayout())) /
    117                             sizeof(T));
    118     AllocateOutputSetMklShape(context, 0, &output_tensor, tensor_out_shape,
    119                               mkl_out_shape);
    120 
    121     Tensor* workspace_tensor;
    122     void* workspace_buf = nullptr;
    123 
    124     TensorShape workspace_shape;
    125     mkl_workspace_shape.SetMklTensor(false);
    126     workspace_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
    127                                mkl_context.lt_workspace)) /
    128                            sizeof(T));
    129     AllocateOutputSetMklShape(context, 1, &workspace_tensor, workspace_shape,
    130                               mkl_workspace_shape);
    131 
    132     mkl_context.pooling_res[dnnResourceWorkspace] = const_cast<void*>(
    133         static_cast<const void*>(workspace_tensor->flat<T>().data()));
    134     mkl_context.pooling_res[dnnResourceSrc] =
    135         const_cast<void*>(static_cast<const void*>(tensor_in.flat<T>().data()));
    136     mkl_context.pooling_res[dnnResourceDst] = const_cast<void*>(
    137         static_cast<const void*>(output_tensor->flat<T>().data()));
    138 
    139     CHECK_EQ(
    140         dnnExecute_F32(mkl_context.prim_pooling_fwd, mkl_context.pooling_res),
    141         E_SUCCESS);
    142 
    143     mkl_context.MklCleanup();
    144   }
    145 
    146  private:
    147   typedef struct {
    148     MklPoolingOpParams params;
    149     MklShape input_shape;
    150     void* pooling_res[dnnResourceNumber];
    151     dnnPrimitive_t prim_pooling_fwd = nullptr;
    152     dnnLayout_t lt_user_input = nullptr, lt_workspace = nullptr;
    153 
    154     void MklCreateLayoutsAndPrimitives(OpKernelContext* context) {
    155       bool input_in_mkl_format = input_shape.IsMklTensor();
    156       // Create or use existing DNN user layout
    157       if (input_in_mkl_format == false) {
    158         CHECK_EQ(dnnLayoutCreate_F32(&lt_user_input, params.in_dim,
    159                                      params.in_sizes, params.in_strides),
    160                  E_SUCCESS);
    161       } else {
    162         lt_user_input = (dnnLayout_t)input_shape.GetCurLayout();
    163       }
    164 
    165       dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax;
    166       dnnPrimitiveAttributes_t primAttr = nullptr;
    167 
    168       // Create DNN primitives
    169       CHECK_EQ(dnnPoolingCreateForward_F32(
    170                    &prim_pooling_fwd, primAttr, algorithm, lt_user_input,
    171                    params.kernel_size, params.kernel_stride, params.in_offset,
    172                    dnnBorderZerosAsymm),
    173                E_SUCCESS);
    174 
    175       // Creates layout for the workspace
    176       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace, prim_pooling_fwd,
    177                                                 dnnResourceWorkspace),
    178                E_SUCCESS);
    179     }
    180 
    181     void MklCleanup() {
    182       bool input_in_mkl_format = input_shape.IsMklTensor();
    183       CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS);
    184       if (!input_in_mkl_format) {
    185         CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS);
    186       }
    187       CHECK_EQ(dnnLayoutDelete_F32(lt_workspace), E_SUCCESS);
    188     }
    189   } MklMaxPoolingOpContext;
    190 
    191   std::vector<int32> ksize_;
    192   std::vector<int32> stride_;
    193   Padding padding_;
    194   TensorFormat data_format_;
    195   bool workspace_enabled_;
    196 };
    197 
    198 // The operation to compute MaxPool gradients.
    199 // It takes three inputs:
    200 //   - The original input tensor
    201 //   - The original output tensor
    202 //   - Backprop tensor for output
    203 // It produces one output: backprop tensor for input.
    204 template <class Device, class T>
    205 class MklMaxPoolingGradOp : public OpKernel {
    206  public:
    207   explicit MklMaxPoolingGradOp(OpKernelConstruction* context)
    208       : OpKernel(context) {
    209     string data_format;
    210 
    211     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    212     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
    213                 errors::InvalidArgument("Invalid data format"));
    214     OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    215     OP_REQUIRES(context, ksize_.size() == 4,
    216                 errors::InvalidArgument("Sliding window ksize field must "
    217                                         "specify 4 dimensions"));
    218     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    219     OP_REQUIRES(context, stride_.size() == 4,
    220                 errors::InvalidArgument("Sliding window strides field must "
    221                                         "specify 4 dimensions"));
    222     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    223     OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
    224                 errors::Unimplemented(
    225                     "Pooling is not yet supported on the batch dimension."));
    226     workspace_enabled_ = false;
    227     // We may not get this attribute for this node if it does not go through
    228     // graph rewrite pass. So we do not check for error while retrieving this
    229     // attribute value.
    230     OP_REQUIRES_OK(context,
    231                    context->GetAttr("workspace_enabled", &workspace_enabled_));
    232   }
    233 
    234   void Compute(OpKernelContext* context) override {
    235     MklMaxPoolingGradOpContext mkl_context;
    236     // Input - The original input tensor
    237     const Tensor& tensor_in = MklGetInput(context, 0);
    238 
    239     // Output - Backprop tensor for input.
    240     Tensor* output_tensor = nullptr;
    241 
    242     GetMklShape(context, 0, &mkl_context.input_shape);
    243     GetMklShape(context, 2, &mkl_context.output_backprop_shape);
    244     bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
    245 
    246     if (input_in_mkl_format == false)
    247       mkl_context.params.in_dim = tensor_in.dims();
    248     else
    249       mkl_context.params.in_dim = mkl_context.input_shape.GetDimension();
    250 
    251     MklPoolParameters pool_params;
    252     if (input_in_mkl_format == false) {
    253       pool_params.Init(context, ksize_, stride_, padding_, data_format_,
    254                        tensor_in.shape());
    255       OP_REQUIRES(
    256           context, (pool_params.depth_window == 1),
    257           errors::Unimplemented("Depthwise max pooling not supported by MKL"));
    258 
    259     } else {
    260       pool_params.Init(context, ksize_, stride_, padding_, data_format_,
    261                        &mkl_context.input_shape);
    262     }
    263 
    264     // Extract the parameters for the op from the pooling specs
    265     ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
    266 
    267     mkl_context.MklCreateLayouts(context);
    268     OP_REQUIRES_OK(context, context->status());
    269 
    270     mkl_context.MklCreatePrimitives(context, workspace_enabled_);
    271     OP_REQUIRES_OK(context, context->status());
    272 
    273     mkl_context.MklPrepareInputs(context, workspace_enabled_);
    274     OP_REQUIRES_OK(context, context->status());
    275 
    276     // Create shape for the input back prop output
    277     TensorShape mkl_input_backprop;
    278     MklShape mkl_output_shape;
    279     mkl_output_shape.SetMklTensor(true);
    280     mkl_output_shape.SetMklLayout(mkl_context.prim_pooling_bwd,
    281                                   dnnResourceDiffSrc);
    282     mkl_output_shape.SetTfLayout(mkl_context.params.in_dim,
    283                                  mkl_context.params.in_sizes,
    284                                  mkl_context.params.in_strides);
    285     mkl_output_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_);
    286 
    287     mkl_input_backprop.AddDim(
    288         dnnLayoutGetMemorySize_F32(
    289             static_cast<dnnLayout_t>(mkl_output_shape.GetMklLayout())) /
    290         sizeof(T));
    291     AllocateOutputSetMklShape(context, 0, &output_tensor, mkl_input_backprop,
    292                               mkl_output_shape);
    293     mkl_context.pooling_res[dnnResourceDiffSrc] = const_cast<void*>(
    294         static_cast<const void*>(output_tensor->flat<T>().data()));
    295 
    296     CHECK_EQ(
    297         dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res),
    298         E_SUCCESS);
    299 
    300     mkl_context.MklCleanup(workspace_enabled_);
    301   }
    302 
    303  private:
    304   typedef struct {
    305     MklPoolingOpParams params;
    306     MklShape input_shape, output_backprop_shape;
    307     void* pooling_resfwd[dnnResourceNumber];
    308     void* pooling_res[dnnResourceNumber];
    309     dnnPrimitive_t prim_pooling_fwd = nullptr, prim_pooling_bwd = nullptr,
    310                    convert_input = nullptr, convert_outbackprop = nullptr;
    311     dnnLayout_t lt_outbackprop_user = nullptr, lt_outbackprop_prim = nullptr,
    312                 lt_input_user = nullptr, lt_input_prim = nullptr;
    313     void* input_buf;
    314     void* outbackprop_buf;
    315     Tensor tmp_output_buf_tensor;
    316     Tensor workspace_buf_tensor;
    317     Tensor input_buf_tensor, outbackprop_buf_tensor;
    318 
    319     void MklCreateLayouts(OpKernelContext* context) {
    320       bool input_in_mkl_format = input_shape.IsMklTensor();
    321       bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor();
    322       // Create DNN user layout for input and outbackprop or get existing layout
    323       if (input_in_mkl_format == false) {
    324         CHECK_EQ(dnnLayoutCreate_F32(&lt_input_user, params.in_dim,
    325                                      params.in_sizes, params.in_strides),
    326                  E_SUCCESS);
    327       } else {
    328         lt_input_user = (dnnLayout_t)input_shape.GetCurLayout();
    329       }
    330 
    331       // We don't care about the output layout for now as we can create it from
    332       // primitives for the max pooling fwd prop
    333       if (outbackprop_in_mkl_format == false) {
    334         CHECK_EQ(dnnLayoutCreate_F32(&lt_outbackprop_user, params.in_dim,
    335                                      params.out_sizes, params.out_strides),
    336                  E_SUCCESS);
    337       } else {
    338         lt_outbackprop_user = (dnnLayout_t)output_backprop_shape.GetCurLayout();
    339       }
    340     }
    341 
    342     // Create DNN primitives
    343     void MklCreatePrimitives(OpKernelContext* context, bool workspace_enabled) {
    344       dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax;
    345       dnnPrimitiveAttributes_t primAttr = nullptr;
    346 
    347       if (workspace_enabled == false) {
    348         CHECK_EQ(dnnPoolingCreateForward_F32(
    349                      &prim_pooling_fwd, primAttr, algorithm, lt_input_user,
    350                      params.kernel_size, params.kernel_stride, params.in_offset,
    351                      dnnBorderZerosAsymm),
    352                  E_SUCCESS);
    353       }
    354 
    355       CHECK_EQ(dnnPoolingCreateBackward_F32(
    356                    &prim_pooling_bwd, primAttr, algorithm, lt_input_user,
    357                    params.kernel_size, params.kernel_stride, params.in_offset,
    358                    dnnBorderZerosAsymm),
    359                E_SUCCESS);
    360 
    361       // Creates conversions
    362       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
    363                    &lt_outbackprop_prim, prim_pooling_bwd, dnnResourceDiffDst),
    364                E_SUCCESS);
    365 
    366       if (workspace_enabled == false) {
    367         CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
    368                      &lt_input_prim, prim_pooling_fwd, dnnResourceSrc),
    369                  E_SUCCESS);
    370         if (!dnnLayoutCompare_F32(lt_input_user, lt_input_prim)) {
    371           CHECK_EQ(dnnConversionCreate_F32(&convert_input, lt_input_user,
    372                                            lt_input_prim),
    373                    E_SUCCESS);
    374           AllocTmpBuffer(context, &input_buf_tensor, lt_input_prim, &input_buf);
    375         }
    376       }
    377 
    378       if (!dnnLayoutCompare_F32(lt_outbackprop_user, lt_outbackprop_prim)) {
    379         CHECK_EQ(
    380             dnnConversionCreate_F32(&convert_outbackprop, lt_outbackprop_user,
    381                                     lt_outbackprop_prim),
    382             E_SUCCESS);
    383         AllocTmpBuffer(context, &outbackprop_buf_tensor, lt_outbackprop_prim,
    384                        &outbackprop_buf);
    385       }
    386     }
    387 
    388     // Compare incoming tensor layouts with MKL preferred layouts and convert
    389     // data to the preferred layout if necessary
    390     void MklPrepareInputs(OpKernelContext* context, bool workspace_enabled) {
    391       const Tensor& tensor_in = MklGetInput(context, 0);
    392       const Tensor& out_backprop = MklGetInput(context, 2);
    393       bool input_in_mkl_format = input_shape.IsMklTensor();
    394       bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor();
    395 
    396       void* tmp_output_buf = nullptr;
    397       void* workspace_buf = nullptr;
    398 
    399       if (workspace_enabled == false) {
    400         if (convert_input != nullptr) {
    401           if (input_in_mkl_format == false) {
    402             CHECK_EQ(dnnConversionExecute_F32(
    403                          convert_input,
    404                          const_cast<void*>(static_cast<const void*>(
    405                              tensor_in.flat<T>().data())),
    406                          input_buf),
    407                      E_SUCCESS);
    408             CHECK_EQ(dnnDelete_F32(convert_input), E_SUCCESS);
    409             convert_input = nullptr;
    410           } else {
    411             input_shape.GetConvertedFlatData(
    412                 lt_input_prim,
    413                 const_cast<void*>(
    414                     static_cast<const void*>(tensor_in.flat<T>().data())),
    415                 input_buf);
    416           }
    417           pooling_resfwd[dnnResourceSrc] = input_buf;
    418         } else {
    419           pooling_resfwd[dnnResourceSrc] = const_cast<void*>(
    420               static_cast<const void*>(tensor_in.flat<T>().data()));
    421         }
    422 
    423         dnnLayout_t lt_workspace;
    424         CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
    425                      &lt_workspace, prim_pooling_fwd, dnnResourceWorkspace),
    426                  E_SUCCESS);
    427         AllocTmpBuffer(context, &workspace_buf_tensor, lt_workspace,
    428                        &workspace_buf);
    429         pooling_resfwd[dnnResourceWorkspace] = workspace_buf;
    430 
    431         dnnLayoutDelete_F32(lt_workspace);
    432 
    433         // We create the layout for max pooling fwd prop tmp output here
    434         AllocTmpBuffer(context, &tmp_output_buf_tensor, lt_outbackprop_prim,
    435                        &tmp_output_buf);
    436         pooling_resfwd[dnnResourceDst] = tmp_output_buf;
    437 
    438         CHECK_EQ(dnnExecute_F32(prim_pooling_fwd, pooling_resfwd), E_SUCCESS);
    439         pooling_res[dnnResourceWorkspace] =
    440             pooling_resfwd[dnnResourceWorkspace];
    441       } else {
    442         const Tensor& workspace = MklGetInput(context, 3);
    443         pooling_res[dnnResourceWorkspace] = const_cast<void*>(
    444             static_cast<const void*>(workspace.flat<T>().data()));
    445       }
    446 
    447       // Out backprop conversions if needed
    448       if (convert_outbackprop != nullptr) {
    449         if (outbackprop_in_mkl_format == false) {
    450           CHECK_EQ(dnnConversionExecute_F32(
    451                        convert_outbackprop,
    452                        const_cast<void*>(static_cast<const void*>(
    453                            out_backprop.flat<T>().data())),
    454                        outbackprop_buf),
    455                    E_SUCCESS);
    456           CHECK_EQ(dnnDelete_F32(convert_outbackprop), E_SUCCESS);
    457         } else {
    458           output_backprop_shape.GetConvertedFlatData(
    459               lt_outbackprop_prim,
    460               const_cast<void*>(
    461                   static_cast<const void*>(out_backprop.flat<T>().data())),
    462               outbackprop_buf);
    463         }
    464         pooling_res[dnnResourceDiffDst] = outbackprop_buf;
    465       } else {
    466         pooling_res[dnnResourceDiffDst] = const_cast<void*>(
    467             static_cast<const void*>(out_backprop.flat<T>().data()));
    468       }
    469     }
    470 
    471     void MklCleanup(bool workspace_enabled) {
    472       bool input_in_mkl_format = input_shape.IsMklTensor();
    473       bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor();
    474       if (workspace_enabled == false) {
    475         CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS);
    476       }
    477       CHECK_EQ(dnnDelete_F32(prim_pooling_bwd), E_SUCCESS);
    478       if (outbackprop_in_mkl_format == false) {
    479         CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_user), E_SUCCESS);
    480       }
    481       CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_prim), E_SUCCESS);
    482       if (input_in_mkl_format == false) {
    483         CHECK_EQ(dnnLayoutDelete_F32(lt_input_user), E_SUCCESS);
    484       }
    485       if (workspace_enabled == false) {
    486         CHECK_EQ(dnnLayoutDelete_F32(lt_input_prim), E_SUCCESS);
    487       }
    488     }
    489   } MklMaxPoolingGradOpContext;
    490 
    491   std::vector<int32> ksize_;
    492   std::vector<int32> stride_;
    493   Padding padding_;
    494   TensorFormat data_format_;
    495 
    496   bool workspace_enabled_;
    497 };  // MklMaxPoolingGradOp
    498 
    499 #else
    500 
    501 // An implementation of MaxPooling (forward).
    502 template <typename Device, typename T>
    503 class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
    504  public:
    505   explicit MklMaxPoolingOp(OpKernelConstruction* context)
    506       : MklPoolingForwardOpBase<T>(context) {
    507     // In Max Pooling, MKLDNN does not allow passing workspace as NULL.
    508     // So we set workspace_enabled_ to true.
    509     this->workspace_enabled_ = true;
    510   }
    511 
    512   void Compute(OpKernelContext* context) override {
    513     try {
    514       const Tensor& input_tensor =
    515           MklGetInput(context, this->kInputTensorIndexInput);
    516       MklDnnShape dnn_shape_input;
    517       GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input);
    518       this->SanityCheckInput(context, input_tensor, dnn_shape_input);
    519       if (!context->status().ok()) return;
    520 
    521       MklDnnData<T> dnn_data_input(&cpu_engine);
    522       MklDnnData<T> dnn_data_output(&cpu_engine);
    523 
    524       // initialize variables for the pooling op
    525       MklPoolParameters pool_params;
    526       // check whether pooling is 2D or 3D
    527       bool is_pool2d = (this->ksize_.size() == 4);
    528       // Get the input tensor and initialize the pooling parameters
    529       TensorShape input_tensor_shape = input_tensor.shape();
    530       this->InitMklPoolParameters(context, &pool_params, dnn_shape_input,
    531                                   input_tensor_shape);
    532       OP_REQUIRES_OK(context, context->status());
    533 
    534       // Declare output tensor
    535       Tensor* output_tensor = nullptr;
    536       memory::dims output_dims_mkl_order;
    537       this->GetOutputDims(pool_params, &output_dims_mkl_order);
    538 
    539       // If input is an empty tensor, allocate an empty output tensor and return
    540       if (input_tensor.NumElements() == 0) {
    541         const int kOutputIndex = 0;
    542         this->AllocateEmptyOutputTensor(context, kOutputIndex, &pool_params,
    543                                         output_dims_mkl_order, &output_tensor);
    544         return;
    545       }
    546 
    547       // Get the input memory descriptor
    548       memory::desc input_md =
    549           dnn_shape_input.IsMklTensor()
    550               ? dnn_shape_input.GetMklLayout()
    551               : is_pool2d ? memory::desc(
    552                                 TFShapeToMklDnnDimsInNCHW(
    553                                     input_tensor_shape, this->data_format_tf_),
    554                                 MklDnnType<T>(), this->data_format_mkldnn_)
    555                           : memory::desc(
    556                                 TFShapeToMklDnnDimsInNCDHW(
    557                                     input_tensor_shape, this->data_format_tf_),
    558                                 MklDnnType<T>(), this->data_format_mkldnn_);
    559 
    560       // Get src/filter/stride/padding information
    561       memory::dims src_dims =
    562           dnn_shape_input.IsMklTensor()
    563               ? dnn_shape_input.GetSizesAsMklDnnDims()
    564               : is_pool2d ? TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
    565                                                       this->data_format_tf_)
    566                           : TFShapeToMklDnnDimsInNCDHW(input_tensor.shape(),
    567                                                        this->data_format_tf_);
    568       memory::dims filter_dims, strides, padding_left, padding_right;
    569       this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
    570                              &padding_left, &padding_right, is_pool2d);
    571 
    572       // Get a pooling op from the cached pool
    573       MklPoolingFwdPrimitive<T>* pooling_fwd = nullptr;
    574       prop_kind pooling_prop_kind;
    575       bool int8_forward_inference =
    576           std::is_same<T, qint8>::value || std::is_same<T, quint8>::value;
    577       if (int8_forward_inference)
    578         pooling_prop_kind = prop_kind::forward_inference;
    579       else
    580         pooling_prop_kind = prop_kind::forward_training;
    581       MklPoolingParams fwdParams(src_dims, output_dims_mkl_order, filter_dims,
    582                                  strides, padding_left, padding_right,
    583                                  algorithm::pooling_max, pooling_prop_kind);
    584       pooling_fwd = MklPoolingFwdPrimitiveFactory<T>::Get(fwdParams);
    585 
    586       // allocate output tensor
    587       this->AllocateOutputTensor(context, *(pooling_fwd->GetPoolingFwdPd()),
    588                                  output_dims_mkl_order,
    589                                  this->data_format_mkldnn_, &output_tensor);
    590       OP_REQUIRES_OK(context, context->status());
    591       dnn_data_output.SetUsrMem(output_dims_mkl_order,
    592                                 pooling_fwd->GetDstMemoryFormat(),
    593                                 output_tensor);
    594 
    595       // check wehther we need to reorder src
    596       const T* src_data = input_tensor.flat<T>().data();
    597       if (input_md.data.format != pooling_fwd->GetSrcMemoryFormat()) {
    598         dnn_data_input.SetUsrMem(input_md, &input_tensor);
    599         auto src_target_primitive_desc = memory::primitive_desc(
    600             {{src_dims}, MklDnnType<T>(), pooling_fwd->GetSrcMemoryFormat()},
    601             cpu_engine);
    602         dnn_data_input.CheckReorderToOpMem(src_target_primitive_desc);
    603         src_data = const_cast<T*>(
    604             reinterpret_cast<T*>(dnn_data_input.GetOpMem().get_data_handle()));
    605       }
    606 
    607       T* dst_data = output_tensor->flat<T>().data();
    608 
    609       if (int8_forward_inference) {
    610         // Execute pooling op
    611         pooling_fwd->Execute(src_data, dst_data);
    612 
    613         // pass min, max from input to output
    614         const Tensor& min_input_t = MklGetInput(context, 1);
    615         const Tensor& max_input_t = MklGetInput(context, 2);
    616         const float min_input = min_input_t.flat<float>()(0);
    617         const float max_input = max_input_t.flat<float>()(0);
    618 
    619         Tensor* output_min = nullptr;
    620         Tensor* output_max = nullptr;
    621         MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
    622         output_min_mkl_shape.SetMklTensor(false);
    623         output_max_mkl_shape.SetMklTensor(false);
    624         AllocateOutputSetMklShape(context, 1, &output_min, {},
    625                                   output_min_mkl_shape);
    626         AllocateOutputSetMklShape(context, 2, &output_max, {},
    627                                   output_max_mkl_shape);
    628         output_min->flat<float>()(0) = min_input;
    629         output_max->flat<float>()(0) = max_input;
    630       } else {
    631         MklDnnData<uint8> dnn_data_wksp(&cpu_engine);
    632         AllocateWorkspaceTensor(context, *(pooling_fwd->GetPoolingFwdPd()),
    633                                 &dnn_data_wksp);
    634         OP_REQUIRES_OK(context, context->status());
    635         T* ws_data =
    636             static_cast<T*>(dnn_data_wksp.GetOpMem().get_data_handle());
    637 
    638         // execute pooling op
    639         pooling_fwd->Execute(src_data, dst_data, ws_data);
    640       }
    641     } catch (mkldnn::error& e) {
    642       string error_msg = "Status: " + std::to_string(e.status) +
    643                          ", message: " + string(e.message) + ", in file " +
    644                          string(__FILE__) + ":" + std::to_string(__LINE__);
    645       OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
    646                                               error_msg));
    647     }
    648   }
    649 
    650  private:
  // Output index at which AllocateWorkspaceTensor() allocates the workspace
  // tensor.
  const int kOutputTensorIndexWorkspace = 1;
  // MKL-DNN CPU engine; Compute() passes its address to the MklDnnData
  // objects it constructs.
  engine cpu_engine = engine(engine::cpu, 0);
    653 
    654   void AllocateWorkspaceTensor(
    655       OpKernelContext* context,
    656       const pooling_forward::primitive_desc& pool_fwd_prim_desc,
    657       MklDnnData<uint8>* dnn_data_wksp) {
    658     CHECK_NOTNULL(dnn_data_wksp);
    659     Tensor* workspace_tensor = nullptr;
    660     memory::primitive_desc workspace_pd =
    661         pool_fwd_prim_desc.workspace_primitive_desc();
    662     size_t workspace_bytes = workspace_pd.get_size();
    663     MklDnnShape workspace_mkl_shape;
    664     workspace_mkl_shape.SetMklTensor(false);
    665     TensorShape workspace_tf_shape;
    666     workspace_tf_shape.AddDim(workspace_bytes);
    667     AllocateOutputSetMklShape(context, kOutputTensorIndexWorkspace,
    668                               &workspace_tensor, workspace_tf_shape,
    669                               workspace_mkl_shape);
    670     CHECK_NOTNULL(workspace_tensor);
    671     dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor);
    672   }
    673 };
    674 
    675 // The operation to compute MaxPool gradients.
    676 // It takes three inputs:
    677 //   - The original input tensor
    678 //   - The original output tensor
    679 //   - Backprop tensor for output
    680 // It produces one output: backprop tensor for input.
    681 template <class Device, class T>
    682 class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
    683  public:
    684   explicit MklMaxPoolingGradOp(OpKernelConstruction* context)
    685       : MklPoolingBackwardOpBase<T>(context) {}
    686   void Compute(OpKernelContext* context) override {
    687     try {
    688       auto cpu_engine = engine(engine::cpu, 0);
    689       const Tensor& orig_input_tensor =
    690           MklGetInput(context, kInputTensorIndexOrigInput);
    691       const Tensor& grad_tensor =
    692           MklGetInput(context, kInputTensorIndexGradient);
    693       const Tensor& workspace_tensor =
    694           MklGetInput(context, kInputTensorIndexWorkspace);
    695       MklDnnShape orig_input_mkl_shape, grad_mkl_shape;
    696       GetMklShape(context, kInputTensorIndexOrigInput, &orig_input_mkl_shape);
    697       GetMklShape(context, kInputTensorIndexGradient, &grad_mkl_shape);
    698       if (!context->status().ok()) return;
    699 
    700       MklDnnData<T> grad_dnn_data(&cpu_engine);
    701       MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
    702 
    703       MklPoolParameters pool_params;
    704       TensorShape orig_input_shape = orig_input_tensor.shape();
    705 
    706       bool is_pool2d = (this->ksize_.size() == 4);
    707       this->InitMklPoolParameters(context, &pool_params, orig_input_mkl_shape,
    708                                   orig_input_shape);
    709 
    710       memory::dims filter_dims, strides, padding_left, padding_right;
    711       this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
    712                              &padding_left, &padding_right, is_pool2d);
    713 
    714       memory::dims orig_input_dims_mkl_order =
    715           orig_input_mkl_shape.IsMklTensor()
    716               ? orig_input_mkl_shape.GetSizesAsMklDnnDims()
    717               : is_pool2d ? TFShapeToMklDnnDimsInNCHW(orig_input_shape,
    718                                                       this->data_format_tf_)
    719                           : TFShapeToMklDnnDimsInNCDHW(orig_input_shape,
    720                                                        this->data_format_tf_);
    721 
    722       memory::dims diff_dst_dims =
    723           grad_mkl_shape.IsMklTensor()
    724               ? grad_mkl_shape.GetSizesAsMklDnnDims()
    725               : is_pool2d ? TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
    726                                                       this->data_format_tf_)
    727                           : TFShapeToMklDnnDimsInNCDHW(grad_tensor.shape(),
    728                                                        this->data_format_tf_);
    729 
    730       memory::dims output_dims_mkl_order;
    731       this->GetOutputDims(pool_params, &output_dims_mkl_order);
    732 
    733       MklPoolingParams bwdParams(
    734           orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims,
    735           strides, padding_left, padding_right, algorithm::pooling_max,
    736           prop_kind::forward_training);
    737       MklPoolingBwdPrimitive<T>* pooling_bwd =
    738           MklPoolingBwdPrimitiveFactory<T>::Get(bwdParams);
    739 
    740       // allocate output tensor and memory primitive
    741       Tensor* output_tensor = nullptr;
    742       this->AllocateOutputTensor(context, *(pooling_bwd->GetPoolingBwdPd()),
    743                                  orig_input_dims_mkl_order,
    744                                  this->data_format_mkldnn_, &output_tensor);
    745       // get diff_dst mem desc
    746       memory::desc diff_dst_md =
    747           grad_mkl_shape.IsMklTensor()
    748               ? grad_mkl_shape.GetMklLayout()
    749               : memory::desc(diff_dst_dims, MklDnnType<T>(),
    750                              this->data_format_mkldnn_);
    751       // check if diff_dst needs to be reordered
    752       const T* diff_dst_data = grad_tensor.flat<T>().data();
    753       if (diff_dst_md.data.format != pooling_bwd->GetDiffDstFormat()) {
    754         auto target_diff_dst = memory::primitive_desc(
    755             {{diff_dst_dims}, MklDnnType<T>(), pooling_bwd->GetDiffDstFormat()},
    756             cpu_engine);
    757         grad_dnn_data.SetUsrMem(diff_dst_md, &grad_tensor);
    758         grad_dnn_data.CheckReorderToOpMem(target_diff_dst);
    759         diff_dst_data = const_cast<T*>(
    760             reinterpret_cast<T*>(grad_dnn_data.GetOpMem().get_data_handle()));
    761       }
    762 
    763       void* ws_data = static_cast<void*>(
    764           const_cast<uint8*>(workspace_tensor.flat<uint8>().data()));
    765 
    766       auto ws_md =
    767           pooling_bwd->GetPoolingFwdPd()->workspace_primitive_desc().desc();
    768       if (ws_md.data.format != pooling_bwd->GetWorkspaceFormat()) {
    769         memory::dims ws_dims;
    770         ws_dims.assign(ws_md.data.dims, ws_md.data.dims + ws_md.data.ndims);
    771         auto target_ws =
    772             memory::primitive_desc({{ws_dims},
    773                                     pooling_bwd->GetWorkspaceDataType(),
    774                                     pooling_bwd->GetWorkspaceFormat()},
    775                                    cpu_engine);
    776         workspace_dnn_data.SetUsrMem(ws_md, &workspace_tensor);
    777         workspace_dnn_data.CheckReorderToOpMem(target_ws);
    778         ws_data = workspace_dnn_data.GetOpMem().get_data_handle();
    779       }
    780 
    781       T* diff_src_data = output_tensor->flat<T>().data();
    782 
    783       // execute pooling
    784       pooling_bwd->Execute(diff_dst_data, diff_src_data, ws_data);
    785     } catch (mkldnn::error& e) {
    786       string error_msg = "Status:" + std::to_string(e.status) +
    787                          ", message: " + string(e.message) + ". in file " +
    788                          string(__FILE__) + ":" + std::to_string(__LINE__);
    789       OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
    790                                               error_msg));
    791     }
    792   }
    793 
    794  private:
    795   // .Input("orig_input: T")
    796   // .Input("orig_output: T")
    797   // .Input("grad: T")
    798   // .Input("workspace: T")
    799   const int kInputTensorIndexOrigInput = 0;
    800   const int kInputTensorIndexOrigOutput = 1;
    801   const int kInputTensorIndexGradient = 2;
    802   const int kInputTensorIndexWorkspace = 3;
    803 
    804   void ConfigureWorkspace(const Tensor& workspace_tensor,
    805                           memory::primitive_desc workspace_pd,
    806                           MklDnnData<uint8>* workspace_dnn_data) {
    807     CHECK_NOTNULL(workspace_dnn_data);
    808 
    809     workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor);
    810   }
    811 
    812   void SanityCheckInputs(OpKernelContext* context,
    813                          const Tensor& orig_input_tensor,
    814                          const Tensor& orig_output_tensor,
    815                          const Tensor& grad_tensor,
    816                          const Tensor& workspace_tensor,
    817                          const MklDnnShape& orig_input_mkl_shape,
    818                          const MklDnnShape& orig_output_mkl_shape,
    819                          const MklDnnShape& grad_mkl_shape,
    820                          const MklDnnShape& workspace_mkl_shape) {
    821     if (!orig_input_mkl_shape.IsMklTensor()) {
    822       OP_REQUIRES(context, orig_input_tensor.dims() == 4,
    823                   errors::InvalidArgument(
    824                       "Original input shape must be 4-dimensional"));
    825     } else {
    826       OP_REQUIRES(context, orig_input_mkl_shape.GetDimension() == 4,
    827                   errors::InvalidArgument(
    828                       "Original input shape must be 4-dimensional"));
    829     }
    830     if (!orig_output_mkl_shape.IsMklTensor()) {
    831       OP_REQUIRES(
    832           context, orig_output_tensor.dims() == 4,
    833           errors::InvalidArgument("Original output must be 4-dimensional"));
    834     } else {
    835       OP_REQUIRES(
    836           context, orig_output_mkl_shape.GetDimension() == 4,
    837           errors::InvalidArgument("Original output must be 4-dimensional"));
    838     }
    839     if (!grad_mkl_shape.IsMklTensor()) {
    840       OP_REQUIRES(context, grad_tensor.dims() == 4,
    841                   errors::InvalidArgument("Gradient must be 4-dimensional"));
    842     } else {
    843       OP_REQUIRES(context, grad_mkl_shape.GetDimension() == 4,
    844                   errors::InvalidArgument("Gradient must be 4-dimensional"));
    845     }
    846     if (this->workspace_enabled_) {
    847       // The workspace should not be an MKL tensor
    848       OP_REQUIRES(context, workspace_mkl_shape.IsMklTensor() == false,
    849                   errors::InvalidArgument(
    850                       "Workspace tensor should not be an MKL Tensor."));
    851       // It should only have one dimension
    852       OP_REQUIRES(
    853           context, workspace_tensor.dims() == 1,
    854           errors::InvalidArgument("Workspace tensor must be 1-dimensional"));
    855     } else {
    856       OP_REQUIRES(
    857           context, this->workspace_enabled_,
    858           errors::Unimplemented("MKL-DNN Max Pooling does not "
    859                                 "yet support the use case "
    860                                 "where MaxPoolGrad is called without first"
    861                                 " calling MaxPool."));
    862     }
    863   }
    864 };  // MklMaxPoolingGradOp
    865 
// Register the float CPU kernels for 3D max pooling (forward and gradient)
// under the MKL op label. These registrations precede the
// `#endif  // INTEL_MKL_ML_ONLY` below, so they are subject to that
// conditional-compilation block.
REGISTER_KERNEL_BUILDER(Name("_MklMaxPool3D")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklOpLabel),
                        MklMaxPoolingOp<CPUDevice, float>);

REGISTER_KERNEL_BUILDER(Name("_MklMaxPool3DGrad")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklOpLabel),
                        MklMaxPoolingGradOp<CPUDevice, float>);
    877 
    878 #endif  // INTEL_MKL_ML_ONLY
    879 
// Kernel registrations placed after the INTEL_MKL_ML_ONLY conditional block
// and still inside the INTEL_MKL guard.

// Float forward max pooling on CPU, under the MKL op label.
REGISTER_KERNEL_BUILDER(Name("_MklMaxPool")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklOpLabel),
                        MklMaxPoolingOp<CPUDevice, float>);

// Quantized forward max pooling for quint8, under the MKL quantized op label.
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedMaxPool")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklMaxPoolingOp<CPUDevice, quint8>);

// Quantized forward max pooling for qint8, under the MKL quantized op label.
REGISTER_KERNEL_BUILDER(Name("_MklQuantizedMaxPool")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint8>("T")
                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
                        MklMaxPoolingOp<CPUDevice, qint8>);

// Float max-pooling gradient on CPU, under the MKL op label.
REGISTER_KERNEL_BUILDER(Name("_MklMaxPoolGrad")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .Label(mkl_op_registry::kMklOpLabel),
                        MklMaxPoolingGradOp<CPUDevice, float>);
    903 
    904 }  // namespace tensorflow
    905 #endif  // INTEL_MKL
    906