      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
     17 #define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
     18 #ifdef INTEL_MKL
     19 
     20 #include <list>
     21 #include <memory>
     22 #include <string>
     23 #include <unordered_map>
     24 #include <utility>
     25 #include <vector>
     26 
     27 #if defined(INTEL_MKL_ML_ONLY) || defined(INTEL_MKL_DNN_ONLY)
     28 #ifndef INTEL_MKL
     29 #error "INTEL_MKL_{ML,DNN}_ONLY require INTEL_MKL"
     30 #endif
     31 #endif
     32 
     33 #if defined(INTEL_MKL_ML_ONLY) && defined(INTEL_MKL_DNN_ONLY)
     34 #error "at most one of INTEL_MKL_ML_ONLY and INTEL_MKL_DNN_ONLY may be defined"
     35 #endif
     36 
     37 #ifdef INTEL_MKL_ML_ONLY
     38 #error "Please use INTEL MKL DNN (the default option for --config=mkl)."
     39 #endif
     40 
     41 #ifdef INTEL_MKL_ML_ONLY
     42 #include "mkl_dnn.h"
     43 #include "mkl_dnn_types.h"
     44 #include "mkl_service.h"
     45 #include "mkl_trans.h"
     46 #endif
     47 
     48 #include "tensorflow/core/framework/op_kernel.h"
     49 #include "tensorflow/core/framework/tensor.h"
     50 #include "tensorflow/core/framework/tensor_shape.h"
     51 #include "tensorflow/core/graph/mkl_graph_util.h"
     52 #include "tensorflow/core/lib/core/errors.h"
     53 #include "tensorflow/core/lib/gtl/array_slice.h"
     54 #include "tensorflow/core/platform/cpu_info.h"
     55 #include "tensorflow/core/platform/logging.h"
     56 #include "tensorflow/core/platform/macros.h"
     57 #include "tensorflow/core/util/env_var.h"
     58 #include "tensorflow/core/util/padding.h"
     59 #include "tensorflow/core/util/tensor_format.h"
     60 
     61 #ifndef INTEL_MKL_ML_ONLY
     62 #include "mkldnn.hpp"
     63 #include "tensorflow/core/lib/core/stringpiece.h"
     64 
     65 using mkldnn::engine;
     66 using mkldnn::memory;
     67 using mkldnn::padding_kind;
     68 using mkldnn::primitive;
     69 using mkldnn::reorder;
     70 #endif
     71 
     72 #ifdef _WIN32
     73 typedef unsigned int uint;
     74 #endif
     75 
     76 namespace tensorflow {
     77 
     78 // This file contains a number of utility classes and functions used by
     79 // MKL-enabled kernels.
     80 
     81 // This class encapsulates all the metadata that is associated with an MKL
     82 // tensor. A tensor is an MKL tensor if it was created as the result of an
     83 // MKL operation and did not go through a conversion to a standard
     84 // TensorFlow tensor.
     85 
     86 // For use with MKL ML, which has been deprecated.
     87 typedef enum { W = 0, H = 1, C = 2, N = 3 } MklDims;
     88 
     89 // The dimension order that MKL-DNN uses internally for 2D activations
     90 // [Batch, Channel, Height, Width] and
     91 // for 2D filters [Out_Channel, In_Channel, Height, Width].
     92 typedef enum {
     93   Dim_N = 0,
     94   Dim_C = 1,
     95   Dim_H = 2,
     96   Dim_W = 3,
     97   Dim_O = 0,
     98   Dim_I = 1
     99 } MklDnnDims;
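// For example (illustrative): a 2D activation with TensorFlow NHWC shape
// [8, 224, 224, 3] is held by MKL-DNN as dims {8, 3, 224, 224}, indexed by
// {Dim_N, Dim_C, Dim_H, Dim_W} above.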
    100 
    101 // The dimension order that MKL-DNN uses internally for 3D activations
    102 // [Batch, Channel, Depth, Height, Width] and
    103 // for 3D filters [Out_Channel, In_Channel, Depth, Height, Width].
    104 typedef enum {
    105   Dim3d_N = 0,
    106   Dim3d_C = 1,
    107   Dim3d_D = 2,
    108   Dim3d_H = 3,
    109   Dim3d_W = 4,
    110   Dim3d_O = 0,
    111   Dim3d_I = 1
    112 } MklDnnDims3D;
    113 
    114 // Enum for the order of dimensions of a TF 2D filter with shape [filter_height,
    115 // filter_width, in_channels, out_channels]
    116 typedef enum {
    117   TF_2DFILTER_DIM_H = 0,
    118   TF_2DFILTER_DIM_W = 1,
    119   TF_2DFILTER_DIM_I = 2,
    120   TF_2DFILTER_DIM_O = 3
    121 } TFFilterDims2d;
    122 
    123 // Enum for the order of dimensions of a TF 3D filter with shape [filter_depth,
    124 // filter_height, filter_width, in_channels, out_channels]
    125 typedef enum {
    126   TF_3DFILTER_DIM_P = 0,
    127   TF_3DFILTER_DIM_H = 1,
    128   TF_3DFILTER_DIM_W = 2,
    129   TF_3DFILTER_DIM_I = 3,
    130   TF_3DFILTER_DIM_O = 4
    131 } TFFilterDims3d;
    132 
    133 // The dimension order that MKL-DNN requires for the filter in a grouped
    134 // convolution (2D only).
    135 typedef enum {
    136   MKL_GROUP_FILTER_DIM_G = 0,
    137   MKL_GROUP_FILTER_DIM_O = 1,
    138   MKL_GROUP_FILTER_DIM_I = 2,
    139   MKL_GROUP_FILTER_DIM_H = 3,
    140   MKL_GROUP_FILTER_DIM_W = 4
    141 } MklDnnFilterGroupDims;
    142 
    143 // Enum used to templatize MklOp kernel implementations
    144 // that support both fp32 and int8 versions.
    145 enum class MklQuantization {
    146   QUANTIZED_VERSION,
    147   FP_VERSION,
    148 };
    149 
    150 static const int kSmallBatchSize = 32;
    151 
    152 #ifdef INTEL_MKL_ML_ONLY
    153 class MklShape {
    154  public:
    155   MklShape() {}
    156   TF_DISALLOW_COPY_AND_ASSIGN(MklShape);  // Cannot copy
    157 
    158   ~MklShape() {
    159     if (sizes_) delete[] sizes_;
    160     if (strides_) delete[] strides_;
    161     if (mklLayout_) CHECK_EQ(dnnLayoutDelete_F32(mklLayout_), E_SUCCESS);
    162     if (tfLayout_) CHECK_EQ(dnnLayoutDelete_F32(tfLayout_), E_SUCCESS);
    163     if (tf_to_mkl_dim_map_) delete[] tf_to_mkl_dim_map_;
    164   }
    165 
    166   const bool IsMklTensor() const { return isMklTensor_; }
    167 
    168   void SetMklTensor(const bool isMklTensor) { isMklTensor_ = isMklTensor; }
    169 
    170   void SetDimensions(const size_t dimension) { dimension_ = dimension; }
    171 
    172   void SetMklLayout(dnnLayout_t mklLayout) { mklLayout_ = mklLayout; }
    173 
    174   void SetMklLayout(const void* primitive, size_t resourceType) {
    175     CHECK_EQ(
    176         dnnLayoutCreateFromPrimitive_F32(&mklLayout_, (dnnPrimitive_t)primitive,
    177                                          (dnnResourceType_t)resourceType),
    178         E_SUCCESS);
    179   }
    180 
    181   void SetTfLayout(const size_t dimension, const size_t* sizes,
    182                    const size_t* strides) {
    183     dimension_ = dimension;
    184     if (dimension > 0) {  // MKL doesn't support zero-dimensional tensors
    185       sizes_ = new size_t[dimension];
    186       strides_ = new size_t[dimension];
    187 
    188       for (int ii = 0; ii < dimension; ii++) {
    189         sizes_[ii] = sizes[ii];
    190         strides_[ii] = strides[ii];
    191       }
    192       CHECK_EQ(dnnLayoutCreate_F32(&tfLayout_, dimension, sizes, strides),
    193                E_SUCCESS);
    194     }
    195   }
    196 
    197   // Default case: MKL dim ordering is the opposite of TF dim ordering.
    198   // MKL -> (DIMS-1)...0 where (DIMS-1) is outermost dim and 0 is innermost dim
    199   // TF  -> 0...(DIMS-1) where 0 is outermost dim and (DIMS-1) is innermost dim
    200   // For layers that rely on data_format semantics (conv, pooling, etc.)
    201   // or operate only on certain dimensions (relu, concat, split, etc.),
    202   // MKL APIs might require us to reorder these dimensions. In such cases,
    203   // kernels should explicitly set this map.
    204   void SetTfDimOrder(const size_t dimension) {
    205     CHECK(dimension == dimension_);
    206     if (tf_to_mkl_dim_map_ == nullptr) {
    207       tf_to_mkl_dim_map_ = new size_t[dimension];
    208     }
    209     for (size_t ii = 0; ii < dimension; ii++) {
    210       tf_to_mkl_dim_map_[ii] = dimension - (ii + 1);
    211     }
    212   }
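  // For example (illustrative): for a 4-D tensor the default map built above is
  // tf_to_mkl_dim_map_ = {3, 2, 1, 0}, i.e. TF dim 0 (outermost) maps to MKL
  // dim 3, and TF dim 3 (innermost) maps to MKL dim 0.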
    213 
    214   void SetTfDimOrder(const size_t dimension, const size_t* tf_to_mkl_dim_map) {
    215     CHECK(dimension == dimension_);
    216     if (tf_to_mkl_dim_map_ == nullptr) {
    217       tf_to_mkl_dim_map_ = new size_t[dimension];
    218     }
    219     for (size_t ii = 0; ii < dimension; ii++) {
    220       tf_to_mkl_dim_map_[ii] = tf_to_mkl_dim_map[ii];
    221     }
    222   }
    223 
    224   void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
    225     CHECK_EQ(dimension, 4);
    226     CHECK(dimension == dimension_);
    227     if (tf_to_mkl_dim_map_ == nullptr) {
    228       tf_to_mkl_dim_map_ = new size_t[dimension];
    229     }
    230     tf_to_mkl_dim_map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDims::W;
    231     tf_to_mkl_dim_map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDims::H;
    232     tf_to_mkl_dim_map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDims::C;
    233     tf_to_mkl_dim_map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDims::N;
    234   }
    235 
    236   const dnnLayout_t GetMklLayout() const { return mklLayout_; }
    237   const dnnLayout_t GetTfLayout() const { return tfLayout_; }
    238   const dnnLayout_t GetCurLayout() const {
    239     return isMklTensor_ ? mklLayout_ : tfLayout_;
    240   }
    241   size_t GetDimension() const { return dimension_; }
    242   const size_t* GetSizes() const { return sizes_; }
    243   int64 dim_size(int index) const { return sizes_[index]; }
    244   int64 tf_dim_size(int index) const {
    245     return sizes_[tf_to_mkl_dim_map_[index]];
    246   }
    247   const size_t* GetStrides() const { return strides_; }
    248   const size_t* GetTfToMklDimMap() const { return tf_to_mkl_dim_map_; }
    249   size_t tf_dim_idx(int index) const { return tf_to_mkl_dim_map_[index]; }
    250 
    251   // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
    252   // corresponds to MKL's Channel dimension.
    253   bool IsMklChannelDim(int d) const { return tf_dim_idx(d) == MklDims::C; }
    254   // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
    255   // corresponds to MKL's Batch dimension.
    256   bool IsMklBatchDim(int d) const { return tf_dim_idx(d) == MklDims::N; }
    257   // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
    258   // corresponds to MKL's Width dimension.
    259   bool IsMklWidthDim(int d) const { return tf_dim_idx(d) == MklDims::W; }
    260   // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
    261   // corresponds to MKL's Height dimension.
    262   bool IsMklHeightDim(int d) const { return tf_dim_idx(d) == MklDims::H; }
    263 
    264   // Check if the TF-Mkl dimension ordering map specifies if the input
    265   // tensor is in NCHW format.
    266   bool IsTensorInNCHWFormat() const {
    267     TensorFormat data_format = FORMAT_NCHW;
    268     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
    269             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
    270             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
    271             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
    272   }
    273 
    274   // Check if the TF-Mkl dimension ordering map specifies if the input
    275   // tensor is in NHWC format.
    276   bool IsTensorInNHWCFormat() const {
    277     TensorFormat data_format = FORMAT_NHWC;
    278     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
    279             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
    280             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
    281             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
    282   }
    283 
    284   void GetConvertedFlatData(dnnLayout_t targetLayout, void* input,
    285                             void* output) const {
    286     dnnLayout_t curLayout;
    287     if (isMklTensor_)
    288       curLayout = mklLayout_;
    289     else
    290       curLayout = tfLayout_;
    291     dnnPrimitive_t convert;
    292     CHECK_EQ(dnnConversionCreate_F32(&convert, curLayout, targetLayout),
    293              E_SUCCESS);
    294     CHECK_EQ(dnnConversionExecute_F32(convert, input, output), E_SUCCESS);
    295     CHECK_EQ(dnnDelete_F32(convert), E_SUCCESS);
    296   }
    297 
    298   // The following methods are used for serializing and de-serializing the
    299   // contents of the mklshape object.
    300   // The data is serialized in this order
    301   // isMklTensor_
    302   // dimension_
    303   // sizes_
    304   // strides_
    305   // mklLayout_
    306   // tfLayout_
    307   // tf_to_mkl_dim_map_
    308 
    309 #define SIZE_OF_MKL_DNN_BUF \
    310   (dnnLayoutSerializationBufferSize_F32())  // Size of buffer needed to
    311                                             // serialize dnn_layout pointer
    312 
    313   // Size of the buffer needed to hold the serialized object; the size is
    314   // computed as: sizeof(isMklTensor_) + sizeof(dimension_) + sizeof(sizes_)
    315   // + sizeof(strides_)
    316   // + sizeof(mklLayout_ buffer) + sizeof(tfLayout_ buffer)
    317   // + sizeof(tf_to_mkl_dim_map_)
    318 
    319 #define SIZE_OF_MKL_SERIAL_DATA(dims) \
    320   (2 * sizeof(size_t) + 3 * dims * sizeof(size_t) + 2 * SIZE_OF_MKL_DNN_BUF)
    321 
    322   // First we need to define some macros for offsets into the serial buffer
    323   // where the different elements of MklShape are written to / read from.
    324 
    325 #define IS_MKL_TENSOR_OFFSET 0
    326 // Location from start of buffer where isMklTensor_ is serialized
    327 #define DIMS_OFFSET \
    328   (IS_MKL_TENSOR_OFFSET + sizeof(size_t))  // Location of dimension_
    329 // Location of sizes_. Note: 'dims' is not used here; it is kept only to
    330 // keep the macros consistent.
    331 #define SIZES_OFFSET(dims) (DIMS_OFFSET + sizeof(size_t))
    332 #define STRIDES_OFFSET(dims) \
    333   (SIZES_OFFSET(dims) + dims * sizeof(size_t))  // Location of strides
    334 #define MKL_LAYOUT_OFFSET(dims) \
    335   (STRIDES_OFFSET(dims) + dims * sizeof(size_t))  // Location of mklLayout_
    336 #define TF_LAYOUT_OFFSET(dims) \
    337   (MKL_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF)  // Location of tfLayout_
    338 // Location of tf_to_mkl_dim_map_
    339 #define TF_TO_MKL_DIM_MAP_OFFSET(dims) \
    340   (TF_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF)
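// Resulting buffer layout (illustrative, for dims == 4 and 8-byte size_t):
// bytes [0,8) isMklTensor_, [8,16) dimension_, [16,48) sizes_,
// [48,80) strides_, [80,80+B) mklLayout_, [80+B,80+2B) tfLayout_, and
// [80+2B,80+2B+32) tf_to_mkl_dim_map_, where B = SIZE_OF_MKL_DNN_BUF;
// the total matches SIZE_OF_MKL_SERIAL_DATA(4).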
    341 
    342   // TODO(agramesh1) make sure to create a const to share with rewrite pass
    343   // for min size of MKL metadata tensor.
    344 
    345   void DeSerializeMklShape(const unsigned char* buf, size_t buf_size) {
    346     CHECK(buf_size >= sizeof(size_t)) << "Bufsize too small in DeSerialize";
    347     // Make sure buffer holds at least isMklTensor_
    348     isMklTensor_ =
    349         *reinterpret_cast<const size_t*>(buf + IS_MKL_TENSOR_OFFSET) != 0;
    350 
    351     if (isMklTensor_) {  // If it is an MKL Tensor then read the rest
    352       dimension_ = *(reinterpret_cast<const size_t*>(buf + DIMS_OFFSET));
    353       CHECK(buf_size >= SIZE_OF_MKL_SERIAL_DATA(dimension_))
    354           << "Bufsize too small in DeSerialize";
    355       sizes_ = new size_t[dimension_];
    356       strides_ = new size_t[dimension_];
    357       tf_to_mkl_dim_map_ = new size_t[dimension_];
    358       for (int i = 0; i < dimension_; i++) {
    359         sizes_[i] =
    360             reinterpret_cast<const size_t*>(buf + SIZES_OFFSET(dimension_))[i];
    361         strides_[i] = reinterpret_cast<const size_t*>(
    362             buf + STRIDES_OFFSET(dimension_))[i];
    363         tf_to_mkl_dim_map_[i] = reinterpret_cast<const size_t*>(
    364             buf + TF_TO_MKL_DIM_MAP_OFFSET(dimension_))[i];
    365       }
    366       CHECK_EQ(dnnLayoutDeserialize_F32(&mklLayout_,
    367                                         buf + MKL_LAYOUT_OFFSET(dimension_)),
    368                E_SUCCESS);
    369       CHECK_EQ(dnnLayoutDeserialize_F32(&tfLayout_,
    370                                         buf + TF_LAYOUT_OFFSET(dimension_)),
    371                E_SUCCESS);
    372     }
    373   }
    374 
    375   void SerializeMklShape(unsigned char* buf, size_t buf_size) const {
    376     CHECK(buf_size >= SIZE_OF_MKL_SERIAL_DATA(dimension_))
    377         << "Bufsize too small to Serialize";
    378     *reinterpret_cast<size_t*>(buf + IS_MKL_TENSOR_OFFSET) =
    379         isMklTensor_ ? 1 : 0;
    380     if (isMklTensor_) {
    381       *(reinterpret_cast<size_t*>(buf + DIMS_OFFSET)) = dimension_;
    382       for (int i = 0; i < dimension_; i++) {
    383         reinterpret_cast<size_t*>(buf + SIZES_OFFSET(dimension_))[i] =
    384             sizes_[i];
    385         reinterpret_cast<size_t*>(buf + STRIDES_OFFSET(dimension_))[i] =
    386             strides_[i];
    387         reinterpret_cast<size_t*>(buf +
    388                                   TF_TO_MKL_DIM_MAP_OFFSET(dimension_))[i] =
    389             tf_to_mkl_dim_map_[i];
    390       }
    391       CHECK_EQ(dnnLayoutSerialize_F32(mklLayout_,
    392                                       buf + MKL_LAYOUT_OFFSET(dimension_)),
    393                E_SUCCESS);
    394       CHECK_EQ(
    395           dnnLayoutSerialize_F32(tfLayout_, buf + TF_LAYOUT_OFFSET(dimension_)),
    396           E_SUCCESS);
    397     }
    398   }
    399 
    400  private:
    401   bool isMklTensor_ =
    402       false;  // Flag to indicate if the tensor is an MKL tensor or not
    403   dnnLayout_t mklLayout_ = nullptr;  // Pointer to the MKL layout
    404   dnnLayout_t tfLayout_ = nullptr;   // Layout of the corresponding
    405   // TensorFlow tensor, used when converting from MKL to a standard tensor
    406   size_t dimension_ = 0;
    407   size_t* sizes_ = nullptr;    // Required by MKL for conversions
    408   size_t* strides_ = nullptr;  // Required by MKL for conversions
    409   size_t* tf_to_mkl_dim_map_ =
    410       nullptr;  // TF dimension corresponding to this MKL dimension
    411 };
    412 
    413 #else
    414 
    415 // Forward decl
    416 TensorFormat MklDnn3DDataFormatToTFDataFormat(memory::format format);
    417 TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format);
    418 memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
    419 memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
    420                                         const memory::dims& strides,
    421                                         memory::data_type dtype);
    422 
    423 class MklDnnShape {
    424  private:
    425   typedef struct {
    426     /// Flag to indicate if the tensor is an MKL tensor or not
    427     bool is_mkl_tensor_ = false;
    428     /// Number of dimensions in TensorFlow format
    429     size_t dimension_ = 0;
    430     /// Sizes of the tensor; required by MKL-DNN for conversions
    431     mkldnn_dims_t sizes_;
    432     memory::format tf_data_format_ = memory::format::format_undef;
    433     memory::data_type T_ = memory::data_type::data_undef;
    434     // MKL layout
    435     mkldnn_memory_desc_t mkl_md_;
    436     /// TF dimension corresponding to this MKL dimension
    437     mkldnn_dims_t map_;
    438   } MklShapeData;
    439   MklShapeData data_;
    440 
    441   typedef std::remove_extent<mkldnn_dims_t>::type mkldnn_dim_t;
    442 #define INVALID_DIM_SIZE -1
    443 
    444  public:
    445   MklDnnShape() {
    446     for (size_t i = 0; i < sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
    447          ++i) {
    448       data_.sizes_[i] = -1;
    449     }
    450     for (size_t i = 0; i < sizeof(data_.map_) / sizeof(data_.map_[0]); ++i) {
    451       data_.map_[i] = -1;
    452     }
    453   }
    454 
    455   ~MklDnnShape() {}
    456   TF_DISALLOW_COPY_AND_ASSIGN(MklDnnShape);  // Cannot copy
    457 
    458   /// Helper function to compare memory::desc objects for MklDnn.
    459   /// Maybe this should go into MklDnn directly.
    460   inline bool CompareMklDnnLayouts(const memory::desc& md1,
    461                                    const memory::desc& md2) const {
    462     mkldnn_memory_desc_t mdd1 = md1.data;
    463     mkldnn_memory_desc_t mdd2 = md2.data;
    464     const char* d1 = reinterpret_cast<const char*>(&mdd1);
    465     const char* d2 = reinterpret_cast<const char*>(&mdd2);
    466 
    467     size_t md_size = sizeof(mdd1);
    468     for (size_t i = 0; i < md_size; i++) {
    469       if (*d1++ != *d2++) {
    470         return false;
    471       }
    472     }
    473     return true;
    474   }
    475 
    476   /// Equality function for MklDnnShape objects
    477   /// @return true if both are equal; false otherwise.
    478   inline bool operator==(const MklDnnShape& input_shape) const {
    479     if (this->IsMklTensor() != input_shape.IsMklTensor()) {
    480       return false;
    481     }
    482 
    483     // If input tensors are in Mkl layout, then we check for dimensions and
    484     // sizes.
    485     if (this->IsMklTensor()) {
    486       return this->GetTfShape() == input_shape.GetTfShape() &&
    487              CompareMklDnnLayouts(this->GetMklLayout(),
    488                                   input_shape.GetMklLayout());
    489     }
    490 
    491     return true;
    492   }
    493 
    494   /// Equality operator for MklDnnShape and TFShape.
    495   /// Returns: true if TF shapes for both are the same, false otherwise
    496   inline bool operator==(const TensorShape& input_shape) const {
    497     if (!this->IsMklTensor()) {
    498       return false;
    499     }
    500 
    501     return this->GetTfShape() == input_shape;
    502   }
    503 
    504   inline const bool IsMklTensor() const { return data_.is_mkl_tensor_; }
    505   inline void SetMklTensor(bool is_mkl_tensor) {
    506     data_.is_mkl_tensor_ = is_mkl_tensor;
    507   }
    508 
    509   inline void SetDimensions(const size_t dimension) {
    510     data_.dimension_ = dimension;
    511   }
    512   inline size_t GetDimension(char dimension) const {
    513     int index = GetMklDnnTensorDimIndex(dimension);
    514     CHECK(index >= 0 && index < this->GetDimension())
    515         << "Invalid index from the dimension: " << index << ", " << dimension;
    516     return this->DimSize(index);
    517   }
    518 
    519   inline size_t GetDimension3D(char dimension) const {
    520     int index = GetMklDnnTensor3DDimIndex(dimension);
    521     CHECK(index >= 0 && index < this->GetDimension())
    522         << "Invalid index from the dimension: " << index << ", " << dimension;
    523     return this->DimSize(index);
    524   }
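  // Usage sketch (illustrative, 'shape' is a hypothetical MklDnnShape): because
  // sizes_ is kept in MKL-DNN order,
  //   size_t batch = shape.GetDimension('N');
  //   size_t channels = shape.GetDimension('C');
  // return the batch and channel sizes regardless of the original TF format.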
    525 
    526   inline int32 GetMklDnnTensorDimIndex(char dimension) const {
    527     switch (dimension) {
    528       case 'N':
    529         return MklDnnDims::Dim_N;
    530       case 'C':
    531         return MklDnnDims::Dim_C;
    532       case 'H':
    533         return MklDnnDims::Dim_H;
    534       case 'W':
    535         return MklDnnDims::Dim_W;
    536       default:
    537         LOG(FATAL) << "Invalid dimension: " << dimension;
    538         return -1;  // Avoid compiler warning about missing return value
    539     }
    540   }
    541 
    542   inline int32 GetMklDnnTensor3DDimIndex(char dimension) const {
    543     switch (dimension) {
    544       case 'N':
    545         return MklDnnDims3D::Dim3d_N;
    546       case 'C':
    547         return MklDnnDims3D::Dim3d_C;
    548       case 'D':
    549         return MklDnnDims3D::Dim3d_D;
    550       case 'H':
    551         return MklDnnDims3D::Dim3d_H;
    552       case 'W':
    553         return MklDnnDims3D::Dim3d_W;
    554       default:
    555         LOG(FATAL) << "Invalid dimension: " << dimension;
    556         return -1;  // Avoid compiler warning about missing return value
    557     }
    558   }
    559 
    560   inline size_t GetDimension() const { return data_.dimension_; }
    561   inline const int* GetSizes() const {
    562     return reinterpret_cast<const int*>(&data_.sizes_[0]);
    563   }
    564 
    565   // Returns an mkldnn::memory::dims object that contains the sizes of this
    566   // MklDnnShape object.
    567   inline memory::dims GetSizesAsMklDnnDims() const {
    568     memory::dims retVal;
    569     if (data_.is_mkl_tensor_) {
    570       size_t dimensions = sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
    571       for (size_t i = 0; i < dimensions; i++) {
    572         if (data_.sizes_[i] != INVALID_DIM_SIZE)
    573           retVal.push_back(data_.sizes_[i]);
    574       }
    575     } else {
    576       CHECK_EQ(data_.is_mkl_tensor_, true);
    577     }
    578     return retVal;
    579   }
    580 
    581   inline int64 DimSize(int index) const {
    582     CHECK_LT(index, sizeof(data_.sizes_) / sizeof(data_.sizes_[0]));
    583     return data_.sizes_[index];
    584   }
    585 
    586   /// Return a TensorShape that describes the TensorFlow shape of the tensor
    587   /// represented by this MklDnnShape.
    588   inline TensorShape GetTfShape() const {
    589     CHECK_EQ(data_.is_mkl_tensor_, true);
    590 
    591     std::vector<int32> shape(data_.dimension_, -1);
    592     if (data_.tf_data_format_ != memory::format::blocked) {
    593       for (size_t idx = 0; idx < data_.dimension_; ++idx) {
    594         shape[idx] = data_.sizes_[TfDimIdx(idx)];
    595       }
    596     } else {
    597       // If the TensorFlow shape is in blocked format, then we don't have a
    598       // dimension map for it, so we just create the TensorFlow shape from the
    599       // sizes in the specified order.
    600       for (size_t idx = 0; idx < data_.dimension_; ++idx) {
    601         shape[idx] = data_.sizes_[idx];
    602       }
    603     }
    604 
    605     TensorShape ts;
    606     bool ret = TensorShapeUtils::MakeShape(shape, &ts).ok();
    607     CHECK_EQ(ret, true);
    608     return ts;
    609   }
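  // Worked example (illustrative): if sizes_ = {8, 3, 224, 224} (MKL-DNN NCHW
  // order) and the TF data format was NHWC, then map_ = {Dim_N, Dim_H, Dim_W,
  // Dim_C} and GetTfShape() returns [8, 224, 224, 3].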
    610 
    611   inline void SetElemType(memory::data_type dt) { data_.T_ = dt; }
    612   inline const memory::data_type GetElemType() { return data_.T_; }
    613 
    614   inline void SetMklLayout(memory::primitive_desc* pd) {
    615     CHECK_NOTNULL(pd);
    616     data_.mkl_md_ = pd->desc().data;
    617   }
    618 
    619   inline void SetMklLayout(memory::desc* md) {
    620     CHECK_NOTNULL(md);
    621     data_.mkl_md_ = md->data;
    622   }
    623 
    624   inline const memory::desc GetMklLayout() const {
    625     return memory::desc(data_.mkl_md_);
    626   }
    627 
    628   inline memory::format GetTfDataFormat() const {
    629     return data_.tf_data_format_;
    630   }
    631   /// We don't create a primitive_descriptor for the TensorFlow layout now.
    632   /// We use lazy evaluation and create it only when needed. The input format
    633   /// can also be the blocked format.
    634   inline void SetTfLayout(size_t dims, const memory::dims& sizes,
    635                           memory::format format) {
    636     CHECK_EQ(dims, sizes.size());
    637     data_.dimension_ = dims;
    638     for (size_t ii = 0; ii < dims; ii++) {
    639       data_.sizes_[ii] = sizes[ii];
    640     }
    641     data_.tf_data_format_ = format;
    642     if (format != memory::format::blocked) {
    643       SetTfDimOrder(dims, format);
    644     }
    645   }
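  // Note (derived from GetTfShape() above): 'sizes' is expected in MKL-DNN
  // canonical (NCHW/NCDHW) order, not TF order; the TF order is recovered via
  // the TF-to-MKL dimension map. Illustrative call for an NHWC tensor, where
  // 'mkl_shape' is a hypothetical MklDnnShape:
  //   mkl_shape.SetTfLayout(4, {8, 3, 224, 224}, memory::format::nhwc);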
    646 
    647   inline const memory::desc GetTfLayout() const {
    648     memory::dims dims;
    649     for (size_t ii = 0; ii < data_.dimension_; ii++) {
    650       dims.push_back(data_.sizes_[ii]);
    651     }
    652 
    653     // Create a blocked memory desc if the input TF format was set to blocked.
    654     if (data_.tf_data_format_ == memory::format::blocked) {
    655       auto strides = CalculateTFStrides(dims);
    656       return CreateBlockedMemDescHelper(dims, strides, data_.T_);
    657     } else {
    658       return memory::desc(dims, data_.T_, data_.tf_data_format_);
    659     }
    660   }
    661 
    662   inline const memory::desc GetCurLayout() const {
    663     return IsMklTensor() ? GetMklLayout() : GetTfLayout();
    664   }
    665 
    666   // nhasabni - I've removed the SetTfDimOrder overload that set a default
    667   // order in the MKL-ML case. We don't need a default dimension order because
    668   // an operator that does not get a data_format attribute and gets all its
    669   // inputs in TensorFlow format will produce output in TensorFlow format.
    670   inline void SetTfDimOrder(const size_t dimension, const mkldnn_dims_t map) {
    671     CHECK(dimension == data_.dimension_);
    672     for (size_t ii = 0; ii < dimension; ii++) {
    673       data_.map_[ii] = map[ii];
    674     }
    675   }
    676 
    677   inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
    678     if (dimension == 5) {
    679       CHECK(dimension == data_.dimension_);
    680       data_.map_[GetTensorDimIndex<3>(data_format, '0')] =
    681           MklDnnDims3D::Dim3d_D;
    682       data_.map_[GetTensorDimIndex<3>(data_format, '1')] =
    683           MklDnnDims3D::Dim3d_H;
    684       data_.map_[GetTensorDimIndex<3>(data_format, '2')] =
    685           MklDnnDims3D::Dim3d_W;
    686       data_.map_[GetTensorDimIndex<3>(data_format, 'C')] =
    687           MklDnnDims3D::Dim3d_C;
    688       data_.map_[GetTensorDimIndex<3>(data_format, 'N')] =
    689           MklDnnDims3D::Dim3d_N;
    690     } else {
    691       CHECK_EQ(dimension, 4);
    692       CHECK(dimension == data_.dimension_);
    693       data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
    694       data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
    695       data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
    696       data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
    697     }
    698   }
    699 
    700   inline void SetTfDimOrder(const size_t dimension, memory::format format) {
    701     TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format);
    702     SetTfDimOrder(dimension, data_format);
    703   }
    704 
    705   inline const mkldnn_dim_t* GetTfToMklDimMap() const { return &data_.map_[0]; }
    706   inline size_t TfDimIdx(int index) const { return data_.map_[index]; }
    707   inline int64 TfDimSize(int index) const {
    708     return data_.sizes_[TfDimIdx(index)];
    709   }
    710 
    711   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
    712   /// corresponds to MKL's Channel dimension.
    713   inline bool IsMklChannelDim(int d) const {
    714     return TfDimIdx(d) == MklDnnDims::Dim_C;
    715   }
    716   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
    717   /// corresponds to MKL's Batch dimension.
    718   inline bool IsMklBatchDim(int d) const {
    719     return TfDimIdx(d) == MklDnnDims::Dim_N;
    720   }
    721   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
    722   /// corresponds to MKL's Width dimension.
    723   inline bool IsMklWidthDim(int d) const {
    724     return TfDimIdx(d) == MklDnnDims::Dim_W;
    725   }
    726   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
    727   /// corresponds to MKL's Height dimension.
    728   inline bool IsMklHeightDim(int d) const {
    729     return TfDimIdx(d) == MklDnnDims::Dim_H;
    730   }
    731 
    732   /// Check if the TF-Mkl dimension ordering map specifies if the input
    733   /// tensor is in NCHW format.
    734   inline bool IsTensorInNCHWFormat() const {
    735     TensorFormat data_format = FORMAT_NCHW;
    736     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
    737             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
    738             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
    739             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
    740   }
    741 
    742   /// Check if the TF-Mkl dimension ordering map specifies if the input
    743   /// tensor is in NHWC format.
    744   inline bool IsTensorInNHWCFormat() const {
    745     TensorFormat data_format = FORMAT_NHWC;
    746     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
    747             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
    748             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
    749             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
    750   }
    751 
    752   /// The following methods are used for serializing and de-serializing the
    753   /// contents of the MklDnnShape object.
    754   /// The data is serialized in this order:
    755   /// is_mkl_tensor_ : dimension_ : sizes_ : map_ : format_ : T_ : mkl_pd_;
    756 
    757   /// Size of the buffer needed to hold the serialized object; the size is
    758   /// computed by following the above-mentioned order.
    759   inline size_t GetSerializeBufferSize() const { return sizeof(MklShapeData); }
    760 
    761   void SerializeMklDnnShape(unsigned char* buf, size_t buf_size) const {
    762     CHECK(buf_size >= GetSerializeBufferSize())
    763         << "Buffer size is too small to SerializeMklDnnShape";
    764     *reinterpret_cast<MklShapeData*>(buf) = data_;
    765   }
    766 
    767   void DeSerializeMklDnnShape(const unsigned char* buf, size_t buf_size) {
    768     // Make sure buffer holds at least is_mkl_tensor_.
    769     CHECK(buf_size >= sizeof(data_.is_mkl_tensor_))
    770         << "Buffer size is too small in DeSerializeMklDnnShape";
    771 
    772     const bool is_mkl_tensor = *reinterpret_cast<const bool*>(buf);
    773     if (is_mkl_tensor) {  // If it is an MKL Tensor then read the rest
    774       CHECK(buf_size >= GetSerializeBufferSize())
    775           << "Buffer size is too small in DeSerializeMklDnnShape";
    776       data_ = *reinterpret_cast<const MklShapeData*>(buf);
    777     }
    778   }
    779 };
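// Round-trip sketch (illustrative) for the serialization helpers above, where
// 'shape' is a populated MklDnnShape and 'meta' is a uint8 Tensor with at
// least shape.GetSerializeBufferSize() elements (hypothetical variables):
//   shape.SerializeMklDnnShape(meta.flat<uint8>().data(),
//                              meta.flat<uint8>().size() * sizeof(uint8));
//   MklDnnShape restored;
//   restored.DeSerializeMklDnnShape(meta.flat<uint8>().data(),
//                                   meta.flat<uint8>().size() * sizeof(uint8));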
    780 
    781 #endif
    782 
    783 // List of MklShape / MklDnnShape objects. Used in Concat/Split layers.
    784 
    785 #ifndef INTEL_MKL_ML_ONLY
    786 typedef std::vector<MklDnnShape> MklDnnShapeList;
    787 #else
    788 typedef std::vector<MklShape> MklShapeList;
    789 #endif
    790 
    791 #ifdef INTEL_MKL_ML_ONLY
    792 // Check if all tensors specified by MklShapes are MKL tensors.
    793 inline bool AreAllMklTensors(const MklShapeList& shapes) {
    794   for (auto& s : shapes) {
    795     if (!s.IsMklTensor()) {
    796       return false;
    797     }
    798   }
    799   return true;
    800 }
    801 
    802 template <typename T>
    803 inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
    804                              const MklShape& mkl_shape) {
    805   Tensor output_tensor;
    806   TensorShape output_shape;
    807 
    808   for (size_t j = 0; j < mkl_shape.GetDimension(); j++) {
    809     // Outermost to innermost dimension
    810     output_shape.AddDim(mkl_shape.GetSizes()[mkl_shape.tf_dim_idx(j)]);
    811   }
    812 
    813   // Allocate output tensor.
    814   context->allocate_temp(DataTypeToEnum<T>::v(), output_shape, &output_tensor);
    815 
    816   dnnLayout_t output_layout = static_cast<dnnLayout_t>(mkl_shape.GetTfLayout());
    817   void* input_buffer = const_cast<T*>(mkl_tensor.flat<T>().data());
    818   void* output_buffer = const_cast<T*>(output_tensor.flat<T>().data());
    819 
    820   if (mkl_tensor.NumElements() != 0) {
    821     mkl_shape.GetConvertedFlatData(output_layout, input_buffer, output_buffer);
    822   }
    823 
    824   return output_tensor;
    825 }
    826 #else
    827 using mkldnn::stream;
    828 template <typename T>
    829 class MklDnnData;
    830 
    831 template <typename T>
    832 inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
    833                              const MklDnnShape& mkl_shape) {
    834   Tensor output_tensor;
    835   try {
    836     if (!mkl_shape.IsMklTensor())
    837       return mkl_tensor;  // return input since it is already TF tensor
    838 
    839     TensorShape output_shape = mkl_shape.GetTfShape();
    840 
    841     // Allocate output tensor.
    842     context->allocate_temp(DataTypeToEnum<T>::v(), output_shape,
    843                            &output_tensor);
    844 
    845     auto cpu_engine = engine(engine::cpu, 0);
    846     MklDnnData<T> input(&cpu_engine);
    847 
    848     // Get Mkl layout of input tensor.
    849     auto input_mkl_md = mkl_shape.GetMklLayout();
    850     auto output_tf_md = mkl_shape.GetTfLayout();
    851     auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine);
    852     input.SetUsrMem(input_mkl_md, &mkl_tensor);
    853 
    854     // reorder
    855     if (input.IsReorderNeeded(output_tf_pd)) {
    856       std::vector<primitive> net;
    857       CHECK_EQ(input.CheckReorderToOpMem(output_tf_pd, &output_tensor, &net),
    858                true);
    859       stream(stream::kind::eager).submit(net).wait();
    860     } else {
    861       // If not, just forward input tensor to output tensor.
    862       CHECK(output_tensor.CopyFrom(mkl_tensor, output_shape));
    863     }
    864   } catch (mkldnn::error& e) {
    865     string error_msg = "Status: " + std::to_string(e.status) +
    866                        ", message: " + string(e.message) + ", in file " +
    867                        string(__FILE__) + ":" + std::to_string(__LINE__);
    868     LOG(FATAL) << "Operation received an exception: " << error_msg;
    869   }
    870   return output_tensor;
    871 }
    872 #endif
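// Illustrative use of ConvertMklToTF() above, producing a tensor in plain TF
// layout from an input tensor plus its shape metadata (hypothetical variables):
//   Tensor tf_tensor = ConvertMklToTF<float>(context, mkl_tensor, mkl_shape);
// In the MKL-DNN build, if 'mkl_shape' is not an MKL tensor the input is
// returned as-is.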
    873 
    874 // Get the MKL shape from the second (metadata) tensor
    875 #ifdef INTEL_MKL_ML_ONLY
    876 inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
    877   mklshape->DeSerializeMklShape(
    878       ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
    879           .flat<uint8>()
    880           .data(),
    881       ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
    882               .flat<uint8>()
    883               .size() *
    884           sizeof(uint8));
    885 }
    886 #else
    887 inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
    888   mklshape->DeSerializeMklDnnShape(
    889       ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
    890           .flat<uint8>()
    891           .data(),
    892       ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
    893               .flat<uint8>()
    894               .size() *
    895           sizeof(uint8));
    896 }
    897 #endif
    898 
    899 // Gets the actual (data) input tensor at index 'n'
    900 inline const Tensor& MklGetInput(OpKernelContext* ctext, int n) {
    901   return ctext->input(GetTensorDataIndex(n, ctext->num_inputs()));
    902 }
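// Typical kernel usage of the helpers above (illustrative sketch, MKL-DNN
// build):
//   const Tensor& input = MklGetInput(context, 0);
//   MklDnnShape input_mkl_shape;
//   GetMklShape(context, 0, &input_mkl_shape);
//   TensorShape tf_shape = input_mkl_shape.IsMklTensor()
//                              ? input_mkl_shape.GetTfShape()
//                              : input.shape();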
    903 
    904 inline void GetMklInputList(OpKernelContext* ctext, StringPiece name,
    905                             OpInputList* input_tensors) {
    906   CHECK_NOTNULL(input_tensors);
    907   ctext->input_list(name, input_tensors);
    908 }
    909 
    910 #ifdef INTEL_MKL_ML_ONLY
    911 
    912 inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
    913                             MklShapeList* mkl_shapes) {
    914   OpInputList input_mkl_tensors;
    915   GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);
    916 
    917   for (int i = 0; i < input_mkl_tensors.size(); i++) {
    918     (*mkl_shapes)[i].DeSerializeMklShape(
    919         input_mkl_tensors[i].flat<uint8>().data(),
    920         input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8));
    921   }
    922 }
    923 
    924 #else
    925 
    926 inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
    927                             MklDnnShapeList* mkl_shapes) {
    928   OpInputList input_mkl_tensors;
    929   GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);
    930 
    931   for (int i = 0; i < input_mkl_tensors.size(); i++) {
    932     (*mkl_shapes)[i].DeSerializeMklDnnShape(
    933         input_mkl_tensors[i].flat<uint8>().data(),
    934         input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8));
    935   }
    936 }
    937 
    938 #endif
    939 
    940 #ifndef INTEL_MKL_ML_ONLY
    941 /// Get the shape of the input tensor pointed to by 'input_idx' in TensorShape
    942 /// format. If the input tensor is in MKL layout, then obtain the TensorShape
    943 /// from the MklDnnShape.
    944 inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx) {
    945   // Sanity check.
    946   CHECK_NOTNULL(context);
    947   CHECK_LT(input_idx, context->num_inputs());
    948 
    949   MklDnnShape input_mkl_shape;
    950   GetMklShape(context, input_idx, &input_mkl_shape);
    951   if (input_mkl_shape.IsMklTensor()) {
    952     return input_mkl_shape.GetTfShape();
    953   } else {
    954     const Tensor& t = MklGetInput(context, input_idx);
    955     return t.shape();
    956   }
    957 }
    958 #endif
    959 
    960 #ifdef INTEL_MKL_ML_ONLY
    961 // Allocate the second output tensor that will contain
    962 // the serialized MKL shape
    963 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
    964                                       const MklShape& mkl_shape) {
    965   Tensor* second_tensor = nullptr;
    966   TensorShape second_shape;
    967   second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mkl_shape.GetDimension()));
    968   OP_REQUIRES_OK(ctext, ctext->allocate_output(
    969                             GetTensorMetaDataIndex(n, ctext->num_outputs()),
    970                             second_shape, &second_tensor));
    971   mkl_shape.SerializeMklShape(
    972       second_tensor->flat<uint8>().data(),
    973       second_tensor->flat<uint8>().size() * sizeof(uint8));
    974 }
    975 
    976 #else
    977 // Allocate the second output tensor that will contain
    978 // the serialized MKL shape
    979 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
    980                                       const MklDnnShape& mkl_shape) {
    981   Tensor* second_tensor = nullptr;
    982   TensorShape second_shape;
    983   second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
    984   OP_REQUIRES_OK(ctext, ctext->allocate_output(
    985                             GetTensorMetaDataIndex(n, ctext->num_outputs()),
    986                             second_shape, &second_tensor));
    987   mkl_shape.SerializeMklDnnShape(
    988       second_tensor->flat<uint8>().data(),
    989       second_tensor->flat<uint8>().size() * sizeof(uint8));
    990 }
    991 #endif
    992 
    993 #ifdef INTEL_MKL_ML_ONLY
    994 // Allocate the output tensor, and create a second output tensor that will
    995 // contain the serialized MKL shape
    996 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
    997                                       Tensor** output,
    998                                       const TensorShape& tf_shape,
    999                                       const MklShape& mkl_shape) {
   1000   Tensor* second_tensor = nullptr;
   1001   TensorShape second_shape;
   1002   second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mkl_shape.GetDimension()));
   1003   OP_REQUIRES_OK(
   1004       ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
   1005                                     tf_shape, output));
   1006   OP_REQUIRES_OK(ctext, ctext->allocate_output(
   1007                             GetTensorMetaDataIndex(n, ctext->num_outputs()),
   1008                             second_shape, &second_tensor));
   1009   mkl_shape.SerializeMklShape(
   1010       second_tensor->flat<uint8>().data(),
   1011       second_tensor->flat<uint8>().size() * sizeof(uint8));
   1012 }
   1013 
   1014 #else
   1015 // Allocate the output tensor, and create a second output tensor that will
   1016 // contain the serialized MKL shape
   1017 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
   1018                                       Tensor** output,
   1019                                       const TensorShape& tf_shape,
   1020                                       const MklDnnShape& mkl_shape) {
   1021   Tensor* second_tensor = nullptr;
   1022   TensorShape second_shape;
   1023   second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
   1024   OP_REQUIRES_OK(
   1025       ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
   1026                                     tf_shape, output));
   1027   OP_REQUIRES_OK(ctext, ctext->allocate_output(
   1028                             GetTensorMetaDataIndex(n, ctext->num_outputs()),
   1029                             second_shape, &second_tensor));
   1030   mkl_shape.SerializeMklDnnShape(
   1031       second_tensor->flat<uint8>().data(),
   1032       second_tensor->flat<uint8>().size() * sizeof(uint8));
   1033 }
   1034 #endif
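// Illustrative sketch (MKL-DNN build) of allocating an op's output in plain TF
// layout with the helper above; the metadata output is still produced so that
// downstream MKL ops can tell the data tensor is not in MKL layout:
//   Tensor* output = nullptr;
//   MklDnnShape output_mkl_shape;
//   output_mkl_shape.SetMklTensor(false);
//   AllocateOutputSetMklShape(context, 0, &output, tf_shape, output_mkl_shape);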
   1035 
   1036 // Allocates a temp tensor and returns the data buffer for temporary storage.
   1037 // The buffer is sized from the given primitive_desc / layout.
   1038 #ifndef INTEL_MKL_ML_ONLY
   1039 template <typename T>
   1040 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
   1041                            const memory::primitive_desc& pd, void** buf_out) {
   1042   TensorShape tf_shape;
   1043 
   1044   tf_shape.AddDim(pd.get_size() / sizeof(T) + 1);
   1045   OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
   1046                                                  tf_shape, tensor_out));
   1047   *buf_out = static_cast<void*>(tensor_out->flat<T>().data());
   1048 }
   1049 #else
   1050 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
   1051                            dnnLayout_t lt_buff, void** buf_out) {
   1052   TensorShape tf_shape;
   1053 
   1054   tf_shape.AddDim(
   1055       dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(lt_buff)) /
   1056           sizeof(float) +
   1057       1);
   1058   OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<float>::v(),
   1059                                                  tf_shape, tensor_out));
   1060   *buf_out = static_cast<void*>(tensor_out->flat<float>().data());
   1061 }
   1062 
   1063 #endif
   1064 template <typename T>
   1065 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
   1066                            TensorShape tf_shape) {
   1067   OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
   1068                                                  tf_shape, tensor_out));
   1069 }
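// Illustrative sketch (MKL-DNN build): allocating scratch space sized for a
// primitive's memory, where 'output_pd' is a memory::primitive_desc obtained
// elsewhere (hypothetical variable):
//   Tensor scratch_tensor;
//   void* scratch_buf = nullptr;
//   AllocTmpBuffer<float>(context, &scratch_tensor, output_pd, &scratch_buf);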
   1070 
   1071 inline void GetStridesFromSizes(TensorFormat data_format, size_t* strides,
   1072                                 const size_t* sizes) {
   1073   // MKL requires strides in NCHW
   1074   if (data_format == FORMAT_NHWC) {
   1075     strides[0] = sizes[2];
   1076     strides[1] = sizes[0] * sizes[2];
   1077     strides[2] = 1;
   1078     strides[3] = sizes[0] * sizes[1] * sizes[2];
   1079   } else {
   1080     strides[0] = 1;
   1081     strides[1] = sizes[0];
   1082     strides[2] = sizes[0] * sizes[1];
   1083     strides[3] = sizes[0] * sizes[1] * sizes[2];
   1084   }
   1085 }
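// Worked example (illustrative) for the NHWC branch above, with 'sizes' indexed
// by MklDims (W=0, H=1, C=2, N=3): for sizes = {7, 5, 3, 2} the strides come
// out as {C, W*C, 1, W*H*C} = {3, 21, 1, 105}, i.e. channels vary fastest.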
   1086 
   1087 #ifdef INTEL_MKL_ML_ONLY
   1088 inline void MklSizesToTFSizes(OpKernelContext* context,
   1089                               TensorFormat data_format_,
   1090                               const MklShape& mkl_shape,
   1091                               TensorShape* tf_shape) {
   1092   size_t tf_dim = mkl_shape.GetDimension();
   1093   const size_t* tf_sizes = mkl_shape.GetSizes();
   1094 
   1095   OP_REQUIRES(context, tf_dim == 4,
   1096               errors::InvalidArgument("MKLSizesToTFSizes: size must be 4-dim"));
   1097   std::vector<int32> sizes;
   1098 
   1099   sizes.push_back(tf_sizes[3]);
   1100 
   1101   if (data_format_ == FORMAT_NHWC) {
   1102     sizes.push_back(tf_sizes[1]);
   1103     sizes.push_back(tf_sizes[0]);
   1104     sizes.push_back(tf_sizes[2]);
   1105   } else {
   1106     sizes.push_back(tf_sizes[2]);
   1107     sizes.push_back(tf_sizes[1]);
   1108     sizes.push_back(tf_sizes[0]);
   1109   }
   1110 
   1111   OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(sizes, tf_shape));
   1112 }
   1113 #endif
   1114 
   1115 inline int32 GetMklTensorDimIndex(char dimension) {
   1116   switch (dimension) {
   1117     case 'N':
   1118       return MklDims::N;
   1119     case 'C':
   1120       return MklDims::C;
   1121     case 'H':
   1122       return MklDims::H;
   1123     case 'W':
   1124       return MklDims::W;
   1125     default:
   1126       LOG(FATAL) << "Invalid dimension: " << dimension;
   1127       return -1;  // Avoid compiler warning about missing return value
   1128   }
   1129 }
   1130 
   1131 #ifdef INTEL_MKL_ML_ONLY
   1132 inline int64 GetMklTensorDim(const MklShape& mkl_shape, char dimension) {
   1133   int index = GetMklTensorDimIndex(dimension);
   1134   CHECK(index >= 0 && index < mkl_shape.GetDimension())
   1135       << "Invalid index from the dimension: " << index << ", " << dimension;
   1136   return mkl_shape.dim_size(index);
   1137 }
   1138 #endif
   1139 
   1140 inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
   1141                                  int idx_out) {
   1142   int num_inputs = context->num_inputs();
   1143   int num_outputs = context->num_outputs();
   1144   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
   1145   int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
   1146   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
   1147   int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
   1148 
   1149   const Tensor& data = context->input(idx_data_in);
   1150   const Tensor& meta = context->input(idx_meta_in);
   1151   Tensor output(data.dtype());
   1152   Tensor meta_output(meta.dtype());
   1153 
   1154   // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
   1155   CHECK(output.CopyFrom(data, data.shape()));
   1156   CHECK(meta_output.CopyFrom(meta, meta.shape()));
   1157   context->set_output(idx_data_out, output);
   1158   context->set_output(idx_meta_out, meta_output);
   1159 }
   1160 
   1161 #ifdef INTEL_MKL_ML_ONLY
   1162 inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in,
   1163                                          int idx_out,
   1164                                          const TensorShape& shape) {
   1165   int num_inputs = context->num_inputs();
   1166   int num_outputs = context->num_outputs();
   1167   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
   1168   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
   1169 
   1170   const Tensor& data = context->input(idx_data_in);
   1171   MklShape mkl_shape_output;
   1172   mkl_shape_output.SetMklTensor(false);
   1173   AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
   1174   Tensor output(data.dtype());
   1175   // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
   1176   CHECK(output.CopyFrom(data, shape));
   1177   context->set_output(idx_data_out, output);
   1178 }
   1179 #else
   1180 inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in,
   1181                                          int idx_out,
   1182                                          const TensorShape& shape) {
   1183   int num_inputs = context->num_inputs();
   1184   int num_outputs = context->num_outputs();
   1185   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
   1186   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
   1187 
   1188   const Tensor& data = context->input(idx_data_in);
   1189   MklDnnShape mkl_shape_output;
   1190   mkl_shape_output.SetMklTensor(false);
   1191   AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
   1192   Tensor output(data.dtype());
   1193   // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
   1194   CHECK(output.CopyFrom(data, shape));
   1195   context->set_output(idx_data_out, output);
   1196 }
   1197 #endif
   1198 
   1199 #ifdef INTEL_MKL_ML_ONLY
   1200 
   1201 inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in,
   1202                                    int idx_out) {
   1203   int num_inputs = context->num_inputs();
   1204   int num_outputs = context->num_outputs();
   1205   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
   1206   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
   1207 
   1208   MklShape mkl_shape_output;
   1209   mkl_shape_output.SetMklTensor(false);
   1210   AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
   1211   if (IsRefType(context->input_dtype(idx_data_in))) {
   1212     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
   1213   } else {
   1214     context->set_output(idx_data_out, context->input(idx_data_in));
   1215   }
   1216 }
   1217 
   1218 #else
   1219 
   1220 inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in,
   1221                                    int idx_out) {
   1222   int num_inputs = context->num_inputs();
   1223   int num_outputs = context->num_outputs();
   1224   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
   1225   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
   1226 
   1227   MklDnnShape dnn_shape_output;
   1228   dnn_shape_output.SetMklTensor(false);
   1229   AllocateOutputSetMklShape(context, idx_out, dnn_shape_output);
   1230   if (IsRefType(context->input_dtype(idx_data_in))) {
   1231     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
   1232   } else {
   1233     context->set_output(idx_data_out, context->input(idx_data_in));
   1234   }
   1235 }
   1236 
   1237 #endif
   1238 
   1239 inline void ForwardMklTensorInToOut(OpKernelContext* context, int idx_in,
   1240                                     int idx_out) {
   1241   int num_inputs = context->num_inputs();
   1242   int num_outputs = context->num_outputs();
   1243   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
   1244   int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
   1245   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
   1246   int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
   1247 
   1248   if (IsRefType(context->input_dtype(idx_data_in))) {
   1249     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
   1250     context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
   1251   } else {
   1252     context->set_output(idx_data_out, context->input(idx_data_in));
   1253     context->set_output(idx_meta_out, context->input(idx_meta_in));
   1254   }
   1255 }
   1256 
   1257 #ifndef INTEL_MKL_ML_ONLY
   1258 // Set a dummy MKLDNN shape (called when the output is in TF format)
   1259 inline void SetDummyMklDnnShapeOutput(OpKernelContext* context,
   1260                                       uint32 idx_data_out) {
   1261   MklDnnShape mkl_shape_output;
   1262   mkl_shape_output.SetMklTensor(false);
   1263   AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
   1264 }
   1265 
   1266 inline void ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
   1267                                                 int idx_in, int idx_out,
   1268                                                 const MklDnnShape& mkl_shape) {
   1269   int num_inputs = context->num_inputs();
   1270   int num_outputs = context->num_outputs();
   1271   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
   1272   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
   1273 
   1274   AllocateOutputSetMklShape(context, idx_out, mkl_shape);
   1275 
   1276   if (IsRefType(context->input_dtype(idx_data_in))) {
   1277     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
   1278   } else {
   1279     context->set_output(idx_data_out, context->input(idx_data_in));
   1280   }
   1281 }
   1282 #endif
   1283 
   1284 // Forward the MKL shape ONLY (used in elementwise and other ops where
    1285 // we call the Eigen implementation and the MKL shape is not used)
   1286 inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
   1287                                       uint32 idx_data_in,
    1288                                       uint32 idx_data_out) {
   1289   uint32 idx_meta_in =
   1290       GetTensorMetaDataIndex(idx_data_in, context->num_inputs());
   1291   uint32 idx_meta_out =
   1292       GetTensorMetaDataIndex(idx_data_out, context->num_outputs());
   1293 
   1294   if (IsRefType(context->input_dtype(idx_data_in))) {
   1295     context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
   1296   } else {
   1297     context->set_output(idx_meta_out, context->input(idx_meta_in));
   1298   }
   1299 }
   1300 
   1301 #ifdef INTEL_MKL_ML_ONLY
   1302 // Set a dummy MKL shape (called when the output is in TF format)
   1303 inline void SetDummyMklShapeOutput(OpKernelContext* context,
   1304                                    uint32 idx_data_out) {
   1305   MklShape mkl_shape_output;
   1306   mkl_shape_output.SetMklTensor(false);
   1307   AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
   1308 }
    1309 // We don't need these functions in MKL-DNN: the equality operator is defined
    1310 // directly on the MklDnnShape class.
   1311 
   1312 // Checks if the TF shape for both MKL tensors is the same or not
   1313 // Returns: true if both TF shapes are the same, false otherwise
   1314 inline bool MklCompareShapes(const MklShape* input_shape_0,
   1315                              const MklShape* input_shape_1) {
   1316   // Check for number of dimensions
   1317   if (input_shape_0->GetDimension() != input_shape_1->GetDimension()) {
   1318     return false;
   1319   }
   1320 
   1321   // Check size of each dimension
   1322   size_t ndims = input_shape_0->GetDimension();
   1323   for (size_t i = 0; i < ndims; i++) {
   1324     if (input_shape_0->dim_size(i) != input_shape_1->dim_size(i)) {
   1325       return false;
   1326     }
   1327   }
   1328 
   1329   return true;
   1330 }
   1331 
   1332 // Checks if the TF shape for both tensors is the same or not
   1333 // Returns: true if TF shapes for both are the same, false otherwise
   1334 inline bool MklCompareShapes(const MklShape* input_shape_0,
   1335                              const TensorShape* input_shape_1) {
   1336   // Check for number of dimensions
   1337   if (input_shape_0->GetDimension() != input_shape_1->dims()) {
   1338     return false;
   1339   }
   1340 
   1341   // Check size of each dimension
   1342   size_t ndims = input_shape_0->GetDimension();
   1343   for (size_t i = 0; i < ndims; i++) {
   1344     if (input_shape_0->tf_dim_size(i) != input_shape_1->dim_size(i)) {
   1345       return false;
   1346     }
   1347   }
   1348 
   1349   return true;
   1350 }
   1351 
   1352 // Checks if the TF shape for both tensors is the same or not
   1353 // Returns: true if TF shapes for both are the same, false otherwise
   1354 inline bool MklCompareShapes(const TensorShape* input_shape_0,
   1355                              const MklShape* input_shape_1) {
   1356   return MklCompareShapes(input_shape_1, input_shape_0);
   1357 }
   1358 
   1359 // Checks if the TF shape for both tensors is the same or not
   1360 // Returns: true if TF shapes for both are the same, false otherwise
   1361 inline bool MklCompareShapes(const TensorShape* input_shape_0,
   1362                              const TensorShape* input_shape_1) {
   1363   // Check for number of dimensions
   1364   if (input_shape_0->dims() != input_shape_1->dims()) {
   1365     return false;
   1366   }
   1367 
   1368   // Check size of each dimension
   1369   size_t ndims = input_shape_0->dims();
   1370   for (size_t i = 0; i < ndims; i++) {
   1371     if (input_shape_0->dim_size(i) != input_shape_1->dim_size(i)) {
   1372       return false;
   1373     }
   1374   }
   1375 
   1376   return true;
   1377 }
   1378 
   1379 // These functions do not compile with MKL-DNN since mkl.h is missing.
   1380 // We may need to remove them later.
   1381 // TODO(intel_tf): Remove this routine when faster MKL layout conversion is
   1382 // out.
   1383 inline void MklNHWCToNCHW(const Tensor& input, Tensor** output) {
   1384   const float* buf_in = input.flat<float>().data();
   1385   float* buf_out = (*output)->flat<float>().data();
   1386 
   1387   int64 N = input.dim_size(0);
   1388   int64 H = input.dim_size(1);
   1389   int64 W = input.dim_size(2);
   1390   int64 C = input.dim_size(3);
   1391   int64 stride_n = H * W * C;
   1392 #pragma omp parallel for num_threads(16)
   1393   for (int64 n = 0; n < N; ++n) {
   1394     mkl_somatcopy('R', 'T', H * W, C, 1, buf_in + n * stride_n, C,
   1395                   buf_out + n * stride_n, H * W);
   1396   }
   1397 }
   1398 
   1399 inline void MklNCHWToNHWC(const Tensor& input, Tensor** output) {
   1400   const float* buf_in = input.flat<float>().data();
   1401   float* buf_out = (*output)->flat<float>().data();
   1402 
   1403   int64 N = (*output)->dim_size(0);
   1404   int64 H = (*output)->dim_size(1);
   1405   int64 W = (*output)->dim_size(2);
   1406   int64 C = (*output)->dim_size(3);
   1407   int64 stride_n = H * W * C;
   1408 #pragma omp parallel for num_threads(16)
   1409   for (int64 n = 0; n < N; ++n) {
   1410     mkl_somatcopy('R', 'T', C, H * W, 1, buf_in + n * stride_n, H * W,
   1411                   buf_out + n * stride_n, C);
   1412   }
   1413 }
   1414 
   1415 #endif
   1416 // -------------------------------------------------------------------
   1417 
   1418 #ifndef INTEL_MKL_ML_ONLY
   1419 
   1420 /// Return MKL-DNN data type (memory::data_type) for input type T
   1421 ///
   1422 /// @input None
   1423 /// @return memory::data_type corresponding to type T
   1424 template <typename T>
   1425 static memory::data_type MklDnnType();
   1426 
   1427 /// Instantiation for float type. Add similar instantiations for other
    1428 /// types if needed.
   1429 template <>
   1430 memory::data_type MklDnnType<float>() {
   1431   return memory::data_type::f32;
   1432 }
   1433 template <>
   1434 memory::data_type MklDnnType<quint8>() {
   1435   return memory::data_type::u8;
   1436 }
   1437 template <>
   1438 memory::data_type MklDnnType<qint8>() {
   1439   return memory::data_type::s8;
   1440 }
   1441 template <>
   1442 memory::data_type MklDnnType<qint32>() {
   1443   return memory::data_type::s32;
   1444 }
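
// Illustrative usage sketch (editorial addition, not part of the original
// source): MklDnnType<T>() resolves the MKL-DNN data type for a template
// parameter at compile time.
//
//   memory::data_type dt_f32 = MklDnnType<float>();  // memory::data_type::f32
//   memory::data_type dt_s8 = MklDnnType<qint8>();   // memory::data_type::s8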
   1445 
   1446 /// Map TensorFlow's data format into MKL-DNN 3D data format
   1447 /// @input: TensorFlow data format
   1448 /// @return: memory::format corresponding to TensorFlow data format;
   1449 ///          Fails with an error if invalid data format.
   1450 inline memory::format TFDataFormatToMklDnn3DDataFormat(TensorFormat format) {
   1451   if (format == FORMAT_NHWC)
   1452     return memory::format::ndhwc;
   1453   else if (format == FORMAT_NCHW)
   1454     return memory::format::ncdhw;
   1455   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
   1456   return memory::format::format_undef;
   1457 }
   1458 
   1459 /// Map TensorFlow's data format into MKL-DNN data format
   1460 ///
   1461 /// @input: TensorFlow data format
   1462 /// @return: memory::format corresponding to TensorFlow data format;
   1463 ///          Fails with an error if invalid data format.
   1464 inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
   1465   if (format == FORMAT_NHWC)
   1466     return memory::format::nhwc;
   1467   else if (format == FORMAT_NCHW)
   1468     return memory::format::nchw;
   1469   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
   1470   return memory::format::format_undef;
   1471 }
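
// Illustrative usage sketch (editorial addition, not part of the original
// source): maps a TensorFlow data format enum to its MKL-DNN equivalent.
//
//   memory::format fmt = TFDataFormatToMklDnnDataFormat(FORMAT_NHWC);
//   // fmt == memory::format::nhwc; FORMAT_NCHW maps to memory::format::nchw.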
   1472 
   1473 /// Map MKL-DNN data format to TensorFlow's data format
   1474 ///
   1475 /// @input: memory::format
   1476 /// @return: Tensorflow data format corresponding to memory::format
   1477 ///          Fails with an error if invalid data format.
   1478 inline TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format) {
   1479   if (format == memory::format::nhwc || format == memory::format::ndhwc)
   1480     return FORMAT_NHWC;
   1481   else if (format == memory::format::nchw || format == memory::format::ncdhw)
   1482     return FORMAT_NCHW;
   1483   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
   1484 
    1485   // Return to prevent compiler warnings; the TF_CHECK_OK above ensures
    1486   // that we never reach this point.
   1487   return FORMAT_NHWC;
   1488 }
   1489 
   1490 /// Map TensorShape object into memory::dims required by MKL-DNN
   1491 ///
    1492 /// This function simply maps the input TensorShape into MKL-DNN dims,
    1493 /// preserving the order of dimensions. E.g., if the input tensor is in
    1494 /// NHWC format, then the resulting dims will also be in NHWC order
    1495 /// (no permutation is applied).
   1496 ///
   1497 /// @input TensorShape object in shape
   1498 /// @return memory::dims corresponding to TensorShape
   1499 inline memory::dims TFShapeToMklDnnDims(const TensorShape& shape) {
   1500   memory::dims dims(shape.dims());
   1501   for (int d = 0; d < shape.dims(); ++d) {
   1502     dims[d] = shape.dim_size(d);
   1503   }
   1504   return dims;
   1505 }
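
// Illustrative usage sketch (editorial addition, not part of the original
// source); values are hypothetical. The dimension order of the input shape
// is preserved as-is.
//
//   TensorShape shape({2, 3, 4});
//   memory::dims dims = TFShapeToMklDnnDims(shape);  // dims == {2, 3, 4}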
   1506 
   1507 /// Map TensorShape object into memory::dims in NCHW format required by MKL-DNN
   1508 ///
    1509 /// This function is a more specific version of the one above. It maps the
    1510 /// input TensorShape into MKL-DNN dims in NCHW format, so it may not
    1511 /// preserve the order of dimensions. E.g., if the input tensor is in NHWC
    1512 /// format, then the dims will be in NCHW order, not NHWC.
   1513 ///
   1514 /// @input TensorShape object in shape
   1515 /// @return memory::dims in MKL-DNN required NCHW format
   1516 inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
   1517                                               TensorFormat format) {
   1518   // Check validity of format.
   1519   CHECK_NE(TFDataFormatToMklDnnDataFormat(format),
   1520            memory::format::format_undef);
   1521 
   1522   int n = shape.dim_size(GetTensorDimIndex(format, 'N'));
   1523   int c = shape.dim_size(GetTensorDimIndex(format, 'C'));
   1524   int h = shape.dim_size(GetTensorDimIndex(format, 'H'));
   1525   int w = shape.dim_size(GetTensorDimIndex(format, 'W'));
   1526 
   1527   // MKL-DNN requires dimensions in NCHW format.
   1528   return memory::dims({n, c, h, w});
   1529 }
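
// Illustrative usage sketch (editorial addition, not part of the original
// source); values are hypothetical. An NHWC TensorShape is permuted into
// NCHW-ordered dims.
//
//   TensorShape shape({8, 224, 224, 3});  // N, H, W, C
//   memory::dims dims = TFShapeToMklDnnDimsInNCHW(shape, FORMAT_NHWC);
//   // dims == {8, 3, 224, 224}           // N, C, H, W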
   1530 
   1531 inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
   1532                                                TensorFormat format) {
   1533   // Check validity of format.
   1534   CHECK_NE(TFDataFormatToMklDnn3DDataFormat(format),
   1535            memory::format::format_undef);
   1536 
   1537   int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N'));
   1538   int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C'));
   1539   int d = shape.dim_size(GetTensorDimIndex<3>(format, '0'));
   1540   int h = shape.dim_size(GetTensorDimIndex<3>(format, '1'));
   1541   int w = shape.dim_size(GetTensorDimIndex<3>(format, '2'));
   1542 
   1543   // MKL-DNN requires dimensions in NCDHW format.
   1544   return memory::dims({n, c, d, h, w});
   1545 }
   1546 
   1547 /// Overloaded version of function above. Input parameters are
   1548 /// self-explanatory.
   1549 inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
   1550                                      TensorFormat format) {
   1551   // Check validity of format.
   1552   CHECK_NE(TFDataFormatToMklDnnDataFormat(format),
   1553            memory::format::format_undef);
   1554 
   1555   int n = in_dims[GetTensorDimIndex(format, 'N')];
   1556   int c = in_dims[GetTensorDimIndex(format, 'C')];
   1557   int h = in_dims[GetTensorDimIndex(format, 'H')];
   1558   int w = in_dims[GetTensorDimIndex(format, 'W')];
   1559 
   1560   // MKL-DNN requires dimensions in NCHW format.
   1561   return memory::dims({n, c, h, w});
   1562 }
   1563 
   1564 /// Map MklDnn memory::dims object into TensorShape object.
   1565 ///
    1566 /// This function simply maps an input shape in MKL-DNN memory::dims format
    1567 /// into Tensorflow's TensorShape object, preserving the dimension order.
   1568 ///
   1569 /// @input MKL-DNN memory::dims object
   1570 /// @output TensorShape corresponding to memory::dims
   1571 inline TensorShape MklDnnDimsToTFShape(const memory::dims& dims) {
   1572   std::vector<int32> shape(dims.size(), -1);
   1573   for (int d = 0; d < dims.size(); d++) {
   1574     shape[d] = dims[d];
   1575   }
   1576 
   1577   TensorShape ret;
   1578   CHECK_EQ(TensorShapeUtils::MakeShape(shape, &ret).ok(), true);
   1579   return ret;
   1580 }
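
// Illustrative usage sketch (editorial addition, not part of the original
// source); values are hypothetical.
//
//   memory::dims dims = {8, 3, 224, 224};
//   TensorShape shape = MklDnnDimsToTFShape(dims);  // shape == [8, 3, 224, 224]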
   1581 
   1582 /// Function to calculate strides given tensor shape in Tensorflow order
   1583 /// E.g., if dims_tf_order is {1, 2, 3, 4}, then as per Tensorflow convention,
   1584 /// dimension with size 1 is outermost dimension; while dimension with size 4 is
   1585 /// innermost dimension. So strides for this tensor would be {4 * 3 * 2,
   1586 /// 4 * 3, 4, 1}, i.e., {24, 12, 4, 1}.
   1587 ///
   1588 /// @input Tensorflow shape in memory::dims type
   1589 /// @return memory::dims containing strides for the tensor.
   1590 inline memory::dims CalculateTFStrides(const memory::dims& dims_tf_order) {
   1591   CHECK_GT(dims_tf_order.size(), 0);
   1592   memory::dims strides(dims_tf_order.size());
   1593   int last_dim_idx = dims_tf_order.size() - 1;
   1594   strides[last_dim_idx] = 1;
   1595   for (int d = last_dim_idx - 1; d >= 0; d--) {
   1596     strides[d] = strides[d + 1] * dims_tf_order[d + 1];
   1597   }
   1598   return strides;
   1599 }
   1600 
   1601 inline padding_kind TFPaddingToMklDnnPadding(Padding pad) {
   1602   // MKL-DNN only supports zero padding.
   1603   return padding_kind::zero;
   1604 }
   1605 
   1606 /// Helper function to create memory descriptor in Blocked format
   1607 ///
   1608 /// @input: Tensor dimensions
   1609 /// @input: strides corresponding to dimensions. One can use utility
   1610 ///         function such as CalculateTFStrides to compute strides
   1611 ///         for given dimensions.
   1612 /// @return: memory::desc object corresponding to blocked memory format
   1613 ///          for given dimensions and strides.
   1614 inline memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
   1615                                                const memory::dims& strides,
   1616                                                memory::data_type dtype) {
   1617   CHECK_EQ(dim.size(), strides.size());
   1618 
    1619   // We have to construct the memory descriptor in C style. This is not at
    1620   // all ideal, but MKLDNN does not offer any API to construct a descriptor
    1621   // in blocked format other than a copy constructor that accepts
    1622   // mkldnn_memory_desc_t.
   1623   mkldnn_memory_desc_t md;
   1624   md.primitive_kind = mkldnn_memory;
   1625   md.ndims = dim.size();
   1626   md.format = mkldnn_blocked;
   1627   md.data_type = memory::convert_to_c(dtype);
   1628 
   1629   for (size_t i = 0; i < dim.size(); i++) {
   1630     md.layout_desc.blocking.block_dims[i] = 1;
   1631     md.layout_desc.blocking.strides[1][i] = 1;
   1632     md.layout_desc.blocking.strides[0][i] = strides[i];
   1633     md.layout_desc.blocking.padding_dims[i] = dim[i];
   1634     md.layout_desc.blocking.offset_padding_to_data[i] = 0;
   1635     md.dims[i] = dim[i];
   1636   }
   1637   md.layout_desc.blocking.offset_padding = 0;
   1638 
   1639   return memory::desc(md);
   1640 }
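
// Illustrative usage sketch (editorial addition, not part of the original
// source); values are hypothetical. Builds a blocked memory descriptor for a
// 4-D float tensor laid out with plain TensorFlow (row-major) strides.
//
//   memory::dims dims = {8, 3, 224, 224};
//   memory::dims strides = CalculateTFStrides(dims);  // {150528, 50176, 224, 1}
//   memory::desc md =
//       CreateBlockedMemDescHelper(dims, strides, memory::data_type::f32);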
   1641 
   1642 template <typename T>
   1643 inline primitive FindOrCreateReorder(const memory* from, const memory* to);
   1644 /*
   1645  * Class to represent all the resources corresponding to a tensor in TensorFlow
   1646  * that are required to execute an operation (such as Convolution).
   1647  */
   1648 template <typename T>
   1649 class MklDnnData {
   1650  private:
   1651   /// MKL-DNN memory primitive for input user memory
   1652   memory* user_memory_;
   1653 
   1654   /// MKL-DNN memory primitive in case input or output reorder is needed.
   1655   memory* reorder_memory_;
   1656 
    1657   /// Operation's memory descriptor
    1658   memory::desc* op_md_;
    1659   /// Flag to indicate whether the data is 3D.
    1660   bool bIs3D;
    1661   /// Operation's temp buffer
   1662   void* allocated_buffer_;
   1663   /// CPU engine on which operation will be executed
   1664   const engine* cpu_engine_;
   1665 
   1666  public:
   1667   explicit MklDnnData(const engine* e)
   1668       : user_memory_(nullptr),
   1669         reorder_memory_(nullptr),
   1670         op_md_(nullptr),
   1671         allocated_buffer_(nullptr),
   1672         cpu_engine_(e) {}
   1673 
   1674   ~MklDnnData() {
   1675     if (allocated_buffer_ != nullptr) {
   1676       cpu_allocator()->DeallocateRaw(allocated_buffer_);
   1677     }
   1678     cpu_engine_ = nullptr;  // We don't own this.
   1679     delete (user_memory_);
   1680     delete (reorder_memory_);
   1681     delete (op_md_);
   1682   }
   1683 
   1684   inline void* GetTensorBuffer(const Tensor* tensor) const {
   1685     CHECK_NOTNULL(tensor);
   1686     return const_cast<void*>(
   1687         static_cast<const void*>(tensor->flat<T>().data()));
   1688   }
   1689 
   1690   void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; }
   1691 
   1692   bool GetIs3D() { return bIs3D; }
   1693 
    1694   /// Set user memory primitive using the specified dimensions, memory format
    1695   /// and data_buffer. The element data type is derived automatically from the
    1696   /// template type T of the calling object.
    1697   ///
    1698   /// In a nutshell, this function lets the user describe an input tensor to
    1699   /// an operation. E.g., the filter of Conv2D has shape {1, 2, 3, 4} and
    1700   /// memory format HWIO, and the buffer containing the actual values is
    1701   /// pointed to by data_buffer.
   1702   inline void SetUsrMem(const memory::dims& dim, memory::format fm,
   1703                         void* data_buffer = nullptr) {
   1704     auto md = memory::desc(dim, MklDnnType<T>(), fm);
   1705     SetUsrMem(md, data_buffer);
   1706   }
   1707 
   1708   inline void SetUsrMem(const memory::dims& dim, memory::format fm,
   1709                         const Tensor* tensor) {
   1710     CHECK_NOTNULL(tensor);
   1711     SetUsrMem(dim, fm, GetTensorBuffer(tensor));
   1712   }
   1713 
   1714   /// Helper function to create memory descriptor in Blocked format
   1715   ///
   1716   /// @input: Tensor dimensions
   1717   /// @input: strides corresponding to dimensions. One can use utility
   1718   ///         function such as CalculateTFStrides to compute strides
   1719   ///         for given dimensions.
   1720   /// @return: memory::desc object corresponding to blocked memory format
   1721   ///          for given dimensions and strides.
   1722   static inline memory::desc CreateBlockedMemDesc(const memory::dims& dim,
   1723                                                   const memory::dims& strides) {
   1724     return CreateBlockedMemDescHelper(dim, strides, MklDnnType<T>());
   1725   }
   1726 
    1727   /// A version of SetUsrMem that allows the user to create memory in blocked
    1728   /// format, so in addition to accepting dimensions, it also accepts strides.
    1729   /// This allows the user to create memory for a tensor in a format that is
    1730   /// not supported natively by MKLDNN. E.g., MKLDNN has no native format for
    1731   /// a 6-dimensional tensor, but by using blocked format a user can still
    1732   /// create memory for a 6D tensor.
   1733   inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
   1734                         void* data_buffer = nullptr) {
   1735     CHECK_EQ(dim.size(), strides.size());
   1736     auto blocked_md = MklDnnData<T>::CreateBlockedMemDesc(dim, strides);
   1737     SetUsrMem(blocked_md, data_buffer);
   1738   }
   1739 
   1740   inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
   1741                         const Tensor* tensor) {
   1742     CHECK_NOTNULL(tensor);
   1743     SetUsrMem(dim, strides, GetTensorBuffer(tensor));
   1744   }
   1745 
   1746   /// A version of function to set user memory primitive that accepts memory
   1747   /// descriptor directly, instead of accepting dimensions and format. This
    1748   /// function is more generic than the one above, but the function above is
   1749   /// sufficient in most cases.
   1750   inline void SetUsrMem(const memory::desc& md, void* data_buffer = nullptr) {
   1751     auto pd = memory::primitive_desc(md, *cpu_engine_);
   1752     SetUsrMem(pd, data_buffer);
   1753   }
   1754 
   1755   /// A version of SetUsrMem with memory descriptor and tensor
   1756   inline void SetUsrMem(const memory::desc& md, const Tensor* tensor) {
   1757     CHECK_NOTNULL(tensor);
   1758     SetUsrMem(md, GetTensorBuffer(tensor));
   1759   }
   1760 
   1761   /// A version of function to set user memory primitive that accepts primitive
   1762   /// descriptor directly, instead of accepting dimensions and format. This
    1763   /// function is more generic than the one above, but the function above is
   1764   /// sufficient in most cases.
   1765   inline void SetUsrMem(const memory::primitive_desc& pd,
   1766                         void* data_buffer = nullptr) {
   1767     CHECK_NOTNULL(cpu_engine_);
   1768     if (user_memory_) delete user_memory_;
   1769     // TODO(nhasabni): can we remove dynamic memory allocation?
   1770     if (data_buffer) {
   1771       user_memory_ = new memory(pd, data_buffer);
   1772     } else {
   1773       user_memory_ = new memory(pd);
   1774     }
   1775   }
   1776 
   1777   /// A version of SetUsrMem with primitive descriptor and tensor
   1778   inline void SetUsrMem(const memory::primitive_desc& pd,
   1779                         const Tensor* tensor) {
   1780     CHECK_NOTNULL(tensor);
   1781     SetUsrMem(pd, GetTensorBuffer(tensor));
   1782   }
   1783 
   1784   /// Get function for user memory primitive.
   1785   inline const memory* GetUsrMem() const { return user_memory_; }
   1786 
   1787   /// Get function for primitive descriptor of user memory primitive.
   1788   inline const memory::primitive_desc GetUsrMemPrimDesc() const {
   1789     CHECK_NOTNULL(user_memory_);
   1790     return user_memory_->get_primitive_desc();
   1791   }
   1792 
   1793   /// Get function for descriptor of user memory.
   1794   inline memory::desc GetUsrMemDesc() {
    1795     // This is ugly. Why does MKL-DNN not provide a const desc() method?
   1796     const memory::primitive_desc pd = GetUsrMemPrimDesc();
   1797     return const_cast<memory::primitive_desc*>(&pd)->desc();
   1798   }
   1799 
   1800   /// Get function for data buffer of user memory primitive.
   1801   inline void* GetUsrMemDataHandle() const {
   1802     CHECK_NOTNULL(user_memory_);
   1803     return user_memory_->get_data_handle();
   1804   }
   1805 
   1806   /// Set function for data buffer of user memory primitive.
   1807   inline void SetUsrMemDataHandle(void* data_buffer) {
   1808     CHECK_NOTNULL(user_memory_);
   1809     CHECK_NOTNULL(data_buffer);
   1810     user_memory_->set_data_handle(data_buffer);
   1811   }
   1812 
   1813   /// Set function for data buffer of user memory primitive.
   1814   inline void SetUsrMemDataHandle(const Tensor* tensor) {
   1815     CHECK_NOTNULL(user_memory_);
   1816     CHECK_NOTNULL(tensor);
   1817     user_memory_->set_data_handle(GetTensorBuffer(tensor));
   1818   }
   1819 
    1820   /// Allocate function for data buffer
    1821   inline void AllocateBuffer(size_t size) {
    1822     const int64 kMemoryAlignment = 64;  // For AVX512 memory alignment.
    1823     allocated_buffer_ = cpu_allocator()->AllocateRaw(kMemoryAlignment, size);
   1824   }
   1825 
   1826   inline void* GetAllocatedBuffer() { return allocated_buffer_; }
   1827 
   1828   /// Get the memory primitive for input and output of an op. If inputs
   1829   /// to an op require reorders, then this function returns memory primitive
   1830   /// for reorder. Otherwise, it will return memory primitive for user memory.
   1831   ///
   1832   /// E.g., Conv2D(I, F) is a primitive with I and F being inputs. Then to
    1833   /// execute Conv2D, we need memory primitives for I and F. But if reorders are
    1834   /// required for I and F (say I_r is the reorder primitive for I and F_r the
    1835   /// reorder primitive for F), then we need I_r and F_r to perform Conv2D.
   1836   inline const memory& GetOpMem() const {
   1837     return reorder_memory_ ? *reorder_memory_ : *user_memory_;
   1838   }
   1839 
   1840   /// Set memory descriptor of an operation in terms of dimensions and memory
    1841   /// format. E.g., for Conv2D, the dimensions would be the same as the user
    1842   /// dimensions, but memory::format would be mkldnn::any because we want
    1843   /// MKL-DNN to choose the best layout/format for the given input dimensions.
   1844   inline void SetOpMemDesc(const memory::dims& dim, memory::format fm) {
   1845     // TODO(nhasabni): can we remove dynamic memory allocation?
   1846     op_md_ = new memory::desc(dim, MklDnnType<T>(), fm);
   1847   }
   1848 
   1849   /// Get function for memory descriptor for an operation
   1850   inline const memory::desc& GetOpMemDesc() const { return *op_md_; }
   1851 
   1852   /// Predicate that checks if we need to reorder user's memory into memory
   1853   /// pointed by op_pd.
   1854   ///
   1855   /// @input: op_pd - memory primitive descriptor of the given input of an
   1856   ///               operation
   1857   /// @return: true in case reorder of input is needed; false, otherwise.
   1858   inline bool IsReorderNeeded(const memory::primitive_desc& op_pd) const {
   1859     CHECK_NOTNULL(user_memory_);
   1860     return op_pd != user_memory_->get_primitive_desc();
   1861   }
   1862 
   1863   /// Predicate that checks if we need to reorder user's memory into memory
   1864   /// based on the provided format.
   1865   ///
   1866   /// @input: target_format - memory format of the given input of an
   1867   ///               operation
   1868   /// @return: true in case reorder of input is needed; false, otherwise.
   1869   inline bool IsReorderNeeded(const memory::format& target_format) const {
   1870     CHECK_NOTNULL(user_memory_);
   1871     return target_format !=
   1872            user_memory_->get_primitive_desc().desc().data.format;
   1873   }
   1874 
    1875   /// Function to create a reorder from the memory pointed to by `from` to the
    1876   /// memory pointed to by `to`. Returns the created primitive.
   1877   inline primitive CreateReorder(const memory* from, const memory* to) const {
   1878     CHECK_NOTNULL(from);
   1879     CHECK_NOTNULL(to);
   1880     return reorder(*from, *to);
   1881   }
   1882 
   1883   /// Function to handle input reordering
   1884   ///
   1885   /// Check if we need to reorder this input of an operation.
   1886   /// Return true and allocate reorder memory primitive if reorder is needed.
   1887   /// Otherwise, return false and do not allocate reorder memory primitive.
   1888   ///
   1889   /// To check if reorder is needed, this function compares memory primitive
   1890   /// descriptor of an operation (op_pd) for the given input with the
   1891   /// user-specified memory primitive descriptor.
   1892   ///
   1893   /// @input: op_pd - memory primitive descriptor of the given input of an
   1894   ///               operation
   1895   /// @input: net - net to which to add reorder primitive in case it is needed.
   1896   /// @return: true in case reorder of input is needed; false, otherwise.
   1897   inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
   1898                                   std::vector<primitive>* net) {
   1899     CHECK_NOTNULL(net);
   1900     CHECK_NOTNULL(user_memory_);
   1901     if (IsReorderNeeded(op_pd)) {
   1902       // TODO(nhasabni): can we remove dynamic memory allocation?
   1903       reorder_memory_ = new memory(op_pd);
   1904       net->push_back(CreateReorder(user_memory_, reorder_memory_));
   1905       return true;
   1906     }
   1907     return false;
   1908   }
   1909 
    1910   /// TODO: this is a faster path, using the reorder primitive cache, than
    1911   /// CheckReorderToOpMem(..., std::vector<primitive>* net); the slow path
    1912   /// will be removed in the future.
   1913   inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd) {
   1914     CHECK_NOTNULL(user_memory_);
   1915     if (IsReorderNeeded(op_pd)) {
   1916       // TODO(nhasabni): can we remove dynamic memory allocation?
    1917       // Primitive reuse does not allow two identical reorder primitives in
    1918       // one stream, so submit this one immediately.
   1919       reorder_memory_ = new memory(op_pd);
   1920       std::vector<primitive> net;
   1921       net.push_back(FindOrCreateReorder<T>(user_memory_, reorder_memory_));
   1922       stream(stream::kind::eager).submit(net).wait();
   1923       return true;
   1924     }
   1925     return false;
   1926   }
   1927 
   1928   /// Overloaded version of above function that accepts memory buffer
   1929   /// where output of reorder needs to be stored.
   1930   ///
   1931   /// @input: op_pd - memory primitive descriptor of the given input of an
   1932   ///               operation
   1933   /// @reorder_data_handle - memory buffer where output of reorder needs to be
    1934   ///                        stored. The primitive does not check whether the
    1935   ///                        buffer is large enough to write to.
   1936   /// @input: net - net to which to add reorder primitive in case it is needed.
   1937   /// @return: true in case reorder of input is needed; false, otherwise.
   1938   inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
   1939                                   void* reorder_data_handle,
   1940                                   std::vector<primitive>* net) {
   1941     CHECK_NOTNULL(net);
   1942     CHECK_NOTNULL(reorder_data_handle);
   1943     CHECK_NOTNULL(user_memory_);
   1944     if (IsReorderNeeded(op_pd)) {
   1945       // TODO(nhasabni): can we remove dynamic memory allocation?
   1946       reorder_memory_ = new memory(op_pd, reorder_data_handle);
   1947       net->push_back(CreateReorder(user_memory_, reorder_memory_));
   1948       return true;
   1949     }
   1950     return false;
   1951   }
   1952 
    1953   /// TODO: this is a faster path, using the reorder primitive cache, than
    1954   /// CheckReorderToOpMem(..., std::vector<primitive>* net); the slow path
    1955   /// will be removed in the future.
   1956   inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
   1957                                   void* reorder_data_handle) {
   1958     CHECK_NOTNULL(reorder_data_handle);
   1959     CHECK_NOTNULL(user_memory_);
   1960     if (IsReorderNeeded(op_pd)) {
   1961       // TODO(nhasabni): can we remove dynamic memory allocation?
    1962       // Primitive reuse does not allow two identical reorder primitives in
    1963       // one stream, so submit this one immediately.
   1964       std::vector<primitive> net;
   1965       reorder_memory_ = new memory(op_pd, reorder_data_handle);
   1966       net.push_back(FindOrCreateReorder<T>(user_memory_, reorder_memory_));
   1967       stream(stream::kind::eager).submit(net).wait();
   1968       return true;
   1969     }
   1970     return false;
   1971   }
   1972 
   1973   /// Another overloaded version of CheckReorderToOpMem that accepts Tensor
   1974   /// where output of reorder needs to be stored.
   1975   ///
   1976   /// @input: op_pd - memory primitive descriptor of the given input of an
   1977   ///               operation
   1978   /// @reorder_tensor - Tensor whose buffer is to be used to store output of
    1979   ///                   reorder. The primitive does not check whether the
    1980   ///                   buffer is large enough to write to.
   1981   /// @input: net - net to which to add reorder primitive in case it is needed.
   1982   /// @return: true in case reorder of input is needed; false, otherwise.
   1983   inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
   1984                                   Tensor* reorder_tensor,
   1985                                   std::vector<primitive>* net) {
   1986     CHECK_NOTNULL(net);
   1987     CHECK_NOTNULL(reorder_tensor);
   1988     return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor), net);
   1989   }
   1990 
    1991   /// TODO: this is a faster path, using the reorder primitive cache, than
    1992   /// CheckReorderToOpMem(..., std::vector<primitive>* net); the slow path
    1993   /// will be removed in the future.
   1994   inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
   1995                                   Tensor* reorder_tensor) {
   1996     CHECK_NOTNULL(reorder_tensor);
   1997     return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor));
   1998   }
   1999 
   2000   /// Function to handle output reorder
   2001   ///
    2002   /// This function is very similar to the input reordering function above.
    2003   /// The only difference is that this function does not add the reorder
    2004   /// primitive to the net. The reason is that the reorder primitive for the
    2005   /// output needs to be added to the list only after the operation has
    2006   /// executed, but we still need to prepare a temporary buffer in case an
    2007   /// output reorder is needed; this temporary buffer holds the output of the
    2008   /// operation before it is fed to the reorder primitive.
   2009   ///
   2010   /// @input memory primitive descriptor for the given output of an operation
   2011   /// @return: true in case reorder of output is needed; false, otherwise.
   2012   inline bool PrepareReorderToUserMemIfReq(
   2013       const memory::primitive_desc& op_pd) {
   2014     CHECK_NOTNULL(user_memory_);
   2015     if (IsReorderNeeded(op_pd)) {
   2016       // TODO(nhasabni): can we remove dynamic memory allocation?
   2017       reorder_memory_ = new memory(op_pd);
   2018       return true;
   2019     }
   2020     return false;
   2021   }
   2022 
   2023   /// Function to actually insert reorder primitive in the net
   2024   ///
   2025   /// This function completes remaining part of output reordering. It inserts
   2026   /// a reordering primitive from the temporary buffer that holds the output
   2027   /// to the user-specified output buffer.
   2028   ///
   2029   /// @input: net - net to which to add reorder primitive
   2030   inline void InsertReorderToUserMem(std::vector<primitive>* net) {
   2031     CHECK_NOTNULL(net);
   2032     CHECK_NOTNULL(user_memory_);
   2033     CHECK_NOTNULL(reorder_memory_);
   2034     net->push_back(CreateReorder(reorder_memory_, user_memory_));
   2035   }
   2036 
    2037   /// TODO: this is a faster path, using the reorder primitive cache, than
    2038   ///       InsertReorderToUserMem(std::vector<primitive>* net); the slow path
    2039   ///       will be removed in the future.
   2040   inline void InsertReorderToUserMem() {
   2041     CHECK_NOTNULL(user_memory_);
   2042     CHECK_NOTNULL(reorder_memory_);
    2043     // Primitive reuse does not allow two identical reorder primitives in
    2044     // one stream, so submit this one immediately.
   2045     std::vector<primitive> net;
   2046     net.push_back(FindOrCreateReorder<T>(reorder_memory_, user_memory_));
   2047     stream(stream::kind::eager).submit(net).wait();
   2048   }
   2049 };
   2050 
   2051 /// Base class for operations with reuse of primitives
   2052 ///
   2053 class MklPrimitive {
   2054  public:
   2055   virtual ~MklPrimitive() {}
   2056 
   2057   // Dummy data which MKL DNN never operates on
   2058   unsigned char* DummyData = nullptr;
   2059 };
   2060 
   2061 const mkldnn::memory::dims NONE_DIMS = {};
   2062 
   2063 //
   2064 // LRUCache is a class which implements LRU (Least Recently Used) cache.
   2065 // The implementation is similar to that of
   2066 //    tensorflow/core/platform/cloud/expiring_lru_cache.h
   2067 // without its thread-safe part because the cache is supposed to be
   2068 // used as thread local (for instance, MklPrimitive caching).
   2069 //
    2070 // The LRU list maintains objects in order of most recent access, with
    2071 // the least recently accessed object at the tail of the list and the
    2072 // most recently accessed object at the head of the list.
    2073 //
    2074 // This class is used to maintain an upper bound on the total number of
    2075 // cached items. When the cache reaches its capacity, the least recently
    2076 // used item is evicted and replaced by the new one added through the
    2077 // SetOp call.
   2078 //
   2079 template <typename T>
   2080 class LRUCache {
   2081  public:
   2082   explicit LRUCache(size_t capacity) {
   2083     capacity_ = capacity;
   2084     Clear();
   2085   }
   2086 
   2087   T* GetOp(const string& key) {
   2088     auto it = cache_.find(key);
   2089     if (it == cache_.end()) {
   2090       return nullptr;
   2091     }
   2092 
   2093     // Move to the front of LRU list as the most recently accessed.
   2094     lru_list_.erase(it->second.lru_iterator);
   2095     lru_list_.push_front(it->first);
   2096     it->second.lru_iterator = lru_list_.begin();
   2097     return it->second.op;
   2098   }
   2099 
   2100   void SetOp(const string& key, T* op) {
   2101     if (lru_list_.size() >= capacity_) {
   2102       Delete();
   2103     }
   2104 
   2105     // Insert an entry to the front of the LRU list
   2106     lru_list_.push_front(key);
   2107     Entry entry(op, lru_list_.begin());
   2108     cache_.emplace(std::make_pair(key, std::move(entry)));
   2109   }
   2110 
   2111   void Clear() {
   2112     if (lru_list_.empty()) return;
   2113 
   2114     // Clean up the cache
   2115     cache_.clear();
   2116     lru_list_.clear();
   2117   }
   2118 
   2119  private:
   2120   struct Entry {
   2121     // The entry's value.
   2122     T* op;
   2123 
   2124     // A list iterator pointing to the entry's position in the LRU list.
   2125     std::list<string>::iterator lru_iterator;
   2126 
   2127     // Constructor
   2128     Entry(T* op, std::list<string>::iterator it) {
   2129       this->op = op;
   2130       this->lru_iterator = it;
   2131     }
   2132 
    2133     // Move constructor
   2134     Entry(Entry&& source) noexcept
   2135         : lru_iterator(std::move(source.lru_iterator)) {
   2136       op = std::move(source.op);
    2137       source.op = nullptr;
   2138     }
   2139 
   2140     // Destructor
   2141     ~Entry() {
   2142       if (op != nullptr) delete op;
   2143     }
   2144   };
   2145 
   2146   // Remove the least recently accessed entry from LRU list, which
   2147   // is the tail of lru_list_. Update cache_ correspondingly.
   2148   bool Delete() {
   2149     if (lru_list_.empty()) return false;
   2150     string key = lru_list_.back();
   2151     lru_list_.pop_back();
   2152     cache_.erase(key);
   2153     return true;
   2154   }
   2155 
   2156   // Cache capacity
   2157   size_t capacity_;
   2158 
   2159   // The cache, a map from string key to a LRU entry.
   2160   std::unordered_map<string, Entry> cache_;
   2161 
   2162   // The LRU list of entries.
   2163   // The front of the list contains the key of the most recently accessed
   2164   // entry, while the back of the list is the least recently accessed entry.
   2165   std::list<string> lru_list_;
   2166 };
   2167 
   2168 template <typename T>
   2169 class MklPrimitiveFactory {
   2170  public:
   2171   MklPrimitiveFactory() {}
   2172 
   2173   ~MklPrimitiveFactory() {}
   2174 
   2175   MklPrimitive* GetOp(const string& key) {
   2176     auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
   2177     return lru_cache.GetOp(key);
   2178   }
   2179 
   2180   void SetOp(const string& key, MklPrimitive* op) {
   2181     auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
   2182     lru_cache.SetOp(key, op);
   2183   }
   2184 
    2185   /// Function to decide whether the HW has AVX512 or AVX2.
    2186   /// For legacy devices (without AVX512 or AVX2),
    2187   /// MKL-DNN GEMM will be used.
   2188   static inline bool IsLegacyPlatform() {
   2189     return (!port::TestCPUFeature(port::CPUFeature::AVX512F) &&
   2190             !port::TestCPUFeature(port::CPUFeature::AVX2));
   2191   }
   2192 
    2193   /// Function to check whether primitive memory optimization is enabled
   2194   static inline bool IsPrimitiveMemOptEnabled() {
   2195     bool is_primitive_mem_opt_enabled = true;
   2196     TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true,
   2197                                    &is_primitive_mem_opt_enabled));
   2198     return is_primitive_mem_opt_enabled;
   2199   }
   2200 
   2201  private:
   2202   static inline LRUCache<MklPrimitive>& GetLRUCache() {
   2203     static const int kCapacity = 1024;  // cache capacity
   2204     static thread_local LRUCache<MklPrimitive> lru_cache_(kCapacity);
   2205     return lru_cache_;
   2206   }
   2207 };
   2208 
    2209 // Utility class for creating keys for the MKL primitive pool.
   2210 class FactoryKeyCreator {
   2211  public:
   2212   FactoryKeyCreator() { key_.reserve(kMaxKeyLength); }
   2213 
   2214   ~FactoryKeyCreator() {}
   2215 
   2216   void AddAsKey(const string& str) { Append(str); }
   2217 
   2218   void AddAsKey(const mkldnn::memory::dims& dims) {
   2219     for (unsigned int i = 0; i < dims.size(); i++) {
   2220       AddAsKey<int>(dims[i]);
   2221     }
   2222   }
   2223 
   2224   template <typename T>
   2225   void AddAsKey(const T data) {
   2226     auto buffer = reinterpret_cast<const char*>(&data);
   2227     Append(StringPiece(buffer, sizeof(T)));
   2228   }
   2229 
   2230   string GetKey() { return key_; }
   2231 
   2232  private:
   2233   string key_;
   2234   const char delimiter = 'x';
   2235   const int kMaxKeyLength = 256;
   2236   void Append(StringPiece s) {
   2237     key_.append(string(s));
   2238     key_.append(1, delimiter);
   2239   }
   2240 };
   2241 
   2242 static inline memory::format get_desired_format(int channel,
   2243                                                 bool is_2d = true) {
   2244   memory::format fmt_desired = memory::format::any;
   2245 
   2246   if (port::TestCPUFeature(port::CPUFeature::AVX512F)) {
   2247     fmt_desired = is_2d ? memory::format::nChw16c : memory::format::nCdhw16c;
   2248   } else if (port::TestCPUFeature(port::CPUFeature::AVX2) &&
   2249              (channel % 8) == 0) {
   2250     fmt_desired = is_2d ? memory::format::nChw8c
   2251                         : memory::format::ncdhw;  // no avx2 support for 3d yet.
   2252   } else {
   2253     fmt_desired = is_2d ? memory::format::nchw : memory::format::ncdhw;
   2254   }
   2255   return fmt_desired;
   2256 }
   2257 
   2258 class MklReorderPrimitive : public MklPrimitive {
   2259  public:
   2260   explicit MklReorderPrimitive(const memory* from, const memory* to) {
   2261     Setup(from, to);
   2262   }
   2263   ~MklReorderPrimitive() {}
   2264 
   2265   std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }
   2266 
   2267   void SetMemory(const memory* from, const memory* to) {
   2268     context_.src_mem->set_data_handle(from->get_data_handle());
   2269     context_.dst_mem->set_data_handle(to->get_data_handle());
   2270   }
   2271 
   2272  private:
   2273   struct ReorderContext {
   2274     std::shared_ptr<mkldnn::memory> src_mem;
   2275     std::shared_ptr<mkldnn::memory> dst_mem;
   2276     std::shared_ptr<primitive> reorder_prim;
   2277     ReorderContext()
   2278         : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
   2279   } context_;
   2280 
   2281   engine cpu_engine_ = engine(engine::cpu, 0);
   2282 
   2283   void Setup(const memory* from, const memory* to) {
   2284     context_.src_mem.reset(new memory(
   2285         {from->get_primitive_desc().desc(), cpu_engine_}, DummyData));
   2286     context_.dst_mem.reset(
   2287         new memory({to->get_primitive_desc().desc(), cpu_engine_}, DummyData));
   2288     context_.reorder_prim = std::make_shared<mkldnn::reorder>(
   2289         reorder(*context_.src_mem, *context_.dst_mem));
   2290   }
   2291 };
   2292 
   2293 template <typename T>
   2294 class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
   2295  public:
   2296   static MklReorderPrimitive* Get(const memory* from, const memory* to) {
   2297     auto reorderPrim = static_cast<MklReorderPrimitive*>(
   2298         MklReorderPrimitiveFactory<T>::GetInstance().GetReorder(from, to));
   2299     if (reorderPrim == nullptr) {
   2300       reorderPrim = new MklReorderPrimitive(from, to);
   2301       MklReorderPrimitiveFactory<T>::GetInstance().SetReorder(from, to,
   2302                                                               reorderPrim);
   2303     }
   2304     reorderPrim->SetMemory(from, to);
   2305     return reorderPrim;
   2306   }
   2307 
   2308   static MklReorderPrimitiveFactory& GetInstance() {
   2309     static MklReorderPrimitiveFactory instance_;
   2310     return instance_;
   2311   }
   2312 
   2313  private:
   2314   MklReorderPrimitiveFactory() {}
   2315   ~MklReorderPrimitiveFactory() {}
   2316 
   2317   static string CreateKey(const memory* from, const memory* to) {
   2318     string prefix = "reorder";
   2319     FactoryKeyCreator key_creator;
   2320     auto const& from_desc = from->get_primitive_desc().desc().data;
   2321     auto const& to_desc = to->get_primitive_desc().desc().data;
   2322     const int KIdxFirstStride = 0;
   2323     memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
   2324     memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);
   2325     memory::dims from_strides(
   2326         from_desc.layout_desc.blocking.strides[KIdxFirstStride],
   2327         &from_desc.layout_desc.blocking
   2328              .strides[KIdxFirstStride][from_desc.ndims]);
   2329     memory::dims to_strides(
   2330         to_desc.layout_desc.blocking.strides[KIdxFirstStride],
   2331         &to_desc.layout_desc.blocking.strides[KIdxFirstStride][to_desc.ndims]);
   2332     key_creator.AddAsKey(prefix);
   2333     key_creator.AddAsKey(static_cast<int>(from_desc.format));
   2334     key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
   2335     key_creator.AddAsKey(from_dims);
   2336     key_creator.AddAsKey(from_strides);
   2337     key_creator.AddAsKey(static_cast<int>(to_desc.format));
   2338     key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
   2339     key_creator.AddAsKey(to_dims);
   2340     key_creator.AddAsKey(to_strides);
   2341     return key_creator.GetKey();
   2342   }
   2343 
   2344   MklPrimitive* GetReorder(const memory* from, const memory* to) {
   2345     string key = CreateKey(from, to);
   2346     return this->GetOp(key);
   2347   }
   2348 
   2349   void SetReorder(const memory* from, const memory* to, MklPrimitive* op) {
   2350     string key = CreateKey(from, to);
   2351     this->SetOp(key, op);
   2352   }
   2353 };
   2354 
    2355 /// Function to find (or create) a reorder from the memory pointed to by
    2356 /// `from` to the memory pointed to by `to`. It creates the primitive or
    2357 /// fetches it from the pool if it is already cached.
    2358 /// Returns the primitive.
   2359 template <typename T>
   2360 inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
   2361   CHECK_NOTNULL(from);
   2362   CHECK_NOTNULL(to);
   2363   MklReorderPrimitive* reorder_prim =
   2364       MklReorderPrimitiveFactory<T>::Get(from, to);
   2365   return *reorder_prim->GetPrimitive();
   2366 }
   2367 
    2368 // Utility function to determine whether a convolution is 1x1 with stride != 1,
    2369 // for the purpose of temporarily disabling primitive reuse.
   2370 inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
   2371                                 memory::dims strides) {
   2372   if (filter_dims.size() != 4 || strides.size() != 2) return false;
   2373 
   2374   return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
   2375           ((strides[0] != 1) || (strides[1] != 1)));
   2376 }
   2377 
    2378 #endif  // !INTEL_MKL_ML_ONLY
   2379 
   2380 }  // namespace tensorflow
   2381 #endif  // INTEL_MKL
   2382 #endif  // TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
   2383