/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.
#ifdef INTEL_MKL

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"

#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#include "tensorflow/core/platform/default/logging.h"
#include "tensorflow/core/util/mkl_util.h"

#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"

using mkldnn::algorithm;
using mkldnn::eltwise_elu;
using mkldnn::eltwise_relu;
using mkldnn::eltwise_tanh;
using mkldnn::prop_kind;
using mkldnn::relu_backward;
using mkldnn::relu_forward;
using mkldnn::stream;
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

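// Shared validation helper: the gradient kernels below use ValidateSameSize()
// to check that the incoming gradient g and the op input a have identical
// shapes before any MKL primitive is created.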
struct MklReluHelpers {
  static void ValidateSameSizeHelper(OpKernelContext* context, const Tensor& g,
                                     const Tensor& a) {
    OP_REQUIRES(context, a.IsSameSize(g),
                errors::InvalidArgument("g and a must be the same size"));
  }
  static bool ValidateSameSize(OpKernelContext* context, const Tensor& g,
                               const Tensor& a) {
    ValidateSameSizeHelper(context, g, a);
    return context->status().ok();
  }
};

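// Two implementations live in this file: when INTEL_MKL_ML is defined, the
// legacy MKL-ML path (dnn*_F32 C API) below is compiled; otherwise the
// MKL-DNN path (mkldnn.hpp) after the #else is used.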
#ifdef INTEL_MKL_ML

template <typename Device, typename T>
class MklReluOp : public OpKernel {
 public:
  ~MklReluOp() {}

  explicit MklReluOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    MklReluOpContext mkl_context;

    const Tensor& input = MklGetInput(context, 0);
    GetMklShape(context, 0, &mkl_context.input_shape);
    void* user_i = static_cast<void*>(const_cast<T*>(input.flat<T>().data()));
    bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();

    if (!input_in_mkl_format && !input.dims()) {  // handle the case of a scalar
      const TensorShape& o_shape = input.shape();
      Tensor* out_tensor = nullptr;
      mkl_context.output_shape.SetMklTensor(false);
      AllocateOutputSetMklShape(context, 0, &out_tensor, o_shape,
                                mkl_context.output_shape);
      void* out_o = static_cast<void*>(out_tensor->flat<T>().data());
      (static_cast<T*>(out_o))[0] =
          std::max((static_cast<T*>(user_i))[0], static_cast<T>(0));
      return;
    }

    // Generate size, stride for input if input is in MKL format.
    if (input_in_mkl_format) {
      mkl_context.in_dims = mkl_context.input_shape.GetDimension();
      mkl_context.in_sizes = new size_t[mkl_context.in_dims];
      mkl_context.in_strides = new size_t[mkl_context.in_dims];
      for (int i = 0; i < mkl_context.in_dims; i++) {
        mkl_context.in_sizes[i] = mkl_context.input_shape.GetSizes()[i];
        mkl_context.in_strides[i] = mkl_context.input_shape.GetStrides()[i];
      }
    } else {
      mkl_context.in_dims = input.dims();
      mkl_context.in_sizes = new size_t[mkl_context.in_dims];
      mkl_context.in_strides = new size_t[mkl_context.in_dims];
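      // The TF shape lists dimensions outermost-first, while MKL expects
      // sizes/strides starting from the innermost dimension, so the order is
      // reversed here (e.g. a TF shape [N, C, H, W] yields in_sizes
      // {W, H, C, N}).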
      for (int i = 0; i < mkl_context.in_dims; i++) {
        mkl_context.in_sizes[i] = input.dim_size((mkl_context.in_dims - 1) - i);
      }
      mkl_context.in_strides[0] = 1;
      for (int i = 1; i < mkl_context.in_dims; i++) {
        mkl_context.in_strides[i] =
            mkl_context.in_strides[i - 1] * mkl_context.in_sizes[i - 1];
      }
    }

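    // A negative slope of 0.0 selects the standard ReLU, f(x) = max(x, 0);
    // the MKL ReLU primitive would compute negative_slope * x for x < 0 if a
    // nonzero slope were passed.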
    float negative_slope = 0.0;
    mkl_context.MklCreateInputLayouts(context);
    CHECK_EQ(dnnReLUCreateForward_F32(&mkl_context.prim_relu_fwd, NULL,
                                      mkl_context.lt_input, negative_slope),
             E_SUCCESS);

    Tensor* output = nullptr;

    if (input_in_mkl_format) {
      TensorShape tf_shape;
      mkl_context.output_shape.SetMklTensor(true);
      mkl_context.output_shape.SetMklLayout(mkl_context.prim_relu_fwd,
                                            dnnResourceDst);
      mkl_context.output_shape.SetTfLayout(
          mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides);
      mkl_context.output_shape.SetTfDimOrder(
          mkl_context.in_dims, mkl_context.input_shape.GetTfToMklDimMap());
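      // The TF-side tensor backing an MKL-layout output is a flat buffer; its
      // element count is derived from the byte size of the MKL layout.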
      tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
                          mkl_context.output_shape.GetMklLayout())) /
                      sizeof(T));
      AllocateOutputSetMklShape(context, 0, &output, tf_shape,
                                mkl_context.output_shape);
    } else {
      const TensorShape& o_shape = input.shape();
      mkl_context.output_shape.SetMklTensor(false);
      AllocateOutputSetMklShape(context, 0, &output, o_shape,
                                mkl_context.output_shape);
    }

    void* user_o = static_cast<void*>(const_cast<T*>(output->flat<T>().data()));

    mkl_context.relu_res[dnnResourceDst] = user_o;
    mkl_context.relu_res[dnnResourceSrc] = user_i;
    CHECK_EQ(dnnExecute_F32(mkl_context.prim_relu_fwd, mkl_context.relu_res),
             E_SUCCESS);
    mkl_context.MklCleanup();
  }

 private:
  typedef struct {
    int in_dims;
    size_t* in_sizes;
    size_t* in_strides;
    MklShape input_shape, output_shape;
    dnnPrimitive_t prim_relu_fwd = nullptr;
    void* relu_res[dnnResourceNumber];
    dnnLayout_t lt_input = nullptr;

    void MklCleanup() {
      bool input_in_mkl_format = input_shape.IsMklTensor();
      if (!input_in_mkl_format) {
        dnnLayoutDelete_F32(lt_input);
        free(in_sizes);
        free(in_strides);
      }
      dnnDelete_F32(prim_relu_fwd);
    }

    void MklCreateInputLayouts(OpKernelContext* context) {
      bool input_in_mkl_format = input_shape.IsMklTensor();
      if (!input_in_mkl_format) {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_input, in_dims, in_sizes, in_strides),
                 E_SUCCESS);
      } else {
        lt_input = static_cast<dnnLayout_t>(input_shape.GetCurLayout());
      }
    }
  } MklReluOpContext;
};

template <typename Device, typename T>
class MklReluGradOp : public OpKernel {
 public:
  ~MklReluGradOp() {}

  explicit MklReluGradOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override;

 private:
  typedef struct {
    int in_dims;
    size_t* in_sizes;
    size_t* in_strides;
    MklShape input_shape, grad_shape, output_shape;
    void* relu_res[dnnResourceNumber];
    dnnPrimitive_t prim_relu_bwd;
    dnnLayout_t lt_input, lt_grad;

    void MklPrepareReluGradInputs(OpKernelContext* context,
                                  Tensor* mkl_tmp_input_buf_tensor) {
      const Tensor& g = MklGetInput(context, 0);
      const Tensor& a = MklGetInput(context, 1);
      void* buf_input = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
      void* mkl_buffer_convert = nullptr;

      dnnPrimitive_t cv_input_to_grad = nullptr;

      // if input and grad are not in the same layout,
      // do a conversion between them.
      if (!dnnLayoutCompare_F32(lt_input, lt_grad)) {
        AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_grad,
                       &mkl_buffer_convert);
        CHECK_EQ(dnnConversionCreate_F32(&cv_input_to_grad, lt_input, lt_grad),
                 E_SUCCESS);
        CHECK_EQ(dnnConversionExecute_F32(cv_input_to_grad, buf_input,
                                          mkl_buffer_convert),
                 E_SUCCESS);
        relu_res[dnnResourceSrc] = mkl_buffer_convert;
        dnnDelete_F32(cv_input_to_grad);
      } else {
        relu_res[dnnResourceSrc] = buf_input;
      }

      void* buf_grad = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));
      relu_res[dnnResourceDiffDst] = buf_grad;
    }

    void MklCreateInputLayouts(OpKernelContext* context) {
      bool grad_is_mkl = grad_shape.IsMklTensor();
      bool input_is_mkl = input_shape.IsMklTensor();
      if (!input_is_mkl) {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_input, in_dims, in_sizes, in_strides),
                 E_SUCCESS);
      } else {
        lt_input = static_cast<dnnLayout_t>(input_shape.GetCurLayout());
      }

      if (!grad_is_mkl) {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_grad, in_dims, in_sizes, in_strides),
                 E_SUCCESS);
      } else {
        lt_grad = static_cast<dnnLayout_t>(grad_shape.GetCurLayout());
      }
    }

    void MklCleanup() {
      bool grad_is_mkl = grad_shape.IsMklTensor();
      bool input_is_mkl = input_shape.IsMklTensor();
      dnnDelete_F32(prim_relu_bwd);
      if (!input_is_mkl) {
        dnnLayoutDelete_F32(lt_input);
        free(in_sizes);
        free(in_strides);
      }
      if (!grad_is_mkl) {
        dnnLayoutDelete_F32(lt_grad);
      }
    }
  } MklReluGradOpContext;
};

template <typename Device, typename T>
void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
  MklReluGradOpContext mkl_context;
  const Tensor& g = MklGetInput(context, 0);
  const Tensor& a = MklGetInput(context, 1);

  void* user_i = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
  void* user_g = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));

  GetMklShape(context, 0, &mkl_context.grad_shape);
  GetMklShape(context, 1, &mkl_context.input_shape);

  bool grad_is_mkl = mkl_context.grad_shape.IsMklTensor();
  bool input_is_mkl = mkl_context.input_shape.IsMklTensor();
  if (!input_is_mkl && !grad_is_mkl &&
      !MklReluHelpers::ValidateSameSize(context, g, a))
    return;
  Tensor* output = nullptr;

  if (!input_is_mkl && !grad_is_mkl && !a.dims()) {
    // handle the scalar case
    const TensorShape& g_shape = g.shape();
    mkl_context.output_shape.SetMklTensor(false);
    AllocateOutputSetMklShape(context, 0, &output, g_shape,
                              mkl_context.output_shape);

    void* out_o = static_cast<void*>(output->flat<T>().data());
    (static_cast<T*>(out_o))[0] =
        (static_cast<T*>(user_g))[0] * ((static_cast<T*>(user_i))[0] > 0);
    return;
  }

  // generate size, stride for input if input/grad is in mkl format.
  if (grad_is_mkl || input_is_mkl) {
    const MklShape* tmp_mkl_shape =
        (grad_is_mkl) ? &mkl_context.grad_shape : &mkl_context.input_shape;

    mkl_context.in_dims = tmp_mkl_shape->GetDimension();
    mkl_context.in_strides = new size_t[mkl_context.in_dims];
    mkl_context.in_sizes = new size_t[mkl_context.in_dims];
    for (int i = 0; i < mkl_context.in_dims; i++) {
      mkl_context.in_sizes[i] = tmp_mkl_shape->GetSizes()[i];
      mkl_context.in_strides[i] = tmp_mkl_shape->GetStrides()[i];
    }
  } else {
    mkl_context.in_dims = g.dims();
    mkl_context.in_strides = new size_t[mkl_context.in_dims];
    mkl_context.in_sizes = new size_t[mkl_context.in_dims];

    for (int i = 0; i < mkl_context.in_dims; i++) {
      mkl_context.in_sizes[i] = g.dim_size((mkl_context.in_dims - 1) - i);
    }
    mkl_context.in_strides[0] = 1;
    for (int i = 1; i < mkl_context.in_dims; i++) {
      mkl_context.in_strides[i] =
          mkl_context.in_strides[i - 1] * mkl_context.in_sizes[i - 1];
    }
  }

  mkl_context.MklCreateInputLayouts(context);
  float negative_slope = 0.0;
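  // The backward primitive is created with the gradient layout for both src
  // and diff_dst; MklPrepareReluGradInputs() converts the input to that
  // layout first if the two layouts differ.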
  CHECK_EQ(dnnReLUCreateBackward_F32(&mkl_context.prim_relu_bwd, NULL,
                                     mkl_context.lt_grad, mkl_context.lt_grad,
                                     negative_slope),
           E_SUCCESS);
  Tensor mkl_tmp_input_buf_tensor;
  mkl_context.MklPrepareReluGradInputs(context, &mkl_tmp_input_buf_tensor);

  if (input_is_mkl || grad_is_mkl) {
    // If grad or input is in MKL format, leave the output in MKL format too.
    TensorShape tf_shape;
    mkl_context.output_shape.SetMklTensor(true);
    mkl_context.output_shape.SetMklLayout(mkl_context.prim_relu_bwd,
                                          dnnResourceDiffSrc);
    mkl_context.output_shape.SetTfLayout(
        mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides);
    // The TF dimension order is copied from whichever input is in MKL layout.
    if (grad_is_mkl) {
      mkl_context.output_shape.SetTfDimOrder(
          mkl_context.in_dims, mkl_context.grad_shape.GetTfToMklDimMap());
    } else {
      mkl_context.output_shape.SetTfDimOrder(
          mkl_context.in_dims, mkl_context.input_shape.GetTfToMklDimMap());
    }

    tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
                        mkl_context.output_shape.GetMklLayout())) /
                    sizeof(T));
    AllocateOutputSetMklShape(context, 0, &output, tf_shape,
                              mkl_context.output_shape);
  } else {
    const TensorShape& o_shape = g.shape();
    mkl_context.output_shape.SetMklTensor(false);
    AllocateOutputSetMklShape(context, 0, &output, o_shape,
                              mkl_context.output_shape);
  }

  mkl_context.relu_res[dnnResourceDiffSrc] =
      static_cast<void*>(output->flat<T>().data());

  CHECK_EQ(dnnExecute_F32(mkl_context.prim_relu_bwd, mkl_context.relu_res),
           E_SUCCESS);
  mkl_context.MklCleanup();
}

#else  // INTEL_MKL_ML

template <typename Device, typename T, algorithm alg_kind>
class MklReluOpBase : public OpKernel {
 public:
  ~MklReluOpBase() {}

  explicit MklReluOpBase(OpKernelConstruction* context) : OpKernel(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) = 0;

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);
      const size_t src_index = 0;  // index of src input tensor
      const size_t dst_index = 0;  // index of dst output tensor
      const Tensor& src_tensor = MklGetInput(context, src_index);
      MklDnnShape dnn_shape_src;
      GetMklShape(context, src_index, &dnn_shape_src);

      Tensor* dst_tensor = nullptr;
      if (src_tensor.dims() == 0) {
        Compute_Scalar(context);
        return;
      }

      // Create relu primitive.
      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> dst(&cpu_engine);

      // Set DNN primitive - src
      memory::desc src_md({}, memory::data_undef, memory::format_undef);
      if (dnn_shape_src.IsMklTensor()) {
        src_md = dnn_shape_src.GetMklLayout();
      } else {
        auto src_dims = TFShapeToMklDnnDims(src_tensor.shape());
        auto src_strides = CalculateTFStrides(src_dims);
        // Create blocked memory descriptor
        src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
      }
      src.SetUsrMem(src_md, &src_tensor);

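      // alpha/beta are the generic eltwise parameters of the MKL-DNN
      // primitive (for eltwise_relu, alpha is the negative slope); both are
      // left at 0 here.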
      T alpha = 0, beta = 0;
      std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
      auto relu_fwd_desc = relu_forward::desc(
          prop_kind::forward_training,
          // Operator memory descriptor is same as user memory descriptor.
          alg_kind, src.GetUsrMemDesc(), alpha, beta);
      relu_fwd_pd.reset(
          new relu_forward::primitive_desc(relu_fwd_desc, cpu_engine));

      // allocate dst tensor
      MklDnnShape dnn_shape_dst;
      TensorShape tf_shape_dst;
      if (dnn_shape_src.IsMklTensor()) {
        dnn_shape_dst.SetMklTensor(true);
        auto dst_pd = relu_fwd_pd->dst_primitive_desc();
        dnn_shape_dst.SetMklLayout(&dst_pd);
        dnn_shape_dst.SetElemType(MklDnnType<T>());
        dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(),
                                  dnn_shape_src.GetSizesAsMklDnnDims(),
                                  dnn_shape_src.GetTfDataFormat());
        tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
      } else {
        dnn_shape_dst.SetMklTensor(false);
        tf_shape_dst = src_tensor.shape();
      }
      AllocateOutputSetMklShape(context, dst_index, &dst_tensor, tf_shape_dst,
                                dnn_shape_dst);

      // Destination memory descriptor is same as source memory descriptor.
      auto dst_md = src_md;
      dst.SetUsrMem(dst_md, dst_tensor);

      // execute net
      std::vector<primitive> net;
      auto relu_fwd =
          relu_forward(*relu_fwd_pd, src.GetOpMem(), dst.GetOpMem());
      net.push_back(relu_fwd);
      stream(stream::kind::eager).submit(net).wait();
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }
};

template <typename Device, typename T, algorithm alg_kind>
class MklReluGradOpBase : public OpKernel {
 public:
  ~MklReluGradOpBase() {}

  explicit MklReluGradOpBase(OpKernelConstruction* context)
      : OpKernel(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) = 0;

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);
      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> diff_dst(&cpu_engine);
      MklDnnData<T> diff_src(&cpu_engine);

      const size_t diff_dst_index = 0;  // index of diff_dst input tensor
      const size_t src_index = 1;       // index of src input tensor
      const size_t diff_src_index = 0;  // index of diff_src output tensor

      const Tensor& src_tensor = MklGetInput(context, src_index);
      const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
      Tensor* diff_src_tensor = nullptr;

      MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
      GetMklShape(context, src_index, &dnn_shape_src);
      GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

      int src_dims_size = src_tensor.dims();
      if (src_dims_size == 0) {
        Compute_Scalar(context);
        return;
      }

      // Set DNN primitives for src & diff_dst
      memory::desc src_md({}, memory::data_undef, memory::format_undef);
      memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef);

      // The backward primitive requires both of its inputs to be in the same
      // format. In the mixed case - where one input is in TensorFlow format
      // and the other is in MKL format - we therefore pick a common format
      // for primitive construction. For performance reasons we choose the MKL
      // format in that case and insert a reorder for the input that is in
      // TensorFlow format. If both inputs are in MKL format or both are in
      // TensorFlow format, no reorder is needed.
      if (!dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
        // If both the inputs are in Tensorflow format, we create blocked memory
        // descriptor.
        auto src_dims = TFShapeToMklDnnDims(src_tensor.shape());
        auto src_strides = CalculateTFStrides(src_dims);
        src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
        diff_dst_md = src_md;
      } else if (dnn_shape_src.IsMklTensor() &&
                 !dnn_shape_diff_dst.IsMklTensor()) {
        // If one input is in MKL format and other is in Tensorflow, then
        // create respective descriptors describing the actual case. For input
        // in Mkl format, we just get Mkl layout from MklDnnShape. For input in
        // Tensorflow format, we create memory descriptor using data format.
        src_md = dnn_shape_src.GetMklLayout();

        memory::format src_mkl_data_format = dnn_shape_src.GetTfDataFormat();
        auto src_tf_data_format =
            MklDnnDataFormatToTFDataFormat(src_mkl_data_format);
        auto diff_dst_dims = TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
                                                       src_tf_data_format);
        diff_dst_md =
            memory::desc(diff_dst_dims, MklDnnType<T>(), src_mkl_data_format);
      } else if (!dnn_shape_src.IsMklTensor() &&
                 dnn_shape_diff_dst.IsMklTensor()) {
        // Same comment as above.
        diff_dst_md = dnn_shape_diff_dst.GetMklLayout();

        memory::format diff_dst_mkl_data_format =
            dnn_shape_diff_dst.GetTfDataFormat();
        auto diff_dst_tf_data_format =
            MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format);
        auto src_dims = TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
                                                  diff_dst_tf_data_format);
        src_md =
            memory::desc(src_dims, MklDnnType<T>(), diff_dst_mkl_data_format);
      } else {
        // If both the inputs are in MKL format, we use Mkl layout of the input
        // tensors.
        src_md = dnn_shape_src.GetMklLayout();
        diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
      }

      src.SetUsrMem(src_md, &src_tensor);
      diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);

      // As per comment above, we tell MKLDNN that both the inputs are in same
      // format. So we set common memory descriptor in MKL format, if any of the
      // inputs are in MKL format. Let's get memory descriptor that we will use
      // for both the inputs.
      memory::desc common_md({}, memory::data_undef, memory::format_undef);
      if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
        common_md = dnn_shape_src.IsMklTensor() ? src_md : diff_dst_md;
      } else {
        // Since both the inputs are in Tensorflow format, and have
        // same shape, we can get memory descriptor from any input.
        common_md = src_md;
      }

      T alpha = 0, beta = 0;
      std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
      auto relu_fwd_desc = relu_forward::desc(prop_kind::forward_training,
                                              alg_kind, src_md, alpha, beta);
      relu_fwd_pd.reset(
          new relu_forward::primitive_desc(relu_fwd_desc, cpu_engine));
      auto relu_bwd_desc =
          relu_backward::desc(alg_kind, common_md, common_md, alpha, beta);
      auto relu_bwd_pd = relu_backward::primitive_desc(
          relu_bwd_desc, cpu_engine, *relu_fwd_pd);

      // allocate diff_src tensor
      MklDnnShape dnn_shape_diff_src;
      TensorShape tf_shape_diff_src;
      if (dnn_shape_src.IsMklTensor()) {
        dnn_shape_diff_src.SetMklTensor(true);
        auto diff_src_pd = relu_bwd_pd.diff_src_primitive_desc();
        dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
        dnn_shape_diff_src.SetElemType(MklDnnType<T>());
        dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(),
                                       dnn_shape_src.GetSizesAsMklDnnDims(),
                                       dnn_shape_src.GetTfDataFormat());
        tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
      } else {
        dnn_shape_diff_src.SetMklTensor(false);
        tf_shape_diff_src = src_tensor.shape();
      }
      AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
                                tf_shape_diff_src, dnn_shape_diff_src);

      // diff_src memory descriptor is same as memory descriptor for both
      // inputs.
      diff_src.SetUsrMem(common_md, diff_src_tensor);

      PrepareAndExecuteNet(relu_bwd_pd, &src, &diff_src, &diff_dst);
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

  void PrepareAndExecuteNet(const relu_backward::primitive_desc& relu_prim_desc,
                            MklDnnData<T>* src, MklDnnData<T>* diff_src,
                            MklDnnData<T>* diff_dst) {
    std::vector<primitive> net;

    // Check if we need to reorder original input tensors into common_md layout
    // that we set for primitive creation. diff_src_primitive_desc is same as
    // common_md.
    src->CheckReorderToOpMem(relu_prim_desc.diff_src_primitive_desc(), &net);
    diff_dst->CheckReorderToOpMem(relu_prim_desc.diff_src_primitive_desc(),
                                  &net);

    net.push_back(relu_backward(relu_prim_desc, src->GetOpMem(),
                                diff_dst->GetOpMem(), diff_src->GetOpMem()));
    stream(stream::kind::eager).submit(net).wait();
  }
};

template <typename Device, typename T>
class MklReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
 public:
  ~MklReluOp() {}

  explicit MklReluOp(OpKernelConstruction* context)
      : MklReluOpBase<Device, T, eltwise_relu>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t src_index = 0;  // index of src input tensor
    const size_t dst_index = 0;  // index of dst output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    MklDnnShape dnn_shape_src;
    GetMklShape(context, src_index, &dnn_shape_src);

    Tensor* dst_tensor = nullptr;
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
                              src_tensor.shape(), dnn_shape_dst);
    void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
    (static_cast<T*>(out_o))[0] =
        std::max((static_cast<T*>(user_i))[0], static_cast<T>(0));
    return;
  }
};

template <typename Device, typename T>
class MklReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> {
 public:
  ~MklReluGradOp() {}

  explicit MklReluGradOp(OpKernelConstruction* context)
      : MklReluGradOpBase<Device, T, eltwise_relu>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
    const size_t src_index = 1;       // index of src input tensor
    const size_t diff_src_index = 0;  // index of diff_src output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
    Tensor* diff_src_tensor = nullptr;

    MklDnnShape dnn_shape_diff_dst;
    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
                              diff_dst_tensor.shape(), dnn_shape_diff_src);
    void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    void* user_g =
        static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
    (static_cast<T*>(out_o))[0] =
        (static_cast<T*>(user_g))[0] * ((static_cast<T*>(user_i))[0] > 0);
    return;
  }
};

template <typename Device, typename T>
class MklEluOp : public MklReluOpBase<Device, T, eltwise_elu> {
 public:
  ~MklEluOp() {}

  explicit MklEluOp(OpKernelConstruction* context)
      : MklReluOpBase<Device, T, eltwise_elu>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t src_index = 0;  // index of src input tensor
    const size_t dst_index = 0;  // index of dst output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    MklDnnShape dnn_shape_src;
    GetMklShape(context, src_index, &dnn_shape_src);

    Tensor* dst_tensor = nullptr;
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
                              src_tensor.shape(), dnn_shape_dst);
    void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
    // elu(x) = exp(x) - 1 if x < 0; x otherwise
    T feature = (static_cast<T*>(user_i))[0];
    if (feature < 0)
      (static_cast<T*>(out_o))[0] = std::exp(feature) - 1;
    else
      (static_cast<T*>(out_o))[0] = feature;
    return;
  }
};

template <typename Device, typename T>
class MklEluGradOp : public MklReluGradOpBase<Device, T, eltwise_elu> {
 public:
  ~MklEluGradOp() {}

  explicit MklEluGradOp(OpKernelConstruction* context)
      : MklReluGradOpBase<Device, T, eltwise_elu>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
    const size_t src_index = 1;       // index of src input tensor
    const size_t diff_src_index = 0;  // index of diff_src output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
    Tensor* diff_src_tensor = nullptr;

    MklDnnShape dnn_shape_diff_dst;
    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
                              diff_dst_tensor.shape(), dnn_shape_diff_src);
    void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    void* user_g =
        static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
    // gradient of elu(x) = 1 if x > 0; elu(x) + 1 otherwise
    T feature = (static_cast<T*>(user_i))[0];
    if (feature > 0) {
      (static_cast<T*>(out_o))[0] = (static_cast<T*>(user_g))[0];
    } else {
      T elu = std::exp(feature) - 1;
      (static_cast<T*>(out_o))[0] = (static_cast<T*>(user_g))[0] * (elu + 1);
    }
  }
};

template <typename Device, typename T>
class MklTanhOp : public MklReluOpBase<Device, T, eltwise_tanh> {
 public:
  ~MklTanhOp() {}

  explicit MklTanhOp(OpKernelConstruction* context)
      : MklReluOpBase<Device, T, eltwise_tanh>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t src_index = 0;  // index of src input tensor
    const size_t dst_index = 0;  // index of dst output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    MklDnnShape dnn_shape_src;
    GetMklShape(context, src_index, &dnn_shape_src);

    Tensor* dst_tensor = nullptr;
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, dst_index, &dst_tensor,
                              src_tensor.shape(), dnn_shape_dst);
    void* out_o = static_cast<void*>(dst_tensor->flat<T>().data());
    // tanh(x) = (e^x - e^(-x)) / (e^x + e^(-x))
    T feature = (static_cast<T*>(user_i))[0];
    T e1 = std::exp(feature);
    T e2 = std::exp(-feature);
    (static_cast<T*>(out_o))[0] = (e1 - e2) / (e1 + e2);
    return;
  }
};

template <typename Device, typename T>
class MklTanhGradOp : public MklReluGradOpBase<Device, T, eltwise_tanh> {
 public:
  ~MklTanhGradOp() {}

  explicit MklTanhGradOp(OpKernelConstruction* context)
      : MklReluGradOpBase<Device, T, eltwise_tanh>(context) {}

  virtual void Compute_Scalar(OpKernelContext* context) {
    const size_t diff_dst_index = 0;  // index of diff_dst input tensor
    const size_t src_index = 1;       // index of src input tensor
    const size_t diff_src_index = 0;  // index of diff_src output tensor
    const Tensor& src_tensor = MklGetInput(context, src_index);
    const Tensor& diff_dst_tensor = MklGetInput(context, diff_dst_index);
    Tensor* diff_src_tensor = nullptr;

    MklDnnShape dnn_shape_diff_dst;
    GetMklShape(context, diff_dst_index, &dnn_shape_diff_dst);

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor,
                              diff_dst_tensor.shape(), dnn_shape_diff_src);
    void* out_o = static_cast<void*>(diff_src_tensor->flat<T>().data());
    void* user_i =
        static_cast<void*>(const_cast<T*>(src_tensor.flat<T>().data()));
    // gradient of tanh(x) = 1 - tanh(x)^2
    T feature = (static_cast<T*>(user_i))[0];
    T e1 = std::exp(feature);
    T e2 = std::exp(-feature);
    T tanh = (e1 - e2) / (e1 + e2);
    void* user_g =
        static_cast<void*>(const_cast<T*>(diff_dst_tensor.flat<T>().data()));
    (static_cast<T*>(out_o))[0] =
        (static_cast<T*>(user_g))[0] * (1 - tanh * tanh);
  }
};

#endif  // INTEL_MKL_ML

// register dnn kernels for supported operations and supported types
#define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type)             \
  REGISTER_KERNEL_BUILDER(Name("_MklRelu")                          \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklReluOp<CPUDevice, type>);              \
  REGISTER_KERNEL_BUILDER(Name("_MklReluGrad")                      \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklReluGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES);
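// TF_CALL_float expands the macro once with type = float, so only float
// kernels are registered here; e.g. it registers a kernel equivalent to
//   REGISTER_KERNEL_BUILDER(Name("_MklRelu")
//                               .Device(DEVICE_CPU)
//                               .TypeConstraint<float>("T")
//                               .Label(mkl_op_registry::kMklOpLabel),
//                           MklReluOp<CPUDevice, float>);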

#ifndef INTEL_MKL_ML

// register dnn kernels for supported operations and supported types
#define REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES(type)              \
  REGISTER_KERNEL_BUILDER(Name("_MklElu")                           \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklEluOp<CPUDevice, type>);               \
  REGISTER_KERNEL_BUILDER(Name("_MklEluGrad")                       \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklEluGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_ELU_MKL_SUPPORTED_KERNELS_TYPES);

#define REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES(type)             \
  REGISTER_KERNEL_BUILDER(Name("_MklTanh")                          \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklTanhOp<CPUDevice, type>);              \
  REGISTER_KERNEL_BUILDER(Name("_MklTanhGrad")                      \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<type>("T")            \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklTanhGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_TANH_MKL_SUPPORTED_KERNELS_TYPES);

#endif  // !INTEL_MKL_ML

}  // namespace tensorflow

#endif  // INTEL_MKL