/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// LRN = Local Response Normalization
// See docs in ../ops/nn_ops.cc. This OpKernel uses the MKL-DNN library: it
// creates MKL layouts and primitives and uses MKL-DNN primitives to compute
// local response normalization.

#ifdef INTEL_MKL

#define EIGEN_USE_THREADS
#include <vector>
#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"

#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/util/work_sharder.h"
#endif

using mkldnn::lrn_across_channels;
using mkldnn::lrn_backward;
using mkldnn::lrn_forward;
using mkldnn::prop_kind;
using mkldnn::stream;

namespace tensorflow {

namespace {
// Create a depth-by-depth band matrix with 1s along a swath of size (2 *
// depth_radius + 1) around the diagonal.
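// For example, depth = 5 and depth_radius = 1 produce:
//   1 1 0 0 0
//   1 1 1 0 0
//   0 1 1 1 0
//   0 0 1 1 1
//   0 0 0 1 1
// (the band is clipped at the matrix edges).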
template <typename T>
void GetBandMatrix(int depth, int depth_radius,
                   Eigen::Tensor<T, 2, Eigen::RowMajor>* result) {
  result->setZero();
  for (int row = 0; row < depth; ++row) {
    const int begin = std::max<int>(0, row - depth_radius);
    const int end = std::min<int>(depth, row + depth_radius + 1);
    Eigen::DSizes<Eigen::DenseIndex, 2> start(row, begin);
    Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, end - begin);
    result->slice(start, sizes).setConstant(T(1));
  }
}

}  // namespace

template <typename T>
class MklLRNOp : public OpKernel {
 public:
  ~MklLRNOp() {}

  explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) {
    int64 depth_radius64;
    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    OP_REQUIRES(
        context,
        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
        errors::InvalidArgument("depth_radius = ", depth_radius64,
                                " larger than int max"));
    depth_radius_ = static_cast<int>(depth_radius64);

    OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
    OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
    workspace_enabled_ = false;
    OP_REQUIRES_OK(context,
                   context->GetAttr("workspace_enabled", &workspace_enabled_));
  }

  void Compute(OpKernelContext* context) override {
    try {
      SanityCheckInputs(context);
      if (!context->status().ok()) return;

      auto cpu_engine = engine(engine::cpu, 0);
      const Tensor& src_tensor = MklGetInput(context, kIdxInput);
      MklDnnShape src_dnn_shape;
      GetMklShape(context, kIdxInput, &src_dnn_shape);

      // MKL-DNN has a notion of kernel_size and not depth_radius.
      int kernel_size = 2 * depth_radius_ + 1;
      float new_alpha = alpha_ * kernel_size;
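      // Note: MKL-DNN applies alpha divided by the window size
      // (local_size) in its LRN formula, while TensorFlow applies alpha
      // to the raw window sum, so alpha is pre-scaled by kernel_size
      // above to make the two match.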

      // If the input tensor is not an MKL tensor, or if the last
      // dimension is not the channel dimension, just use Eigen:
      // MKL-DNN only supports normalization over the channel dimension.
      if (!src_dnn_shape.IsMklTensor()) {
        MklDefaultToEigen(context, src_tensor);
        return;
      } else if (!src_dnn_shape.IsMklChannelDim(src_dnn_shape.GetDimension() -
                                                1)) {
        Tensor converted_tensor =
            ConvertMklToTF<T>(context, src_tensor, src_dnn_shape);
        MklDefaultToEigen(context, converted_tensor);
        return;
      }
      // At this point, we can assume that the src is an MKL tensor
      // and we can enable the workspace.
      workspace_enabled_ = true;
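      // The workspace is an opaque MKL-DNN buffer in which the forward
      // primitive saves intermediate normalization data, so that the
      // backward primitive can reuse it instead of recomputing it.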

      MklDnnData<T> src_dnn_data(&cpu_engine);
      MklDnnData<T> dst_dnn_data(&cpu_engine);
      MklDnnData<uint8> workspace_dnn_data(&cpu_engine);

      TensorShape tf_output_shape = src_tensor.shape();

      memory::desc src_md = src_dnn_shape.GetCurLayout();
      memory::dims input_dims = src_dnn_shape.GetSizesAsMklDnnDims();

      // Create memory for user input.
      // Since TensorFlow always performs normalization over the last
      // dimension and MKL-DNN performs it over Channel, we tell MKL-DNN
      // that the input is in NHWC layout, with Channel as the last
      // dimension.
      src_dnn_data.SetUsrMem(src_md, &src_tensor);
      src_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);

      // output_dnn_data and workspace both have the same shape as the input.
      dst_dnn_data.SetUsrMem(src_md);
      dst_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);

      // Create the LRN primitive descriptor.
      // TensorFlow's normalization semantics are across channels;
      // MKL-DNN also supports normalization within a channel.
      auto lrn_desc = lrn_forward::desc(prop_kind::forward, lrn_across_channels,
                                        src_dnn_data.GetUsrMemDesc(),
                                        kernel_size, new_alpha, beta_, bias_);
      auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, cpu_engine);

      // Allocate output_dnn_data tensor.
      Tensor* output_tensor = nullptr;
      memory::format input_format = src_dnn_shape.GetTfDataFormat();
      AllocateOutputTensor(context, lrn_prim_desc, input_dims, input_format,
                           &output_tensor);
      OP_REQUIRES_OK(context, context->status());
      CHECK_NOTNULL(output_tensor);
      dst_dnn_data.SetUsrMemDataHandle(output_tensor);

      // Handle workspace required for MKL-DNN.
      AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);
      OP_REQUIRES_OK(context, context->status());

      PrepareAndExecuteNet(lrn_prim_desc, &src_dnn_data, &dst_dnn_data,
                           &workspace_dnn_data);
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception: ", error_msg));
    }
  }

 private:
  void PrepareAndExecuteNet(const lrn_forward::primitive_desc& lrn_fwd_desc,
                            MklDnnData<T>* src_dnn_data,
                            MklDnnData<T>* dst_dnn_data,
                            MklDnnData<uint8>* wksp_dnn_data = nullptr) {
    // Check for input reorder.
    src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc());

    // Create the LRN forward primitive and add it to the net.
    std::vector<primitive> net;
    if (wksp_dnn_data != nullptr) {
      net.push_back(lrn_forward(lrn_fwd_desc, src_dnn_data->GetOpMem(),
                                wksp_dnn_data->GetOpMem(),
                                dst_dnn_data->GetOpMem()));
    } else {
      net.push_back(lrn_forward(lrn_fwd_desc, src_dnn_data->GetOpMem(),
                                dst_dnn_data->GetOpMem()));
    }
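    // Submit the net to an eager MKL-DNN stream; wait() blocks until the
    // primitives have finished executing.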
    stream(stream::kind::eager).submit(net).wait();
  }

  void AllocateOutputTensor(
      OpKernelContext* context,
      const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
      const memory::dims output_dims_mkl_order,
      const memory::format& output_tf_format, Tensor** output_tensor) {
    CHECK_NOTNULL(output_tensor);
    memory::primitive_desc dst_pd = lrn_fwd_prim_desc.dst_primitive_desc();

    MklDnnShape output_mkl_shape;
    // We only handle the case when the inputs and output are in MKL format;
    // any other case is handled by Eigen.
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);
    TensorShape output_tf_shape;
    // Only allocate enough space for the elements we need.
    size_t num_bytes = dst_pd.get_size();
    CHECK_EQ(num_bytes % sizeof(T), 0);
    output_tf_shape.AddDim(num_bytes / sizeof(T));
    AllocateOutputSetMklShape(context, kIdxOutput, output_tensor,
                              output_tf_shape, output_mkl_shape);
  }

  // Fallback implementation - Taken from lrn_op.cc
  // TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a
  // copy.
  void MklDefaultToEigen(OpKernelContext* context, const Tensor& input) {
    const int batch = static_cast<int>(input.dim_size(0));
    const int rows = static_cast<int>(input.dim_size(1));
    const int cols = static_cast<int>(input.dim_size(2));
    const int depth = static_cast<int>(input.dim_size(3));
    const int nodes = cols * rows;

    auto in_shaped = input.shaped<T, 2>({nodes * batch, depth});
    // Multiplying the input with the band matrix has the effect of reducing
    // the correct patch along the depth.
    Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
    GetBandMatrix<T>(depth, depth_radius_, &multiplier);

    Tensor* output_dnn_data = nullptr;
    MklDnnShape mkl_output_mkl_shape;
    mkl_output_mkl_shape.SetMklTensor(false);
    mkl_output_mkl_shape.SetDimensions(4);
    AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
                              input.shape(), mkl_output_mkl_shape);
    CHECK_NOTNULL(output_dnn_data);

    Tensor* workspace_tensor = nullptr;
    MklDnnShape workspace_mkl_shape;
    workspace_mkl_shape.SetMklTensor(false);
    TensorShape workspace_tf_shape;
    workspace_tf_shape.AddDim(0);
    AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor,
                              workspace_tf_shape, workspace_mkl_shape);
    CHECK_NOTNULL(workspace_tensor);

    auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
    Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
    auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
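    // Special-case the common beta values: beta == 1 needs only a
    // reciprocal and beta == 0.5 only a reciprocal square root, both of
    // which are cheaper than the generic log/exp power below.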
    if (beta_ == T(1)) {
      out_shaped.device(context->eigen_cpu_device()) =
          in_shaped * tmp.inverse();
    } else if (beta_ == T(0.5)) {
      out_shaped.device(context->eigen_cpu_device()) = in_shaped * tmp.rsqrt();
    } else {
      out_shaped.device(context->eigen_cpu_device()) =
          in_shaped * (tmp.log() * -beta_).exp();
    }
  }

  void AllocateWorkspaceTensor(
      OpKernelContext* context,
      const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
      MklDnnData<uint8>* dnn_data_wksp) {
    CHECK_NOTNULL(dnn_data_wksp);
    Tensor* workspace_tensor = nullptr;
    memory::primitive_desc workspace_pd =
        lrn_fwd_prim_desc.workspace_primitive_desc();
    size_t workspace_bytes = workspace_pd.get_size();
    MklDnnShape workspace_mkl_shape;
    // The workspace tensor is a uint8 tensor that has exactly the number
    // of bytes necessary.
    workspace_mkl_shape.SetMklTensor(false);
    TensorShape workspace_tf_shape;
    workspace_tf_shape.AddDim(workspace_bytes);
    AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor,
                              workspace_tf_shape, workspace_mkl_shape);
    CHECK_NOTNULL(workspace_tensor);
    dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor);
  }

  void SanityCheckInputs(OpKernelContext* context) {
    const Tensor& src_tensor = MklGetInput(context, kIdxInput);
    MklDnnShape src_dnn_shape;
    GetMklShape(context, kIdxInput, &src_dnn_shape);
    if (src_dnn_shape.IsMklTensor()) {
      OP_REQUIRES(context, src_dnn_shape.GetDimension() == 4,
                  errors::InvalidArgument("input must be 4-dimensional"));
      OP_REQUIRES(context,
                  FastBoundsCheck(src_tensor.NumElements(),
                                  std::numeric_limits<int>::max()),
                  errors::InvalidArgument("argument to LRN too large"));
    } else {
      OP_REQUIRES(context, src_tensor.dims() == 4,
                  errors::InvalidArgument("input must be 4-dimensional"));
      OP_REQUIRES(context,
                  FastBoundsCheck(src_tensor.NumElements(),
                                  std::numeric_limits<int>::max()),
                  errors::InvalidArgument("argument to LRN too large"));
    }
  }
  const int kIdxInput = 0, kIdxOutput = 0, kIdxWorkspace = 1;

  typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
  bool workspace_enabled_;
  int depth_radius_;
  float bias_;
  float alpha_;
  float beta_;
};

template <typename T>
class MklLRNGradOp : public OpKernel {
 public:
  explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
    int64 depth_radius64;
    OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
    OP_REQUIRES(
        context,
        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
        errors::InvalidArgument("depth_radius = ", depth_radius64,
                                " larger than int max"));
    depth_radius_ = static_cast<int>(depth_radius64);
    OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
    OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
    OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
    workspace_enabled_ = false;
    OP_REQUIRES_OK(context,
                   context->GetAttr("workspace_enabled", &workspace_enabled_));
  }

  void Compute(OpKernelContext* context) override {
    try {
      SanityCheckInputs(context);
      if (!context->status().ok()) return;

      auto cpu_engine = engine(engine::cpu, 0);
      MklDnnData<T> input_grad_dnn_data(&cpu_engine);
      MklDnnData<T> orig_input_dnn_data(&cpu_engine);
      MklDnnData<T> orig_output_dnn_data(&cpu_engine);
      MklDnnData<T> output_dnn_data(&cpu_engine);

      MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
          orig_output_dnn_shape;
      GetMklShape(context, kIdxGradient, &input_grad_dnn_shape);
      GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape);
      GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape);

      // We only use MKL-DNN if all of the necessary inputs are present
      // in MKL-DNN format and Channel is the last dimension.
      bool can_use_mkldnn = workspace_enabled_ &&
                            input_grad_dnn_shape.IsMklTensor() &&
                            orig_input_dnn_shape.IsMklTensor() &&
                            orig_output_dnn_shape.IsMklTensor() &&
                            input_grad_dnn_shape.IsMklChannelDim(
                                input_grad_dnn_shape.GetDimension() - 1) &&
                            orig_input_dnn_shape.IsMklChannelDim(
                                orig_input_dnn_shape.GetDimension() - 1) &&
                            orig_output_dnn_shape.IsMklChannelDim(
                                orig_output_dnn_shape.GetDimension() - 1);

      if (!can_use_mkldnn) {
        // Fall back to Eigen.
        MklDefaultToEigen(context);
        return;
      }
      // At this point, we have the all-clear to use MKL-DNN constructs.
      // Naming: diff_dst is input_gradient_tensor; src is orig_input_tensor.
      const Tensor& input_grad_tensor = MklGetInput(context, kIdxGradient);
      const Tensor& orig_input_tensor = MklGetInput(context, kIdxOrigInput);

      // Get the input sizes in the order MKL-DNN requires.
      // LRN does not have a data_format attribute; its inputs are NHWC
      // by default.
      memory::desc original_output_md = orig_output_dnn_shape.GetCurLayout();
      memory::desc target_diff_dst_md = ConfigureInputGradient(
          input_grad_tensor, input_grad_dnn_shape, &input_grad_dnn_data);

      memory::desc orig_input_md = orig_input_dnn_shape.GetCurLayout();
      memory::dims orig_input_dims =
          orig_input_dnn_shape.GetSizesAsMklDnnDims();
      orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
      orig_input_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);

      // output_dnn_data has the same shape as the original input.
      output_dnn_data.SetUsrMem(orig_input_md);
      output_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);

      // MKL-DNN has a notion of kernel_size and not depth_radius.
      int kernel_size = 2 * depth_radius_ + 1;
      float new_alpha = alpha_ * kernel_size;

      // Create the LRN backward primitive descriptor. It requires the LRN
      // forward primitive descriptor as well.
      auto lrn_fwd_desc = lrn_forward::desc(
          prop_kind::forward, lrn_across_channels, orig_input_md, kernel_size,
          new_alpha, beta_, bias_);
      auto lrn_fwd_prim_desc =
          lrn_forward::primitive_desc(lrn_fwd_desc, cpu_engine);
      auto lrn_bwd_desc = lrn_backward::desc(
          lrn_across_channels, original_output_md, target_diff_dst_md,
          kernel_size, new_alpha, beta_, bias_);
      auto lrn_bwd_prim_desc = lrn_backward::primitive_desc(
          lrn_bwd_desc, cpu_engine, lrn_fwd_prim_desc);
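      // The forward primitive descriptor is passed as a hint so that
      // MKL-DNN chooses a backward implementation whose workspace and
      // memory formats are consistent with the forward pass.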

      Tensor* output_tensor = nullptr;
      memory::format orig_input_format = orig_input_dnn_shape.GetTfDataFormat();
      AllocateOutputTensor(context, lrn_bwd_prim_desc, orig_input_dims,
                           orig_input_format, &output_tensor);
      OP_REQUIRES_OK(context, context->status());
      CHECK_NOTNULL(output_tensor);
      output_dnn_data.SetUsrMemDataHandle(output_tensor);

      // Create the LRN backward primitive and add it to the net.
      // At this point the workspace is known to be enabled, so we don't
      // need to check; pass the input workspace to the LRN backward
      // primitive.
      const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
      MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
      ConfigureWorkspace(workspace_tensor,
                         lrn_fwd_prim_desc.workspace_primitive_desc(),
                         &workspace_dnn_data);

      PrepareAndExecuteNet(
          lrn_bwd_prim_desc, lrn_fwd_prim_desc, &orig_input_dnn_data,
          &input_grad_dnn_data, &output_dnn_data,
          memory::primitive_desc(target_diff_dst_md, cpu_engine),
          &workspace_dnn_data);
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception: ", error_msg));
    }
  }

  void AllocateOutputTensor(
      OpKernelContext* context,
      const lrn_backward::primitive_desc& lrn_bkwd_prim_desc,
      const memory::dims output_dims_mkl_order,
      const memory::format& output_tf_format, Tensor** output_tensor) {
    CHECK_NOTNULL(output_tensor);
    memory::primitive_desc dst_pd =
        lrn_bkwd_prim_desc.diff_src_primitive_desc();
    MklDnnShape output_mkl_shape;

    // We assume that all outputs at this point are MKL Tensors.
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    TensorShape output_tf_shape;
    size_t num_bytes = dst_pd.get_size();
    CHECK_EQ(num_bytes % sizeof(T), 0);
    output_tf_shape.AddDim(num_bytes / sizeof(T));
    AllocateOutputSetMklShape(context, kIdxOutput, output_tensor,
                              output_tf_shape, output_mkl_shape);
  }

  memory::desc ConfigureInputGradient(const Tensor& input_grad_tensor,
                                      const MklDnnShape& input_grad_dnn_shape,
                                      MklDnnData<T>* input_grad_dnn_data) {
    CHECK_NOTNULL(input_grad_dnn_data);
    // This shouldn't be necessary at this point, but just in case.
    CHECK_EQ(input_grad_dnn_shape.IsMklTensor(), true);

    memory::desc input_grad_md = input_grad_dnn_shape.GetCurLayout();
    memory::dims orig_input_dims = input_grad_dnn_shape.GetSizesAsMklDnnDims();
    input_grad_dnn_data->SetUsrMem(input_grad_md, &input_grad_tensor);
    input_grad_dnn_data->SetOpMemDesc(orig_input_dims, memory::format::nhwc);
    return input_grad_md;
  }

  void PrepareAndExecuteNet(
      const lrn_backward::primitive_desc& lrn_bkwd_desc,
      const lrn_forward::primitive_desc& lrn_fwd_desc,
      MklDnnData<T>* src_dnn_data, MklDnnData<T>* input_gradient_diff_dst,
      MklDnnData<T>* output_diff_src,
      const memory::primitive_desc& target_diff_dst_pd,
      const MklDnnData<uint8>* workspace_dnn_data = nullptr) {
    // Check for input reordering on the diff dst input.
    input_gradient_diff_dst->CheckReorderToOpMem(
        lrn_bkwd_desc.diff_dst_primitive_desc());

    // Check for input reordering on the original input.
    src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc());
    // Create the LRN backward primitive and add it to the net.
    std::vector<primitive> net;
    if (nullptr == workspace_dnn_data) {
      net.push_back(lrn_backward(lrn_bkwd_desc, src_dnn_data->GetOpMem(),
                                 input_gradient_diff_dst->GetOpMem(),
                                 output_diff_src->GetOpMem()));
    } else {
      net.push_back(lrn_backward(lrn_bkwd_desc, src_dnn_data->GetOpMem(),
                                 input_gradient_diff_dst->GetOpMem(),
                                 workspace_dnn_data->GetOpMem(),
                                 output_diff_src->GetOpMem()));
    }
    stream(stream::kind::eager).submit(net).wait();
  }

  void ConfigureWorkspace(const Tensor& workspace_tensor,
                          memory::primitive_desc workspace_pd,
                          MklDnnData<uint8>* workspace_dnn_data) {
    CHECK_NOTNULL(workspace_dnn_data);

    workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor);
  }

  // Fallback implementation - Taken from lrn_op.cc
  // TODO(inteltf) Check if we can use EigenLRNOp directly instead of making a
  // copy.
  void MklDefaultToEigen(OpKernelContext* context) {
    Tensor input_gradient_tensor;
    Tensor orig_input_tensor;
    Tensor orig_output_tensor;

    MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
        orig_output_dnn_shape;
    GetMklShape(context, kIdxGradient, &input_grad_dnn_shape);
    GetMklShape(context, kIdxOrigInput, &orig_input_dnn_shape);
    GetMklShape(context, kIdxOrigOutput, &orig_output_dnn_shape);

    if (input_grad_dnn_shape.IsMklTensor()) {
      input_gradient_tensor = ConvertMklToTF<T>(
          context, MklGetInput(context, kIdxGradient), input_grad_dnn_shape);
    } else {
      input_gradient_tensor = MklGetInput(context, kIdxGradient);
    }

    if (orig_input_dnn_shape.IsMklTensor()) {
      orig_input_tensor = ConvertMklToTF<T>(
          context, MklGetInput(context, kIdxOrigInput), orig_input_dnn_shape);
    } else {
      orig_input_tensor = MklGetInput(context, kIdxOrigInput);
    }

    if (orig_output_dnn_shape.IsMklTensor()) {
      orig_output_tensor = ConvertMklToTF<T>(
          context, MklGetInput(context, kIdxOrigOutput), orig_output_dnn_shape);
    } else {
      orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
    }

    const int64 batch = static_cast<int64>(input_gradient_tensor.dim_size(0));
    const int64 rows = static_cast<int64>(input_gradient_tensor.dim_size(1));
    const int64 cols = static_cast<int64>(input_gradient_tensor.dim_size(2));
    const int64 depth = static_cast<int64>(input_gradient_tensor.dim_size(3));
    const auto nodes = cols * rows;

    auto grads_shaped =
        input_gradient_tensor.shaped<T, 2>({nodes * batch, depth});

    auto in_shaped = orig_input_tensor.shaped<T, 2>({nodes * batch, depth});
    auto activations = orig_output_tensor.shaped<T, 2>({nodes * batch, depth});

    Tensor* output_dnn_data;
    MklDnnShape mkl_output_mkl_shape;
    mkl_output_mkl_shape.SetMklTensor(false);
    mkl_output_mkl_shape.SetDimensions(4);
    AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
                              input_gradient_tensor.shape(),
                              mkl_output_mkl_shape);

    auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
    out_shaped.setZero();
    auto shard = [this, activations, in_shaped, grads_shaped, out_shaped,
                  depth](int64 begin, int64 end) {
      for (int64 i = begin; i < end; ++i) {
        for (int64 j = 0; j < depth; ++j) {
          int64 depth_begin = std::max<int64>(0, j - depth_radius_);
          int64 depth_end = std::min<int64>(depth, j + depth_radius_ + 1);

          T norm(0);
          for (int64 k = depth_begin; k < depth_end; ++k) {
            norm += in_shaped(i, k) * in_shaped(i, k);
          }
          norm = alpha_ * norm + bias_;
          DCHECK_GT(norm, T(1e-6));
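          // With y(j) = x(j) * norm^-beta and
          // norm = bias + alpha * sum_k x(k)^2 over the depth window,
          //   d y(j) / d x(k) = -2 * alpha * beta * x(k) * y(j) / norm
          //                     + (k == j ? norm^-beta : 0),
          // accumulated against grads_shaped(i, j) below.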
          for (int64 k = depth_begin; k < depth_end; ++k) {
            T dyi = T(-2) * alpha_ * beta_ * in_shaped(i, k) *
                    activations(i, j) / norm;
            if (k == j) {
              dyi += Eigen::numext::pow(norm, -beta_);
            }
            dyi *= grads_shaped(i, j);
            const_cast<typename TTypes<T, 2>::Tensor&>(out_shaped)(i, k) += dyi;
          }
        }
      }
    };
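    // Shard over the batch * spatial rows; depth * depth approximates the
    // per-row cost of the nested channel loops above.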
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
          depth * depth, shard);
  }

  void SanityCheckInputs(OpKernelContext* context) {
    const Tensor& input_gradient_tensor = MklGetInput(context, kIdxGradient);
    const Tensor& orig_input_tensor = MklGetInput(context, kIdxOrigInput);
    const Tensor& orig_output_tensor = MklGetInput(context, kIdxOrigOutput);
    const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
    MklDnnShape in_grads_dnn_shape, in_image_dnn_shape, out_image_dnn_shape,
        workspace_dnn_shape;
    GetMklShape(context, kIdxGradient, &in_grads_dnn_shape);
    GetMklShape(context, kIdxOrigInput, &in_image_dnn_shape);
    GetMklShape(context, kIdxOrigOutput, &out_image_dnn_shape);
    GetMklShape(context, kIdxWorkspace, &workspace_dnn_shape);
    if (in_grads_dnn_shape.IsMklTensor()) {
      OP_REQUIRES(context, in_grads_dnn_shape.GetDimension() == 4,
                  errors::InvalidArgument("Input gradient must be "
                                          "4-dimensional"));
    } else {
      OP_REQUIRES(
          context, input_gradient_tensor.dims() == 4,
          errors::InvalidArgument("input gradient must be 4-dimensional"));
    }

    if (in_image_dnn_shape.IsMklTensor()) {
      OP_REQUIRES(context, in_image_dnn_shape.GetDimension() == 4,
                  errors::InvalidArgument("input images must be "
                                          "4-dimensional"));
    } else {
      OP_REQUIRES(context, orig_input_tensor.dims() == 4,
                  errors::InvalidArgument("input images must be "
                                          "4-dimensional"));
    }

    if (out_image_dnn_shape.IsMklTensor()) {
      OP_REQUIRES(context, out_image_dnn_shape.GetDimension() == 4,
                  errors::InvalidArgument("Output image must be "
                                          "4-dimensional"));
    } else {
      OP_REQUIRES(
          context, orig_output_tensor.dims() == 4,
          errors::InvalidArgument("Output image must be 4-dimensional"));
    }

    if (workspace_enabled_) {
      OP_REQUIRES(
          context, !workspace_dnn_shape.IsMklTensor(),
          errors::InvalidArgument("Workspace should not be MKL Tensor."));
      OP_REQUIRES(context, workspace_tensor.dims() == 1,
                  errors::InvalidArgument("Workspace must be 1-dimensional"));
    }
  }

  // Input("input_grads: T")
  // Input("input_image: T")
  // Input("output_image: T")
  // Input("workspace: uint8")
  const int kIdxGradient = 0, kIdxOrigInput = 1, kIdxOrigOutput = 2,
            kIdxWorkspace = 3, kIdxOutput = 0;

  typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
  bool workspace_enabled_;
  int depth_radius_;
  float bias_;
  float alpha_;
  float beta_;
};

#define REGISTER_MKL_LRN_CPU(T)                                     \
  REGISTER_KERNEL_BUILDER(Name("_MklLRN")                           \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklLRNOp<T>);                             \
  REGISTER_KERNEL_BUILDER(Name("_MklLRNGrad")                       \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklLRNGradOp<T>);

TF_CALL_float(REGISTER_MKL_LRN_CPU);

}  // namespace tensorflow

#endif  // INTEL_MKL