      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 // LRN = Local Response Normalization
     17 // See docs in ../ops/nn_ops.cc. This OpKernel uses the MKL library to create
     18 // MKL layouts and primitives and uses MKL-DNN primitives to compute the
     19 // local response normalization.
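        //
        // For reference, TensorFlow's LRN over the last (depth) dimension is
        //   sqr_sum[b, r, c, d] = sum(input[b, r, c, d - depth_radius :
        //                                            d + depth_radius + 1] ** 2)
        //   output[b, r, c, d]  = input[b, r, c, d] /
        //                         (bias + alpha * sqr_sum[b, r, c, d]) ** beta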
     20 
     21 #ifdef INTEL_MKL
     22 
     23 #define EIGEN_USE_THREADS
     24 #include <vector>
     25 #include "mkl_dnn.h"
     26 #include "mkl_dnn_types.h"
     27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     28 #include "tensorflow/core/framework/op_kernel.h"
     29 #include "tensorflow/core/framework/register_types.h"
     30 #include "tensorflow/core/framework/tensor.h"
     31 #include "tensorflow/core/kernels/bounds_check.h"
     32 #include "tensorflow/core/kernels/ops_util.h"
     33 #include "tensorflow/core/lib/core/errors.h"
     34 #include "tensorflow/core/util/mkl_util.h"
     35 #include "tensorflow/core/util/tensor_format.h"
     36 
     37 #if !defined(IS_MOBILE_PLATFORM)
     38 #include "tensorflow/core/util/work_sharder.h"
     39 #endif
     40 
     41 #ifndef INTEL_MKL_ML
     42 #include "mkldnn.hpp"
     43 using mkldnn::lrn_across_channels;
     44 using mkldnn::lrn_backward;
     45 using mkldnn::lrn_forward;
     46 using mkldnn::prop_kind;
     47 using mkldnn::stream;
     48 #endif
     49 
     50 namespace tensorflow {
     51 
     52 namespace {
     53 // Create a depth-by-depth band matrix with 1s along a swath of size (2 *
     54 // depth_radius + 1) around the diagonal.
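        // For example, with depth = 5 and depth_radius = 1 the band matrix is
        //   1 1 0 0 0
        //   1 1 1 0 0
        //   0 1 1 1 0
        //   0 0 1 1 1
        //   0 0 0 1 1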
     55 template <typename T>
     56 void GetBandMatrix(int depth, int depth_radius,
     57                    Eigen::Tensor<T, 2, Eigen::RowMajor>* result) {
     58   result->setZero();
     59   for (int row = 0; row < depth; ++row) {
     60     const int begin = std::max<int>(0, row - depth_radius);
     61     const int end = std::min<int>(depth, row + depth_radius + 1);
     62     Eigen::DSizes<Eigen::DenseIndex, 2> start(row, begin);
     63     Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, end - begin);
     64     result->slice(start, sizes).setConstant(T(1));
     65   }
     66 }
     67 
     68 }  // namespace
     69 
     70 #ifdef INTEL_MKL_ML
     71 
     72 template <typename T>
     73 class MklLRNOp : public OpKernel {
     74  public:
     75   ~MklLRNOp() {}
     76 
     77   explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) {
     78     int64 depth_radius64;
     79     OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
     80     OP_REQUIRES(
     81         context,
     82         FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
     83         errors::InvalidArgument("depth_radius = ", depth_radius64,
     84                                 " larger than int max"));
     85     depth_radius_ = static_cast<size_t>(depth_radius64);
     86 
     87     OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
     88     OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
     89     OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
     90     workspace_enabled_ = false;
     91     context->GetAttr("workspace_enabled", &workspace_enabled_);
     92   }
     93 
     94   void Compute(OpKernelContext* context) override {
     95     MklLRNOpContext mkl_context;
     96 
     97     const Tensor& input = MklGetInput(context, 0);
     98     GetMklShape(context, 0, &mkl_context.input_shape);
     99     bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
    100 
    101     // Sanity checks
    102     mkl_context.in_dims = input_in_mkl_format
    103                               ? mkl_context.input_shape.GetDimension()
    104                               : input.dims();
    105     OP_REQUIRES(context, mkl_context.in_dims == 4,
    106                 errors::InvalidArgument("input must be 4-dimensional"));
    107     OP_REQUIRES(
    108         context,
    109         FastBoundsCheck(input.NumElements(), std::numeric_limits<int>::max()),
    110         errors::InvalidArgument("argument to LRN too large"));
    111 
    112     if (!input_in_mkl_format) {
    113       mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
    114                                     beta_, input);
    115       return;
    116     }
    117 
    118     if (input_in_mkl_format) {
    119       // MKL supports normalization over channel dimension only
    120       if (mkl_context.input_shape.tf_dim_idx(mkl_context.in_dims - 1) ==
    121           MklDims::C) {
    122         mkl_context.lt_input =
    123             static_cast<dnnLayout_t>(mkl_context.input_shape.GetCurLayout());
    124         workspace_enabled_ = true;
    125       } else {
    126         Tensor converted_tensor =
    127             ConvertMklToTF<T>(context, input, mkl_context.input_shape);
    128         mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
    129                                       beta_, converted_tensor);
    130         return;
    131       }
    132     }
    133 
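            // MKL takes a kernel size (2 * depth_radius + 1) rather than a depth
            // radius; alpha is scaled by kernel_size below, assuming MKL (like
            // MKL-DNN) normalizes alpha by the kernel size internally.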
    134     int kernel_size = 2 * depth_radius_ + 1;
    135 
    136     CHECK_EQ(dnnLRNCreateForward_F32(
    137                  &mkl_context.lrn_fwd, NULL, mkl_context.lt_input, kernel_size,
    138                  static_cast<float>(alpha_ * kernel_size), beta_, bias_),
    139              E_SUCCESS);
    140 
    141     // Allocate output tensor and shape
    142     Tensor* output = nullptr;
    143     Tensor* workspace = nullptr;
    144 
    145     // Convert Inputs if needed
    146     Tensor mkl_tmp_input_buf_tensor;
    147     mkl_context.MklPrepareLRNInputs(context, &mkl_tmp_input_buf_tensor);
    148 
    149     // Allocate Layer Outputs
    150     mkl_context.MklAllocateOutputs(context, &output, &workspace,
    151                                    workspace_enabled_);
    152 
    153     Tensor mkl_tmp_workspace_buf_tensor;
    154     mkl_context.MklPrepareLRNOutputs(context, output, workspace,
    155                                      &mkl_tmp_workspace_buf_tensor,
    156                                      workspace_enabled_);
    157 
    158     // Execute LRN.
    159     CHECK_EQ(dnnExecute_F32(mkl_context.lrn_fwd, mkl_context.lrn_res),
    160              E_SUCCESS);
    161 
    162     // Release MKL resources.
    163     mkl_context.MklCleanup();
    164   }
    165 
    166  private:
    167   typedef struct {
    168     size_t in_dims;
    169     size_t in_sizes[4];
    170     size_t in_strides[4];
    171     size_t out_sizes[4];
    172     size_t out_strides[4];
    173     MklShape input_shape;
    174     dnnPrimitive_t lrn_fwd = nullptr;
    175     dnnPrimitive_t convert_input = nullptr;
    176     dnnLayout_t lt_input = nullptr;
    177     dnnLayout_t lt_internal_input = nullptr;
    178     dnnLayout_t lt_internal_workspace = nullptr;
    179     dnnLayout_t lt_internal_output = nullptr;
    180     void* lrn_res[dnnResourceNumber];
    181 
    182     // Convert Inputs if needed
    183     void MklPrepareLRNInputs(OpKernelContext* context,
    184                              Tensor* mkl_tmp_input_buf_tensor) {
    185       const Tensor& input = MklGetInput(context, 0);
    186       void* mkl_buf_input =
    187           const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
    188 
    189       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_internal_input, lrn_fwd,
    190                                                 dnnResourceSrc),
    191                E_SUCCESS);
    192 
    193       void* mkl_buf_convert_input = nullptr;
    194       bool mkl_convert_input = false;
    195       mkl_convert_input = !dnnLayoutCompare_F32(lt_internal_input, lt_input);
    196 
    197       if (mkl_convert_input) {
    198         CHECK_EQ(dnnConversionCreate_F32(&convert_input, lt_input,
    199                                          lt_internal_input),
    200                  E_SUCCESS);
    201         AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_internal_input,
    202                        &mkl_buf_convert_input);
    203         CHECK_EQ(dnnConversionExecute_F32(convert_input, mkl_buf_input,
    204                                           mkl_buf_convert_input),
    205                  E_SUCCESS);
    206         dnnDelete_F32(convert_input);
    207       }
    208 
    209       lrn_res[dnnResourceSrc] =
    210           (mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input;
    211     }
    212 
    213     // Allocate Layer Outputs
    214     void MklAllocateOutputs(OpKernelContext* context, Tensor** output,
    215                             Tensor** workspace, bool workspace_enabled_) {
    216       TensorShape mkl_output_tf_shape; /* First tensor */
    217       MklShape mkl_output_mkl_shape;   /* Second tensor */
    218 
    219       mkl_output_mkl_shape.SetMklTensor(true);
    220       mkl_output_mkl_shape.SetMklLayout(lrn_fwd, dnnResourceDst);
    221       mkl_output_mkl_shape.SetTfLayout(in_dims, input_shape.GetSizes(),
    222                                        input_shape.GetStrides());
    223       mkl_output_mkl_shape.SetTfDimOrder(in_dims,
    224                                          input_shape.GetTfToMklDimMap());
    225       mkl_output_tf_shape.AddDim(
    226           dnnLayoutGetMemorySize_F32(
    227               static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
    228           sizeof(T));
    229       AllocateOutputSetMklShape(context, 0, output,
    230                                 mkl_output_tf_shape /* First tensor */,
    231                                 mkl_output_mkl_shape /* Second Tensor */);
    232 
    233       if (workspace_enabled_) {
    234         TensorShape mkl_workspace_tf_shape; /* First tensor */
    235         MklShape mkl_workspace_mkl_shape;   /* Second tensor */
    236         mkl_workspace_mkl_shape.SetMklTensor(false);
    237         mkl_workspace_mkl_shape.SetMklLayout(lrn_fwd, dnnResourceWorkspace);
    238         // Assumes workspace has same TF layout and TF dim order as input
    239         mkl_workspace_mkl_shape.SetTfLayout(in_dims, input_shape.GetSizes(),
    240                                             input_shape.GetStrides());
    241         mkl_workspace_mkl_shape.SetTfDimOrder(in_dims,
    242                                               input_shape.GetTfToMklDimMap());
    243         mkl_workspace_tf_shape.AddDim(
    244             dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
    245                 mkl_workspace_mkl_shape.GetMklLayout())) /
    246             sizeof(T));
    247         AllocateOutputSetMklShape(context, 1, workspace,
    248                                   mkl_workspace_tf_shape /* First tensor */,
    249                                   mkl_workspace_mkl_shape /* Second Tensor */);
    250       }
    251     }
    252 
    253     void MklPrepareLRNOutputs(OpKernelContext* context, Tensor* output,
    254                               Tensor* workspace,
    255                               Tensor* mkl_tmp_workspace_buf_tensor,
    256                               bool workspace_enabled_) {
    257       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_internal_workspace, lrn_fwd,
    258                                                 dnnResourceWorkspace),
    259                E_SUCCESS);
    260 
    261       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_internal_output, lrn_fwd,
    262                                                 dnnResourceDst),
    263                E_SUCCESS);
    264 
    265       void* mkl_buf_output =
    266           const_cast<void*>(static_cast<const void*>(output->flat<T>().data()));
    267       lrn_res[dnnResourceDst] = mkl_buf_output;
    268 
    269       void* mkl_buf_workspace = nullptr;
    270       if (workspace_enabled_) {
    271         mkl_buf_workspace = const_cast<void*>(
    272             static_cast<const void*>(workspace->flat<T>().data()));
    273       } else {
    274         AllocTmpBuffer(context, mkl_tmp_workspace_buf_tensor,
    275                        lt_internal_workspace, &mkl_buf_workspace);
    276       }
    277       lrn_res[dnnResourceWorkspace] = mkl_buf_workspace;
    278     }
    279 
    280     // Fallback implementation - Taken from lrn_op.cc
    281     // TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a
    282     // copy.
    283     void MklDefaultToEigen(OpKernelContext* context, int depth_radius_,
    284                            float bias_, float alpha_, float beta_,
    285                            const Tensor& input) {
    286       const int batch = static_cast<int>(input.dim_size(0));
    287       const int rows = static_cast<int>(input.dim_size(1));
    288       const int cols = static_cast<int>(input.dim_size(2));
    289       const int depth = static_cast<int>(input.dim_size(3));
    290       const int nodes = cols * rows;
    291 
    292       auto in_shaped = input.shaped<T, 2>({nodes * batch, depth});
    293       // Multiplying the input with the band matrix has the effect of
    294       // reducing the correct patch along the depth.
    296       Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
    297       GetBandMatrix<T>(depth, depth_radius_, &multiplier);
    298 
    299       Tensor *output, *workspace;
    300       MklShape mkl_output_mkl_shape, mkl_workspace_mkl_shape;
    301       mkl_output_mkl_shape.SetMklTensor(false);
    302       mkl_output_mkl_shape.SetDimensions(4);
    303       AllocateOutputSetMklShape(context, 0, &output, input.shape(),
    304                                 mkl_output_mkl_shape);
    305 
    306       mkl_workspace_mkl_shape.SetMklTensor(false);
    307       mkl_workspace_mkl_shape.SetDimensions(4);
    308       AllocateOutputSetMklShape(context, 1, &workspace, input.shape(),
    309                                 mkl_workspace_mkl_shape);
    310 
    311       auto out_shaped = output->shaped<T, 2>({nodes * batch, depth});
    312       Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
    313       auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
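              // beta == 1 and beta == 0.5 are special-cased with cheaper elementwise
              // ops (reciprocal and rsqrt); the general case computes tmp^(-beta) as
              // exp(log(tmp) * -beta).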
    314       if (beta_ == T(1)) {
    315         out_shaped.device(context->eigen_cpu_device()) =
    316             in_shaped * tmp.inverse();
    317       } else if (beta_ == T(0.5)) {
    318         out_shaped.device(context->eigen_cpu_device()) =
    319             in_shaped * tmp.rsqrt();
    320       } else {
    321         out_shaped.device(context->eigen_cpu_device()) =
    322             in_shaped * (tmp.log() * -beta_).exp();
    323       }
    324     }
    325 
    326     // Release MKL resources.
    327     void MklCleanup() {
    328       dnnDelete_F32(lrn_fwd);
    329       dnnLayoutDelete_F32(lt_internal_input);
    330       dnnLayoutDelete_F32(lt_internal_workspace);
    331       dnnLayoutDelete_F32(lt_internal_output);
    332     }
    333   } MklLRNOpContext;
    334 
    335   typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
    336 
    337   bool workspace_enabled_;
    338   int depth_radius_;
    339   float bias_;
    340   float alpha_;
    341   float beta_;
    342 };
    343 
    344 template <typename T>
    345 class MklLRNGradOp : public OpKernel {
    346  public:
    347   explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
    348     int64 depth_radius64;
    349     OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    350     OP_REQUIRES(
    351         context,
    352         FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
    353         errors::InvalidArgument("depth_radius = ", depth_radius64,
    354                                 " larger than int max"));
    355     depth_radius_ = static_cast<int>(depth_radius64);
    356     OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
    357     OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
    358     OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
    359     workspace_enabled_ = false;
    360     context->GetAttr("workspace_enabled", &workspace_enabled_);
    361   }
    362 
    363   void Compute(OpKernelContext* context) override {
    364     MklLRNGradOpContext mkl_context;
    365     mkl_context.depth_radius_ = depth_radius_;
    366     mkl_context.bias_ = bias_;
    367     mkl_context.alpha_ = alpha_;
    368     mkl_context.beta_ = beta_;
    369 
    370     const Tensor& in_grads = MklGetInput(context, 0);
    371     const Tensor& in_image = MklGetInput(context, 1);
    372     const Tensor& out_image = MklGetInput(context, 2);
    373 
    374     GetMklShape(context, 0, &mkl_context.ingrad_shape);
    375     GetMklShape(context, 1, &mkl_context.inimage_shape);
    376     GetMklShape(context, 2, &mkl_context.outimage_shape);
    377 
    378     bool ingrad_in_mkl_format = mkl_context.ingrad_shape.IsMklTensor();
    379     bool inimage_in_mkl_format = mkl_context.inimage_shape.IsMklTensor();
    380     bool outimage_in_mkl_format = mkl_context.outimage_shape.IsMklTensor();
    381 
    382     mkl_context.in_dims = inimage_in_mkl_format
    383                               ? mkl_context.inimage_shape.GetDimension()
    384                               : in_image.dims();
    385     OP_REQUIRES(context, mkl_context.in_dims == 4,
    386                 errors::InvalidArgument("input images must be 4-dimensional"));
    387 
    388     if (!workspace_enabled_) {
    389       mkl_context.MklDefaultToEigen(context);
    390       return;
    391     }
    392 
    393     if (ingrad_in_mkl_format || inimage_in_mkl_format) {
    394       const MklShape* tmp_mkl_shape = (ingrad_in_mkl_format)
    395                                           ? &mkl_context.ingrad_shape
    396                                           : &mkl_context.inimage_shape;
    397       if (tmp_mkl_shape->tf_dim_idx(mkl_context.in_dims - 1) != MklDims::C) {
    398         // Fall back to Eigen.
    399         mkl_context.MklDefaultToEigen(context);
    400         return;
    401       } else {  // MKL supports normalization over channel dimension only
    402         for (int i = 0; i < mkl_context.in_dims; i++) {
    403           mkl_context.in_sizes[i] = mkl_context.out_sizes[i] =
    404               tmp_mkl_shape->GetSizes()[i];
    405           mkl_context.in_strides[i] = mkl_context.out_strides[i] =
    406               tmp_mkl_shape->GetStrides()[i];
    407         }
    408       }
    409     } else {
    410       // Fall back to Eigen.
    411       mkl_context.MklDefaultToEigen(context);
    412       return;
    413     }
    414 
    415     // Dimensions check for sanity purpose
    416     if (ingrad_in_mkl_format) {
    417       OP_REQUIRES(
    418           context, mkl_context.ingrad_shape.GetDimension() == 4,
    419           errors::InvalidArgument("input gradient must be 4-dimensional"));
    420     } else {
    421       OP_REQUIRES(
    422           context, in_grads.dims() == 4,
    423           errors::InvalidArgument("input gradient must be 4-dimensional"));
    424     }
    425 
    426     if (outimage_in_mkl_format) {
    427       OP_REQUIRES(
    428           context, mkl_context.outimage_shape.GetDimension() == 4,
    429           errors::InvalidArgument("Output image must be 4-dimensional"));
    430     } else {
    431       OP_REQUIRES(
    432           context, out_image.dims() == 4,
    433           errors::InvalidArgument("Output image must be 4-dimensional"));
    434     }
    435 
    436     // Prepare mkl input layout
    437     mkl_context.MklPrepareLRNInputsLayouts(context);
    438     int ksize = 2 * depth_radius_ + 1;
    439 
    440     CHECK_EQ(dnnLRNCreateBackward_F32(
    441                  &mkl_context.lrn_bwd, NULL, mkl_context.lt_input,
    442                  mkl_context.lt_output, ksize,
    443                  static_cast<float>(alpha_ * ksize), beta_, bias_),
    444              E_SUCCESS);
    445 
    446     // Allocate output tensor and shape.
    447     TensorShape mkl_output_tf_shape; /* First tensor */
    448     MklShape mkl_output_mkl_shape;   /* Second tensor */
    449     mkl_output_mkl_shape.SetMklTensor(true);
    450     CHECK_NE(mkl_context.lrn_bwd, nullptr);
    451     mkl_output_mkl_shape.SetMklLayout(mkl_context.lrn_bwd, dnnResourceDiffSrc);
    452     mkl_output_mkl_shape.SetTfLayout(mkl_context.in_dims, mkl_context.out_sizes,
    453                                      mkl_context.out_strides);
    454     if (ingrad_in_mkl_format) {
    455       mkl_output_mkl_shape.SetTfDimOrder(
    456           mkl_context.in_dims, mkl_context.ingrad_shape.GetTfToMklDimMap());
    457     } else {
    458       mkl_output_mkl_shape.SetTfDimOrder(
    459           mkl_context.in_dims, mkl_context.inimage_shape.GetTfToMklDimMap());
    460     }
    461     mkl_output_tf_shape.AddDim(
    462         dnnLayoutGetMemorySize_F32(
    463             static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
    464         sizeof(T));
    465     Tensor* output = nullptr;
    466     AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape,
    467                               mkl_output_mkl_shape);
    468 
    469     // Get pointers to output data.
    470     void* user_output =
    471         const_cast<void*>(static_cast<const void*>(output->flat<T>().data()));
    472 
    473     Tensor mkl_tmp_input_buf_tensor, mkl_tmp_image_buf_tensor,
    474         mkl_tmp_outimage_buf_tensor;
    475     // Convert Inputs if needed
    476     mkl_context.MklPrepareLRNGradInput(context, &mkl_tmp_input_buf_tensor,
    477                                        &mkl_tmp_image_buf_tensor,
    478                                        &mkl_tmp_outimage_buf_tensor);
    479 
    480     // We do not do any conversion for the output; we simply emit it
    481     // in MKL format.
    482     mkl_context.res_lrn_bwd[dnnResourceDiffSrc] = user_output;
    483     // Execute LRN backward using dnnExecute
    484     CHECK_EQ(dnnExecute_F32(mkl_context.lrn_bwd, mkl_context.res_lrn_bwd),
    485              E_SUCCESS);
    486     // Release MKL resources.
    487     mkl_context.Mklcleanup();
    488   }
    489 
    490  private:
    491   typedef struct {
    492     int depth_radius_;
    493     float bias_;
    494     float alpha_;
    495     float beta_;
    496     size_t in_dims;
    497     size_t in_sizes[4];
    498     size_t in_strides[4];
    499     size_t out_sizes[4];
    500     size_t out_strides[4];
    501     MklShape ingrad_shape, inimage_shape, outimage_shape;
    502     dnnPrimitive_t lrn_bwd = nullptr;
    503     dnnPrimitive_t convert_input = nullptr;
    504     dnnLayout_t lt_input = nullptr;
    505     dnnLayout_t lt_output = nullptr;
    506     dnnLayout_t lt_bdw_input = nullptr;
    507     dnnLayout_t lt_workspace = nullptr;
    508     dnnLayout_t lt_internal_input = nullptr;
    509     void* res_lrn_bwd[dnnResourceNumber];
    510 
    511     // Prepare the MKL input layouts.
    512     void MklPrepareLRNInputsLayouts(OpKernelContext* context) {
    513       bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();
    514       bool inimage_in_mkl_format = inimage_shape.IsMklTensor();
    515       if (!ingrad_in_mkl_format) {
    516         CHECK_EQ(dnnLayoutCreate_F32(&lt_input, in_dims, in_sizes, in_strides),
    517                  E_SUCCESS);
    518       } else {
    519         lt_input = static_cast<dnnLayout_t>(ingrad_shape.GetCurLayout());
    520       }
    521 
    522       if (!inimage_in_mkl_format) {
    523         CHECK_EQ(
    524             dnnLayoutCreate_F32(&lt_output, in_dims, out_sizes, out_strides),
    525             E_SUCCESS);
    526       } else {
    527         lt_output = static_cast<dnnLayout_t>(inimage_shape.GetCurLayout());
    528       }
    529     }
    530 
    531     // convert input if needed
    532     void MklPrepareLRNGradInput(OpKernelContext* context,
    533                                 Tensor* mkl_tmp_input_buf_tensor,
    534                                 Tensor* mkl_tmp_image_buf_tensor,
    535                                 Tensor* mkl_tmp_outimage_buf_tensor) {
    536       const Tensor& in_grads = MklGetInput(context, 0);
    537       const Tensor& in_image = MklGetInput(context, 1);
    538       const Tensor& out_image = MklGetInput(context, 2);
    539       const Tensor& workspace = MklGetInput(
    540           context,
    541           3); /* Workspace is enabled; get the buffer to the workspace. */
    542 
    543       void* user_input = const_cast<void*>(
    544           static_cast<const void*>(in_grads.flat<T>().data()));
    545       void* user_fwd_input = const_cast<void*>(
    546           static_cast<const void*>(in_image.flat<T>().data()));
    547       void* user_fwd_output = const_cast<void*>(
    548           static_cast<const void*>(out_image.flat<T>().data()));
    549       void* workspace_buffer = const_cast<void*>(
    550           static_cast<const void*>(workspace.flat<T>().data()));
    551 
    552       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_workspace, lrn_bwd,
    553                                                 dnnResourceWorkspace),
    554                E_SUCCESS);
    555       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_bdw_input, lrn_bwd,
    556                                                 dnnResourceDiffDst),
    557                E_SUCCESS);
    558       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&lt_internal_input, lrn_bwd,
    559                                                 dnnResourceSrc),
    560                E_SUCCESS);
    561 
    562       bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();
    563       if (ingrad_in_mkl_format) {
    564         if (!dnnLayoutCompare_F32(lt_bdw_input, lt_input)) {
    565           AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_bdw_input,
    566                          &res_lrn_bwd[dnnResourceDiffDst]);
    567           ingrad_shape.GetConvertedFlatData(lt_bdw_input, user_input,
    568                                             res_lrn_bwd[dnnResourceDiffDst]);
    569         } else {
    570           res_lrn_bwd[dnnResourceDiffDst] = user_input;
    571         }
    572       } else {
    573         if (!dnnLayoutCompare_F32(lt_bdw_input, lt_input)) {
    574           CHECK_EQ(
    575               dnnConversionCreate_F32(&convert_input, lt_input, lt_bdw_input),
    576               E_SUCCESS);
    577 
    578           AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_bdw_input,
    579                          &res_lrn_bwd[dnnResourceDiffDst]);
    580           CHECK_EQ(dnnConversionExecute_F32(convert_input, user_input,
    581                                             res_lrn_bwd[dnnResourceDiffDst]),
    582                    E_SUCCESS);
    583           dnnDelete_F32(convert_input);
    584         } else {
    585           res_lrn_bwd[dnnResourceDiffDst] = user_input;
    586         }
    587       }
    588 
    589       bool inimage_in_mkl_format = inimage_shape.IsMklTensor();
    590       if (inimage_in_mkl_format) {
    591         if (!dnnLayoutCompare_F32(
    592                 lt_internal_input,
    593                 static_cast<dnnLayout_t>(inimage_shape.GetCurLayout()))) {
    594           AllocTmpBuffer(context, mkl_tmp_image_buf_tensor, lt_internal_input,
    595                          &res_lrn_bwd[dnnResourceSrc]);
    596           inimage_shape.GetConvertedFlatData(lt_internal_input, user_fwd_input,
    597                                              res_lrn_bwd[dnnResourceSrc]);
    598         } else {
    599           res_lrn_bwd[dnnResourceSrc] = user_fwd_input;
    600         }
    601       } else {
    602         if (!dnnLayoutCompare_F32(
    603                 lt_internal_input,
    604                 static_cast<dnnLayout_t>(inimage_shape.GetCurLayout()))) {
    605           CHECK_EQ(dnnConversionCreate_F32(
    606                        &convert_input,
    607                        static_cast<dnnLayout_t>(inimage_shape.GetCurLayout()),
    608                        lt_internal_input),
    609                    E_SUCCESS);
    610 
    611           AllocTmpBuffer(context, mkl_tmp_image_buf_tensor, lt_internal_input,
    612                          &res_lrn_bwd[dnnResourceSrc]);
    613           CHECK_EQ(dnnConversionExecute_F32(convert_input, user_fwd_input,
    614                                             res_lrn_bwd[dnnResourceSrc]),
    615                    E_SUCCESS);
    616           dnnDelete_F32(convert_input);
    617         } else {
    618           res_lrn_bwd[dnnResourceSrc] = user_fwd_input;
    619         }
    620       }
    621 
    622       res_lrn_bwd[dnnResourceWorkspace] = workspace_buffer;
    623     }
    624 
    625     // Fallback implementation - Taken from lrn_op.cc
    626     // TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a
    627     // copy.
    628     void MklDefaultToEigen(OpKernelContext* context) {
    629       Tensor in_grads;
    630       Tensor in_image;
    631       Tensor out_image;
    632 
    633       GetMklShape(context, 0, &ingrad_shape);
    634       GetMklShape(context, 1, &inimage_shape);
    635       GetMklShape(context, 2, &outimage_shape);
    636 
    637       if (ingrad_shape.IsMklTensor()) {
    638         in_grads =
    639             ConvertMklToTF<T>(context, MklGetInput(context, 0), ingrad_shape);
    640       } else {
    641         in_grads = MklGetInput(context, 0);
    642       }
    643 
    644       if (inimage_shape.IsMklTensor()) {
    645         in_image =
    646             ConvertMklToTF<T>(context, MklGetInput(context, 1), inimage_shape);
    647       } else {
    648         in_image = MklGetInput(context, 1);
    649       }
    650 
    651       if (outimage_shape.IsMklTensor()) {
    652         out_image =
    653             ConvertMklToTF<T>(context, MklGetInput(context, 2), outimage_shape);
    654       } else {
    655         out_image = MklGetInput(context, 2);
    656       }
    657 
    658       const int64 batch = static_cast<int64>(in_grads.dim_size(0));
    659       const int64 rows = static_cast<int64>(in_grads.dim_size(1));
    660       const int64 cols = static_cast<int64>(in_grads.dim_size(2));
    661       const int64 depth = static_cast<int64>(in_grads.dim_size(3));
    662       const auto nodes = cols * rows;
    663 
    664       auto grads_shaped = in_grads.shaped<T, 2>({nodes * batch, depth});
    665 
    666       auto in_shaped = in_image.shaped<T, 2>({nodes * batch, depth});
    667       auto activations = out_image.shaped<T, 2>({nodes * batch, depth});
    668 
    669       Tensor* output;
    670       MklShape mkl_output_mkl_shape;
    671       mkl_output_mkl_shape.SetMklTensor(false);
    672       mkl_output_mkl_shape.SetDimensions(4);
    673       AllocateOutputSetMklShape(context, 0, &output, in_grads.shape(),
    674                                 mkl_output_mkl_shape);
    675 
    676       auto out_shaped = output->shaped<T, 2>({nodes * batch, depth});
    677       out_shaped.setZero();
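              // For each pixel i, channel j, and channel k in the band around j:
              //   norm      = alpha * sum_k in(i, k)^2 + bias
              //   y_j       = in(i, j) * norm^(-beta)          (== activations(i, j))
              //   dy_j/dx_k = -2 * alpha * beta * in(i, k) * y_j / norm
              //               + (k == j) * norm^(-beta)
              // and grads(i, j) * dy_j/dx_k is accumulated into out_shaped(i, k).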
    678       auto shard = [this, activations, in_shaped, grads_shaped, out_shaped,
    679                     depth](int64 begin, int64 end) {
    680         for (int64 i = begin; i < end; ++i) {
    681           for (int64 j = 0; j < depth; ++j) {
    682             int64 depth_begin = std::max<int64>(0, j - depth_radius_);
    683             int64 depth_end = std::min<int64>(depth, j + depth_radius_ + 1);
    684 
    685             T norm(0);
    686             for (int64 k = depth_begin; k < depth_end; ++k) {
    687               norm += in_shaped(i, k) * in_shaped(i, k);
    688             }
    689             norm = alpha_ * norm + bias_;
    690             DCHECK_GT(norm, T(1e-6));
    691             for (int64 k = depth_begin; k < depth_end; ++k) {
    692               T dyi = T(-2) * alpha_ * beta_ * in_shaped(i, k) *
    693                       activations(i, j) / norm;
    694               if (k == j) {
    695                 dyi += Eigen::numext::pow(norm, -beta_);
    696               }
    697               dyi *= grads_shaped(i, j);
    698               const_cast<typename TTypes<T, 2>::Tensor&>(out_shaped)(i, k) +=
    699                   dyi;
    700             }
    701           }
    702         }
    703       };
    704       auto worker_threads =
    705           *(context->device()->tensorflow_cpu_worker_threads());
    706       Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
    707             depth * depth, shard);
    708     }
    709 
    710     // release mkl resources
    711     // Release MKL resources.
    712       bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();
    713       bool inimage_in_mkl_format = inimage_shape.IsMklTensor();
    714       if (!ingrad_in_mkl_format) {
    715         CHECK_EQ(dnnLayoutDelete_F32(lt_input), E_SUCCESS);
    716       }
    717 
    718       if (!inimage_in_mkl_format) {
    719         CHECK_EQ(dnnLayoutDelete_F32(lt_output), E_SUCCESS);
    720       }
    721       dnnDelete_F32(lrn_bwd);
    722       dnnLayoutDelete_F32(lt_bdw_input);
    723       dnnLayoutDelete_F32(lt_workspace);
    724     }
    725   } MklLRNGradOpContext;
    726 
    727   typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
    728   bool workspace_enabled_;
    729   int depth_radius_;
    730   float bias_;
    731   float alpha_;
    732   float beta_;
    733 };
    734 
    735 #else
    736 
    737 template <typename T>
    738 class MklLRNOp : public OpKernel {
    739  public:
    740   ~MklLRNOp() {}
    741 
    742   explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) {
    743     int64 depth_radius64;
    744     OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    745     OP_REQUIRES(
    746         context,
    747         FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
    748         errors::InvalidArgument("depth_radius = ", depth_radius64,
    749                                 " larger than int max"));
    750     depth_radius_ = static_cast<size_t>(depth_radius64);
    751 
    752     OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
    753     OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
    754     OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
    755     workspace_enabled_ = false;
    756     context->GetAttr("workspace_enabled", &workspace_enabled_);
    757   }
    758 
    759   void Compute(OpKernelContext* context) override {
    760     try {
    761       SanityCheckInputs(context);
    762       if (!context->status().ok()) return;
    763 
    764       auto cpu_engine = engine(engine::cpu, 0);
    765       const Tensor& src_tensor = MklGetInput(context, kIdxInput);
    766       MklDnnShape src_dnn_shape;
    767       GetMklShape(context, kIdxInput, &src_dnn_shape);
    768 
    769       // MKL-DNN has a notion of kernel_size and not depth_radius.
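              // MKL-DNN also scales alpha by 1 / kernel_size internally, so alpha is
              // multiplied by kernel_size here to preserve TensorFlow's unscaled-sum
              // semantics.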
    770       int kernel_size = 2 * depth_radius_ + 1;
    771       float new_alpha = alpha_ * kernel_size;
    772 
    773       // If the input tensor is not an MKL tensor, or if the last
    774       // dimension is not the channel, then just use Eigen.
    775       // MKL only supports normalization over the channel dimension.
    776       if (!src_dnn_shape.IsMklTensor()) {
    777         MklDefaultToEigen(context, src_tensor);
    778         return;
    779       } else if (!src_dnn_shape.IsMklChannelDim(src_dnn_shape.GetDimension() -
    780                                                 1)) {
    781         Tensor converted_tensor =
    782             ConvertMklToTF<T>(context, src_tensor, src_dnn_shape);
    783         MklDefaultToEigen(context, converted_tensor);
    784         return;
    785       }
    786       // At this point, we can assume that the src is an MklTensor
    787       // and we can enable the workspace
    788       workspace_enabled_ = true;
    789 
    790       MklDnnData<T> src_dnn_data(&cpu_engine);
    791       MklDnnData<T> dst_dnn_data(&cpu_engine);
    792       MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
    793 
    794       TensorShape tf_output_shape = src_tensor.shape();
    795 
    796       memory::desc src_md = src_dnn_shape.GetCurLayout();
    797       memory::dims input_dims = src_dnn_shape.GetSizesAsMklDnnDims();
    798 
    799       // Create memory for user input.
    800       // Since TensorFlow always performs normalization over the last dimension,
    801       // and MKL-DNN performs normalization over the channel, we tell MKL-DNN
    802       // that the input is in NHWC layout with the channel as the last dimension.
    803       src_dnn_data.SetUsrMem(src_md, &src_tensor);
    804       src_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);
    805 
    806       // output_dnn_data and workspace both have the same shape as input
    807       dst_dnn_data.SetUsrMem(src_md);
    808       dst_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);
    809 
    810       // Create LRN primitive descriptor.
    811       // TensorFlow's normalization semantics are across channels;
    812       // MKL-DNN also supports normalization within a channel.
    813       auto lrn_desc = lrn_forward::desc(prop_kind::forward, lrn_across_channels,
    814                                         src_dnn_data.GetUsrMemDesc(),
    815                                         kernel_size, new_alpha, beta_, bias_);
    816       auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, cpu_engine);
    817 
    818       // Allocate output_dnn_data tensor.
    819       Tensor* output_tensor = nullptr;
    820       memory::format input_format = src_dnn_shape.GetTfDataFormat();
    821       AllocateOutputTensor(context, lrn_prim_desc, input_dims, input_format,
    822                            &output_tensor);
    823       OP_REQUIRES_OK(context, context->status());
    824       CHECK_NOTNULL(output_tensor);
    825       dst_dnn_data.SetUsrMemDataHandle(output_tensor);
    826 
    827       // Handle workspace required for MKL-DNN.
    828       AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);
    829       OP_REQUIRES_OK(context, context->status());
    830 
    831       PrepareAndExecuteNet(lrn_prim_desc, &src_dnn_data, &dst_dnn_data,
    832                            &workspace_dnn_data);
    833     } catch (mkldnn::error& e) {
    834       string error_msg = "Status: " + std::to_string(e.status) +
    835                          ", message: " + string(e.message) + ", in file " +
    836                          string(__FILE__) + ":" + std::to_string(__LINE__);
    837       OP_REQUIRES_OK(
    838           context,
    839           errors::Aborted("Operation received an exception:", error_msg));
    840     }
    841   }
    842 
    843  private:
    844   void PrepareAndExecuteNet(const lrn_forward::primitive_desc& lrn_fwd_desc,
    845                             MklDnnData<T>* src_dnn_data,
    846                             MklDnnData<T>* dst_dnn_data,
    847                             MklDnnData<uint8>* wksp_dnn_data = nullptr) {
    848     std::vector<primitive> net;
    849 
    850     // Check for input reorder
    851     src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc(), &net);
    852 
    853     // Create the LRN forward primitive and add it to the net.
    854     if (wksp_dnn_data != nullptr) {
    855       net.push_back(lrn_forward(lrn_fwd_desc, src_dnn_data->GetOpMem(),
    856                                 wksp_dnn_data->GetOpMem(),
    857                                 dst_dnn_data->GetOpMem()));
    858     } else {
    859       net.push_back(lrn_forward(lrn_fwd_desc, src_dnn_data->GetOpMem(),
    860                                 dst_dnn_data->GetOpMem()));
    861     }
    862     stream(stream::kind::eager).submit(net).wait();
    863   }
    864 
    865   void AllocateOutputTensor(
    866       OpKernelContext* context,
    867       const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
    868       const memory::dims output_dims_mkl_order,
    869       const memory::format& output_tf_format, Tensor** output_tensor) {
    870     CHECK_NOTNULL(output_tensor);
    871     memory::primitive_desc dst_pd = lrn_fwd_prim_desc.dst_primitive_desc();
    872 
    873     MklDnnShape output_mkl_shape;
    874     // We only handle the case when the inputs and output are in Mkl format
    875     // Any other case is handled by Eigen
    876     output_mkl_shape.SetMklTensor(true);
    877     output_mkl_shape.SetMklLayout(&dst_pd);
    878     output_mkl_shape.SetElemType(MklDnnType<T>());
    879     output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
    880                                  output_dims_mkl_order, output_tf_format);
    881     TensorShape output_tf_shape;
    882     // only allocate enough space for the elements we need.
    883     size_t num_bytes = dst_pd.get_size();
    884     CHECK_EQ(num_bytes % sizeof(T), 0);
    885     output_tf_shape.AddDim(num_bytes / sizeof(T));
    886     AllocateOutputSetMklShape(context, kIdxOutput, output_tensor,
    887                               output_tf_shape, output_mkl_shape);
    888   }
    889 
    890   // Fallback implementation - Taken from lrn_op.cc
    891   // TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a
    892   // copy.
    893   void MklDefaultToEigen(OpKernelContext* context, const Tensor& input) {
    894     const int batch = static_cast<int>(input.dim_size(0));
    895     const int rows = static_cast<int>(input.dim_size(1));
    896     const int cols = static_cast<int>(input.dim_size(2));
    897     const int depth = static_cast<int>(input.dim_size(3));
    898     const int nodes = cols * rows;
    899 
    900     auto in_shaped = input.shaped<T, 2>({nodes * batch, depth});
    901     // Multiplying the input with the band matrix has the effect of
    902     // reducing the correct patch along the depth.
    904     Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
    905     GetBandMatrix<T>(depth, depth_radius_, &multiplier);
    906 
    907     Tensor* output_dnn_data = nullptr;
    908     MklDnnShape mkl_output_mkl_shape;
    909     mkl_output_mkl_shape.SetMklTensor(false);
    910     mkl_output_mkl_shape.SetDimensions(4);
    911     AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
    912                               input.shape(), mkl_output_mkl_shape);
    913     CHECK_NOTNULL(output_dnn_data);
    914 
    915     Tensor* workspace_tensor = nullptr;
    916     MklDnnShape workspace_mkl_shape;
    917     workspace_mkl_shape.SetMklTensor(false);
    918     TensorShape workspace_tf_shape;
    919     workspace_tf_shape.AddDim(0);
    920     AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor,
    921                               workspace_tf_shape, workspace_mkl_shape);
    922     CHECK_NOTNULL(workspace_tensor);
    923 
    924     auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
    925     Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
    926     auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
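            // beta == 1 and beta == 0.5 are special-cased with cheaper elementwise
            // ops (reciprocal and rsqrt); the general case computes tmp^(-beta) as
            // exp(log(tmp) * -beta).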
    927     if (beta_ == T(1)) {
    928       out_shaped.device(context->eigen_cpu_device()) =
    929           in_shaped * tmp.inverse();
    930     } else if (beta_ == T(0.5)) {
    931       out_shaped.device(context->eigen_cpu_device()) = in_shaped * tmp.rsqrt();
    932     } else {
    933       out_shaped.device(context->eigen_cpu_device()) =
    934           in_shaped * (tmp.log() * -beta_).exp();
    935     }
    936   }
    937 
    938   void AllocateWorkspaceTensor(
    939       OpKernelContext* context,
    940       const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
    941       MklDnnData<uint8>* dnn_data_wksp) {
    942     CHECK_NOTNULL(dnn_data_wksp);
    943     Tensor* workspace_tensor = nullptr;
    944     memory::primitive_desc workspace_pd =
    945         lrn_fwd_prim_desc.workspace_primitive_desc();
    946     size_t workspace_bytes = workspace_pd.get_size();
    947     MklDnnShape workspace_mkl_shape;
    948     // The workspace tensor is a uint8 tensor that has
    949     // exactly the number of bytes necessary.
    950     workspace_mkl_shape.SetMklTensor(false);
    951     TensorShape workspace_tf_shape;
    952     workspace_tf_shape.AddDim(workspace_bytes);
    953     AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor,
    954                               workspace_tf_shape, workspace_mkl_shape);
    955     CHECK_NOTNULL(workspace_tensor);
    956     dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor);
    957   }
    958 
    959   void SanityCheckInputs(OpKernelContext* context) {
    960     const Tensor& src_tensor = MklGetInput(context, kIdxInput);
    961     MklDnnShape src_dnn_shape;
    962     GetMklShape(context, kIdxInput, &src_dnn_shape);
    963     if (src_dnn_shape.IsMklTensor()) {
    964       OP_REQUIRES(context, src_dnn_shape.GetDimension() == 4,
    965                   errors::InvalidArgument("input must be 4-dimensional"));
    966       OP_REQUIRES(context,
    967                   FastBoundsCheck(src_tensor.NumElements(),
    968                                   std::numeric_limits<int>::max()),
    969                   errors::InvalidArgument("argument to LRN too large"));
    970     } else {
    971       OP_REQUIRES(context, src_tensor.dims() == 4,
    972                   errors::InvalidArgument("input must be 4-dimensional"));
    973       OP_REQUIRES(context,
    974                   FastBoundsCheck(src_tensor.NumElements(),
    975                                   std::numeric_limits<int>::max()),
    976                   errors::InvalidArgument("argument to LRN too large"));
    977     }
    978   }
    979   const int kIdxInput = 0, kIdxOutput = 0, kIdxWorkspace = 1;
    980 
    981   typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
    982   bool workspace_enabled_;
    983   int depth_radius_;
    984   float bias_;
    985   float alpha_;
    986   float beta_;
    987 };
    988 
    989 template <typename T>
    990 class MklLRNGradOp : public OpKernel {
    991  public:
    992   explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
    993     int64 depth_radius64;
    994     OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    995     OP_REQUIRES(
    996         context,
    997         FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
    998         errors::InvalidArgument("depth_radius = ", depth_radius64,
    999                                 " larger than int max"));
   1000     depth_radius_ = static_cast<int>(depth_radius64);
   1001     OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
   1002     OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
   1003     OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
   1004     workspace_enabled_ = false;
   1005     context->GetAttr("workspace_enabled", &workspace_enabled_);
   1006   }
   1007 
   1008   void Compute(OpKernelContext* context) override {
   1009     try {
   1010       SanityCheckInputs(context);
   1011       if (!context->status().ok()) return;
   1012 
   1013       auto cpu_engine = engine(engine::cpu, 0);
   1014       MklDnnData<T> input_grad_dnn_data(&cpu_engine);
   1015       MklDnnData<T> orig_input_dnn_data(&cpu_engine);
   1016       MklDnnData<T> orig_output_dnn_data(&cpu_engine);
   1017       MklDnnData<T> output_dnn_data(&cpu_engine);
   1018 
   1019       MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
   1020           orig_output_dnn_shape;
   1021       GetMklShape(context, kIdxGradient, &input_grad_dnn_shape);
   1022       GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape);
   1023       GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape);
   1024 
   1025       // We only use MKL-DNN if all of the necessary inputs are present
   1026       // in MKL-DNN format and the channel is the last dimension.
   1027       bool can_use_mkldnn = workspace_enabled_ &&
   1028                             input_grad_dnn_shape.IsMklTensor() &&
   1029                             orig_input_dnn_shape.IsMklTensor() &&
   1030                             orig_output_dnn_shape.IsMklTensor() &&
   1031                             input_grad_dnn_shape.IsMklChannelDim(
   1032                                 input_grad_dnn_shape.GetDimension() - 1) &&
   1033                             orig_input_dnn_shape.IsMklChannelDim(
   1034                                 orig_input_dnn_shape.GetDimension() - 1) &&
   1035                             orig_output_dnn_shape.IsMklChannelDim(
   1036                                 orig_output_dnn_shape.GetDimension() - 1);
   1037 
   1038       if (!can_use_mkldnn) {
   1039         // Fall back to Eigen.
   1040         MklDefaultToEigen(context);
   1041         return;
   1042       }
   1043       // At this point, we have the all clear to use MklDnn constructs
   1044       // Naming: diff_dst is input_gradient_tensor; src is orig_input_tensor.
   1045       const Tensor& input_grad_tensor = MklGetInput(context, kIdxGradient);
   1046       const Tensor& orig_input_tensor = MklGetInput(context, kIdxOrigInput);
   1047       const Tensor& orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
   1048 
   1049       // Get input sizes in MKL-DNN required NCHW format.
   1050       // LRN does not have a data_format attribute, but by default its
   1051       // input is in NHWC format.
   1052       memory::desc original_output_md = orig_output_dnn_shape.GetCurLayout();
   1053       memory::desc target_diff_dst_md = ConfigureInputGradient(
   1054           input_grad_tensor, input_grad_dnn_shape, &input_grad_dnn_data);
   1055 
   1056       memory::desc orig_input_md = orig_input_dnn_shape.GetCurLayout();
   1057       memory::dims orig_input_dims =
   1058           orig_input_dnn_shape.GetSizesAsMklDnnDims();
   1059       orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
   1060       orig_input_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);
   1061 
   1062       // output_dnn_data has the same shape as original input
   1063       output_dnn_data.SetUsrMem(orig_input_md);
   1064       output_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);
   1065 
   1066       // MKL-DNN has a notion of kernel_size and not depth_radius.
   1067       int kernel_size = 2 * depth_radius_ + 1;
   1068       float new_alpha = alpha_ * kernel_size;
   1069 
   1070       // Create the LRN backward primitive descriptor. It also requires the
   1071       // LRN forward primitive descriptor.
   1072       auto lrn_fwd_desc = lrn_forward::desc(
   1073           prop_kind::forward, lrn_across_channels, orig_input_md, kernel_size,
   1074           new_alpha, beta_, bias_);
   1075       auto lrn_fwd_prim_desc =
   1076           lrn_forward::primitive_desc(lrn_fwd_desc, cpu_engine);
   1077       auto lrn_bwd_desc = lrn_backward::desc(
   1078           lrn_across_channels, original_output_md, target_diff_dst_md,
   1079           kernel_size, new_alpha, beta_, bias_);
   1080       auto lrn_bwd_prim_desc = lrn_backward::primitive_desc(
   1081           lrn_bwd_desc, cpu_engine, lrn_fwd_prim_desc);
   1082 
   1083       Tensor* output_tensor = nullptr;
   1084       memory::format orig_input_format = orig_input_dnn_shape.GetTfDataFormat();
   1085       AllocateOutputTensor(context, lrn_bwd_prim_desc, orig_input_dims,
   1086                            orig_input_format, &output_tensor);
   1087       OP_REQUIRES_OK(context, context->status());
   1088       CHECK_NOTNULL(output_tensor);
   1089       output_dnn_data.SetUsrMemDataHandle(output_tensor);
   1090 
   1091       // Create LRN primitive and add it to the net
   1092       // At this point, workspace is enabled, so we don't need
   1093       // to check. Pass input workspace to LRN backward primitive.
   1094       const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
   1095       MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
   1096       ConfigureWorkspace(workspace_tensor,
   1097                          lrn_fwd_prim_desc.workspace_primitive_desc(),
   1098                          &workspace_dnn_data);
   1099 
   1100       PrepareAndExecuteNet(
   1101           lrn_bwd_prim_desc, lrn_fwd_prim_desc, &orig_input_dnn_data,
   1102           &input_grad_dnn_data, &output_dnn_data,
   1103           memory::primitive_desc(target_diff_dst_md, cpu_engine),
   1104           &workspace_dnn_data);
   1105     } catch (mkldnn::error& e) {
   1106       string error_msg = "Status: " + std::to_string(e.status) +
   1107                          ", message: " + string(e.message) + ", in file " +
   1108                          string(__FILE__) + ":" + std::to_string(__LINE__);
   1109       OP_REQUIRES_OK(
   1110           context,
   1111           errors::Aborted("Operation received an exception:", error_msg));
   1112     }
   1113   }
   1114 
   1115   void AllocateOutputTensor(
   1116       OpKernelContext* context,
   1117       const lrn_backward::primitive_desc& lrn_bkwd_prim_desc,
   1118       const memory::dims output_dims_mkl_order,
   1119       const memory::format& output_tf_format, Tensor** output_tensor) {
   1120     CHECK_NOTNULL(output_tensor);
   1121     memory::primitive_desc dst_pd =
   1122         lrn_bkwd_prim_desc.diff_src_primitive_desc();
   1123     MklDnnShape output_mkl_shape;
   1124 
   1125     // We assume that all outputs at this point are MKL Tensors
   1126     output_mkl_shape.SetMklTensor(true);
   1127     output_mkl_shape.SetMklLayout(&dst_pd);
   1128     output_mkl_shape.SetElemType(MklDnnType<T>());
   1129     output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
   1130                                  output_dims_mkl_order, output_tf_format);
   1131 
   1132     TensorShape output_tf_shape;
   1133     size_t num_bytes = dst_pd.get_size();
   1134     CHECK_EQ(num_bytes % sizeof(T), 0);
   1135     output_tf_shape.AddDim(num_bytes / sizeof(T));
   1136     AllocateOutputSetMklShape(context, kIdxOutput, output_tensor,
   1137                               output_tf_shape, output_mkl_shape);
   1138   }
   1139 
   1140   memory::desc ConfigureInputGradient(const Tensor& input_grad_tensor,
   1141                                       const MklDnnShape& input_grad_dnn_shape,
   1142                                       MklDnnData<T>* input_grad_dnn_data) {
   1143     CHECK_NOTNULL(input_grad_dnn_data);
   1144     // This shouldn't be necessary at this point, but just in case
   1145     CHECK_EQ(input_grad_dnn_shape.IsMklTensor(), true);
   1146 
   1147     memory::desc input_grad_md = input_grad_dnn_shape.GetCurLayout();
   1148     memory::dims orig_input_dims = input_grad_dnn_shape.GetSizesAsMklDnnDims();
   1149     input_grad_dnn_data->SetUsrMem(input_grad_md, &input_grad_tensor);
   1150     input_grad_dnn_data->SetOpMemDesc(orig_input_dims, memory::format::nhwc);
   1151     return input_grad_md;
   1152   }
   1153 
   1154   void PrepareAndExecuteNet(
   1155       const lrn_backward::primitive_desc& lrn_bkwd_desc,
   1156       const lrn_forward::primitive_desc& lrn_fwd_desc,
   1157       MklDnnData<T>* src_dnn_data, MklDnnData<T>* input_gradient_diff_dst,
   1158       MklDnnData<T>* output_diff_src,
   1159       const memory::primitive_desc& target_diff_dst_pd,
   1160       const MklDnnData<uint8>* workspace_dnn_data = nullptr) {
   1161     std::vector<primitive> net;
   1162 
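             // CheckReorderToOpMem appends a reorder primitive to 'net' only when the
             // user memory layout differs from the layout the primitive expects;
             // otherwise the user memory is used directly.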
   1163     // Check for input reordering on the diff dst input
   1164     input_gradient_diff_dst->CheckReorderToOpMem(
   1165         lrn_bkwd_desc.diff_dst_primitive_desc(), &net);
   1166 
   1167     // Check for input reordering on the original input
   1168     src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc(), &net);
    1169     // Create the LRN backward primitive and add it to the net.
   1170     if (nullptr == workspace_dnn_data) {
   1171       net.push_back(lrn_backward(lrn_bkwd_desc, src_dnn_data->GetOpMem(),
   1172                                  input_gradient_diff_dst->GetOpMem(),
   1173                                  output_diff_src->GetOpMem()));
   1174     } else {
   1175       net.push_back(lrn_backward(lrn_bkwd_desc, src_dnn_data->GetOpMem(),
   1176                                  input_gradient_diff_dst->GetOpMem(),
   1177                                  workspace_dnn_data->GetOpMem(),
   1178                                  output_diff_src->GetOpMem()));
   1179     }
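             // Submit the queued primitives (any reorders plus the LRN backward) to
             // an eager MKL-DNN stream and block until they complete.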
   1180     stream(stream::kind::eager).submit(net).wait();
   1181   }
   1182 
   1183   void ConfigureWorkspace(const Tensor& workspace_tensor,
   1184                           memory::primitive_desc workspace_pd,
   1185                           MklDnnData<uint8>* workspace_dnn_data) {
   1186     CHECK_NOTNULL(workspace_dnn_data);
   1187 
   1188     workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor);
   1189   }
   1190 
   1191   // Fallback implementation - Taken from lrn_op.cc
   1192   // TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
   1193   // copy.
   1194   void MklDefaultToEigen(OpKernelContext* context) {
   1195     Tensor input_gradient_tensor;
   1196     Tensor orig_input_tensor;
   1197     Tensor orig_output_tensor;
   1198 
   1199     MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
   1200         orig_output_dnn_shape;
   1201     GetMklShape(context, kIdxGradient, &input_grad_dnn_shape);
   1202     GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape);
   1203     GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape);
   1204 
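             // The Eigen fallback below operates on plain TF tensors, so convert any
             // inputs that arrived in MKL layout back to TF layout first.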
   1205     if (input_grad_dnn_shape.IsMklTensor()) {
   1206       input_gradient_tensor = ConvertMklToTF<T>(
   1207           context, MklGetInput(context, kIdxGradient), input_grad_dnn_shape);
   1208     } else {
   1209       input_gradient_tensor = MklGetInput(context, kIdxGradient);
   1210     }
   1211 
   1212     if (orig_input_dnn_shape.IsMklTensor()) {
   1213       orig_input_tensor = ConvertMklToTF<T>(
   1214           context, MklGetInput(context, kIdxOrigInput), orig_input_dnn_shape);
   1215     } else {
   1216       orig_input_tensor = MklGetInput(context, kIdxOrigInput);
   1217     }
   1218 
   1219     if (orig_output_dnn_shape.IsMklTensor()) {
   1220       orig_output_tensor = ConvertMklToTF<T>(
   1221           context, MklGetInput(context, kIdxOrigOutput), orig_output_dnn_shape);
   1222     } else {
   1223       orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
   1224     }
   1225 
   1226     const int64 batch = static_cast<int64>(input_gradient_tensor.dim_size(0));
   1227     const int64 rows = static_cast<int64>(input_gradient_tensor.dim_size(1));
   1228     const int64 cols = static_cast<int64>(input_gradient_tensor.dim_size(2));
   1229     const int64 depth = static_cast<int64>(input_gradient_tensor.dim_size(3));
   1230     const auto nodes = cols * rows;
   1231 
   1232     auto grads_shaped =
   1233         input_gradient_tensor.shaped<T, 2>({nodes * batch, depth});
   1234 
   1235     auto in_shaped = orig_input_tensor.shaped<T, 2>({nodes * batch, depth});
   1236     auto activations = orig_output_tensor.shaped<T, 2>({nodes * batch, depth});
   1237 
   1238     Tensor* output_dnn_data;
   1239     MklShape mkl_output_mkl_shape;
   1240     mkl_output_mkl_shape.SetMklTensor(false);
   1241     mkl_output_mkl_shape.SetDimensions(4);
   1242     AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
   1243                               input_gradient_tensor.shape(),
   1244                               mkl_output_mkl_shape);
   1245 
   1246     auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
   1247     out_shaped.setZero();
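             // Reference math for the loop below.  Forward LRN computes
             //   norm(i, j) = bias + alpha * sum_{k in window(j)} in(i, k)^2
             //   out(i, j)  = in(i, j) * norm(i, j)^(-beta)
             // so for every output position j whose window contains k,
             //   d out(i, j) / d in(i, k) =
             //       -2 * alpha * beta * in(i, k) * out(i, j) / norm(i, j)
             //       + (k == j ? norm(i, j)^(-beta) : 0),
             // and each contribution is scaled by the incoming gradient grads(i, j).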
   1248     auto shard = [this, activations, in_shaped, grads_shaped, out_shaped,
   1249                   depth](int64 begin, int64 end) {
   1250       for (int64 i = begin; i < end; ++i) {
   1251         for (int64 j = 0; j < depth; ++j) {
   1252           int64 depth_begin = std::max<int64>(0, j - depth_radius_);
   1253           int64 depth_end = std::min<int64>(depth, j + depth_radius_ + 1);
   1254 
   1255           T norm(0);
   1256           for (int64 k = depth_begin; k < depth_end; ++k) {
   1257             norm += in_shaped(i, k) * in_shaped(i, k);
   1258           }
   1259           norm = alpha_ * norm + bias_;
   1260           DCHECK_GT(norm, T(1e-6));
   1261           for (int64 k = depth_begin; k < depth_end; ++k) {
   1262             T dyi = T(-2) * alpha_ * beta_ * in_shaped(i, k) *
   1263                     activations(i, j) / norm;
   1264             if (k == j) {
   1265               dyi += Eigen::numext::pow(norm, -beta_);
   1266             }
   1267             dyi *= grads_shaped(i, j);
   1268             const_cast<typename TTypes<T, 2>::Tensor&>(out_shaped)(i, k) += dyi;
   1269           }
   1270         }
   1271       }
   1272     };
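             // Shard the batch * rows * cols spatial positions across the intra-op
             // thread pool; depth * depth estimates the cost per position, since each
             // of the 'depth' channels touches a window of up to 2 * depth_radius + 1
             // neighbors.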
   1273     auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
   1274     Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
   1275           depth * depth, shard);
   1276   }
   1277 
   1278   void SanityCheckInputs(OpKernelContext* context) {
   1279     const Tensor& input_gradient_tensor = MklGetInput(context, kIdxGradient);
   1280     const Tensor& orig_input_tensor = MklGetInput(context, kIdxOrigInput);
   1281     const Tensor& orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
   1282     const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
   1283     MklDnnShape in_grads_dnn_shape, in_image_dnn_shape, out_image_dnn_shape,
   1284         workspace_dnn_shape;
   1285     GetMklShape(context, kIdxGradient, &in_grads_dnn_shape);
   1286     GetMklShape(context, kIdxOrigInput, &in_image_dnn_shape);
   1287     GetMklShape(context, kIdxOrigOutput, &out_image_dnn_shape);
   1288     GetMklShape(context, kIdxWorkspace, &workspace_dnn_shape);
   1289     if (in_grads_dnn_shape.IsMklTensor()) {
   1290       OP_REQUIRES(context, in_grads_dnn_shape.GetDimension() == 4,
   1291                   errors::InvalidArgument("Input gradient must be "
   1292                                           "4-dimensional"));
   1293     } else {
   1294       OP_REQUIRES(
   1295           context, input_gradient_tensor.dims() == 4,
   1296           errors::InvalidArgument("input gradient must be 4-dimensional"));
   1297     }
   1298 
   1299     if (in_image_dnn_shape.IsMklTensor()) {
   1300       OP_REQUIRES(context, in_image_dnn_shape.GetDimension() == 4,
   1301                   errors::InvalidArgument("input images must be "
   1302                                           "4-dimensional"));
   1303     } else {
   1304       OP_REQUIRES(context, orig_input_tensor.dims() == 4,
   1305                   errors::InvalidArgument("input images must be "
   1306                                           "4-dimensional"));
   1307     }
   1308 
   1309     if (out_image_dnn_shape.IsMklTensor()) {
   1310       OP_REQUIRES(context, out_image_dnn_shape.GetDimension() == 4,
   1311                   errors::InvalidArgument("Output image must be "
   1312                                           "4-dimensional"));
   1313     } else {
   1314       OP_REQUIRES(
   1315           context, orig_output_tensor.dims() == 4,
   1316           errors::InvalidArgument("Output image must be 4-dimensional"));
   1317     }
   1318 
   1319     if (workspace_enabled_) {
    1320       OP_REQUIRES(
    1321           context, !workspace_dnn_shape.IsMklTensor(),
    1322           errors::InvalidArgument("Workspace should not be MKL Tensor."));
    1323       OP_REQUIRES(context, workspace_tensor.dims() == 1,
    1324                   errors::InvalidArgument("Workspace must be 1-dimensional"));
   1328     }
   1329   }
   1330 
   1331   // Input("input_grads: T")
   1332   // Input("input_image: T")
   1333   // Input("output_image: T")
   1334   // Input("workspace: uint8")
   1335   const int kIdxGradient = 0, kIdxOrigInput = 1, kIdxOrigOutput = 2,
   1336             kIdxWorkspace = 3, kIdxOutput = 0;
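           // kIdxOutput indexes the op's outputs (its single output at slot 0); the
           // other kIdx* constants index its inputs.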
   1337 
   1338   typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
   1339   bool workspace_enabled_;
   1340   int depth_radius_;
   1341   float bias_;
   1342   float alpha_;
   1343   float beta_;
   1344 };
   1345 
   1346 #endif  // INTEL_MKL_ML
   1347 
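         // Register MKL kernels for _MklLRN and _MklLRNGrad (float only).  The
         // kMklOpLabel constraint limits these kernels to nodes labeled by the MKL
         // graph-rewrite pass.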
   1348 #define REGISTER_MKL_LRN_CPU(T)                                     \
   1349   REGISTER_KERNEL_BUILDER(Name("_MklLRN")                           \
   1350                               .Device(DEVICE_CPU)                   \
   1351                               .TypeConstraint<T>("T")               \
   1352                               .Label(mkl_op_registry::kMklOpLabel), \
   1353                           MklLRNOp<T>);                             \
   1354   REGISTER_KERNEL_BUILDER(Name("_MklLRNGrad")                       \
   1355                               .Device(DEVICE_CPU)                   \
   1356                               .TypeConstraint<T>("T")               \
   1357                               .Label(mkl_op_registry::kMklOpLabel), \
   1358                           MklLRNGradOp<T>);
   1359 
   1360 TF_CALL_float(REGISTER_MKL_LRN_CPU);
   1361 
   1362 }  // namespace tensorflow
   1363 
   1364 #endif  // INTEL_MKL
   1365