Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // See docs in ../ops/nn_ops.cc.
     17 #ifdef INTEL_MKL
     18 #define EIGEN_USE_THREADS
     19 #include "tensorflow/core/framework/op_kernel.h"
     20 #include "tensorflow/core/kernels/mkl_pooling_ops_common.h"
     21 #include "tensorflow/core/lib/core/errors.h"
     22 #include "tensorflow/core/util/mkl_util.h"
     23 #include "tensorflow/core/util/padding.h"
     24 
     25 #ifndef INTEL_MKL_ML
     26 #include <algorithm>
     27 #include "mkldnn.hpp"
     28 using mkldnn::algorithm;
     29 using mkldnn::engine;
     30 using mkldnn::error;
     31 using mkldnn::memory;
     32 using mkldnn::padding_kind;
     33 using mkldnn::pooling_backward;
     34 using mkldnn::pooling_forward;
     35 using mkldnn::prop_kind;
     36 #endif
     37 
     38 namespace tensorflow {
     39 
     40 typedef Eigen::ThreadPoolDevice CPUDevice;
     41 
     42 // MKL-DNN is now default. MKL-ML must be specified explicitly.
     43 #ifdef INTEL_MKL_ML
     44 
     45 // An implementation of MaxPooling (forward).
     46 template <typename Device, typename T>
     47 class MklMaxPoolingOp : public OpKernel {
     48  public:
     49   explicit MklMaxPoolingOp(OpKernelConstruction* context) : OpKernel(context) {
     50     string data_format;
     51 
     52     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
     53     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
     54                 errors::InvalidArgument("Invalid data format"));
     55     OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
     56     OP_REQUIRES(context, ksize_.size() == 4,
     57                 errors::InvalidArgument("Sliding window ksize field must "
     58                                         "specify 4 dimensions"));
     59     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
     60     OP_REQUIRES(context, stride_.size() == 4,
     61                 errors::InvalidArgument("Sliding window stride field must "
     62                                         "specify 4 dimensions"));
     63     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
     64     OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
     65                 errors::Unimplemented("Pooling is not yet supported on the "
     66                                       "batch dimension."));
     67 
     68     workspace_enabled_ = false;
     69     // We may not get this attribute for this node if it does not go through
     70     // graph rewrite pass. So we do not check for error while retrieving this
     71     // attribute value.
     72     context->GetAttr("workspace_enabled", &workspace_enabled_);
     73   }
     74 
     75   void Compute(OpKernelContext* context) override {
     76     MklMaxPoolingOpContext mkl_context;
     77     // Get the input tensor
     78     const Tensor& tensor_in = MklGetInput(context, 0);
     79     GetMklShape(context, 0, &mkl_context.input_shape);
     80     bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
     81 
     82     mkl_context.params.in_dim = 4;
     83     MklPoolParameters pool_params;
     84     if (input_in_mkl_format == false) {
     85       pool_params.Init(context, ksize_, stride_, padding_, data_format_,
     86                        tensor_in.shape());
     87       OP_REQUIRES(
     88           context, (pool_params.depth_window == 1),
     89           errors::Unimplemented("Depthwise max pooling not supported by MKL"));
     90 
     91     } else {
     92       pool_params.Init(context, ksize_, stride_, padding_, data_format_,
     93                        &mkl_context.input_shape);
     94     }
     95 
     96     // Extract the parameters for the op from the pooling specs
     97 
     98     ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
     99 
    100     mkl_context.MklCreateLayoutsAndPrimitives(context);
    101     OP_REQUIRES_OK(context, context->status());
    102 
    103     // Declare output tensor
    104     TensorShape tensor_out_shape;
    105     MklShape mkl_out_shape, mkl_workspace_shape;
    106     mkl_out_shape.SetMklTensor(true);
    107     mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_fwd, dnnResourceDst);
    108     mkl_out_shape.SetTfLayout(mkl_context.params.in_dim,
    109                               mkl_context.params.out_sizes,
    110                               mkl_context.params.out_strides);
    111     mkl_out_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_);
    112 
    113     Tensor* output_tensor = nullptr;
    114     tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
    115                                 mkl_out_shape.GetMklLayout())) /
    116                             sizeof(T));
    117     AllocateOutputSetMklShape(context, 0, &output_tensor, tensor_out_shape,
    118                               mkl_out_shape);
    119 
    120     Tensor* workspace_tensor;
    121     void* workspace_buf = nullptr;
    122 
    123     TensorShape workspace_shape;
    124     mkl_workspace_shape.SetMklTensor(false);
    125     workspace_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
    126                                mkl_context.lt_workspace)) /
    127                            sizeof(T));
    128     AllocateOutputSetMklShape(context, 1, &workspace_tensor, workspace_shape,
    129                               mkl_workspace_shape);
    130 
    131     mkl_context.pooling_res[dnnResourceWorkspace] = const_cast<void*>(
    132         static_cast<const void*>(workspace_tensor->flat<T>().data()));
    133     mkl_context.pooling_res[dnnResourceSrc] =
    134         const_cast<void*>(static_cast<const void*>(tensor_in.flat<T>().data()));
    135     mkl_context.pooling_res[dnnResourceDst] = const_cast<void*>(
    136         static_cast<const void*>(output_tensor->flat<T>().data()));
    137 
    138     CHECK_EQ(
    139         dnnExecute_F32(mkl_context.prim_pooling_fwd, mkl_context.pooling_res),
    140         E_SUCCESS);
    141 
    142     mkl_context.MklCleanup();
    143   }
    144 
    145  private:
    146   typedef struct {
    147     MklPoolingOpParams params;
    148     MklShape input_shape;
    149     void* pooling_res[dnnResourceNumber];
    150     dnnPrimitive_t prim_pooling_fwd = nullptr;
    151     dnnLayout_t lt_user_input = nullptr, lt_workspace = nullptr;
    152 
    153     void MklCreateLayoutsAndPrimitives(OpKernelContext* context) {
    154       bool input_in_mkl_format = input_shape.IsMklTensor();
    155       // Create or use existing DNN user layout
    156       if (input_in_mkl_format == false) {
    157         CHECK_EQ(dnnLayoutCreate_F32(&lt_user_input, params.in_dim,
    158                                      params.in_sizes, params.in_strides),
    159                  E_SUCCESS);
    160       } else {
    161         lt_user_input = (dnnLayout_t)input_shape.GetCurLayout();
    162       }
    163 
    164       dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax;
    165       dnnPrimitiveAttributes_t primAttr = nullptr;
    166 
    167       // Create DNN primitives
    168       CHECK_EQ(dnnPoolingCreateForward_F32(
    169                    &prim_pooling_fwd, primAttr, algorithm, lt_user_input,
    170                    params.kernel_size, params.kernel_stride, params.in_offset,
    171                    dnnBorderZerosAsymm),
    172                E_SUCCESS);
    173 
    174       // Creates layout for the workspace
    175       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace, prim_pooling_fwd,
    176                                                 dnnResourceWorkspace),
    177                E_SUCCESS);
    178     }
    179 
    180     void MklCleanup() {
    181       bool input_in_mkl_format = input_shape.IsMklTensor();
    182       CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS);
    183       if (!input_in_mkl_format) {
    184         CHECK_EQ(dnnLayoutDelete_F32(lt_user_input), E_SUCCESS);
    185       }
    186       CHECK_EQ(dnnLayoutDelete_F32(lt_workspace), E_SUCCESS);
    187     }
    188   } MklMaxPoolingOpContext;
    189 
    190   std::vector<int32> ksize_;
    191   std::vector<int32> stride_;
    192   Padding padding_;
    193   TensorFormat data_format_;
    194   bool workspace_enabled_;
    195 };
    196 
    197 // The operation to compute MaxPool gradients.
    198 // It takes three inputs:
    199 //   - The original input tensor
    200 //   - The original output tensor
    201 //   - Backprop tensor for output
    202 // It produces one output: backprop tensor for input.
    203 template <class Device, class T>
    204 class MklMaxPoolingGradOp : public OpKernel {
    205  public:
    206   explicit MklMaxPoolingGradOp(OpKernelConstruction* context)
    207       : OpKernel(context) {
    208     string data_format;
    209 
    210     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    211     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
    212                 errors::InvalidArgument("Invalid data format"));
    213     OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
    214     OP_REQUIRES(context, ksize_.size() == 4,
    215                 errors::InvalidArgument("Sliding window ksize field must "
    216                                         "specify 4 dimensions"));
    217     OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
    218     OP_REQUIRES(context, stride_.size() == 4,
    219                 errors::InvalidArgument("Sliding window strides field must "
    220                                         "specify 4 dimensions"));
    221     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    222     OP_REQUIRES(context, ksize_[0] == 1 && stride_[0] == 1,
    223                 errors::Unimplemented(
    224                     "Pooling is not yet supported on the batch dimension."));
    225     workspace_enabled_ = false;
    226     // We may not get this attribute for this node if it does not go through
    227     // graph rewrite pass. So we do not check for error while retrieving this
    228     // attribute value.
    229     context->GetAttr("workspace_enabled", &workspace_enabled_);
    230   }
    231 
    232   void Compute(OpKernelContext* context) override {
    233     MklMaxPoolingGradOpContext mkl_context;
    234     // Input - The original input tensor
    235     const Tensor& tensor_in = MklGetInput(context, 0);
    236 
    237     // Output - Backprop tensor for input.
    238     Tensor* output_tensor = nullptr;
    239 
    240     GetMklShape(context, 0, &mkl_context.input_shape);
    241     GetMklShape(context, 2, &mkl_context.output_backprop_shape);
    242     bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
    243 
    244     if (input_in_mkl_format == false)
    245       mkl_context.params.in_dim = tensor_in.dims();
    246     else
    247       mkl_context.params.in_dim = mkl_context.input_shape.GetDimension();
    248 
    249     MklPoolParameters pool_params;
    250     if (input_in_mkl_format == false) {
    251       pool_params.Init(context, ksize_, stride_, padding_, data_format_,
    252                        tensor_in.shape());
    253       OP_REQUIRES(
    254           context, (pool_params.depth_window == 1),
    255           errors::Unimplemented("Depthwise max pooling not supported by MKL"));
    256 
    257     } else {
    258       pool_params.Init(context, ksize_, stride_, padding_, data_format_,
    259                        &mkl_context.input_shape);
    260     }
    261 
    262     // Extract the parameters for the op from the pooling specs
    263     ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
    264 
    265     mkl_context.MklCreateLayouts(context);
    266     OP_REQUIRES_OK(context, context->status());
    267 
    268     mkl_context.MklCreatePrimitives(context, workspace_enabled_);
    269     OP_REQUIRES_OK(context, context->status());
    270 
    271     mkl_context.MklPrepareInputs(context, workspace_enabled_);
    272     OP_REQUIRES_OK(context, context->status());
    273 
    274     // Create shape for the input back prop output
    275     TensorShape mkl_input_backprop;
    276     MklShape mkl_output_shape;
    277     mkl_output_shape.SetMklTensor(true);
    278     mkl_output_shape.SetMklLayout(mkl_context.prim_pooling_bwd,
    279                                   dnnResourceDiffSrc);
    280     mkl_output_shape.SetTfLayout(mkl_context.params.in_dim,
    281                                  mkl_context.params.in_sizes,
    282                                  mkl_context.params.in_strides);
    283     mkl_output_shape.SetTfDimOrder(mkl_context.params.in_dim, data_format_);
    284 
    285     mkl_input_backprop.AddDim(
    286         dnnLayoutGetMemorySize_F32(
    287             static_cast<dnnLayout_t>(mkl_output_shape.GetMklLayout())) /
    288         sizeof(T));
    289     AllocateOutputSetMklShape(context, 0, &output_tensor, mkl_input_backprop,
    290                               mkl_output_shape);
    291     mkl_context.pooling_res[dnnResourceDiffSrc] = const_cast<void*>(
    292         static_cast<const void*>(output_tensor->flat<T>().data()));
    293 
    294     CHECK_EQ(
    295         dnnExecute_F32(mkl_context.prim_pooling_bwd, mkl_context.pooling_res),
    296         E_SUCCESS);
    297 
    298     mkl_context.MklCleanup(workspace_enabled_);
    299   }
    300 
    301  private:
    302   typedef struct {
    303     MklPoolingOpParams params;
    304     MklShape input_shape, output_backprop_shape;
    305     void* pooling_resfwd[dnnResourceNumber];
    306     void* pooling_res[dnnResourceNumber];
    307     dnnPrimitive_t prim_pooling_fwd = nullptr, prim_pooling_bwd = nullptr,
    308                    convert_input = nullptr, convert_outbackprop = nullptr;
    309     dnnLayout_t lt_outbackprop_user = nullptr, lt_outbackprop_prim = nullptr,
    310                 lt_input_user = nullptr, lt_input_prim = nullptr;
    311     void* input_buf;
    312     void* outbackprop_buf;
    313     Tensor tmp_output_buf_tensor;
    314     Tensor workspace_buf_tensor;
    315     Tensor input_buf_tensor, outbackprop_buf_tensor;
    316 
    317     void MklCreateLayouts(OpKernelContext* context) {
    318       bool input_in_mkl_format = input_shape.IsMklTensor();
    319       bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor();
    320       // Create DNN user layout for input and outbackprop or get existing layout
    321       if (input_in_mkl_format == false) {
    322         CHECK_EQ(dnnLayoutCreate_F32(&lt_input_user, params.in_dim,
    323                                      params.in_sizes, params.in_strides),
    324                  E_SUCCESS);
    325       } else {
    326         lt_input_user = (dnnLayout_t)input_shape.GetCurLayout();
    327       }
    328 
    329       // We don't care about the output layout for now as we can create it from
    330       // primitives for the max pooling fwd prop
    331       if (outbackprop_in_mkl_format == false) {
    332         CHECK_EQ(dnnLayoutCreate_F32(&lt_outbackprop_user, params.in_dim,
    333                                      params.out_sizes, params.out_strides),
    334                  E_SUCCESS);
    335       } else {
    336         lt_outbackprop_user = (dnnLayout_t)output_backprop_shape.GetCurLayout();
    337       }
    338     }
    339 
    340     // Create DNN primitives
    341     void MklCreatePrimitives(OpKernelContext* context, bool workspace_enabled) {
    342       dnnAlgorithm_t algorithm = dnnAlgorithmPoolingMax;
    343       dnnPrimitiveAttributes_t primAttr = nullptr;
    344 
    345       if (workspace_enabled == false) {
    346         CHECK_EQ(dnnPoolingCreateForward_F32(
    347                      &prim_pooling_fwd, primAttr, algorithm, lt_input_user,
    348                      params.kernel_size, params.kernel_stride, params.in_offset,
    349                      dnnBorderZerosAsymm),
    350                  E_SUCCESS);
    351       }
    352 
    353       CHECK_EQ(dnnPoolingCreateBackward_F32(
    354                    &prim_pooling_bwd, primAttr, algorithm, lt_input_user,
    355                    params.kernel_size, params.kernel_stride, params.in_offset,
    356                    dnnBorderZerosAsymm),
    357                E_SUCCESS);
    358 
    359       // Creates conversions
    360       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
    361                    &lt_outbackprop_prim, prim_pooling_bwd, dnnResourceDiffDst),
    362                E_SUCCESS);
    363 
    364       if (workspace_enabled == false) {
    365         CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
    366                      &lt_input_prim, prim_pooling_fwd, dnnResourceSrc),
    367                  E_SUCCESS);
    368         if (!dnnLayoutCompare_F32(lt_input_user, lt_input_prim)) {
    369           CHECK_EQ(dnnConversionCreate_F32(&convert_input, lt_input_user,
    370                                            lt_input_prim),
    371                    E_SUCCESS);
    372           AllocTmpBuffer(context, &input_buf_tensor, lt_input_prim, &input_buf);
    373         }
    374       }
    375 
    376       if (!dnnLayoutCompare_F32(lt_outbackprop_user, lt_outbackprop_prim)) {
    377         CHECK_EQ(
    378             dnnConversionCreate_F32(&convert_outbackprop, lt_outbackprop_user,
    379                                     lt_outbackprop_prim),
    380             E_SUCCESS);
    381         AllocTmpBuffer(context, &outbackprop_buf_tensor, lt_outbackprop_prim,
    382                        &outbackprop_buf);
    383       }
    384     }
    385 
    386     // Compare incoming tensor layouts with MKL preferred layouts and convert
    387     // data to the preferred layout if necessary
    388     void MklPrepareInputs(OpKernelContext* context, bool workspace_enabled) {
    389       const Tensor& tensor_in = MklGetInput(context, 0);
    390       const Tensor& out_backprop = MklGetInput(context, 2);
    391       bool input_in_mkl_format = input_shape.IsMklTensor();
    392       bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor();
    393 
    394       void* tmp_output_buf = nullptr;
    395       void* workspace_buf = nullptr;
    396 
    397       if (workspace_enabled == false) {
    398         if (convert_input != nullptr) {
    399           if (input_in_mkl_format == false) {
    400             CHECK_EQ(dnnConversionExecute_F32(
    401                          convert_input,
    402                          const_cast<void*>(static_cast<const void*>(
    403                              tensor_in.flat<T>().data())),
    404                          input_buf),
    405                      E_SUCCESS);
    406             CHECK_EQ(dnnDelete_F32(convert_input), E_SUCCESS);
    407             convert_input = nullptr;
    408           } else {
    409             input_shape.GetConvertedFlatData(
    410                 lt_input_prim,
    411                 const_cast<void*>(
    412                     static_cast<const void*>(tensor_in.flat<T>().data())),
    413                 input_buf);
    414           }
    415           pooling_resfwd[dnnResourceSrc] = input_buf;
    416         } else {
    417           pooling_resfwd[dnnResourceSrc] = const_cast<void*>(
    418               static_cast<const void*>(tensor_in.flat<T>().data()));
    419         }
    420 
    421         dnnLayout_t lt_workspace;
    422         CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
    423                      &lt_workspace, prim_pooling_fwd, dnnResourceWorkspace),
    424                  E_SUCCESS);
    425         AllocTmpBuffer(context, &workspace_buf_tensor, lt_workspace,
    426                        &workspace_buf);
    427         pooling_resfwd[dnnResourceWorkspace] = workspace_buf;
    428 
    429         dnnLayoutDelete_F32(lt_workspace);
    430 
    431         // We create the layout for max pooling fwd prop tmp output here
    432         AllocTmpBuffer(context, &tmp_output_buf_tensor, lt_outbackprop_prim,
    433                        &tmp_output_buf);
    434         pooling_resfwd[dnnResourceDst] = tmp_output_buf;
    435 
    436         CHECK_EQ(dnnExecute_F32(prim_pooling_fwd, pooling_resfwd), E_SUCCESS);
    437         pooling_res[dnnResourceWorkspace] =
    438             pooling_resfwd[dnnResourceWorkspace];
    439       } else {
    440         const Tensor& workspace = MklGetInput(context, 3);
    441         pooling_res[dnnResourceWorkspace] = const_cast<void*>(
    442             static_cast<const void*>(workspace.flat<T>().data()));
    443       }
    444 
    445       // Out backprop conversions if needed
    446       if (convert_outbackprop != nullptr) {
    447         if (outbackprop_in_mkl_format == false) {
    448           CHECK_EQ(dnnConversionExecute_F32(
    449                        convert_outbackprop,
    450                        const_cast<void*>(static_cast<const void*>(
    451                            out_backprop.flat<T>().data())),
    452                        outbackprop_buf),
    453                    E_SUCCESS);
    454           CHECK_EQ(dnnDelete_F32(convert_outbackprop), E_SUCCESS);
    455         } else {
    456           output_backprop_shape.GetConvertedFlatData(
    457               lt_outbackprop_prim,
    458               const_cast<void*>(
    459                   static_cast<const void*>(out_backprop.flat<T>().data())),
    460               outbackprop_buf);
    461         }
    462         pooling_res[dnnResourceDiffDst] = outbackprop_buf;
    463       } else {
    464         pooling_res[dnnResourceDiffDst] = const_cast<void*>(
    465             static_cast<const void*>(out_backprop.flat<T>().data()));
    466       }
    467     }
    468 
    469     void MklCleanup(bool workspace_enabled) {
    470       bool input_in_mkl_format = input_shape.IsMklTensor();
    471       bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor();
    472       if (workspace_enabled == false) {
    473         CHECK_EQ(dnnDelete_F32(prim_pooling_fwd), E_SUCCESS);
    474       }
    475       CHECK_EQ(dnnDelete_F32(prim_pooling_bwd), E_SUCCESS);
    476       if (outbackprop_in_mkl_format == false) {
    477         CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_user), E_SUCCESS);
    478       }
    479       CHECK_EQ(dnnLayoutDelete_F32(lt_outbackprop_prim), E_SUCCESS);
    480       if (input_in_mkl_format == false) {
    481         CHECK_EQ(dnnLayoutDelete_F32(lt_input_user), E_SUCCESS);
    482       }
    483       if (workspace_enabled == false) {
    484         CHECK_EQ(dnnLayoutDelete_F32(lt_input_prim), E_SUCCESS);
    485       }
    486     }
    487   } MklMaxPoolingGradOpContext;
    488 
    489   std::vector<int32> ksize_;
    490   std::vector<int32> stride_;
    491   Padding padding_;
    492   TensorFormat data_format_;
    493 
    494   bool workspace_enabled_;
    495 };  // MklMaxPoolingGradOp
    496 
    497 #else
    498 
    499 // An implementation of MaxPooling (forward).
    500 template <typename Device, typename T>
    501 class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
    502  public:
    503   explicit MklMaxPoolingOp(OpKernelConstruction* context)
    504       : MklPoolingForwardOpBase<T>(context) {
    505     // In Max Pooling, MKLDNN does not allow passing workspace as NULL.
    506     // So we set workspace_enabled_ to true.
    507     this->workspace_enabled_ = true;
    508   }
    509 
    510   void Compute(OpKernelContext* context) override {
    511     try {
    512       auto cpu_engine = engine(engine::cpu, 0);
    513       const Tensor& input_tensor =
    514           MklGetInput(context, this->kInputTensorIndexInput);
    515       MklDnnShape dnn_shape_input;
    516       GetMklShape(context, this->kInputTensorIndexInput, &dnn_shape_input);
    517       this->SanityCheckInput(context, input_tensor, dnn_shape_input);
    518       if (!context->status().ok()) return;
    519 
    520       MklDnnData<T> dnn_data_input(&cpu_engine);
    521       MklDnnData<T> dnn_data_output(&cpu_engine);
    522       MklDnnData<uint8> dnn_data_wksp(&cpu_engine);
    523 
    524       // initialize variables for the pooling op
    525       MklPoolParameters pool_params;
    526       // Get the input tensor and initialize the pooling parameters
    527       this->ConfigureInput(context, dnn_shape_input, input_tensor, &pool_params,
    528                            &dnn_data_input);
    529       OP_REQUIRES_OK(context, context->status());
    530 
    531       // Declare output tensor
    532       Tensor* output_tensor = nullptr;
    533       memory::dims output_dims_mkl_order;
    534       this->GetOutputDims(pool_params, &output_dims_mkl_order);
    535 
    536       // If input is in Mkl layout, then just get the memory format from it
    537       // directly, instead of using input data_format to MaxPool.
    538       if (dnn_shape_input.IsMklTensor()) {
    539         dnn_data_output.SetUsrMem(
    540             output_dims_mkl_order,
    541             static_cast<memory::format>(
    542                 dnn_data_input.GetUsrMemDesc().data.format));
    543       } else {
    544         dnn_data_output.SetUsrMem(output_dims_mkl_order,
    545                                   this->data_format_mkldnn_);
    546       }
    547 
    548       // describe the memory layout; let mkl-dnn choose the best for the op
    549       dnn_data_output.SetOpMemDesc(output_dims_mkl_order, memory::format::any);
    550 
    551       auto pool_desc = pooling_forward::desc(
    552           prop_kind::forward, algorithm::pooling_max,
    553           dnn_data_input.GetUsrMemDesc(), dnn_data_output.GetUsrMemDesc(),
    554           memory::dims({pool_params.row_stride, pool_params.col_stride}),
    555           memory::dims({pool_params.window_rows, pool_params.window_cols}),
    556           memory::dims({static_cast<int>(pool_params.pad_top),
    557                         static_cast<int>(pool_params.pad_left)}),
    558           memory::dims({static_cast<int>(pool_params.pad_bottom),
    559                         static_cast<int>(pool_params.pad_right)}),
    560           TFPaddingToMklDnnPadding(this->padding_));
    561       auto pool_fwd_desc =
    562           pooling_forward::primitive_desc(pool_desc, cpu_engine);
    563 
    564       this->AllocateOutputTensor(context, pool_fwd_desc, output_dims_mkl_order,
    565                                  this->data_format_mkldnn_, &output_tensor);
    566       OP_REQUIRES_OK(context, context->status());
    567       dnn_data_output.SetUsrMemDataHandle(output_tensor);
    568 
    569       AllocateWorkspaceTensor(context, pool_fwd_desc, &dnn_data_wksp);
    570       OP_REQUIRES_OK(context, context->status());
    571 
    572       this->PrepareAndExecuteNet(pool_fwd_desc, &dnn_data_input,
    573                                  &dnn_data_output, &dnn_data_wksp);
    574     } catch (mkldnn::error& e) {
    575       string error_msg = "Status: " + std::to_string(e.status) +
    576                          ", message: " + string(e.message) + ", in file " +
    577                          string(__FILE__) + ":" + std::to_string(__LINE__);
    578       OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
    579                                               error_msg));
    580     }
    581   }  // Compute
    582 
    583  private:
    584   const int kOutputTensorIndexWorkspace = 1;
    585 
    586   void AllocateWorkspaceTensor(
    587       OpKernelContext* context,
    588       const pooling_forward::primitive_desc& pool_fwd_prim_desc,
    589       MklDnnData<uint8>* dnn_data_wksp) {
    590     CHECK_NOTNULL(dnn_data_wksp);
    591     Tensor* workspace_tensor = nullptr;
    592     memory::primitive_desc workspace_pd =
    593         pool_fwd_prim_desc.workspace_primitive_desc();
    594     size_t workspace_bytes = workspace_pd.get_size();
    595     MklDnnShape workspace_mkl_shape;
    596     workspace_mkl_shape.SetMklTensor(false);
    597     TensorShape workspace_tf_shape;
    598     workspace_tf_shape.AddDim(workspace_bytes);
    599     AllocateOutputSetMklShape(context, kOutputTensorIndexWorkspace,
    600                               &workspace_tensor, workspace_tf_shape,
    601                               workspace_mkl_shape);
    602     CHECK_NOTNULL(workspace_tensor);
    603     dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor);
    604   }
    605 };
    606 
    607 // The operation to compute MaxPool gradients.
    608 // It takes three inputs:
    609 //   - The original input tensor
    610 //   - The original output tensor
    611 //   - Backprop tensor for output
    612 // It produces one output: backprop tensor for input.
    613 template <class Device, class T>
    614 class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
    615  public:
    616   explicit MklMaxPoolingGradOp(OpKernelConstruction* context)
    617       : MklPoolingBackwardOpBase<T>(context) {}
    618 
    619   void Compute(OpKernelContext* context) override {
    620     try {
    621       auto cpu_engine = engine(engine::cpu, 0);
    622       const Tensor& orig_input_tensor =
    623           MklGetInput(context, kInputTensorIndexOrigInput);
    624       const Tensor& orig_output_tensor =
    625           MklGetInput(context, kInputTensorIndexOrigOutput);
    626       const Tensor& grad_tensor =
    627           MklGetInput(context, kInputTensorIndexGradient);
    628       const Tensor& workspace_tensor =
    629           MklGetInput(context, kInputTensorIndexWorkspace);
    630       MklDnnShape orig_input_mkl_shape, orig_output_mkl_shape, grad_mkl_shape,
    631           workspace_mkl_shape;
    632       GetMklShape(context, kInputTensorIndexOrigInput, &orig_input_mkl_shape);
    633       GetMklShape(context, kInputTensorIndexOrigOutput, &orig_output_mkl_shape);
    634       GetMklShape(context, kInputTensorIndexGradient, &grad_mkl_shape);
    635       GetMklShape(context, kInputTensorIndexWorkspace, &workspace_mkl_shape);
    636 
    637       SanityCheckInputs(context, orig_input_tensor, orig_output_tensor,
    638                         grad_tensor, workspace_tensor, orig_input_mkl_shape,
    639                         orig_output_mkl_shape, grad_mkl_shape,
    640                         workspace_mkl_shape);
    641       if (!context->status().ok()) return;
    642 
    643       MklDnnData<T> grad_dnn_data(&cpu_engine);
    644       MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
    645       MklDnnData<T> output_dnn_data(&cpu_engine);
    646       Tensor* output_tensor = nullptr;
    647       MklPoolParameters pool_params;
    648       TensorShape orig_input_shape;
    649       memory::dims output_dims_mkl_order, orig_input_dims_mkl_order;
    650       memory::desc original_input_md = ConfigureOriginalInput(
    651           context, orig_input_tensor, orig_input_mkl_shape,
    652           &orig_input_dims_mkl_order, &pool_params, &orig_input_shape);
    653 
    654       memory::desc original_output_md = this->ConfigureOriginalOutput(
    655           pool_params, orig_output_mkl_shape, output_dims_mkl_order);
    656 
    657       memory::desc target_diff_dst_md = this->ConfigureInputGradient(
    658           grad_mkl_shape, grad_tensor, &grad_dnn_data, original_output_md);
    659 
    660       output_dnn_data.SetUsrMem(original_input_md);
    661 
    662       // Create the forward pooling primitive descriptor so we can
    663       // pass it as a hint to the backward pooling primitive descriptor
    664       auto pool_fwd_desc = pooling_forward::desc(
    665           prop_kind::forward, algorithm::pooling_max, original_input_md,
    666           original_output_md,
    667           memory::dims({pool_params.row_stride, pool_params.col_stride}),
    668           memory::dims({pool_params.window_rows, pool_params.window_cols}),
    669           memory::dims({static_cast<int>(pool_params.pad_top),
    670                         static_cast<int>(pool_params.pad_left)}),
    671           memory::dims({static_cast<int>(pool_params.pad_bottom),
    672                         static_cast<int>(pool_params.pad_right)}),
    673           TFPaddingToMklDnnPadding(this->padding_));
    674       auto pool_fwd_prim_desc =
    675           pooling_forward::primitive_desc(pool_fwd_desc, cpu_engine);
    676 
    677       auto pool_bkwd_desc = pooling_backward::desc(
    678           algorithm::pooling_max, output_dnn_data.GetUsrMemDesc(),
    679           target_diff_dst_md,
    680           memory::dims({pool_params.row_stride, pool_params.col_stride}),
    681           memory::dims({pool_params.window_rows, pool_params.window_cols}),
    682           memory::dims({static_cast<int>(pool_params.pad_top),
    683                         static_cast<int>(pool_params.pad_left)}),
    684           memory::dims({static_cast<int>(pool_params.pad_bottom),
    685                         static_cast<int>(pool_params.pad_right)}),
    686           TFPaddingToMklDnnPadding(this->padding_));
    687       auto pool_bkwd_prim_desc = pooling_backward::primitive_desc(
    688           pool_bkwd_desc, cpu_engine, pool_fwd_prim_desc);
    689 
    690       this->AllocateOutputTensor(context, pool_bkwd_prim_desc,
    691                                  orig_input_dims_mkl_order,
    692                                  this->data_format_mkldnn_, &output_tensor);
    693       output_dnn_data.SetUsrMemDataHandle(output_tensor);
    694 
    695       ConfigureWorkspace(workspace_tensor,
    696                          pool_fwd_prim_desc.workspace_primitive_desc(),
    697                          &workspace_dnn_data);
    698       this->PrepareAndExecuteNet(
    699           pool_bkwd_prim_desc, &grad_dnn_data, &output_dnn_data,
    700           memory::primitive_desc(target_diff_dst_md, cpu_engine),
    701           &workspace_dnn_data);
    702     } catch (mkldnn::error& e) {
    703       string error_msg = "Status: " + std::to_string(e.status) +
    704                          ", message: " + string(e.message) + ", in file " +
    705                          string(__FILE__) + ":" + std::to_string(__LINE__);
    706       OP_REQUIRES_OK(context, errors::Aborted("Compute received an exception:",
    707                                               error_msg));
    708     }
    709   }  // Compute
    710 
    711  private:
    712   // .Input("orig_input: T")
    713   // .Input("orig_output: T")
    714   // .Input("grad: T")
    715   // .Input("workspace: T")
    716   const int kInputTensorIndexOrigInput = 0;
    717   const int kInputTensorIndexOrigOutput = 1;
    718   const int kInputTensorIndexGradient = 2;
    719   const int kInputTensorIndexWorkspace = 3;
    720   //  Output("output: T") in Base Class
    721 
    722   memory::desc ConfigureOriginalInput(
    723       OpKernelContext* context, const Tensor& tensor_original_input,
    724       const MklDnnShape& original_input_mkl_shape,
    725       memory::dims* original_input_dims_mkl_order,
    726       MklPoolParameters* pool_params, TensorShape* input_tensor_shape) {
    727     *input_tensor_shape = tensor_original_input.shape();
    728     return MklPoolingBackwardOpBase<T>::ConfigureOriginalInput(
    729         context, tensor_original_input, original_input_mkl_shape,
    730         original_input_dims_mkl_order, pool_params, *input_tensor_shape);
    731   }
    732 
    733   void ConfigureWorkspace(const Tensor& workspace_tensor,
    734                           memory::primitive_desc workspace_pd,
    735                           MklDnnData<uint8>* workspace_dnn_data) {
    736     CHECK_NOTNULL(workspace_dnn_data);
    737 
    738     workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor);
    739   }
    740 
    741   void SanityCheckInputs(OpKernelContext* context,
    742                          const Tensor& orig_input_tensor,
    743                          const Tensor& orig_output_tensor,
    744                          const Tensor& grad_tensor,
    745                          const Tensor& workspace_tensor,
    746                          const MklDnnShape& orig_input_mkl_shape,
    747                          const MklDnnShape& orig_output_mkl_shape,
    748                          const MklDnnShape& grad_mkl_shape,
    749                          const MklDnnShape& workspace_mkl_shape) {
    750     if (!orig_input_mkl_shape.IsMklTensor()) {
    751       OP_REQUIRES(context, orig_input_tensor.dims() == 4,
    752                   errors::InvalidArgument("Original input shape must be "
    753                                           "4-dimensional"));
    754     } else {
    755       OP_REQUIRES(context, orig_input_mkl_shape.GetDimension() == 4,
    756                   errors::InvalidArgument("Original input shape must be "
    757                                           "4-dimensional"));
    758     }
    759     if (!orig_output_mkl_shape.IsMklTensor()) {
    760       OP_REQUIRES(context, orig_output_tensor.dims() == 4,
    761                   errors::InvalidArgument("Original output must be "
    762                                           "4-dimensional"));
    763     } else {
    764       OP_REQUIRES(context, orig_output_mkl_shape.GetDimension() == 4,
    765                   errors::InvalidArgument("Original output must be "
    766                                           "4-dimensional"));
    767     }
    768     if (!grad_mkl_shape.IsMklTensor()) {
    769       OP_REQUIRES(context, grad_tensor.dims() == 4,
    770                   errors::InvalidArgument("Gradient must be 4-dimensional"));
    771     } else {
    772       OP_REQUIRES(context, grad_mkl_shape.GetDimension() == 4,
    773                   errors::InvalidArgument("Gradient must be "
    774                                           "4-dimensional"));
    775     }
    776     if (this->workspace_enabled_) {
    777       // The workspace should not be an MKL tensor
    778       OP_REQUIRES(context, workspace_mkl_shape.IsMklTensor() == false,
    779                   errors::InvalidArgument("Workspace tensor should not"
    780                                           " be an MKL Tensor."));
    781       // It should only have one dimension
    782       OP_REQUIRES(context, workspace_tensor.dims() == 1,
    783                   errors::InvalidArgument("Workspace tensor must be "
    784                                           "1-dimensional"));
    785     } else {
    786       OP_REQUIRES(
    787           context, this->workspace_enabled_,
    788           errors::Unimplemented("MKL-DNN Max Pooling does not "
    789                                 "yet support the use case "
    790                                 "where MaxPoolGrad is called without first"
    791                                 " calling MaxPool."));
    792     }
    793   }
    794 };  // MklMaxPoolingGradOp
    795 
    796 #endif  // INTEL_MKL_ML
    797 
    798 REGISTER_KERNEL_BUILDER(Name("_MklMaxPool")
    799                             .Device(DEVICE_CPU)
    800                             .TypeConstraint<float>("T")
    801                             .Label(mkl_op_registry::kMklOpLabel),
    802                         MklMaxPoolingOp<CPUDevice, float>);
    803 
    804 REGISTER_KERNEL_BUILDER(Name("_MklMaxPoolGrad")
    805                             .Device(DEVICE_CPU)
    806                             .TypeConstraint<float>("T")
    807                             .Label(mkl_op_registry::kMklOpLabel),
    808                         MklMaxPoolingGradOp<CPUDevice, float>);
    809 
    810 }  // namespace tensorflow
    811 #endif  // INTEL_MKL
    812