      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #ifdef INTEL_MKL
     16 
     17 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     18 #include "tensorflow/core/framework/op_kernel.h"
     19 #include "tensorflow/core/framework/register_types.h"
     20 #include "tensorflow/core/framework/tensor.h"
     21 #include "tensorflow/core/framework/tensor_types.h"
     22 #include "tensorflow/core/util/tensor_format.h"
     23 
     24 #include "mkl_dnn.h"
     25 #include "mkl_dnn_types.h"
     26 #include "tensorflow/core/util/mkl_util.h"
     27 
     28 #ifndef INTEL_MKL_ML
     29 #include "mkldnn.hpp"
     30 
     31 using mkldnn::batch_normalization_backward;
     32 using mkldnn::batch_normalization_forward;
     33 using mkldnn::prop_kind;
     34 using mkldnn::stream;
     35 using mkldnn::use_global_stats;
     36 using mkldnn::use_scale_shift;
     37 #endif
     38 
     39 // TODO(inteltf) Address comments from PR 8968.
     40 
     41 namespace tensorflow {
     42 using CPUDevice = Eigen::ThreadPoolDevice;
     43 
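// Note on structure (summarizing the preprocessor guards below): this file
// provides two alternative implementations of the fused batch-norm forward
// and backward kernels. The block guarded by INTEL_MKL_ML uses the legacy
// MKL-ML (dnn*_F32) C API, while the block guarded by !INTEL_MKL_ML uses the
// newer MKL-DNN (mkldnn.hpp) C++ API.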
     44 #ifdef INTEL_MKL_ML
     45 
     46 template <typename Device, typename T>
     47 class MklFusedBatchNormOp : public OpKernel {
     48  public:
     49   explicit MklFusedBatchNormOp(OpKernelConstruction* context)
     50       : OpKernel(context) {
     51     float epsilon;
     52     OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
     53     epsilon_ = T(epsilon);
     54     string tensor_format;
     55     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
     56     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
     57                 errors::InvalidArgument("Invalid data format"));
     58     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
     59   }
     60 
     61   void Compute(OpKernelContext* context) override {
     62     MklFusedBatchNormOpContext mkl_context;
     63     const Tensor& input = MklGetInput(context, 0);
     64     const Tensor& scale = MklGetInput(context, 1);
     65     const Tensor& shift = MklGetInput(context, 2);
     66     const Tensor& est_mean = MklGetInput(context, 3);
     67     const Tensor& est_variance = MklGetInput(context, 4);
     68 
     69     GetMklShape(context, 0, &(mkl_context.mkl_shape_input_shape));
     70     bool input_in_mkl_format = mkl_context.mkl_shape_input_shape.IsMklTensor();
     71 
     72     if (!input_in_mkl_format) {
     73       OP_REQUIRES(context, input.dims() == 4,
     74                   errors::InvalidArgument("input must be 4-dimensional",
     75                                           input.shape().DebugString()));
     76     }
     77     OP_REQUIRES(context, scale.dims() == 1,
     78                 errors::InvalidArgument("scale must be 1-dimensional",
     79                                         scale.shape().DebugString()));
     80     OP_REQUIRES(context, shift.dims() == 1,
     81                 errors::InvalidArgument("offset must be 1-dimensional",
     82                                         shift.shape().DebugString()));
     83     OP_REQUIRES(context, est_mean.dims() == 1,
     84                 errors::InvalidArgument("estimated_mean must be 1-dimensional",
     85                                         est_mean.shape().DebugString()));
     86 
     87     OP_REQUIRES(
     88         context, est_variance.dims() == 1,
     89         errors::InvalidArgument("estimated_variance must be 1-dimensional",
     90                                 est_variance.shape().DebugString()));
     91 
     92     if (is_training_) {
     93       OP_REQUIRES(context, est_mean.dim_size(0) == 0,
      94                   errors::InvalidArgument("estimated_mean must be empty for training",
     95                                           est_mean.shape().DebugString()));
     96       OP_REQUIRES(context, est_variance.dim_size(0) == 0,
     97                   errors::InvalidArgument(
     98                       "estimated_variance must be empty for training",
     99                       est_variance.shape().DebugString()));
    100     }
    101 
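    // Flag semantics: dnnUseScaleShift applies the learned scale/shift
    // (gamma/beta); dnnUseInputMeanVariance makes the primitive read mean and
    // variance from the caller instead of computing them from the batch, so
    // it is set only for inference.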
    102     unsigned int flag_batch_norm =
    103         is_training_ ? dnnUseScaleShift
    104                      : (dnnUseInputMeanVariance | dnnUseScaleShift);
    105 
    106     mkl_context.MklExtractParams(context, tensor_format_);
    107 
    108     // Create layout only for input data as it is used in Op primitive.
    109     mkl_context.MklCreateInputLayout(context);
    110 
    111     // Create Op primitive.
    112     CHECK_EQ(dnnBatchNormalizationCreateForward_v2_F32(
    113                  &(mkl_context.mkl_prim_batchnorm), nullptr,
    114                  mkl_context.mkl_lt_input, static_cast<float>(epsilon_),
    115                  flag_batch_norm),
    116              E_SUCCESS);
    117 
     118     // Temporary tensors with buffers for the context inputs, in case
     119     // conversion to MKL-Op-specific layouts is required. It is assumed here
    120     // that TF's 1D tensors (scale, shift, est_mean, and est_variance) won't
    121     // require any conversion.
    122     // Since scale-shift is combined in MKL, a buffer is required.
    123     Tensor mkl_tmp_input_buf_tensor, mkl_tmp_scale_shift_buf_tensor;
    124     mkl_context.MklPrepareContextInputs(context, &mkl_tmp_input_buf_tensor,
    125                                         &mkl_tmp_scale_shift_buf_tensor);
    126 
    127     // Output data in MKL layout
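    // The TF output tensor is allocated as a flat buffer whose size is taken
    // from the MKL layout (dnnLayoutGetMemorySize_F32), while mkl_shape_output
    // records both the MKL layout and the corresponding TF dimension order.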
    128     Tensor* output = nullptr;
    129     TensorShape tf_shape_output;
    130     MklShape mkl_shape_output;
    131     mkl_shape_output.SetMklTensor(true);
    132     mkl_shape_output.SetMklLayout(mkl_context.mkl_prim_batchnorm,
    133                                   dnnResourceDst);
    134     mkl_shape_output.SetTfLayout(mkl_context.mkl_params.in_dim,
    135                                  mkl_context.mkl_params.in_sizes,
    136                                  mkl_context.mkl_params.in_strides);
    137     mkl_shape_output.SetTfDimOrder(mkl_context.mkl_params.in_dim,
    138                                    tensor_format_);
    139     tf_shape_output.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
    140                                mkl_shape_output.GetMklLayout())) /
    141                            sizeof(T));
    142     AllocateOutputSetMklShape(context, 0, &output, tf_shape_output,
    143                               mkl_shape_output);
    144     mkl_context.mkl_res_batchnorm[dnnResourceDst] =
    145         static_cast<void*>(output->flat<T>().data());
    146 
    147     // Batch mean in TF layout
    148     Tensor* batch_mean = nullptr;
    149     MklShape mkl_shape_batch_mean;
    150     mkl_shape_batch_mean.SetMklTensor(false);
    151     AllocateOutputSetMklShape(context, 1, &batch_mean, scale.shape(),
    152                               mkl_shape_batch_mean);
    153     // Batch variance in TF layout
    154     Tensor* batch_variance = nullptr;
    155     MklShape mkl_shape_batch_variance;
    156     mkl_shape_batch_variance.SetMklTensor(false);
    157     AllocateOutputSetMklShape(context, 2, &batch_variance, scale.shape(),
    158                               mkl_shape_batch_variance);
    159     // If training mode, set dnnResourceMean and dnnResourceVariance to
    160     // output tensors for batch mean and variance.
    161     // Otherwise, set dnnResourceMean and dnnResourceVariance to
    162     // estimated mean and variance.
    163     if (is_training_)
    164       mkl_context.MklSetMeanVariance(*batch_mean, *batch_variance);
    165     else
    166       mkl_context.MklSetMeanVariance(est_mean, est_variance);
    167 
     168     // Now that all resources are set, the primitive is ready for dnnExecute.
    169     CHECK_EQ(dnnExecute_F32(mkl_context.mkl_prim_batchnorm,
    170                             mkl_context.mkl_res_batchnorm),
    171              E_SUCCESS);
    172 
    173     // Mean and variance (without Bessel's correction) saved for backward
    174     // computation to serve as pre-computed mean and variance.
    175     Tensor* saved_mean = nullptr;
    176     MklShape mkl_shape_saved_mean;
    177     mkl_shape_saved_mean.SetMklTensor(false);
    178     AllocateOutputSetMklShape(context, 3, &saved_mean, scale.shape(),
    179                               mkl_shape_saved_mean);
    180     std::memcpy(
    181         reinterpret_cast<char*>(saved_mean->flat<float>().data()),
    182         reinterpret_cast<char*>(mkl_context.mkl_res_batchnorm[dnnResourceMean]),
    183         scale.NumElements() * sizeof(float));
    184     Tensor* saved_variance = nullptr;
    185     MklShape mkl_shape_saved_variance;
    186     mkl_shape_saved_variance.SetMklTensor(false);
    187     AllocateOutputSetMklShape(context, 4, &saved_variance, scale.shape(),
    188                               mkl_shape_saved_variance);
    189     std::memcpy(reinterpret_cast<char*>(saved_variance->flat<float>().data()),
    190                 reinterpret_cast<char*>(
    191                     mkl_context.mkl_res_batchnorm[dnnResourceVariance]),
    192                 scale.NumElements() * sizeof(float));
    193 
    194     // Bessel's correction on variance, if training mode is on
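    // MKL produces the biased (population) variance over the N * H * W samples
    // of each channel; TF's batch_variance output is the unbiased estimate, so
    // each value is rescaled by n / (n - 1) with n = N * H * W.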
    195     if (is_training_) {
    196       float* p_var = static_cast<float*>(batch_variance->flat<T>().data());
    197       auto depth = mkl_context.mkl_params.depth;
    198       size_t orig_size = mkl_context.mkl_params.in_sizes[0] *
    199                          mkl_context.mkl_params.in_sizes[1] *
    200                          mkl_context.mkl_params.in_sizes[3];
    201       size_t adjust_size = orig_size - 1;
    202       float adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
    203       for (int i = 0; i < depth; i++) p_var[i] = adjust_factor * p_var[i];
    204     }
    205 
    206     mkl_context.MklCleanup();
    207   }
    208 
    209  private:
    210   T epsilon_;
    211   TensorFormat tensor_format_;
    212   bool is_training_;
    213 
    214   // Structure containing all info for MklOp
    215   typedef struct {
    216     // Parameters used for input and output layouts
    217     struct MklBatchNormParams {
     218       // Dimensions, sizes, and strides of the BatchNormOp src
    219       size_t in_dim;
    220       size_t in_sizes[4];
    221       size_t in_strides[4];
     222       size_t depth;  // Batch normalization is done per channel.
    223     } mkl_params;
    224 
    225     MklShape mkl_shape_input_shape;
    226 
    227     // MKL primitive and resources for BatchNormOp
    228     dnnPrimitive_t mkl_prim_batchnorm = nullptr;
    229     void* mkl_res_batchnorm[dnnResourceNumber];
    230 
    231     // MKL layouts for inputs in the context
    232     dnnLayout_t mkl_lt_input = nullptr;
    233 
    234     void MklCleanup() {
    235       bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
    236       if (!input_in_mkl_format) dnnLayoutDelete_F32(mkl_lt_input);
    237       if (mkl_prim_batchnorm != nullptr) dnnDelete_F32(mkl_prim_batchnorm);
    238     }
    239 
    240     void MklExtractParams(OpKernelContext* context,
    241                           const TensorFormat& tensor_format) {
    242       const Tensor& input = MklGetInput(context, 0);
    243       bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
    244       mkl_params.in_dim = input_in_mkl_format
    245                               ? mkl_shape_input_shape.GetDimension()
    246                               : input.dims();
    247       mkl_params.in_sizes[0] = static_cast<size_t>(
    248           input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[0]
    249                               : GetTensorDim(input, tensor_format, 'W'));
    250       mkl_params.in_sizes[1] = static_cast<size_t>(
    251           input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[1]
    252                               : GetTensorDim(input, tensor_format, 'H'));
    253       mkl_params.in_sizes[2] = static_cast<size_t>(
    254           input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[2]
    255                               : GetTensorDim(input, tensor_format, 'C'));
    256       mkl_params.in_sizes[3] = static_cast<size_t>(
    257           input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[3]
    258                               : GetTensorDim(input, tensor_format, 'N'));
    259       mkl_params.depth = mkl_params.in_sizes[2];
    260       GetStridesFromSizes(tensor_format, mkl_params.in_strides,
    261                           mkl_params.in_sizes);
    262     }
    263 
    264     void MklCreateInputLayout(OpKernelContext* context) {
    265       const Tensor& input = MklGetInput(context, 0);
    266       bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
    267       if (input_in_mkl_format) {
    268         mkl_lt_input =
    269             static_cast<dnnLayout_t>(mkl_shape_input_shape.GetCurLayout());
    270       } else {
    271         CHECK_EQ(
    272             dnnLayoutCreate_F32(&mkl_lt_input, mkl_params.in_dim,
    273                                 mkl_params.in_sizes, mkl_params.in_strides),
    274             E_SUCCESS);
    275       }
    276     }
    277     void MklPrepareContextInputs(OpKernelContext* context,
    278                                  Tensor* mkl_tmp_input_buf_tensor,
    279                                  Tensor* mkl_tmp_scale_shift_buf_tensor) {
    280       bool mkl_convert_input;
    281       dnnPrimitive_t mkl_prim_convert_input = nullptr;
    282       dnnLayout_t mkl_lt_internal_input = nullptr;
    283       void* mkl_buf_converted_input = nullptr;
    284       // Compare with internal layouts and convert if needed
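      // dnnLayoutCreateFromPrimitive_F32 queries the layout the primitive
      // expects for dnnResourceSrc; if it differs from the incoming layout,
      // the input is converted into a temporary buffer before execution.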
    285       const Tensor& input = MklGetInput(context, 0);
    286       void* mkl_buf_input =
    287           const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
    288       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
    289                    &mkl_lt_internal_input, mkl_prim_batchnorm, dnnResourceSrc),
    290                E_SUCCESS);
    291       mkl_convert_input =
    292           !dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_input);
    293       if (mkl_convert_input) {
    294         CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, mkl_lt_input,
    295                                          mkl_lt_internal_input),
    296                  E_SUCCESS);
    297         AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
    298                        &mkl_buf_converted_input);
    299         CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
    300                                           mkl_buf_converted_input),
    301                  E_SUCCESS);
    302         dnnDelete_F32(mkl_prim_convert_input);
    303       }
    304       dnnLayoutDelete_F32(mkl_lt_internal_input);
    305       mkl_res_batchnorm[dnnResourceSrc] =
    306           (mkl_convert_input) ? mkl_buf_converted_input : mkl_buf_input;
    307 
     308       // The scale-shift layout is created from the primitive, so no
     309       // conversion is needed; however, a buffer has to be allocated.
    310       dnnLayout_t mkl_lt_scale_shift = nullptr;
    311       void* mkl_buf_scale_shift = nullptr;
    312       CHECK_EQ(
    313           dnnLayoutCreateFromPrimitive_F32(
    314               &mkl_lt_scale_shift, mkl_prim_batchnorm, dnnResourceScaleShift),
    315           E_SUCCESS);
    316       AllocTmpBuffer(context, mkl_tmp_scale_shift_buf_tensor,
    317                      mkl_lt_scale_shift, &mkl_buf_scale_shift);
     318       // Fill the scale-shift buffer: a 2 x depth array of scales followed by shifts.
    319       const Tensor& scale = MklGetInput(context, 1);
    320       const Tensor& shift = MklGetInput(context, 2);
    321       float* buf_scale_shift = static_cast<float*>(mkl_buf_scale_shift);
    322       float* buf_scale = const_cast<float*>(
    323           static_cast<const float*>(scale.flat<float>().data()));
    324       float* buf_shift = const_cast<float*>(
    325           static_cast<const float*>(shift.flat<float>().data()));
    326       auto depth = mkl_params.depth;
    327       for (int i = 0; i < depth; i++) {
    328         buf_scale_shift[i] = buf_scale[i];
    329         buf_scale_shift[i + depth] = buf_shift[i];
    330       }
    331       mkl_res_batchnorm[dnnResourceScaleShift] = mkl_buf_scale_shift;
    332     }
    333 
    334     inline void MklSetMeanVariance(const Tensor& mean, const Tensor& variance) {
    335       mkl_res_batchnorm[dnnResourceMean] = const_cast<void*>(
    336           static_cast<const void*>(mean.flat<float>().data()));
    337       mkl_res_batchnorm[dnnResourceVariance] = const_cast<void*>(
    338           static_cast<const void*>(variance.flat<float>().data()));
    339     }
    340   } MklFusedBatchNormOpContext;
    341 };
    342 
    343 template <typename Device, typename T>
    344 class MklFusedBatchNormGradOp : public OpKernel {
    345  public:
    346   explicit MklFusedBatchNormGradOp(OpKernelConstruction* context)
    347       : OpKernel(context) {
    348     float epsilon;
    349     OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    350     epsilon_ = T(epsilon);
    351     string tensor_format;
    352     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    353     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
    354                 errors::InvalidArgument("Invalid data format"));
    355   }
    356 
    357   void Compute(OpKernelContext* context) override {
    358     MklFusedBatchNormGradOpContext mkl_context;
    359 
    360     const Tensor& out_backprop = MklGetInput(context, 0);
    361     const Tensor& input = MklGetInput(context, 1);
    362     const Tensor& scale = MklGetInput(context, 2);
    363     const Tensor& saved_mean = MklGetInput(context, 3);
    364     const Tensor& saved_var = MklGetInput(context, 4);
    365 
     366     // Here scale, mean, and variance are 1D tensors and are assumed to
     367     // have the same layout in MKL and TF.
    368     GetMklShape(context, 0, &(mkl_context.mkl_shape_out_backprop));
    369     GetMklShape(context, 1, &(mkl_context.mkl_shape_input_shape));
    370 
    371     bool input_in_mkl_format = mkl_context.mkl_shape_input_shape.IsMklTensor();
    372     bool out_backprop_in_mkl_format =
    373         mkl_context.mkl_shape_out_backprop.IsMklTensor();
    374     if (!out_backprop_in_mkl_format) {
    375       OP_REQUIRES(context, out_backprop.dims() == 4,
    376                   errors::InvalidArgument("input must be 4-dimensional",
    377                                           out_backprop.shape().DebugString()));
    378     }
    379     if (!input_in_mkl_format) {
    380       OP_REQUIRES(context, input.dims() == 4,
    381                   errors::InvalidArgument("input must be 4-dimensional",
    382                                           input.shape().DebugString()));
    383     }
    384     OP_REQUIRES(context, scale.dims() == 1,
    385                 errors::InvalidArgument("scale must be 1-dimensional",
    386                                         scale.shape().DebugString()));
    387     OP_REQUIRES(context, saved_mean.dims() == 1,
    388                 errors::InvalidArgument("saved mean must be 1-dimensional",
    389                                         saved_mean.shape().DebugString()));
    390     OP_REQUIRES(context, saved_var.dims() == 1,
    391                 errors::InvalidArgument("saved variance must be 1-dimensional",
    392                                         saved_var.shape().DebugString()));
    393 
    394     mkl_context.MklExtractParams(context, tensor_format_);
    395 
    396     mkl_context.MklCreateInputLayout(context);
    397 
    398     unsigned int flag_batch_norm_grad = dnnUseScaleShift;
    399 
    400     // Create Backward Op primitive.
    401     CHECK_EQ(dnnBatchNormalizationCreateBackward_v2_F32(
    402                  &(mkl_context.mkl_prim_batchnorm_bwd), nullptr,
    403                  mkl_context.mkl_lt_input, static_cast<float>(epsilon_),
    404                  flag_batch_norm_grad),
    405              E_SUCCESS);
    406 
    407     // Temporary tensors and their buffers if conversion is required
    408     Tensor mkl_tmp_input_buf_tensor, mkl_tmp_outbackprop_buf_tensor,
    409         mkl_tmp_scaleshift_buf_tensor;
    410     mkl_context.MklPrepareContextInputs(context, &mkl_tmp_input_buf_tensor,
    411                                         &mkl_tmp_outbackprop_buf_tensor,
    412                                         &mkl_tmp_scaleshift_buf_tensor);
    413 
    414     // Allocate tensor for grad w.r.t. input(x)
    415     Tensor* in_backprop = nullptr;
    416     TensorShape tf_shape_in_backprop;
    417     MklShape mkl_shape_in_backprop;
    418     mkl_shape_in_backprop.SetMklTensor(true);
    419     mkl_shape_in_backprop.SetMklLayout(mkl_context.mkl_prim_batchnorm_bwd,
    420                                        dnnResourceDiffSrc);
    421     mkl_shape_in_backprop.SetTfLayout(mkl_context.mkl_params.in_dims,
    422                                       mkl_context.mkl_params.in_sizes,
    423                                       mkl_context.mkl_params.in_strides);
    424     mkl_shape_in_backprop.SetTfDimOrder(mkl_context.mkl_params.in_dims,
    425                                         tensor_format_);
    426     tf_shape_in_backprop.AddDim(
    427         dnnLayoutGetMemorySize_F32(
    428             static_cast<dnnLayout_t>(mkl_shape_in_backprop.GetMklLayout())) /
    429         sizeof(T));
    430     AllocateOutputSetMklShape(context, 0, &in_backprop, tf_shape_in_backprop,
    431                               mkl_shape_in_backprop);
    432     mkl_context.mkl_res_batchnorm_bwd[dnnResourceDiffSrc] =
    433         static_cast<void*>(in_backprop->flat<T>().data());
    434 
     435     // grad_scale and grad_shift are combined in MKL, so create a single
     436     // temporary buffer for both.
     437     // Also point dnnResourceDiffScaleShift at that temporary buffer.
    438     Tensor mkl_tmp_grad_scale_shift_buf_tensor;
    439     mkl_context.MklPrepareGradScaleShift(context,
    440                                          &mkl_tmp_grad_scale_shift_buf_tensor);
    441 
    442     // All dnn resources are set now, ready to execute
    443     CHECK_EQ(dnnExecute_F32(mkl_context.mkl_prim_batchnorm_bwd,
    444                             mkl_context.mkl_res_batchnorm_bwd),
    445              E_SUCCESS);
    446 
    447     // Now separate out scale and shift grad and copy to individual tensors
    448     const TensorShape& tf_shape_scale_shift = scale.shape();
     449     // Allocate tensor for grad w.r.t. scale (gamma)
    450     Tensor* scale_backprop = nullptr;
    451     MklShape mkl_shape_scale_backprop;
    452     AllocateOutputSetMklShape(context, 1, &scale_backprop, tf_shape_scale_shift,
    453                               mkl_shape_scale_backprop);
    454 
     455     // Allocate tensor for grad w.r.t. shift (beta)
    456     Tensor* shift_backprop = nullptr;
    457     MklShape mkl_shape_shift_backprop;
    458     AllocateOutputSetMklShape(context, 2, &shift_backprop, tf_shape_scale_shift,
    459                               mkl_shape_shift_backprop);
    460 
    461     // copy scale and shift grads to tensors
    462     float* mkl_buf_scale_shift = const_cast<float*>(static_cast<const float*>(
    463         mkl_tmp_grad_scale_shift_buf_tensor.flat<T>().data()));
    464     float* tf_buf_scale = const_cast<float*>(
    465         static_cast<const float*>(scale_backprop->flat<T>().data()));
    466     float* tf_buf_shift = const_cast<float*>(
    467         static_cast<const float*>(shift_backprop->flat<T>().data()));
    468     auto depth = mkl_context.mkl_params.depth;
    469     for (int i = 0; i < depth; i++) {
    470       tf_buf_scale[i] = mkl_buf_scale_shift[i];
    471       tf_buf_shift[i] = mkl_buf_scale_shift[i + depth];
    472     }
    473 
    474     // Two placeholders for estimated_mean and estimated_variance, which are
    475     // used for inference and thus not needed here for gradient computation.
    476     Tensor* placeholder_1 = nullptr;
    477     MklShape mkl_shape_placeholder_1;
    478     AllocateOutputSetMklShape(context, 3, &placeholder_1, TensorShape({}),
    479                               mkl_shape_placeholder_1);
    480     Tensor* placeholder_2 = nullptr;
    481     MklShape mkl_shape_placeholder_2;
    482     AllocateOutputSetMklShape(context, 4, &placeholder_2, TensorShape({}),
    483                               mkl_shape_placeholder_2);
    484 
    485     mkl_context.MklCleanup();
    486   }
    487 
    488  private:
    489   T epsilon_;
    490   TensorFormat tensor_format_;
    491 
    492   // Structure containing all info for MklOp
    493   typedef struct {
    494     // Parameters used for input and output layouts
    495     struct MklBatchNormParams {
     496       // Dimensions, sizes, and strides of the BatchNormOp src
    497       size_t in_dims;
    498       size_t in_sizes[4];
    499       size_t in_strides[4];
     500       size_t depth;  // Batch normalization is done per channel.
    501     } mkl_params;
    502 
    503     MklShape mkl_shape_out_backprop;
    504     MklShape mkl_shape_input_shape;
    505 
    506     // MKL primitive and resources for BatchNormOp
    507     dnnPrimitive_t mkl_prim_batchnorm_bwd = nullptr;
    508     void* mkl_res_batchnorm_bwd[dnnResourceNumber];
    509 
    510     // MKL layouts for inputs in the context
    511     dnnLayout_t mkl_lt_out_backprop = nullptr;
    512     dnnLayout_t mkl_lt_input = nullptr;
    513 
    514     void MklCleanup() {
    515       bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
    516       bool out_backprop_in_mkl_format = mkl_shape_out_backprop.IsMklTensor();
    517       if (!input_in_mkl_format) dnnLayoutDelete_F32(mkl_lt_input);
    518       if (!out_backprop_in_mkl_format) dnnLayoutDelete_F32(mkl_lt_out_backprop);
    519 
    520       dnnDelete_F32(mkl_prim_batchnorm_bwd);
    521     }
    522 
    523     void MklExtractParams(OpKernelContext* context,
    524                           const TensorFormat& tensor_format) {
    525       const Tensor& input = MklGetInput(context, 1);
    526       bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
    527       mkl_params.in_dims = input_in_mkl_format
    528                                ? mkl_shape_input_shape.GetDimension()
    529                                : input.dims();
    530       mkl_params.in_sizes[0] = static_cast<size_t>(
    531           input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[0]
    532                               : GetTensorDim(input, tensor_format, 'W'));
    533       mkl_params.in_sizes[1] = static_cast<size_t>(
    534           input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[1]
    535                               : GetTensorDim(input, tensor_format, 'H'));
    536       mkl_params.in_sizes[2] = static_cast<size_t>(
    537           input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[2]
    538                               : GetTensorDim(input, tensor_format, 'C'));
    539       mkl_params.in_sizes[3] = static_cast<size_t>(
    540           input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[3]
    541                               : GetTensorDim(input, tensor_format, 'N'));
    542       mkl_params.depth = mkl_params.in_sizes[2];
    543       GetStridesFromSizes(tensor_format, mkl_params.in_strides,
    544                           mkl_params.in_sizes);
    545     }
    546 
    547     void MklCreateInputLayout(OpKernelContext* context) {
    548       bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
    549       if (input_in_mkl_format) {
    550         mkl_lt_input =
    551             static_cast<dnnLayout_t>(mkl_shape_input_shape.GetCurLayout());
    552       } else {
    553         CHECK_EQ(
    554             dnnLayoutCreate_F32(&mkl_lt_input, mkl_params.in_dims,
    555                                 mkl_params.in_sizes, mkl_params.in_strides),
    556             E_SUCCESS);
    557       }
    558 
    559       bool out_backprop_in_mkl_format = mkl_shape_out_backprop.IsMklTensor();
    560       if (out_backprop_in_mkl_format) {
    561         mkl_lt_out_backprop =
    562             static_cast<dnnLayout_t>(mkl_shape_out_backprop.GetCurLayout());
    563       } else {
    564         CHECK_EQ(
    565             dnnLayoutCreate_F32(&mkl_lt_out_backprop, mkl_params.in_dims,
    566                                 mkl_params.in_sizes, mkl_params.in_strides),
    567             E_SUCCESS);
    568       }
    569     }
    570 
    571     void MklPrepareContextInputs(OpKernelContext* context,
    572                                  Tensor* mkl_tmp_input_buf_tensor,
    573                                  Tensor* mkl_tmp_outbackprop_buf_tensor,
    574                                  Tensor* mkl_tmp_scaleshift_buf_tensor) {
    575       bool mkl_convert_input;
    576       dnnPrimitive_t mkl_prim_convert_input = nullptr;
    577       dnnLayout_t mkl_lt_internal_input = nullptr;
    578       void* mkl_buf_converted_input = nullptr;
    579       // Compare with internal layouts and convert if needed
    580       const Tensor& input = MklGetInput(context, 1);
    581       void* mkl_buf_input =
    582           const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
    583       CHECK_EQ(
    584           dnnLayoutCreateFromPrimitive_F32(
    585               &mkl_lt_internal_input, mkl_prim_batchnorm_bwd, dnnResourceSrc),
    586           E_SUCCESS);
    587       mkl_convert_input =
    588           !dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_input);
    589       if (mkl_convert_input) {
    590         CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, mkl_lt_input,
    591                                          mkl_lt_internal_input),
    592                  E_SUCCESS);
    593         AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
    594                        &mkl_buf_converted_input);
    595         CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
    596                                           mkl_buf_converted_input),
    597                  E_SUCCESS);
    598         dnnDelete_F32(mkl_prim_convert_input);
    599       }
    600       dnnLayoutDelete_F32(mkl_lt_internal_input);
    601       mkl_res_batchnorm_bwd[dnnResourceSrc] =
    602           (mkl_convert_input) ? mkl_buf_converted_input : mkl_buf_input;
    603 
    604       bool mkl_convert_out_backprop;
    605       dnnPrimitive_t mkl_prim_convert_out_backprop = nullptr;
    606       dnnLayout_t mkl_lt_internal_out_backprop = nullptr;
    607       void* mkl_buf_converted_out_backprop = nullptr;
    608       // Compare with internal layouts and convert if needed
    609       const Tensor& out_backprop = MklGetInput(context, 0);
    610       void* mkl_buf_out_backprop = const_cast<void*>(
    611           static_cast<const void*>(out_backprop.flat<T>().data()));
    612       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_out_backprop,
    613                                                 mkl_prim_batchnorm_bwd,
    614                                                 dnnResourceDiffDst),
    615                E_SUCCESS);
    616       mkl_convert_out_backprop = !dnnLayoutCompare_F32(
    617           mkl_lt_internal_out_backprop, mkl_lt_out_backprop);
    618       if (mkl_convert_out_backprop) {
    619         CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_out_backprop,
    620                                          mkl_lt_out_backprop,
    621                                          mkl_lt_internal_out_backprop),
    622                  E_SUCCESS);
    623         AllocTmpBuffer(context, mkl_tmp_outbackprop_buf_tensor,
    624                        mkl_lt_internal_out_backprop,
    625                        &mkl_buf_converted_out_backprop);
    626         CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_out_backprop,
    627                                           mkl_buf_out_backprop,
    628                                           mkl_buf_converted_out_backprop),
    629                  E_SUCCESS);
    630         dnnDelete_F32(mkl_prim_convert_out_backprop);
    631       }
    632       dnnLayoutDelete_F32(mkl_lt_internal_out_backprop);
    633       mkl_res_batchnorm_bwd[dnnResourceDiffDst] =
    634           (mkl_convert_out_backprop) ? mkl_buf_converted_out_backprop
    635                                      : mkl_buf_out_backprop;
    636 
    637       // Set dnnResourceMean and dnnResourceVariance
    638       const Tensor& saved_mean = MklGetInput(context, 3);
    639       const Tensor& saved_var = MklGetInput(context, 4);
    640       void* mkl_buf_saved_mean = const_cast<void*>(
    641           static_cast<const void*>(saved_mean.flat<T>().data()));
    642       void* mkl_buf_saved_var = const_cast<void*>(
    643           static_cast<const void*>(saved_var.flat<T>().data()));
    644       mkl_res_batchnorm_bwd[dnnResourceMean] = mkl_buf_saved_mean;
    645       mkl_res_batchnorm_bwd[dnnResourceVariance] = mkl_buf_saved_var;
    646 
    647       // Set dnnResourceScaleShift
     648       // Note: the backward Op needs only the current scale values; the
     649       // shift values may be uninitialized and are not used.
    650       const Tensor& scale = MklGetInput(context, 2);
    651       dnnLayout_t mkl_lt_scale_shift = nullptr;
    652       void* mkl_buf_scale_shift = nullptr;
    653       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_scale_shift,
    654                                                 mkl_prim_batchnorm_bwd,
    655                                                 dnnResourceScaleShift),
    656                E_SUCCESS);
    657       AllocTmpBuffer(context, mkl_tmp_scaleshift_buf_tensor, mkl_lt_scale_shift,
    658                      &mkl_buf_scale_shift);
    659       float* pscale =
    660           const_cast<float*>(static_cast<const float*>(scale.flat<T>().data()));
    661       float* pscale_shift = static_cast<float*>(mkl_buf_scale_shift);
    662       auto depth = mkl_params.depth;
    663       for (int i = 0; i < depth; i++) pscale_shift[i] = pscale[i];
    664       mkl_res_batchnorm_bwd[dnnResourceScaleShift] = mkl_buf_scale_shift;
    665       dnnLayoutDelete_F32(mkl_lt_scale_shift);
    666     }
    667 
    668     void MklPrepareGradScaleShift(OpKernelContext* context,
    669                                   Tensor* mkl_tmp_grad_scale_shift_buf_tensor) {
    670       dnnLayout_t mkl_lt_grad_scaleshift = nullptr;
    671       void* mkl_buf_grad_scaleshift = nullptr;
    672       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_grad_scaleshift,
    673                                                 mkl_prim_batchnorm_bwd,
    674                                                 dnnResourceDiffScaleShift),
    675                E_SUCCESS);
    676       AllocTmpBuffer(context, mkl_tmp_grad_scale_shift_buf_tensor,
    677                      mkl_lt_grad_scaleshift, &mkl_buf_grad_scaleshift);
    678       mkl_res_batchnorm_bwd[dnnResourceDiffScaleShift] =
    679           mkl_buf_grad_scaleshift;
    680       dnnLayoutDelete_F32(mkl_lt_grad_scaleshift);
    681     }
    682   } MklFusedBatchNormGradOpContext;
    683 };
    684 #endif
    685 
    686 #ifndef INTEL_MKL_ML
    687 
    688 template <typename Device, typename T>
    689 class MklFusedBatchNormOp : public OpKernel {
    690  public:
    691   explicit MklFusedBatchNormOp(OpKernelConstruction* context)
    692       : OpKernel(context) {
    693     float epsilon;
    694     OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    695     epsilon_ = T(epsilon);
    696     string tensor_format;
    697     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    698     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
    699                 errors::InvalidArgument("Invalid data format"));
    700     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
    701   }
    702 
    703   void Compute(OpKernelContext* context) override {
    704     try {
    705       auto cpu_engine = engine(engine::cpu, 0);
    706       const size_t kSrcIndex = 0;       // index of src input tensor
    707       const size_t kScaleIndex = 1;     // index of scale tensor
    708       const size_t kShiftIndex = 2;     // index of shift tensor
    709       const size_t kMeanIndex = 3;      // index of est_mean tensor
    710       const size_t kVarianceIndex = 4;  // index of est_variance tensor
    711 
    712       const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
    713       const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
    714       const Tensor& shift_tensor = MklGetInput(context, kShiftIndex);
    715       const Tensor& est_mean_tensor = MklGetInput(context, kMeanIndex);
    716       const Tensor& est_variance_tensor = MklGetInput(context, kVarianceIndex);
    717 
    718       TensorShape tf_shape_src;
    719       MklDnnShape dnn_shape_src;
    720       GetMklShape(context, kSrcIndex, &dnn_shape_src);
    721 
    722       if (dnn_shape_src.IsMklTensor()) {
    723         tf_shape_src = dnn_shape_src.GetTfShape();
    724         OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
    725                     errors::InvalidArgument("input must be 4-dimensional",
    726                                             src_tensor.shape().DebugString()));
    727       } else {
    728         tf_shape_src = src_tensor.shape();
    729         OP_REQUIRES(context, src_tensor.dims() == 4,
    730                     errors::InvalidArgument("input must be 4-dimensional",
    731                                             src_tensor.shape().DebugString()));
    732       }
    733       OP_REQUIRES(context, scale_tensor.dims() == 1,
    734                   errors::InvalidArgument("scale must be 1-dimensional",
    735                                           scale_tensor.shape().DebugString()));
    736       OP_REQUIRES(context, shift_tensor.dims() == 1,
    737                   errors::InvalidArgument("offset must be 1-dimensional",
    738                                           shift_tensor.shape().DebugString()));
    739       OP_REQUIRES(
    740           context, est_mean_tensor.dims() == 1,
    741           errors::InvalidArgument("estimated_mean must be 1-dimensional",
    742                                   est_mean_tensor.shape().DebugString()));
    743       OP_REQUIRES(
    744           context, est_variance_tensor.dims() == 1,
    745           errors::InvalidArgument("estimated_variance must be 1-dimensional",
    746                                   est_variance_tensor.shape().DebugString()));
    747 
    748       if (is_training_) {
    749         OP_REQUIRES(
    750             context, est_mean_tensor.dim_size(0) == 0,
    751             errors::InvalidArgument("estimated_mean must be empty for training",
    752                                     est_mean_tensor.shape().DebugString()));
    753         OP_REQUIRES(context, est_variance_tensor.dim_size(0) == 0,
    754                     errors::InvalidArgument(
    755                         "estimated_variance must be empty for training",
    756                         est_variance_tensor.shape().DebugString()));
    757       }
    758 
     759       // Special case: input with 0 elements (e.g. a 0 batch size).
    760       Tensor* dst_tensor = nullptr;
    761       if (tf_shape_src.num_elements() == 0) {
    762         HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(),
    763                          &dst_tensor);
    764         return;
    765       }
    766 
    767       if (dnn_shape_src.IsMklTensor())
    768         depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
    769       else
    770         ExtractParams(context);
    771 
    772       // Indices of output tensors
    773       const size_t kDstIndex = 0;
    774 
    775       // allocate 4 output TF tensors
    776       Tensor* batch_mean_tensor = nullptr;
    777       Tensor* batch_variance_tensor = nullptr;
    778       Tensor* saved_mean_tensor = nullptr;
    779       Tensor* saved_variance_tensor = nullptr;
    780       AllocateTFOutputs(context, scale_tensor.shape(), &batch_mean_tensor,
    781                         &batch_variance_tensor, &saved_mean_tensor,
    782                         &saved_variance_tensor);
    783 
    784       if (is_training_)
    785         SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor);
    786       else
    787         SetMeanVariance(est_mean_tensor, est_variance_tensor);
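      // mean_values_ / variance_values_ now point at the batch statistics
      // outputs (training) or at the estimated statistics inputs (inference);
      // below they are copied into the saved_* outputs, which back the
      // primitive's mean/variance memory.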
    788 
    789       MklDnnData<T> src(&cpu_engine);
    790       MklDnnData<T> dst(&cpu_engine);
    791 
    792       memory::format format_m;
    793       if (dnn_shape_src.IsMklTensor()) {
    794         if (dnn_shape_src.IsTensorInNCHWFormat()) {
    795           format_m = memory::format::nchw;
    796         } else {
    797           format_m = memory::format::nhwc;
    798         }
    799       } else {
    800         format_m = TFDataFormatToMklDnnDataFormat(tensor_format_);
    801       }
    802 
    803       // set src primitive
    804       memory::dims src_dims;
    805       if (dnn_shape_src.IsMklTensor()) {
    806         src_dims = TFShapeToMklDnnDimsInNCHW(dnn_shape_src.GetTfShape(),
    807                                              tensor_format_);
    808       } else {
    809         src_dims =
    810             TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
    811       }
    812 
    813       auto src_md = dnn_shape_src.IsMklTensor()
    814                         ? dnn_shape_src.GetMklLayout()
    815                         : memory::desc(src_dims, MklDnnType<T>(), format_m);
    816       src.SetUsrMem(src_md, &src_tensor);
    817 
    818       // set weights primitive
    819       // MKL-DNN packs scale & shift as "weights":
    820       // <scale>...<scale><shift>...<shift>
    821       auto weights_desc =
    822           memory::desc({2, depth_}, MklDnnType<T>(), memory::format::nc);
    823       auto weights_pd = memory::primitive_desc(weights_desc, cpu_engine);
    824       auto weights_m = memory(weights_pd);
    825       T* weights_data = reinterpret_cast<T*>(weights_m.get_data_handle());
    826       T* scale_tf =
    827           reinterpret_cast<T*>(const_cast<T*>(scale_tensor.flat<T>().data()));
    828       T* shift_tf =
    829           reinterpret_cast<T*>(const_cast<T*>(shift_tensor.flat<T>().data()));
    830 
    831       for (int k = 0; k < depth_; k++) {
    832         weights_data[k] = scale_tf[k];
    833         weights_data[k + depth_] = shift_tf[k];
    834       }
    835 
    836       // set mean primitive
    837       auto mean_desc =
    838           memory::desc({1, depth_}, MklDnnType<T>(), memory::format::nc);
    839       auto mean_pd = memory::primitive_desc(mean_desc, cpu_engine);
    840       char* saved_mean_data_tf =
    841           reinterpret_cast<char*>(saved_mean_tensor->flat<T>().data());
    842       std::memcpy(saved_mean_data_tf, reinterpret_cast<char*>(mean_values_),
    843                   depth_ * sizeof(T));
    844       auto mean_m =
    845           memory(mean_pd, reinterpret_cast<void*>(saved_mean_data_tf));
    846 
    847       // set variance primitive
    848       auto variance_desc =
    849           memory::desc({1, depth_}, MklDnnType<T>(), memory::format::nc);
    850       auto variance_pd = memory::primitive_desc(variance_desc, cpu_engine);
    851       char* saved_variance_data_tf =
    852           reinterpret_cast<char*>(saved_variance_tensor->flat<T>().data());
    853       std::memcpy(saved_variance_data_tf,
    854                   reinterpret_cast<char*>(variance_values_),
    855                   depth_ * sizeof(T));
    856       auto variance_m = memory(variance_pd, saved_variance_data_tf);
    857 
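      // forward_training computes mean/variance from the batch, while
      // forward_scoring (inference) relies on the statistics supplied above,
      // hence the extra use_global_stats flag; use_scale_shift makes the
      // primitive apply gamma/beta in both modes.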
    858       prop_kind pk = (is_training_) ? prop_kind::forward_training
    859                                     : prop_kind::forward_scoring;
    860       auto bnrm_fwd_desc = batch_normalization_forward::desc(
    861           pk, src.GetUsrMemDesc(), epsilon_,
    862           is_training_ ? use_scale_shift
    863                        : (use_scale_shift | use_global_stats));
    864       auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc(
    865           bnrm_fwd_desc, cpu_engine);
    866 
    867       // allocate dst tensor
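      // If the input arrived in MKL layout, the output is produced in MKL
      // layout as well: the TF tensor is then just a flat buffer sized from
      // the primitive's dst descriptor, and dnn_shape_dst carries the shape.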
    868       MklDnnShape dnn_shape_dst;
    869       TensorShape tf_shape_dst;
    870       if (dnn_shape_src.IsMklTensor()) {
    871         dnn_shape_dst.SetMklTensor(true);
    872         auto dst_pd = bnrm_fwd_pd.dst_primitive_desc();
    873         dnn_shape_dst.SetMklLayout(&dst_pd);
    874         dnn_shape_dst.SetElemType(MklDnnType<T>());
    875         dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(), src_dims,
    876                                   format_m);
    877         tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
    878       } else {
    879         dnn_shape_dst.SetMklTensor(false);
    880         tf_shape_dst = src_tensor.shape();
    881       }
    882       AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst,
    883                                 dnn_shape_dst);
    884 
    885       // Output of batchnorm has same shape as input.
    886       dst.SetUsrMem(src_md, dst_tensor);
    887 
    888       primitive bnrm_fwd_op;
    889       if (is_training_) {
    890         bnrm_fwd_op =
    891             batch_normalization_forward(bnrm_fwd_pd, src.GetOpMem(), weights_m,
    892                                         dst.GetOpMem(), mean_m, variance_m);
    893       } else {
    894         bnrm_fwd_op = batch_normalization_forward(
    895             bnrm_fwd_pd, src.GetOpMem(), mean_m, variance_m,
    896             (const primitive::at)weights_m, dst.GetOpMem());
    897       }
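      // Submit the forward primitive to an eager stream; wait() blocks until
      // the batch-norm computation has finished.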
    898       std::vector<primitive> net;
    899       net.push_back(bnrm_fwd_op);
    900       stream(stream::kind::eager).submit(net).wait();
    901 
    902       // copy batch_mean data
    903       T* batch_mean_data_tf =
    904           reinterpret_cast<T*>(batch_mean_tensor->flat<T>().data());
    905       std::memcpy(reinterpret_cast<char*>(batch_mean_data_tf),
    906                   reinterpret_cast<char*>(mean_m.get_data_handle()),
    907                   depth_ * sizeof(T));
    908 
    909       // copy batch_variance data with Bessel's correction
    910       // if training mode is on
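      // src_dims is in NCHW order, so N * H * W is the per-channel sample
      // count n; multiplying by n / (n - 1) converts the population variance
      // computed by MKL-DNN into TF's unbiased batch_variance.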
    911       float adjust_factor = 1.0;
    912       if (is_training_) {
    913         size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3];
    914         size_t adjust_size = orig_size - 1;
    915         adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
    916       }
    917       for (int k = 0; k < depth_; k++)
    918         batch_variance_tensor->flat<T>().data()[k] =
    919             (reinterpret_cast<T*>(variance_m.get_data_handle()))[k] *
    920             adjust_factor;
    921     } catch (mkldnn::error& e) {
    922       string error_msg = "Status: " + std::to_string(e.status) +
    923                          ", message: " + string(e.message) + ", in file " +
    924                          string(__FILE__) + ":" + std::to_string(__LINE__);
    925       OP_REQUIRES_OK(
    926           context,
    927           errors::Aborted("Operation received an exception:", error_msg));
    928     }
    929   }
    930 
    931  private:
    932   T epsilon_;
    933   TensorFormat tensor_format_;
    934   bool is_training_;
    935   T* mean_values_;
    936   T* variance_values_;
     937   size_t depth_;  // Batch normalization is done per channel.
    938 
    939   void ExtractParams(OpKernelContext* context) {
    940     const Tensor& input = MklGetInput(context, 0);
    941     depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
    942   }
    943 
    944   void SetMeanVariance(const Tensor& mean, const Tensor& variance) {
    945     mean_values_ = reinterpret_cast<T*>(const_cast<T*>(mean.flat<T>().data()));
    946     variance_values_ =
    947         reinterpret_cast<T*>(const_cast<T*>(variance.flat<T>().data()));
    948   }
    949 
    950   void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
    951                         TensorShape tf_shape_scale, Tensor** dst_tensor) {
    952     CHECK_NOTNULL(dst_tensor);
    953 
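    // For an empty input, produce a zero-filled output with the same TF shape;
    // AllocateTFOutputs below fills the mean/variance outputs with NaN.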
    954     const size_t kDstIndex = 0;
    955     MklDnnShape dnn_shape_dst;
    956     dnn_shape_dst.SetMklTensor(false);
    957     AllocateOutputSetMklShape(context, kDstIndex, dst_tensor, tf_shape_src,
    958                               dnn_shape_dst);
    959     CHECK_NOTNULL(*dst_tensor);
    960     memset(const_cast<char*>((*dst_tensor)->tensor_data().data()), 0,
    961            (*dst_tensor)->tensor_data().size());
    962 
    963     Tensor* batch_mean_tensor = nullptr;
    964     Tensor* batch_variance_tensor = nullptr;
    965     Tensor* saved_mean_tensor = nullptr;
    966     Tensor* saved_variance_tensor = nullptr;
    967     AllocateTFOutputs(context, tf_shape_scale, &batch_mean_tensor,
    968                       &batch_variance_tensor, &saved_mean_tensor,
    969                       &saved_variance_tensor);
    970   }
    971 
    972   void AllocateTFOutputs(OpKernelContext* context, TensorShape tf_shape_scale,
    973                          Tensor** batch_mean_tensor,
    974                          Tensor** batch_variance_tensor,
    975                          Tensor** saved_mean_tensor,
    976                          Tensor** saved_variance_tensor) {
    977     CHECK_NOTNULL(batch_mean_tensor);
    978     CHECK_NOTNULL(batch_variance_tensor);
    979     CHECK_NOTNULL(saved_mean_tensor);
    980     CHECK_NOTNULL(saved_variance_tensor);
    981 
    982     const size_t kBatchMeanIndex = 1;
    983     const size_t kBatchVarianceIndex = 2;
    984     const size_t kSavedMeanIndex = 3;
    985     const size_t kSavedVarianceIndex = 4;
    986 
    987     // allocate batch mean output tensor
    988     MklDnnShape mkl_shape_batch_mean;
    989     mkl_shape_batch_mean.SetMklTensor(false);
    990     AllocateOutputSetMklShape(context, kBatchMeanIndex, batch_mean_tensor,
    991                               tf_shape_scale, mkl_shape_batch_mean);
    992     CHECK_NOTNULL(*batch_mean_tensor);
    993     // set NAN mean value in case of empty input tensor
    994     for (int k = 0; k < tf_shape_scale.num_elements(); k++)
    995       (*batch_mean_tensor)->flat<T>().data()[k] = NAN;
    996 
    997     // allocate batch variance output tensor
    998     MklDnnShape mkl_shape_batch_variance;
    999     mkl_shape_batch_variance.SetMklTensor(false);
   1000     AllocateOutputSetMklShape(context, kBatchVarianceIndex,
   1001                               batch_variance_tensor, tf_shape_scale,
   1002                               mkl_shape_batch_variance);
   1003     CHECK_NOTNULL(*batch_variance_tensor);
   1004     // set NAN variance value in case of empty input tensor
   1005     for (int k = 0; k < tf_shape_scale.num_elements(); k++)
   1006       (*batch_variance_tensor)->flat<T>().data()[k] = NAN;
   1007 
   1008     // Mean and variance (without Bessel's correction) saved for backward
   1009     // computation to serve as pre-computed mean and variance.
   1010     MklDnnShape mkl_shape_saved_mean;
   1011     mkl_shape_saved_mean.SetMklTensor(false);
   1012     AllocateOutputSetMklShape(context, kSavedMeanIndex, saved_mean_tensor,
   1013                               tf_shape_scale, mkl_shape_saved_mean);
   1014     CHECK_NOTNULL(*saved_mean_tensor);
   1015     // set NAN mean value in case of empty input tensor
   1016     for (int k = 0; k < tf_shape_scale.num_elements(); k++)
   1017       (*saved_mean_tensor)->flat<T>().data()[k] = NAN;
   1018 
   1019     MklDnnShape mkl_shape_saved_variance;
   1020     mkl_shape_saved_variance.SetMklTensor(false);
   1021     AllocateOutputSetMklShape(context, kSavedVarianceIndex,
   1022                               saved_variance_tensor, tf_shape_scale,
   1023                               mkl_shape_saved_variance);
   1024     CHECK_NOTNULL(*saved_variance_tensor);
   1025     // set NAN variance value in case of empty input tensor
   1026     for (int k = 0; k < tf_shape_scale.num_elements(); k++)
   1027       (*saved_variance_tensor)->flat<T>().data()[k] = NAN;
   1028   }
   1029 };
   1030 
   1031 template <typename Device, typename T>
   1032 class MklFusedBatchNormGradOp : public OpKernel {
   1033  public:
   1034   explicit MklFusedBatchNormGradOp(OpKernelConstruction* context)
   1035       : OpKernel(context) {
   1036     float epsilon;
   1037     OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
   1038     epsilon_ = T(epsilon);
   1039     string tensor_format;
   1040     OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
   1041     OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
   1042                 errors::InvalidArgument("Invalid data format"));
   1043     OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
   1044   }

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);
      const size_t kDiffDstIndex = 0;   // index of diff_dst tensor
      const size_t kSrcIndex = 1;       // index of src input tensor
      const size_t kScaleIndex = 2;     // index of scale tensor
      const size_t kMeanIndex = 3;      // index of saved_mean tensor
      const size_t kVarianceIndex = 4;  // index of saved_variance tensor
      const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIndex);
      const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
      const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
      const Tensor& saved_mean_tensor = MklGetInput(context, kMeanIndex);
      const Tensor& saved_variance_tensor =
          MklGetInput(context, kVarianceIndex);

      MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
      GetMklShape(context, kSrcIndex, &dnn_shape_src);
      GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst);
      TensorShape tf_shape_src, tf_shape_diff_dst;

      if (dnn_shape_diff_dst.IsMklTensor()) {
        tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape();
        OP_REQUIRES(
            context, dnn_shape_diff_dst.GetDimension() == 4,
            errors::InvalidArgument("input must be 4-dimensional",
                                    diff_dst_tensor.shape().DebugString()));
      } else {
        tf_shape_diff_dst = diff_dst_tensor.shape();
        OP_REQUIRES(
            context, diff_dst_tensor.dims() == 4,
            errors::InvalidArgument("input must be 4-dimensional",
                                    diff_dst_tensor.shape().DebugString()));
      }

      if (dnn_shape_src.IsMklTensor()) {
        tf_shape_src = dnn_shape_src.GetTfShape();
        OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      } else {
        tf_shape_src = src_tensor.shape();
        OP_REQUIRES(context, src_tensor.dims() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      }

      OP_REQUIRES(context, scale_tensor.dims() == 1,
                  errors::InvalidArgument("scale must be 1-dimensional",
                                          scale_tensor.shape().DebugString()));
      OP_REQUIRES(
          context, saved_mean_tensor.dims() == 1,
          errors::InvalidArgument("saved mean must be 1-dimensional",
                                  saved_mean_tensor.shape().DebugString()));

      OP_REQUIRES(
          context, saved_variance_tensor.dims() == 1,
          errors::InvalidArgument("saved variance must be 1-dimensional",
                                  saved_variance_tensor.shape().DebugString()));

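      // If either the input or the backpropagated gradient is empty, emit
      // zero-filled gradient outputs and return early.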
      Tensor* diff_src_tensor = nullptr;
      if (tf_shape_src.num_elements() == 0 ||
          tf_shape_diff_dst.num_elements() == 0) {
        HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(),
                         &diff_src_tensor);
        return;
      }

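      // Determine the number of channels (depth_): read it from the MKL
      // shape when present, otherwise derive it from the TF data format.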
      if (dnn_shape_src.IsMklTensor())
        depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
      else
        ExtractParams(context);

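      // Pick the MKL-DNN memory format that matches the layout of the
      // incoming data (NCHW or NHWC).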
      memory::format format_m;
      if (dnn_shape_src.IsMklTensor()) {
        if (dnn_shape_src.IsTensorInNCHWFormat())
          format_m = memory::format::nchw;
        else
          format_m = memory::format::nhwc;
      } else {
        format_m = TFDataFormatToMklDnnDataFormat(tensor_format_);
      }

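      // MklDnnData wraps the user-supplied buffers together with the memory
      // descriptors that the MKL-DNN primitives operate on.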
      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> mean(&cpu_engine);
      MklDnnData<T> variance(&cpu_engine);
      MklDnnData<T> diff_dst(&cpu_engine);
      MklDnnData<T> diff_src(&cpu_engine);

      memory::dims src_dims, diff_dst_dims;
      if (dnn_shape_src.IsMklTensor())
        src_dims = TFShapeToMklDnnDimsInNCHW(dnn_shape_src.GetTfShape(),
                                             tensor_format_);
      else
        src_dims =
            TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);

      if (dnn_shape_diff_dst.IsMklTensor())
        diff_dst_dims = TFShapeToMklDnnDimsInNCHW(
            dnn_shape_diff_dst.GetTfShape(), tensor_format_);
      else
        diff_dst_dims =
            TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), tensor_format_);

      // set src and diff_dst primitives
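      // If either tensor already carries an MKL layout, reuse that layout for
      // both src and diff_dst so the two share a single format; otherwise
      // build a plain descriptor from the TF dims and data format.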
      memory::desc src_md({}, memory::data_undef, memory::format_undef);
      memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef);
      if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
        if (dnn_shape_src.IsMklTensor()) {
          src_md = dnn_shape_src.GetMklLayout();
          diff_dst_md = src_md;
        } else {
          diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
          src_md = diff_dst_md;
        }
      } else {
        src_md = memory::desc(src_dims, MklDnnType<T>(), format_m);
        diff_dst_md = src_md;
      }
      src.SetUsrMem(src_md, &src_tensor);
      diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);

      // weights -- MKL-DNN packs scale and shift into a single weights buffer
      // in the order: scale, ..., scale, shift, ..., shift
      auto weights_desc =
          memory::desc({2, depth_}, MklDnnType<T>(), memory::format::nc);
      auto weights_pd = memory::primitive_desc(weights_desc, cpu_engine);
      auto weights_m = memory(weights_pd);
      T* weights_data = reinterpret_cast<T*>(weights_m.get_data_handle());
      T* scale_tf =
          reinterpret_cast<T*>(const_cast<T*>(scale_tensor.flat<T>().data()));
      for (int k = 0; k < depth_; k++) {
        weights_data[k] = scale_tf[k];
        weights_data[k + depth_] = 0;
      }
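      // Only the scale (gamma) values are consumed by the backward pass; the
      // shift (beta) half of the weights buffer is zero-filled.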

      // set mean primitive
      memory::dims mv_dims = GetMeanVarianceDims();
      mean.SetUsrMem(mv_dims, memory::format::nc,
                     const_cast<void*>(static_cast<const void*>(
                         saved_mean_tensor.flat<T>().data())));
      mean.SetOpMemDesc(mv_dims, memory::format::nc);

      // set variance primitive
      variance.SetUsrMem(mv_dims, memory::format::nc,
                         const_cast<void*>(static_cast<const void*>(
                             saved_variance_tensor.flat<T>().data())));
      variance.SetOpMemDesc(mv_dims, memory::format::nc);

      // set diff_weight primitive
      auto diff_weights_desc =
          memory::desc({2, depth_}, MklDnnType<T>(), memory::format::nc);
      auto diff_weights_pd =
          memory::primitive_desc(diff_weights_desc, cpu_engine);
      auto diff_weights_m = memory(diff_weights_pd);

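      // A forward primitive descriptor is needed as a hint when constructing
      // the backward primitive descriptor below.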
      auto bnrm_fwd_desc = batch_normalization_forward::desc(
          prop_kind::forward_training, src.GetUsrMemDesc(), epsilon_,
          is_training_ ? use_scale_shift
                       : (use_scale_shift | use_global_stats));
      auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc(
          bnrm_fwd_desc, cpu_engine);

      // Indices of output tensors
      const size_t kDiffSrcIndex = 0;  // index of diff_src tensor

      // allocate diff_src tensor
      MklDnnShape dnn_shape_diff_src;
      TensorShape tf_shape_diff_src;
      if (dnn_shape_src.IsMklTensor()) {
        dnn_shape_diff_src.SetMklTensor(true);
        auto diff_src_pd = bnrm_fwd_pd.dst_primitive_desc();
        dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
        dnn_shape_diff_src.SetElemType(MklDnnType<T>());
        dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(), src_dims,
                                       format_m);
        dnn_shape_diff_src.SetTfDimOrder(dnn_shape_src.GetDimension(),
                                         tensor_format_);
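        // With an MKL layout, the TF-side tensor is just a flat buffer large
        // enough to hold the MKL-formatted data.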
        tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
      } else {
        dnn_shape_diff_src.SetMklTensor(false);
        tf_shape_diff_src = src_tensor.shape();
      }
      AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor,
                                tf_shape_diff_src, dnn_shape_diff_src);

      diff_src.SetUsrMem(src_md, diff_src_tensor);

      prop_kind pk = prop_kind::backward;
      auto bnrm_bwd_desc = batch_normalization_backward::desc(
          pk, diff_src.GetUsrMemDesc(), src.GetUsrMemDesc(), epsilon_,
          /* For inference, specify use_global_stats:
             1. on forward prop, the mean and variance
                provided as inputs are used;
             2. on backward prop, the mean and variance are
                treated as constants, which reduces
                the amount of MKL computation.
          */
          is_training_ ? use_scale_shift
                       : (use_scale_shift | use_global_stats));
      auto bnrm_bwd_pd = batch_normalization_backward::primitive_desc(
          bnrm_bwd_desc, cpu_engine, bnrm_fwd_pd);

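      // The backward primitive reads src, mean, variance, diff_dst, and the
      // packed weights, and writes diff_src and diff_weights.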
      auto bnrm_bwd_op = batch_normalization_backward(
          bnrm_bwd_pd, src.GetOpMem(), mean.GetOpMem(), variance.GetOpMem(),
          diff_dst.GetOpMem(), weights_m, diff_src.GetOpMem(), diff_weights_m);

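      // Run the backward primitive on the eager MKL-DNN stream; the results
      // are written into diff_src and diff_weights_m.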
      std::vector<primitive> net;
      net.push_back(bnrm_bwd_op);
      stream(stream::kind::eager).submit(net).wait();

      // allocate 4 output TF tensors
      Tensor* diff_scale_tensor = nullptr;
      Tensor* diff_shift_tensor = nullptr;
      AllocateTFOutputs(context, scale_tensor.shape(), &diff_scale_tensor,
                        &diff_shift_tensor);

      // copy data: diff_scale and diff_shift
      T* diff_weights_data_dnn =
          reinterpret_cast<T*>(diff_weights_m.get_data_handle());
      for (int i = 0; i < depth_; i++) {
        diff_scale_tensor->flat<T>().data()[i] = diff_weights_data_dnn[i];
        diff_shift_tensor->flat<T>().data()[i] =
            diff_weights_data_dnn[i + depth_];
      }
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  T epsilon_;
  TensorFormat tensor_format_;
  int depth_;  // batch normalization is performed per channel
  bool is_training_;

  void ExtractParams(OpKernelContext* context) {
    const Tensor& input = MklGetInput(context, 0);
    depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
  }

  void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
                        TensorShape tf_shape_scale_shift,
                        Tensor** diff_src_tensor) {
    const size_t kDiffSrcIndex = 0;

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor,
                              tf_shape_src, dnn_shape_diff_src);
    for (size_t i = 0; i < (*diff_src_tensor)->shape().num_elements(); i++)
      (*diff_src_tensor)->flat<T>().data()[i] = 0;

    Tensor* diff_scale_tensor = nullptr;
    Tensor* diff_shift_tensor = nullptr;
    AllocateTFOutputs(context, tf_shape_scale_shift, &diff_scale_tensor,
                      &diff_shift_tensor);
  }

  void AllocateTFOutputs(OpKernelContext* context,
                         TensorShape tf_shape_scale_shift,
                         Tensor** diff_scale_tensor,
                         Tensor** diff_shift_tensor) {
    CHECK_NOTNULL(diff_scale_tensor);
    CHECK_NOTNULL(diff_shift_tensor);

    const size_t kDiffScaleIndex = 1;
    const size_t kDiffShiftIndex = 2;
    const size_t kP1Index = 3;
    const size_t kP2Index = 4;

    // Allocate and zero-initialize the diff_scale and diff_shift outputs;
    // the caller fills in the computed gradients when the input is non-empty.
    MklDnnShape mkl_shape_diff_scale;
    mkl_shape_diff_scale.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor,
                              tf_shape_scale_shift, mkl_shape_diff_scale);
    CHECK_NOTNULL(*diff_scale_tensor);
    for (size_t i = 0; i < (*diff_scale_tensor)->shape().num_elements(); i++)
      (*diff_scale_tensor)->flat<T>().data()[i] = 0;

    MklDnnShape mkl_shape_diff_shift;
    mkl_shape_diff_shift.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor,
                              tf_shape_scale_shift, mkl_shape_diff_shift);
    CHECK_NOTNULL(*diff_shift_tensor);
    for (size_t i = 0; i < (*diff_shift_tensor)->shape().num_elements(); i++)
      (*diff_shift_tensor)->flat<T>().data()[i] = 0;

    // Placeholders for estimated_mean and estimated_variance, which are
    // used for inference and thus not needed here for gradient computation.
    Tensor *p1_tensor = nullptr, *p2_tensor = nullptr;
    MklDnnShape mkl_shape_p;
    mkl_shape_p.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}),
                              mkl_shape_p);
    AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}),
                              mkl_shape_p);
  }

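  // MKL-DNN expects the saved mean and variance as 1 x C buffers in 'nc'
  // format.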
  memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); }
};

#endif

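// Register the MKL fused batch norm forward and backward kernels for float
// on CPU. The kMklOpLabel label restricts them to nodes rewritten by the MKL
// graph layout pass.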
#define REGISTER_MKL_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNorm")                \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklFusedBatchNormOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_CPU);
#undef REGISTER_MKL_CPU

#define REGISTER_MKL_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNormGrad")            \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklFusedBatchNormGradOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_CPU);
#undef REGISTER_MKL_CPU
}  // namespace tensorflow

#endif  // INTEL_MKL