      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
     17 #define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
     18 #ifdef INTEL_MKL
     19 
     20 #include <string>
     21 #include <vector>
     22 
     23 #include "mkl_dnn.h"
     24 #include "mkl_dnn_types.h"
     25 #include "mkl_service.h"
     26 #include "mkl_trans.h"
     27 #include "tensorflow/core/framework/op_kernel.h"
     28 #include "tensorflow/core/framework/tensor.h"
     29 #include "tensorflow/core/framework/tensor_shape.h"
     30 #include "tensorflow/core/graph/mkl_graph_util.h"
     31 #include "tensorflow/core/lib/core/errors.h"
     32 #include "tensorflow/core/lib/gtl/array_slice.h"
     33 #include "tensorflow/core/platform/logging.h"
     34 #include "tensorflow/core/platform/macros.h"
     35 #include "tensorflow/core/util/padding.h"
     36 #include "tensorflow/core/util/tensor_format.h"
     37 
     38 #ifndef INTEL_MKL_ML
     39 #include "mkldnn.hpp"
     40 
     41 using mkldnn::engine;
     42 using mkldnn::memory;
     43 using mkldnn::padding_kind;
     44 using mkldnn::primitive;
     45 using mkldnn::reorder;
     46 #endif
     47 
     48 // This file contains a number of utility classes and functions used by
     49 // MKL-enabled kernels.
     50 
     51 namespace tensorflow {
     52 
     53 // The MklShape class below encapsulates all the metadata associated with an
     54 // MKL tensor. A tensor is an MKL tensor if it was created as the result of an
     55 // MKL operation and did not go through a conversion to a standard
     56 // TensorFlow tensor.
     57 
     58 typedef enum { W = 0, H = 1, C = 2, N = 3 } MklDims;
     59 typedef enum {
     60   Dim_N = 0,
     61   Dim_C = 1,
     62   Dim_H = 2,
     63   Dim_W = 3,
     64   Dim_O = 0,
     65   Dim_I = 1
     66 } MklDnnDims;
     67 
     68 class MklShape {
     69  public:
     70   MklShape() {}
     71   TF_DISALLOW_COPY_AND_ASSIGN(MklShape);  // Cannot copy
     72 
     73   ~MklShape() {
     74     if (sizes_) delete[] sizes_;
     75     if (strides_) delete[] strides_;
     76     if (mklLayout_) CHECK_EQ(dnnLayoutDelete_F32(mklLayout_), E_SUCCESS);
     77     if (tfLayout_) CHECK_EQ(dnnLayoutDelete_F32(tfLayout_), E_SUCCESS);
     78     if (tf_to_mkl_dim_map_) delete[] tf_to_mkl_dim_map_;
     79   }
     80 
     81   const bool IsMklTensor() const { return isMklTensor_; }
     82 
     83   void SetMklTensor(const bool isMklTensor) { isMklTensor_ = isMklTensor; }
     84 
     85   void SetDimensions(const size_t dimension) { dimension_ = dimension; }
     86 
     87   void SetMklLayout(dnnLayout_t mklLayout) { mklLayout_ = mklLayout; }
     88 
     89   void SetMklLayout(const void* primitive, size_t resourceType) {
     90     CHECK_EQ(
     91         dnnLayoutCreateFromPrimitive_F32(&mklLayout_, (dnnPrimitive_t)primitive,
     92                                          (dnnResourceType_t)resourceType),
     93         E_SUCCESS);
     94   }
     95 
     96   void SetTfLayout(const size_t dimension, const size_t* sizes,
     97                    const size_t* strides) {
     98     dimension_ = dimension;
     99     if (dimension > 0) {  // MKL doesn't support zero-dimension tensors
    100       sizes_ = new size_t[dimension];
    101       strides_ = new size_t[dimension];
    102 
    103       for (int ii = 0; ii < dimension; ii++) {
    104         sizes_[ii] = sizes[ii];
    105         strides_[ii] = strides[ii];
    106       }
    107       CHECK_EQ(dnnLayoutCreate_F32(&tfLayout_, dimension, sizes, strides),
    108                E_SUCCESS);
    109     }
    110   }
    111 
    112   // Default case - MKL dim ordering is the opposite of TF dim ordering.
    113   // MKL -> (DIMS-1)...0 where (DIMS-1) is outermost dim and 0 is innermost dim
    114   // TF  -> 0...(DIMS-1) where 0 is outermost dim and (DIMS-1) is innermost dim
    115   // For layers that rely on data_format semantics (conv, pooling, etc.)
    116   // or operate only on certain dimensions (relu, concat, split, etc.),
    117   // MKL APIs might require us to reorder these dimensions. In such cases,
    118   // kernels should explicitly set this map.
    119   void SetTfDimOrder(const size_t dimension) {
    120     CHECK(dimension == dimension_);
    121     if (tf_to_mkl_dim_map_ == nullptr) {
    122       tf_to_mkl_dim_map_ = new size_t[dimension];
    123     }
    124     for (size_t ii = 0; ii < dimension; ii++) {
    125       tf_to_mkl_dim_map_[ii] = dimension - (ii + 1);
    126     }
    127   }
    128 
    129   void SetTfDimOrder(const size_t dimension, const size_t* tf_to_mkl_dim_map) {
    130     CHECK(dimension == dimension_);
    131     if (tf_to_mkl_dim_map_ == nullptr) {
    132       tf_to_mkl_dim_map_ = new size_t[dimension];
    133     }
    134     for (size_t ii = 0; ii < dimension; ii++) {
    135       tf_to_mkl_dim_map_[ii] = tf_to_mkl_dim_map[ii];
    136     }
    137   }
    138 
    139   void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
    140     CHECK_EQ(dimension, 4);
    141     CHECK(dimension == dimension_);
    142     if (tf_to_mkl_dim_map_ == nullptr) {
    143       tf_to_mkl_dim_map_ = new size_t[dimension];
    144     }
    145     tf_to_mkl_dim_map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDims::W;
    146     tf_to_mkl_dim_map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDims::H;
    147     tf_to_mkl_dim_map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDims::C;
    148     tf_to_mkl_dim_map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDims::N;
    149   }
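  // Illustrative sketch (not part of the API): with the MklDims enum above
  // (W = 0, H = 1, C = 2, N = 3), SetTfDimOrder(4, FORMAT_NHWC) fills
  //   tf_to_mkl_dim_map_ = {3, 1, 0, 2}   // TF dim i maps to MKL dim map[i]
  // and SetTfDimOrder(4, FORMAT_NCHW) fills
  //   tf_to_mkl_dim_map_ = {3, 2, 1, 0}
  // so that tf_dim_size(i) == sizes_[tf_to_mkl_dim_map_[i]] returns the size of
  // TF dimension i, with sizes_ stored in MKL dimension order.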
    150 
    151   const dnnLayout_t GetMklLayout() const { return mklLayout_; }
    152   const dnnLayout_t GetTfLayout() const { return tfLayout_; }
    153   const dnnLayout_t GetCurLayout() const {
    154     return isMklTensor_ ? mklLayout_ : tfLayout_;
    155   }
    156   size_t GetDimension() const { return dimension_; }
    157   const size_t* GetSizes() const { return sizes_; }
    158   int64 dim_size(int index) const { return sizes_[index]; }
    159   int64 tf_dim_size(int index) const {
    160     return sizes_[tf_to_mkl_dim_map_[index]];
    161   }
    162   const size_t* GetStrides() const { return strides_; }
    163   const size_t* GetTfToMklDimMap() const { return tf_to_mkl_dim_map_; }
    164   size_t tf_dim_idx(int index) const { return tf_to_mkl_dim_map_[index]; }
    165 
    166   // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
    167   // corresponds to MKL's Channel dimension.
    168   bool IsMklChannelDim(int d) const { return tf_dim_idx(d) == MklDims::C; }
    169   // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
    170   // corresponds to MKL's Batch dimension.
    171   bool IsMklBatchDim(int d) const { return tf_dim_idx(d) == MklDims::N; }
    172   // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
    173   // corresponds to MKL's Width dimension.
    174   bool IsMklWidthDim(int d) const { return tf_dim_idx(d) == MklDims::W; }
    175   // Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
    176   // corresponds to MKL's Height dimension.
    177   bool IsMklHeightDim(int d) const { return tf_dim_idx(d) == MklDims::H; }
    178 
    179   // Check if the TF-Mkl dimension ordering map specifies if the input
    180   // tensor is in NCHW format.
    181   bool IsTensorInNCHWFormat() const {
    182     TensorFormat data_format = FORMAT_NCHW;
    183     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
    184             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
    185             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
    186             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
    187   }
    188 
    189   // Check if the TF-Mkl dimension ordering map specifies if the input
    190   // tensor is in NHWC format.
    191   bool IsTensorInNHWCFormat() const {
    192     TensorFormat data_format = FORMAT_NHWC;
    193     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
    194             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
    195             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
    196             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
    197   }
    198 
    199   void GetConvertedFlatData(dnnLayout_t targetLayout, void* input,
    200                             void* output) const {
    201     dnnLayout_t curLayout;
    202     if (isMklTensor_)
    203       curLayout = mklLayout_;
    204     else
    205       curLayout = tfLayout_;
    206     dnnPrimitive_t convert;
    207     CHECK_EQ(dnnConversionCreate_F32(&convert, curLayout, targetLayout),
    208              E_SUCCESS);
    209     CHECK_EQ(dnnConversionExecute_F32(convert, input, output), E_SUCCESS);
    210     CHECK_EQ(dnnDelete_F32(convert), E_SUCCESS);
    211   }
    212 
    213   // The following methods are used for serializing and de-serializing the
    214   // contents of the MklShape object.
    215   // The data is serialized in this order:
    216   // isMklTensor_
    217   // dimension_
    218   // sizes_
    219   // strides_
    220   // mklLayout_
    221   // tfLayout_
    222   // tf_to_mkl_dim_map_
    223 
    224 #define SIZE_OF_MKL_DNN_BUF \
    225   (dnnLayoutSerializationBufferSize_F32())  // Size of buffer needed to
    226                                             // serialize dnn_layout pointer
    227 
    228   // Size of the buffer needed to hold the serialized object; computed as
    229   // sizeof(isMklTensor_) + sizeof(dimension_) + sizeof(sizes_) +
    230   // sizeof(strides_)
    231   // + sizeof(mklLayout_ buffer) + sizeof(tfLayout_ buffer)
    232   // + sizeof(tf_to_mkl_dim_map_)
    233 
    234 #define SIZE_OF_MKL_SERIAL_DATA(dims) \
    235   (2 * sizeof(size_t) + 3 * dims * sizeof(size_t) + 2 * SIZE_OF_MKL_DNN_BUF)
    236 
    237   // First we need to define some macros for offsets into the serial buffer
    238   // where the different elements of MklShape are written to / read from.
    239 
    240 #define IS_MKL_TENSOR_OFFSET 0
    241 // Location from start of buffer where isMklTensor_ is serialized
    242 #define DIMS_OFFSET \
    243   (IS_MKL_TENSOR_OFFSET + sizeof(size_t))  // Location of dimension_
    244 // Location of sizes. Note dim is not used here, left here
    245 // to make macros consistent.
    246 #define SIZES_OFFSET(dims) (DIMS_OFFSET + sizeof(size_t))
    247 #define STRIDES_OFFSET(dims) \
    248   (SIZES_OFFSET(dims) + dims * sizeof(size_t))  // Location of strides
    249 #define MKL_LAYOUT_OFFSET(dims) \
    250   (STRIDES_OFFSET(dims) + dims * sizeof(size_t))  // Location of mklLayout_
    251 #define TF_LAYOUT_OFFSET(dims) \
    252   (MKL_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF)  // Location of tfLayout_
    253 // Location of tf_to_mkl_dim_map_
    254 #define TF_TO_MKL_DIM_MAP_OFFSET(dims) \
    255   (TF_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF)
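  // Worked sketch of the resulting buffer layout, assuming an 8-byte size_t and
  // writing B for SIZE_OF_MKL_DNN_BUF (only known at runtime): for a 4-D tensor
  // the offsets evaluate to
  //   IS_MKL_TENSOR_OFFSET        = 0
  //   DIMS_OFFSET                 = 8
  //   SIZES_OFFSET(4)             = 16
  //   STRIDES_OFFSET(4)           = 16 + 4 * 8 = 48
  //   MKL_LAYOUT_OFFSET(4)        = 48 + 4 * 8 = 80
  //   TF_LAYOUT_OFFSET(4)         = 80 + B
  //   TF_TO_MKL_DIM_MAP_OFFSET(4) = 80 + 2 * B
  // and the map itself occupies the final 4 * 8 bytes, which is consistent with
  // SIZE_OF_MKL_SERIAL_DATA(4) = 2 * 8 + 3 * 4 * 8 + 2 * B = 112 + 2 * B.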
    256 
    257   // TODO(agramesh1) make sure to create a const to share with rewrite pass
    258   // for min size of MKL metadata tensor.
    259 
    260   void DeSerializeMklShape(const unsigned char* buf, size_t buf_size) {
    261     CHECK(buf_size >= sizeof(size_t)) << "Bufsize too small in DeSerialize";
    262     // Make sure buffer holds at least isMklTensor_
    263     isMklTensor_ =
    264         *reinterpret_cast<const size_t*>(buf + IS_MKL_TENSOR_OFFSET) != 0;
    265 
    266     if (isMklTensor_) {  // If it is an MKL Tensor then read the rest
    267       dimension_ = *(reinterpret_cast<const size_t*>(buf + DIMS_OFFSET));
    268       CHECK(buf_size >= SIZE_OF_MKL_SERIAL_DATA(dimension_))
    269           << "Bufsize too small in DeSerialize";
    270       sizes_ = new size_t[dimension_];
    271       strides_ = new size_t[dimension_];
    272       tf_to_mkl_dim_map_ = new size_t[dimension_];
    273       for (int i = 0; i < dimension_; i++) {
    274         sizes_[i] =
    275             reinterpret_cast<const size_t*>(buf + SIZES_OFFSET(dimension_))[i];
    276         strides_[i] = reinterpret_cast<const size_t*>(
    277             buf + STRIDES_OFFSET(dimension_))[i];
    278         tf_to_mkl_dim_map_[i] = reinterpret_cast<const size_t*>(
    279             buf + TF_TO_MKL_DIM_MAP_OFFSET(dimension_))[i];
    280       }
    281       CHECK_EQ(dnnLayoutDeserialize_F32(&mklLayout_,
    282                                         buf + MKL_LAYOUT_OFFSET(dimension_)),
    283                E_SUCCESS);
    284       CHECK_EQ(dnnLayoutDeserialize_F32(&tfLayout_,
    285                                         buf + TF_LAYOUT_OFFSET(dimension_)),
    286                E_SUCCESS);
    287     }
    288   }
    289 
    290   void SerializeMklShape(unsigned char* buf, size_t buf_size) const {
    291     CHECK(buf_size >= SIZE_OF_MKL_SERIAL_DATA(dimension_))
    292         << "Bufsize too small to Serialize";
    293     *reinterpret_cast<size_t*>(buf + IS_MKL_TENSOR_OFFSET) =
    294         isMklTensor_ ? 1 : 0;
    295     if (isMklTensor_) {
    296       *(reinterpret_cast<size_t*>(buf + DIMS_OFFSET)) = dimension_;
    297       for (int i = 0; i < dimension_; i++) {
    298         reinterpret_cast<size_t*>(buf + SIZES_OFFSET(dimension_))[i] =
    299             sizes_[i];
    300         reinterpret_cast<size_t*>(buf + STRIDES_OFFSET(dimension_))[i] =
    301             strides_[i];
    302         reinterpret_cast<size_t*>(buf +
    303                                   TF_TO_MKL_DIM_MAP_OFFSET(dimension_))[i] =
    304             tf_to_mkl_dim_map_[i];
    305       }
    306       CHECK_EQ(dnnLayoutSerialize_F32(mklLayout_,
    307                                       buf + MKL_LAYOUT_OFFSET(dimension_)),
    308                E_SUCCESS);
    309       CHECK_EQ(
    310           dnnLayoutSerialize_F32(tfLayout_, buf + TF_LAYOUT_OFFSET(dimension_)),
    311           E_SUCCESS);
    312     }
    313   }
    314 
    315  private:
    316   bool isMklTensor_ =
    317       false;  // Flag to indicate if the tensor is an MKL tensor or not
    318   dnnLayout_t mklLayout_ = nullptr;  // Pointer to the MKL layout
    319   dnnLayout_t tfLayout_ = nullptr;   // Pointer to layout of the corresponding
    320   // TensorFlow tensor, used when converting from MKL to a standard tensor
    321   size_t dimension_ = 0;
    322   size_t* sizes_ = nullptr;    // Required by MKL for conversions
    323   size_t* strides_ = nullptr;  // Required by MKL for conversions
    324   size_t* tf_to_mkl_dim_map_ =
    325       nullptr;  // TF dimension corresponding to this MKL dimension
    326 };
    327 
    328 #ifndef INTEL_MKL_ML
    329 
    330 // Forward decl
    331 TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format);
    332 memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
    333 memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
    334                                         const memory::dims& strides,
    335                                         memory::data_type dtype);
    336 
    337 class MklDnnShape {
    338  private:
    339   typedef struct {
    340     /// Flag to indicate if the tensor is an MKL tensor or not
    341     bool is_mkl_tensor_ = false;
    342     /// Number of dimensions in Tensorflow format
    343     size_t dimension_ = 0;
    344     /// Required by MKL-DNN for conversions
    345     mkldnn_dims_t sizes_;
    346     memory::format tf_data_format_ = memory::format::format_undef;
    347     memory::data_type T_ = memory::data_type::data_undef;
    348     // MKL layout
    349     mkldnn_memory_desc_t mkl_md_;
    350     /// TF dimension corresponding to this MKL dimension
    351     mkldnn_dims_t map_;
    352   } MklShapeData;
    353   MklShapeData data_;
    354 
    355   typedef std::remove_extent<mkldnn_dims_t>::type mkldnn_dim_t;
    356 #define INVALID_DIM_SIZE -1
    357 
    358  public:
    359   MklDnnShape() {
    360     for (size_t i = 0; i < sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
    361          ++i) {
    362       data_.sizes_[i] = -1;
    363     }
    364     for (size_t i = 0; i < sizeof(data_.map_) / sizeof(data_.map_[0]); ++i) {
    365       data_.map_[i] = -1;
    366     }
    367   }
    368 
    369   ~MklDnnShape() {}
    370   TF_DISALLOW_COPY_AND_ASSIGN(MklDnnShape);  // Cannot copy
    371 
    372   /// Helper function to compare memory::desc objects for MklDnn.
    373   /// Maybe this should go into MklDnn directly.
    374   inline bool CompareMklDnnLayouts(const memory::desc& md1,
    375                                    const memory::desc& md2) const {
    376     mkldnn_memory_desc_t mdd1 = md1.data;
    377     mkldnn_memory_desc_t mdd2 = md2.data;
    378     const char* d1 = reinterpret_cast<const char*>(&mdd1);
    379     const char* d2 = reinterpret_cast<const char*>(&mdd2);
    380 
    381     size_t md_size = sizeof(mdd1);
    382     for (size_t i = 0; i < md_size; i++) {
    383       if (*d1++ != *d2++) {
    384         return false;
    385       }
    386     }
    387     return true;
    388   }
    389 
    390   /// Equality function for MklDnnShape objects
    391   /// @return true if both are equal; false otherwise.
    392   inline bool operator==(const MklDnnShape& input_shape) const {
    393     if (this->IsMklTensor() != input_shape.IsMklTensor()) {
    394       return false;
    395     }
    396 
    397     // If input tensors are in Mkl layout, then we check for dimensions and
    398     // sizes.
    399     if (this->IsMklTensor()) {
    400       return this->GetTfShape() == input_shape.GetTfShape() &&
    401              CompareMklDnnLayouts(this->GetMklLayout(),
    402                                   input_shape.GetMklLayout());
    403     }
    404 
    405     return true;
    406   }
    407 
    408   /// Equality operator for MklDnnShape and TFShape.
    409   /// Returns: true if TF shapes for both are the same, false otherwise
    410   inline bool operator==(const TensorShape& input_shape) const {
    411     if (!this->IsMklTensor()) {
    412       return false;
    413     }
    414 
    415     return this->GetTfShape() == input_shape;
    416   }
    417 
    418   inline const bool IsMklTensor() const { return data_.is_mkl_tensor_; }
    419   inline void SetMklTensor(bool is_mkl_tensor) {
    420     data_.is_mkl_tensor_ = is_mkl_tensor;
    421   }
    422 
    423   inline void SetDimensions(const size_t dimension) {
    424     data_.dimension_ = dimension;
    425   }
    426   inline size_t GetDimension(char dimension) const {
    427     int index = GetMklDnnTensorDimIndex(dimension);
    428     CHECK(index >= 0 && index < this->GetDimension())
    429         << "Invalid index from the dimension: " << index << ", " << dimension;
    430     return this->DimSize(index);
    431   }
    432 
    433   inline int32 GetMklDnnTensorDimIndex(char dimension) const {
    434     switch (dimension) {
    435       case 'N':
    436         return MklDnnDims::Dim_N;
    437       case 'C':
    438         return MklDnnDims::Dim_C;
    439       case 'H':
    440         return MklDnnDims::Dim_H;
    441       case 'W':
    442         return MklDnnDims::Dim_W;
    443       default:
    444         LOG(FATAL) << "Invalid dimension: " << dimension;
    445         return -1;  // Avoid compiler warning about missing return value
    446     }
    447   }
    448 
    449   inline size_t GetDimension() const { return data_.dimension_; }
    450   inline const int* GetSizes() const {
    451     return reinterpret_cast<const int*>(&data_.sizes_[0]);
    452   }
    453 
    454   // Returns an mkldnn::memory::dims object that contains the sizes of this
    455   // MklDnnShape object.
    456   inline memory::dims GetSizesAsMklDnnDims() const {
    457     memory::dims retVal;
    458     if (data_.is_mkl_tensor_) {
    459       size_t dimensions = sizeof(data_.sizes_) / sizeof(data_.sizes_[0]);
    460       for (size_t i = 0; i < dimensions; i++) {
    461         if (data_.sizes_[i] != INVALID_DIM_SIZE)
    462           retVal.push_back(data_.sizes_[i]);
    463       }
    464     } else {
    465       CHECK_EQ(data_.is_mkl_tensor_, true);
    466     }
    467     return retVal;
    468   }
    469 
    470   inline int64 DimSize(int index) const {
    471     CHECK_LT(index, sizeof(data_.sizes_) / sizeof(data_.sizes_[0]));
    472     return data_.sizes_[index];
    473   }
    474 
    475   /// Return a TensorShape that describes the TensorFlow shape of the tensor
    476   /// represented by this MklDnnShape.
    477   inline TensorShape GetTfShape() const {
    478     CHECK_EQ(data_.is_mkl_tensor_, true);
    479 
    480     std::vector<int32> shape(data_.dimension_, -1);
    481     if (data_.tf_data_format_ != memory::format::blocked) {
    482       for (size_t idx = 0; idx < data_.dimension_; ++idx) {
    483         shape[idx] = data_.sizes_[TfDimIdx(idx)];
    484       }
    485     } else {
    486       // If the TensorFlow shape is in blocked format, then we don't have a
    487       // dimension map for it. So we just create the TensorFlow shape from the
    488       // sizes in the specified order.
    489       for (size_t idx = 0; idx < data_.dimension_; ++idx) {
    490         shape[idx] = data_.sizes_[idx];
    491       }
    492     }
    493 
    494     TensorShape ts;
    495     bool ret = TensorShapeUtils::MakeShape(shape, &ts).ok();
    496     CHECK_EQ(ret, true);
    497     return ts;
    498   }
    499 
    500   inline void SetElemType(memory::data_type dt) { data_.T_ = dt; }
    501   inline const memory::data_type GetElemType() { return data_.T_; }
    502 
    503   inline void SetMklLayout(memory::primitive_desc* pd) {
    504     CHECK_NOTNULL(pd);
    505     data_.mkl_md_ = pd->desc().data;
    506   }
    507 
    508   inline void SetMklLayout(memory::desc* md) {
    509     CHECK_NOTNULL(md);
    510     data_.mkl_md_ = md->data;
    511   }
    512 
    513   inline const memory::desc GetMklLayout() const {
    514     return memory::desc(data_.mkl_md_);
    515   }
    516 
    517   inline memory::format GetTfDataFormat() const {
    518     return data_.tf_data_format_;
    519   }
    520   /// We don't create a primitive_descriptor for the TensorFlow layout now.
    521   /// We use lazy evaluation and create it only when needed. The input format
    522   /// can also be the blocked format.
    523   inline void SetTfLayout(size_t dims, const memory::dims& sizes,
    524                           memory::format format) {
    525     CHECK_EQ(dims, sizes.size());
    526     data_.dimension_ = dims;
    527     for (size_t ii = 0; ii < dims; ii++) {
    528       data_.sizes_[ii] = sizes[ii];
    529     }
    530     data_.tf_data_format_ = format;
    531     if (format != memory::format::blocked) {
    532       SetTfDimOrder(dims, format);
    533     }
    534   }
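  /// Illustrative sketch (hypothetical values): a kernel producing a 4-D output
  /// whose sizes in MKL-DNN's canonical NCHW order are {N=2, C=3, H=4, W=5}
  /// would typically call
  ///   mkl_shape.SetTfLayout(4, {2, 3, 4, 5}, memory::format::nhwc);
  /// which records the sizes and the TF data format and, because the format is
  /// not blocked, also sets the TF-to-MKL dimension map via SetTfDimOrder().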
    535 
    536   inline const memory::desc GetTfLayout() const {
    537     memory::dims dims;
    538     for (size_t ii = 0; ii < data_.dimension_; ii++) {
    539       dims.push_back(data_.sizes_[ii]);
    540     }
    541 
    542     // Create Blocked memory desc if input TF format was set like that.
    543     if (data_.tf_data_format_ == memory::format::blocked) {
    544       auto strides = CalculateTFStrides(dims);
    545       return CreateBlockedMemDescHelper(dims, strides, data_.T_);
    546     } else {
    547       return memory::desc(dims, data_.T_, data_.tf_data_format_);
    548     }
    549   }
    550 
    551   inline const memory::desc GetCurLayout() const {
    552     return IsMklTensor() ? GetMklLayout() : GetTfLayout();
    553   }
    554 
    555   // nhasabni - I've removed the SetTfDimOrder overload that was setting the
    556   // default order in the MKL-ML case. We don't need a default dimension order
    557   // because when an operator that does not have a data_format attribute gets
    558   // all inputs in TensorFlow format, it produces output in TensorFlow format.
    559   inline void SetTfDimOrder(const size_t dimension, const mkldnn_dims_t map) {
    560     CHECK(dimension == data_.dimension_);
    561     for (size_t ii = 0; ii < dimension; ii++) {
    562       data_.map_[ii] = map[ii];
    563     }
    564   }
    565 
    566   inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
    567     // TODO(nhasabni): Why do we restrict this to 4D?
    568     CHECK_EQ(dimension, 4);
    569     CHECK(dimension == data_.dimension_);
    570     data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
    571     data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
    572     data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
    573     data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
    574   }
    575 
    576   inline void SetTfDimOrder(const size_t dimension, memory::format format) {
    577     TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format);
    578     SetTfDimOrder(dimension, data_format);
    579   }
    580 
    581   inline const mkldnn_dim_t* GetTfToMklDimMap() const { return &data_.map_[0]; }
    582   inline size_t TfDimIdx(int index) const { return data_.map_[index]; }
    583   inline int64 TfDimSize(int index) const {
    584     return data_.sizes_[TfDimIdx(index)];
    585   }
    586 
    587   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
    588   /// corresponds to MKL's Channel dimension.
    589   inline bool IsMklChannelDim(int d) const {
    590     return TfDimIdx(d) == MklDnnDims::Dim_C;
    591   }
    592   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
    593   /// corresponds to MKL's Batch dimension.
    594   inline bool IsMklBatchDim(int d) const {
    595     return TfDimIdx(d) == MklDnnDims::Dim_N;
    596   }
    597   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
    598   /// corresponds to MKL's Width dimension.
    599   inline bool IsMklWidthDim(int d) const {
    600     return TfDimIdx(d) == MklDnnDims::Dim_W;
    601   }
    602   /// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
    603   /// corresponds to MKL's Height dimension.
    604   inline bool IsMklHeightDim(int d) const {
    605     return TfDimIdx(d) == MklDnnDims::Dim_H;
    606   }
    607 
    608   /// Check if the TF-Mkl dimension ordering map specifies if the input
    609   /// tensor is in NCHW format.
    610   inline bool IsTensorInNCHWFormat() const {
    611     TensorFormat data_format = FORMAT_NCHW;
    612     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
    613             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
    614             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
    615             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
    616   }
    617 
    618   /// Check if the TF-Mkl dimension ordering map specifies if the input
    619   /// tensor is in NHWC format.
    620   inline bool IsTensorInNHWCFormat() const {
    621     TensorFormat data_format = FORMAT_NHWC;
    622     return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
    623             IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
    624             IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
    625             IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
    626   }
    627 
    628   /// The following methods are used for serializing and de-serializing the
    629   /// contents of the MklDnnShape object.
    630   /// The data is serialized in this order:
    631   /// is_mkl_tensor_ : dimension_ : sizes_ : map_ : format_ : T_ : mkl_pd_;
    632 
    633   /// Size of the buffer needed to hold the serialized object, computed by
    634   /// following the order mentioned above
    635   inline size_t GetSerializeBufferSize() const { return sizeof(MklShapeData); }
    636 
    637   void SerializeMklDnnShape(unsigned char* buf, size_t buf_size) const {
    638     CHECK(buf_size >= GetSerializeBufferSize())
    639         << "Buffer size is too small to SerializeMklDnnShape";
    640     *reinterpret_cast<MklShapeData*>(buf) = data_;
    641   }
    642 
    643   void DeSerializeMklDnnShape(const unsigned char* buf, size_t buf_size) {
    644     // Make sure buffer holds at least is_mkl_tensor_.
    645     CHECK(buf_size >= sizeof(data_.is_mkl_tensor_))
    646         << "Buffer size is too small in DeSerializeMklDnnShape";
    647 
    648     const bool is_mkl_tensor = *reinterpret_cast<const bool*>(buf);
    649     if (is_mkl_tensor) {  // If it is an MKL Tensor then read the rest
    650       CHECK(buf_size >= GetSerializeBufferSize())
    651           << "Buffer size is too small in DeSerializeMklDnnShape";
    652       data_ = *reinterpret_cast<const MklShapeData*>(buf);
    653     }
    654   }
    655 };
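// Illustrative sketch (hypothetical kernel code, not part of this header): a
// producer kernel emitting an MKL-layout output typically fills the shape as
//
//   MklDnnShape out_shape;
//   out_shape.SetMklTensor(true);
//   out_shape.SetElemType(MklDnnType<float>());
//   out_shape.SetMklLayout(&output_pd);  // output_pd: hypothetical primitive_desc
//   out_shape.SetTfLayout(4, output_dims_nchw, memory::format::nchw);
//
// and then serializes it into the metadata output with AllocateOutputSetMklShape()
// defined below; a consumer recovers it with GetMklShape(), which deserializes it.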
    656 
    657 #endif
    658 
    659 // List of MklShape objects. Used in Concat/Split layers.
    660 
    661 typedef std::vector<MklShape> MklShapeList;
    662 
    663 #ifndef INTEL_MKL_ML
    664 typedef std::vector<MklDnnShape> MklDnnShapeList;
    665 #endif
    666 
    667 // Check if all tensors specified by MklShapes are MKL tensors.
    668 inline bool AreAllMklTensors(const MklShapeList& shapes) {
    669   for (auto& s : shapes) {
    670     if (!s.IsMklTensor()) {
    671       return false;
    672     }
    673   }
    674   return true;
    675 }
    676 
    677 #ifdef INTEL_MKL_ML
    678 template <typename T>
    679 inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
    680                              const MklShape& mkl_shape) {
    681   Tensor output_tensor;
    682   TensorShape output_shape;
    683 
    684   for (size_t j = 0; j < mkl_shape.GetDimension(); j++) {
    685     // Outermost to innermost dimension
    686     output_shape.AddDim(mkl_shape.GetSizes()[mkl_shape.tf_dim_idx(j)]);
    687   }
    688 
    689   // Allocate output tensor.
    690   context->allocate_temp(DataTypeToEnum<T>::v(), output_shape, &output_tensor);
    691 
    692   dnnLayout_t output_layout = static_cast<dnnLayout_t>(mkl_shape.GetTfLayout());
    693   void* input_buffer = const_cast<T*>(mkl_tensor.flat<T>().data());
    694   void* output_buffer = const_cast<T*>(output_tensor.flat<T>().data());
    695 
    696   if (mkl_tensor.NumElements() != 0) {
    697     mkl_shape.GetConvertedFlatData(output_layout, input_buffer, output_buffer);
    698   }
    699 
    700   return output_tensor;
    701 }
    702 #else
    703 template <typename T>
    704 inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
    705                              const MklDnnShape& mkl_shape) {
    706   Tensor output_tensor;
    707   TensorShape output_shape;
    708 
    709   TF_CHECK_OK(
    710       Status(error::Code::UNIMPLEMENTED, "Unimplemented conversion function"));
    711 
    712   return output_tensor;
    713 }
    714 #endif
    715 
    716 // Get the MKL shape from the input's metadata tensor (serialized as uint8)
    717 inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
    718   mklshape->DeSerializeMklShape(
    719       ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
    720           .flat<uint8>()
    721           .data(),
    722       ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
    723               .flat<uint8>()
    724               .size() *
    725           sizeof(uint8));
    726 }
    727 
    728 #ifndef INTEL_MKL_ML
    729 inline void GetMklShape(OpKernelContext* ctext, int n, MklDnnShape* mklshape) {
    730   mklshape->DeSerializeMklDnnShape(
    731       ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
    732           .flat<uint8>()
    733           .data(),
    734       ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
    735               .flat<uint8>()
    736               .size() *
    737           sizeof(uint8));
    738 }
    739 #endif
    740 
    741 // Gets the actual input (data) tensor, as opposed to the metadata tensor
    742 inline const Tensor& MklGetInput(OpKernelContext* ctext, int n) {
    743   return ctext->input(GetTensorDataIndex(n, ctext->num_inputs()));
    744 }
    745 
    746 inline void GetMklInputList(OpKernelContext* ctext, StringPiece name,
    747                             OpInputList* input_tensors) {
    748   CHECK_NOTNULL(input_tensors);
    749   ctext->input_list(name, input_tensors);
    750 }
    751 
    752 #ifdef INTEL_MKL_ML
    753 
    754 inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
    755                             MklShapeList* mkl_shapes) {
    756   OpInputList input_mkl_tensors;
    757   GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);
    758 
    759   for (int i = 0; i < input_mkl_tensors.size(); i++) {
    760     (*mkl_shapes)[i].DeSerializeMklShape(
    761         input_mkl_tensors[i].flat<uint8>().data(),
    762         input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8));
    763   }
    764 }
    765 
    766 #else
    767 
    768 inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
    769                             MklDnnShapeList* mkl_shapes) {
    770   OpInputList input_mkl_tensors;
    771   GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);
    772 
    773   for (int i = 0; i < input_mkl_tensors.size(); i++) {
    774     (*mkl_shapes)[i].DeSerializeMklDnnShape(
    775         input_mkl_tensors[i].flat<uint8>().data(),
    776         input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8));
    777   }
    778 }
    779 
    780 #endif
    781 
    782 #ifndef INTEL_MKL_ML
    783 /// Get the shape of the input tensor pointed to by 'input_idx' in TensorShape
    784 /// format. If the input tensor is in MKL layout, then the TensorShape is
    785 /// obtained from the MklDnnShape.
    786 inline TensorShape GetTfShape(OpKernelContext* context, size_t input_idx) {
    787   // Sanity check.
    788   CHECK_NOTNULL(context);
    789   CHECK_LT(input_idx, context->num_inputs());
    790 
    791   MklDnnShape input_mkl_shape;
    792   GetMklShape(context, input_idx, &input_mkl_shape);
    793   if (input_mkl_shape.IsMklTensor()) {
    794     return input_mkl_shape.GetTfShape();
    795   } else {
    796     const Tensor& t = MklGetInput(context, input_idx);
    797     return t.shape();
    798   }
    799 }
    800 #endif
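// Illustrative sketch (hypothetical kernel code): inside Compute(), a kernel
// that only needs the logical TensorFlow shape of input 0 can write
//
//   TensorShape src_tf_shape = GetTfShape(context, 0);
//
// which covers both the MKL-layout case (via the deserialized MklDnnShape) and
// the plain TensorFlow-layout case (via Tensor::shape()).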
    801 
    802 // Allocate the second output tensor that will contain
    803 // the serialized MKL shape
    804 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
    805                                       const MklShape& mkl_shape) {
    806   Tensor* second_tensor = nullptr;
    807   TensorShape second_shape;
    808   second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mkl_shape.GetDimension()));
    809   OP_REQUIRES_OK(ctext, ctext->allocate_output(
    810                             GetTensorMetaDataIndex(n, ctext->num_outputs()),
    811                             second_shape, &second_tensor));
    812   mkl_shape.SerializeMklShape(
    813       second_tensor->flat<uint8>().data(),
    814       second_tensor->flat<uint8>().size() * sizeof(uint8));
    815 }
    816 
    817 #ifndef INTEL_MKL_ML
    818 // Allocate the second output tensor that will contain
    819 // the serialized MKL shape
    820 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
    821                                       const MklDnnShape& mkl_shape) {
    822   Tensor* second_tensor = nullptr;
    823   TensorShape second_shape;
    824   second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
    825   OP_REQUIRES_OK(ctext, ctext->allocate_output(
    826                             GetTensorMetaDataIndex(n, ctext->num_outputs()),
    827                             second_shape, &second_tensor));
    828   mkl_shape.SerializeMklDnnShape(
    829       second_tensor->flat<uint8>().data(),
    830       second_tensor->flat<uint8>().size() * sizeof(uint8));
    831 }
    832 #endif
    833 
    834 // Allocate the output tensor, create a second output tensor that will contain
    835 // the serialized MKL shape
    836 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
    837                                       Tensor** output,
    838                                       const TensorShape& tf_shape,
    839                                       const MklShape& mkl_shape) {
    840   Tensor* second_tensor = nullptr;
    841   TensorShape second_shape;
    842   second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mkl_shape.GetDimension()));
    843   OP_REQUIRES_OK(
    844       ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
    845                                     tf_shape, output));
    846   OP_REQUIRES_OK(ctext, ctext->allocate_output(
    847                             GetTensorMetaDataIndex(n, ctext->num_outputs()),
    848                             second_shape, &second_tensor));
    849   mkl_shape.SerializeMklShape(
    850       second_tensor->flat<uint8>().data(),
    851       second_tensor->flat<uint8>().size() * sizeof(uint8));
    852 }
    853 
    854 #ifndef INTEL_MKL_ML
    855 // Allocate the output tensor, create a second output tensor that will contain
    856 // the serialized MKL shape
    857 inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
    858                                       Tensor** output,
    859                                       const TensorShape& tf_shape,
    860                                       const MklDnnShape& mkl_shape) {
    861   Tensor* second_tensor = nullptr;
    862   TensorShape second_shape;
    863   second_shape.AddDim(mkl_shape.GetSerializeBufferSize());
    864   OP_REQUIRES_OK(
    865       ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
    866                                     tf_shape, output));
    867   OP_REQUIRES_OK(ctext, ctext->allocate_output(
    868                             GetTensorMetaDataIndex(n, ctext->num_outputs()),
    869                             second_shape, &second_tensor));
    870   mkl_shape.SerializeMklDnnShape(
    871       second_tensor->flat<uint8>().data(),
    872       second_tensor->flat<uint8>().size() * sizeof(uint8));
    873 }
    874 #endif
    875 
    876 // Allocates a temp tensor and returns the data buffer for temporary storage.
    878 #ifndef INTEL_MKL_ML
    879 template <typename T>
    880 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
    881                            const memory::primitive_desc& pd, void** buf_out) {
    882   TensorShape tf_shape;
    883 
    884   tf_shape.AddDim(pd.get_size() / sizeof(T) + 1);
    885   OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
    886                                                  tf_shape, tensor_out));
    887   *buf_out = static_cast<void*>(tensor_out->flat<T>().data());
    888 }
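// Illustrative sketch (hypothetical values): a kernel that needs scratch space
// for an MKL-DNN reorder can size it from the destination primitive_desc:
//
//   Tensor tmp_tensor;
//   void* tmp_buf = nullptr;
//   AllocTmpBuffer<float>(context, &tmp_tensor, dst_pd, &tmp_buf);  // dst_pd: hypothetical
//
// The allocation adds one extra element so that at least pd.get_size() bytes
// remain available after the integer division by sizeof(T).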
    889 #endif
    890 
    891 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
    892                            dnnLayout_t lt_buff, void** buf_out) {
    893   TensorShape tf_shape;
    894 
    895   tf_shape.AddDim(
    896       dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(lt_buff)) /
    897           sizeof(float) +
    898       1);
    899   OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<float>::v(),
    900                                                  tf_shape, tensor_out));
    901   *buf_out = static_cast<void*>(tensor_out->flat<float>().data());
    902 }
    903 
    904 template <typename T>
    905 inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
    906                            TensorShape tf_shape) {
    907   OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
    908                                                  tf_shape, tensor_out));
    909 }
    910 
    911 inline void GetStridesFromSizes(TensorFormat data_format, size_t* strides,
    912                                 const size_t* sizes) {
    913   // MKL requires strides in NCHW
    914   if (data_format == FORMAT_NHWC) {
    915     strides[0] = sizes[2];
    916     strides[1] = sizes[0] * sizes[2];
    917     strides[2] = 1;
    918     strides[3] = sizes[0] * sizes[1] * sizes[2];
    919   } else {
    920     strides[0] = 1;
    921     strides[1] = sizes[0];
    922     strides[2] = sizes[0] * sizes[1];
    923     strides[3] = sizes[0] * sizes[1] * sizes[2];
    924   }
    925 }
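// Worked sketch (hypothetical values), assuming sizes and strides are indexed
// by MklDims (W, H, C, N) as elsewhere in this file: for sizes
// {W=5, H=4, C=3, N=2},
//   GetStridesFromSizes(FORMAT_NHWC, strides, sizes) -> strides = {3, 15, 1, 60}
//   GetStridesFromSizes(FORMAT_NCHW, strides, sizes) -> strides = {1, 5, 20, 60}
// i.e. the innermost (stride-1) dimension is C for NHWC and W for NCHW.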
    926 
    927 inline void MklSizesToTFSizes(OpKernelContext* context,
    928                               TensorFormat data_format_,
    929                               const MklShape& mkl_shape,
    930                               TensorShape* tf_shape) {
    931   size_t tf_dim = mkl_shape.GetDimension();
    932   const size_t* tf_sizes = mkl_shape.GetSizes();
    933 
    934   OP_REQUIRES(context, tf_dim == 4,
    935               errors::InvalidArgument("MKLSizesToTFSizes: size must be 4-dim"));
    936   std::vector<int32> sizes;
    937 
    938   sizes.push_back(tf_sizes[3]);
    939 
    940   if (data_format_ == FORMAT_NHWC) {
    941     sizes.push_back(tf_sizes[1]);
    942     sizes.push_back(tf_sizes[0]);
    943     sizes.push_back(tf_sizes[2]);
    944   } else {
    945     sizes.push_back(tf_sizes[2]);
    946     sizes.push_back(tf_sizes[1]);
    947     sizes.push_back(tf_sizes[0]);
    948   }
    949 
    950   OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(sizes, tf_shape));
    951 }
    952 
    953 inline int32 GetMklTensorDimIndex(char dimension) {
    954   switch (dimension) {
    955     case 'N':
    956       return MklDims::N;
    957     case 'C':
    958       return MklDims::C;
    959     case 'H':
    960       return MklDims::H;
    961     case 'W':
    962       return MklDims::W;
    963     default:
    964       LOG(FATAL) << "Invalid dimension: " << dimension;
    965       return -1;  // Avoid compiler warning about missing return value
    966   }
    967 }
    968 
    969 inline int64 GetMklTensorDim(const MklShape& mkl_shape, char dimension) {
    970   int index = GetMklTensorDimIndex(dimension);
    971   CHECK(index >= 0 && index < mkl_shape.GetDimension())
    972       << "Invalid index from the dimension: " << index << ", " << dimension;
    973   return mkl_shape.dim_size(index);
    974 }
    975 
    976 inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
    977                                  int idx_out) {
    978   int num_inputs = context->num_inputs();
    979   int num_outputs = context->num_outputs();
    980   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
    981   int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
    982   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
    983   int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
    984 
    985   const Tensor& data = context->input(idx_data_in);
    986   const Tensor& meta = context->input(idx_meta_in);
    987   Tensor output(data.dtype());
    988   Tensor meta_output(meta.dtype());
    989 
    990   // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
    991   CHECK(output.CopyFrom(data, data.shape()));
    992   CHECK(meta_output.CopyFrom(meta, meta.shape()));
    993   context->set_output(idx_data_out, output);
    994   context->set_output(idx_meta_out, meta_output);
    995 }
    996 
    997 #ifdef INTEL_MKL_ML
    998 inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in,
    999                                          int idx_out,
   1000                                          const TensorShape& shape) {
   1001   int num_inputs = context->num_inputs();
   1002   int num_outputs = context->num_outputs();
   1003   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
   1004   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
   1005 
   1006   const Tensor& data = context->input(idx_data_in);
   1007   MklShape mkl_shape_output;
   1008   mkl_shape_output.SetMklTensor(false);
   1009   AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
   1010   Tensor output(data.dtype());
   1011   // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
   1012   CHECK(output.CopyFrom(data, shape));
   1013   context->set_output(idx_data_out, output);
   1014 }
   1015 #else
   1016 inline void CopyTfTensorInToOutWithShape(OpKernelContext* context, int idx_in,
   1017                                          int idx_out,
   1018                                          const TensorShape& shape) {
   1019   int num_inputs = context->num_inputs();
   1020   int num_outputs = context->num_outputs();
   1021   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
   1022   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
   1023 
   1024   const Tensor& data = context->input(idx_data_in);
   1025   MklDnnShape mkl_shape_output;
   1026   mkl_shape_output.SetMklTensor(false);
   1027   AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
   1028   Tensor output(data.dtype());
   1029   // TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
   1030   CHECK(output.CopyFrom(data, shape));
   1031   context->set_output(idx_data_out, output);
   1032 }
   1033 #endif
   1034 
   1035 #ifdef INTEL_MKL_ML
   1036 
   1037 inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in,
   1038                                    int idx_out) {
   1039   int num_inputs = context->num_inputs();
   1040   int num_outputs = context->num_outputs();
   1041   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
   1042   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
   1043 
   1044   MklShape mkl_shape_output;
   1045   mkl_shape_output.SetMklTensor(false);
   1046   AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
   1047   if (IsRefType(context->input_dtype(idx_data_in))) {
   1048     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
   1049   } else {
   1050     context->set_output(idx_data_out, context->input(idx_data_in));
   1051   }
   1052 }
   1053 
   1054 #else
   1055 
   1056 inline void ForwardTfTensorInToOut(OpKernelContext* context, int idx_in,
   1057                                    int idx_out) {
   1058   int num_inputs = context->num_inputs();
   1059   int num_outputs = context->num_outputs();
   1060   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
   1061   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
   1062 
   1063   MklDnnShape dnn_shape_output;
   1064   dnn_shape_output.SetMklTensor(false);
   1065   AllocateOutputSetMklShape(context, idx_out, dnn_shape_output);
   1066   if (IsRefType(context->input_dtype(idx_data_in))) {
   1067     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
   1068   } else {
   1069     context->set_output(idx_data_out, context->input(idx_data_in));
   1070   }
   1071 }
   1072 
   1073 #endif
   1074 
   1075 inline void ForwardMklTensorInToOut(OpKernelContext* context, int idx_in,
   1076                                     int idx_out) {
   1077   int num_inputs = context->num_inputs();
   1078   int num_outputs = context->num_outputs();
   1079   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
   1080   int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
   1081   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
   1082   int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
   1083 
   1084   if (IsRefType(context->input_dtype(idx_data_in))) {
   1085     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
   1086     context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
   1087   } else {
   1088     context->set_output(idx_data_out, context->input(idx_data_in));
   1089     context->set_output(idx_meta_out, context->input(idx_meta_in));
   1090   }
   1091 }
   1092 
   1093 #ifndef INTEL_MKL_ML
   1094 inline void ForwardMklTensorInToOutWithMklShape(OpKernelContext* context,
   1095                                                 int idx_in, int idx_out,
   1096                                                 const MklDnnShape& mkl_shape) {
   1097   int num_inputs = context->num_inputs();
   1098   int num_outputs = context->num_outputs();
   1099   int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
   1100   int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
   1101 
   1102   AllocateOutputSetMklShape(context, idx_out, mkl_shape);
   1103 
   1104   if (IsRefType(context->input_dtype(idx_data_in))) {
   1105     context->forward_ref_input_to_ref_output(idx_data_in, idx_data_out);
   1106   } else {
   1107     context->set_output(idx_data_out, context->input(idx_data_in));
   1108   }
   1109 }
   1110 #endif
   1111 
   1112 // Forward the MKL shape ONLY (used in elementwise and other ops where
   1113 // we call the Eigen implementation and the MKL shape is not used)
   1114 inline void ForwardMklMetaDataInToOut(OpKernelContext* context,
   1115                                       uint32 idx_data_in, uint32_t idx_data_out) {
   1116   uint32 idx_meta_in = GetTensorMetaDataIndex(idx_data_in, context->num_inputs());
   1117   uint32 idx_meta_out =
   1118       GetTensorMetaDataIndex(idx_data_out, context->num_outputs());
   1119 
   1120   if (IsRefType(context->input_dtype(idx_data_in))) {
   1121     context->forward_ref_input_to_ref_output(idx_meta_in, idx_meta_out);
   1122   } else {
   1123     context->set_output(idx_meta_out, context->input(idx_meta_in));
   1124   }
   1125 }
   1126 
   1127 // Set a dummy MKL shape (called when the output is in TF format)
   1128 inline void SetDummyMklShapeOutput(OpKernelContext* context,
   1129                                    uint32 idx_data_out) {
   1130   MklShape mkl_shape_output;
   1131   mkl_shape_output.SetMklTensor(false);
   1132   AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output);
   1133 }
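// Illustrative sketch (hypothetical kernel code): an op that falls back to the
// Eigen implementation but still participates in the MKL layout convention can
// either forward its TF-format input directly,
//
//   ForwardTfTensorInToOut(context, 0, 0);  // data output + dummy MKL metadata
//
// or, when it computes the data output itself, pair that computation with
//   SetDummyMklShapeOutput(context, 0);
// so that downstream MKL ops see a non-MKL (TensorFlow-format) tensor.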
   1134 
   1135 #ifdef INTEL_MKL_ML
   1136 // We don't need these functions in MKL-DNN. We have defined the equality
   1137 // operator on the MklDnnShape class directly.
   1138 
   1139 // Checks if the TF shape for both MKL tensors is the same or not
   1140 // Returns: true if both TF shapes are the same, false otherwise
   1141 inline bool MklCompareShapes(const MklShape* input_shape_0,
   1142                              const MklShape* input_shape_1) {
   1143   // Check for number of dimensions
   1144   if (input_shape_0->GetDimension() != input_shape_1->GetDimension()) {
   1145     return false;
   1146   }
   1147 
   1148   // Check size of each dimension
   1149   size_t ndims = input_shape_0->GetDimension();
   1150   for (size_t i = 0; i < ndims; i++) {
   1151     if (input_shape_0->dim_size(i) != input_shape_1->dim_size(i)) {
   1152       return false;
   1153     }
   1154   }
   1155 
   1156   return true;
   1157 }
   1158 
   1159 // Checks if the TF shape for both tensors is the same or not
   1160 // Returns: true if TF shapes for both are the same, false otherwise
   1161 inline bool MklCompareShapes(const MklShape* input_shape_0,
   1162                              const TensorShape* input_shape_1) {
   1163   // Check for number of dimensions
   1164   if (input_shape_0->GetDimension() != input_shape_1->dims()) {
   1165     return false;
   1166   }
   1167 
   1168   // Check size of each dimension
   1169   size_t ndims = input_shape_0->GetDimension();
   1170   for (size_t i = 0; i < ndims; i++) {
   1171     if (input_shape_0->tf_dim_size(i) != input_shape_1->dim_size(i)) {
   1172       return false;
   1173     }
   1174   }
   1175 
   1176   return true;
   1177 }
   1178 
   1179 // Checks if the TF shape for both tensors is the same or not
   1180 // Returns: true if TF shapes for both are the same, false otherwise
   1181 inline bool MklCompareShapes(const TensorShape* input_shape_0,
   1182                              const MklShape* input_shape_1) {
   1183   return MklCompareShapes(input_shape_1, input_shape_0);
   1184 }
   1185 
   1186 // Checks if the TF shape for both tensors is the same or not
   1187 // Returns: true if TF shapes for both are the same, false otherwise
   1188 inline bool MklCompareShapes(const TensorShape* input_shape_0,
   1189                              const TensorShape* input_shape_1) {
   1190   // Check for number of dimensions
   1191   if (input_shape_0->dims() != input_shape_1->dims()) {
   1192     return false;
   1193   }
   1194 
   1195   // Check size of each dimension
   1196   size_t ndims = input_shape_0->dims();
   1197   for (size_t i = 0; i < ndims; i++) {
   1198     if (input_shape_0->dim_size(i) != input_shape_1->dim_size(i)) {
   1199       return false;
   1200     }
   1201   }
   1202 
   1203   return true;
   1204 }
   1205 #endif
   1206 
   1207 // These functions do not compile with MKL-DNN since mkl.h is missing.
   1208 // We may need to remove them later.
   1209 // TODO(intel_tf): Remove this routine when faster MKL layout conversion is
   1210 // out.
   1211 inline void MklNHWCToNCHW(const Tensor& input, Tensor** output) {
   1212   const float* buf_in = input.flat<float>().data();
   1213   float* buf_out = (*output)->flat<float>().data();
   1214 
   1215   int64 N = input.dim_size(0);
   1216   int64 H = input.dim_size(1);
   1217   int64 W = input.dim_size(2);
   1218   int64 C = input.dim_size(3);
   1219   int64 stride_n = H * W * C;
   1220 #pragma omp parallel for num_threads(16)
   1221   for (int64 n = 0; n < N; ++n) {
   1222     mkl_somatcopy('R', 'T', H * W, C, 1, buf_in + n * stride_n, C,
   1223                   buf_out + n * stride_n, H * W);
   1224   }
   1225 }
   1226 
   1227 inline void MklNCHWToNHWC(const Tensor& input, Tensor** output) {
   1228   const float* buf_in = input.flat<float>().data();
   1229   float* buf_out = (*output)->flat<float>().data();
   1230 
   1231   int64 N = (*output)->dim_size(0);
   1232   int64 H = (*output)->dim_size(1);
   1233   int64 W = (*output)->dim_size(2);
   1234   int64 C = (*output)->dim_size(3);
   1235   int64 stride_n = H * W * C;
   1236 #pragma omp parallel for num_threads(16)
   1237   for (int64 n = 0; n < N; ++n) {
   1238     mkl_somatcopy('R', 'T', C, H * W, 1, buf_in + n * stride_n, H * W,
   1239                   buf_out + n * stride_n, C);
   1240   }
   1241 }
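// Illustrative note on the two helpers above: each image in the batch is
// handled by a single mkl_somatcopy() transpose of a row-major matrix; for
// MklNHWCToNCHW, the (H*W) x C matrix of one image becomes C x (H*W).
// A hedged usage sketch (nhwc_in and the dimensions are hypothetical):
//
//   Tensor nchw_out(DT_FLOAT, TensorShape({N, C, H, W}));
//   Tensor* out_ptr = &nchw_out;
//   MklNHWCToNCHW(nhwc_in, &out_ptr);  // nhwc_in has shape {N, H, W, C}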
   1242 
   1243 // -------------------------------------------------------------------
   1244 
   1245 #ifndef INTEL_MKL_ML
   1246 
   1247 /// Return MKL-DNN data type (memory::data_type) for input type T
   1248 ///
   1249 /// @input None
   1250 /// @return memory::data_type corresponding to type T
   1251 template <typename T>
   1252 static memory::data_type MklDnnType();
   1253 
   1254 /// Instantiation for float type. Add similar instantiations for other
   1255 /// types if needed.
   1256 template <>
   1257 memory::data_type MklDnnType<float>() {
   1258   return memory::data_type::f32;
   1259 }
   1260 
   1261 /// Map TensorFlow's data format into MKL-DNN data format
   1262 ///
   1263 /// @input: TensorFlow data format
   1264 /// @return: memory::format corresponding to TensorFlow data format;
   1265 ///          Fails with an error if the data format is invalid.
   1266 inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
   1267   if (format == FORMAT_NHWC)
   1268     return memory::format::nhwc;
   1269   else if (format == FORMAT_NCHW)
   1270     return memory::format::nchw;
   1271   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
   1272   // Return to get rid of compiler warning
   1273   return memory::format::format_undef;
   1274 }
   1275 
   1276 /// Map MKL-DNN data format to TensorFlow's data format
   1277 ///
   1278 /// @input: memory::format
    1279 /// @return: TensorFlow data format corresponding to memory::format;
    1280 ///          Fails with an error if the data format is invalid.
   1281 inline TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format) {
   1282   if (format == memory::format::nhwc)
   1283     return FORMAT_NHWC;
   1284   else if (format == memory::format::nchw)
   1285     return FORMAT_NCHW;
   1286   TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
   1287 
   1288   // Return to prevent compiler warnings, otherwise TF_CHECK_OK will ensure
   1289   // that we don't come here.
   1290   return FORMAT_NHWC;
   1291 }
   1292 
   1293 /// Map TensorShape object into memory::dims required by MKL-DNN
   1294 ///
    1295 /// This function simply maps the input TensorShape into MKL-DNN dims
    1296 /// one-to-one, preserving the order of dimensions. E.g., if the input
    1297 /// tensor is in NHWC format, then the returned dims will also be in
    1298 /// NHWC order.
   1299 ///
   1300 /// @input TensorShape object in shape
   1301 /// @return memory::dims corresponding to TensorShape
   1302 inline memory::dims TFShapeToMklDnnDims(const TensorShape& shape) {
   1303   memory::dims dims(shape.dims());
   1304   for (int d = 0; d < shape.dims(); ++d) {
   1305     dims[d] = shape.dim_size(d);
   1306   }
   1307   return dims;
   1308 }
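
         // For example (illustrative only), a 4-D shape maps to dims of the same
         // order and sizes:
         //
         //   TensorShape shape({8, 224, 224, 3});             // hypothetical shape
         //   memory::dims dims = TFShapeToMklDnnDims(shape);  // {8, 224, 224, 3}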
   1309 
   1310 /// Map TensorShape object into memory::dims in NCHW format required by MKL-DNN
   1311 ///
    1312 /// This function is a more specific version of the function above. It maps
    1313 /// the input TensorShape into MKL-DNN dims in NCHW format, so it may not
    1314 /// preserve the order of dimensions. E.g., if the input tensor is in NHWC
    1315 /// format, the returned dims will be in NCHW format, not NHWC.
   1316 ///
   1317 /// @input TensorShape object in shape
   1318 /// @return memory::dims in MKL-DNN required NCHW format
   1319 inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
   1320                                               TensorFormat format) {
   1321   // Check validity of format.
   1322   CHECK_NE(TFDataFormatToMklDnnDataFormat(format),
   1323            memory::format::format_undef);
   1324 
   1325   int n = shape.dim_size(GetTensorDimIndex(format, 'N'));
   1326   int c = shape.dim_size(GetTensorDimIndex(format, 'C'));
   1327   int h = shape.dim_size(GetTensorDimIndex(format, 'H'));
   1328   int w = shape.dim_size(GetTensorDimIndex(format, 'W'));
   1329 
   1330   // MKL-DNN requires dimensions in NCHW format.
   1331   return memory::dims({n, c, h, w});
   1332 }
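
         // For example (illustrative only), an NHWC shape is reordered into NCHW dims:
         //
         //   TensorShape nhwc_shape({8, 224, 224, 3});  // hypothetical N,H,W,C shape
         //   memory::dims nchw_dims =
         //       TFShapeToMklDnnDimsInNCHW(nhwc_shape, FORMAT_NHWC);
         //   // nchw_dims == {8, 3, 224, 224}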
   1333 
   1334 /// Overloaded version of function above. Input parameters are
   1335 /// self-explanatory.
   1336 inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
   1337                                      TensorFormat format) {
   1338   // Check validity of format.
   1339   CHECK_NE(TFDataFormatToMklDnnDataFormat(format),
   1340            memory::format::format_undef);
   1341 
   1342   int n = in_dims[GetTensorDimIndex(format, 'N')];
   1343   int c = in_dims[GetTensorDimIndex(format, 'C')];
   1344   int h = in_dims[GetTensorDimIndex(format, 'H')];
   1345   int w = in_dims[GetTensorDimIndex(format, 'W')];
   1346 
   1347   // MKL-DNN requires dimensions in NCHW format.
   1348   return memory::dims({n, c, h, w});
   1349 }
   1350 
   1351 /// Map MklDnn memory::dims object into TensorShape object.
   1352 ///
    1353 /// This function simply maps the input shape, given in MKL-DNN memory::dims
    1354 /// format, into TensorFlow's TensorShape object, preserving dimension order.
   1355 ///
   1356 /// @input MKL-DNN memory::dims object
   1357 /// @output TensorShape corresponding to memory::dims
   1358 inline TensorShape MklDnnDimsToTFShape(const memory::dims& dims) {
   1359   std::vector<int32> shape(dims.size(), -1);
   1360   for (int d = 0; d < dims.size(); d++) {
   1361     shape[d] = dims[d];
   1362   }
   1363 
   1364   TensorShape ret;
   1365   CHECK_EQ(TensorShapeUtils::MakeShape(shape, &ret).ok(), true);
   1366   return ret;
   1367 }
   1368 
   1369 /// Function to calculate strides given tensor shape in Tensorflow order
    1370 /// E.g., if dims_tf_order is {1, 2, 3, 4}, then per TensorFlow convention
    1371 /// the dimension with size 1 is the outermost dimension, while the dimension
    1372 /// with size 4 is the innermost. So strides for this tensor would be
    1373 /// {4 * 3 * 2, 4 * 3, 4, 1}, i.e., {24, 12, 4, 1}.
   1374 ///
   1375 /// @input Tensorflow shape in memory::dims type
   1376 /// @return memory::dims containing strides for the tensor.
   1377 inline memory::dims CalculateTFStrides(const memory::dims& dims_tf_order) {
   1378   CHECK_GT(dims_tf_order.size(), 0);
   1379   memory::dims strides(dims_tf_order.size());
   1380   int last_dim_idx = dims_tf_order.size() - 1;
   1381   strides[last_dim_idx] = 1;
   1382   for (int d = last_dim_idx - 1; d >= 0; d--) {
   1383     strides[d] = strides[d + 1] * dims_tf_order[d + 1];
   1384   }
   1385   return strides;
   1386 }
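
         // For example (illustrative only), the strides for the dims used in the
         // comment above:
         //
         //   memory::dims dims_tf_order = {1, 2, 3, 4};
         //   memory::dims strides = CalculateTFStrides(dims_tf_order);
         //   // strides == {24, 12, 4, 1}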
   1387 
   1388 inline padding_kind TFPaddingToMklDnnPadding(Padding pad) {
   1389   // MKL-DNN only supports zero padding.
   1390   return padding_kind::zero;
   1391 }
   1392 
   1393 /// Helper function to create memory descriptor in Blocked format
   1394 ///
   1395 /// @input: Tensor dimensions
   1396 /// @input: strides corresponding to dimensions. One can use utility
   1397 ///         function such as CalculateTFStrides to compute strides
   1398 ///         for given dimensions.
   1399 /// @return: memory::desc object corresponding to blocked memory format
   1400 ///          for given dimensions and strides.
   1401 inline memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
   1402                                                const memory::dims& strides,
   1403                                                memory::data_type dtype) {
   1404   CHECK_EQ(dim.size(), strides.size());
   1405 
    1406   // We have to construct the memory descriptor in C style. This is not at
    1407   // all ideal, but MKL-DNN does not offer any API to construct a descriptor
    1408   // in blocked format other than a constructor that accepts an
    1409   // mkldnn_memory_desc_t.
   1410   mkldnn_memory_desc_t md;
   1411   md.primitive_kind = mkldnn_memory;
   1412   md.ndims = dim.size();
   1413   md.format = mkldnn_blocked;
   1414   md.data_type = memory::convert_to_c(dtype);
   1415 
   1416   for (size_t i = 0; i < dim.size(); i++) {
   1417     md.layout_desc.blocking.block_dims[i] = 1;
   1418     md.layout_desc.blocking.strides[1][i] = 1;
   1419     md.layout_desc.blocking.strides[0][i] = strides[i];
   1420     md.layout_desc.blocking.padding_dims[i] = dim[i];
   1421     md.layout_desc.blocking.offset_padding_to_data[i] = 0;
   1422     md.dims[i] = dim[i];
   1423   }
   1424   md.layout_desc.blocking.offset_padding = 0;
   1425 
   1426   return memory::desc(md);
   1427 }
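
         // A minimal usage sketch (illustrative only): build a blocked descriptor for
         // a row-major float tensor by pairing this helper with CalculateTFStrides.
         //
         //   memory::dims dims = {2, 3, 4, 5};                 // hypothetical shape
         //   memory::dims strides = CalculateTFStrides(dims);  // {60, 20, 5, 1}
         //   memory::desc blocked_md =
         //       CreateBlockedMemDescHelper(dims, strides, memory::data_type::f32);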
   1428 
   1429 /*
   1430  * Class to represent all the resources corresponding to a tensor in TensorFlow
   1431  * that are required to execute an operation (such as Convolution).
   1432  */
   1433 template <typename T>
   1434 class MklDnnData {
   1435  private:
   1436   /// MKL-DNN memory primitive for input user memory
   1437   memory* user_memory_;
   1438 
   1439   /// MKL-DNN memory primitive in case input or output reorder is needed.
   1440   memory* reorder_memory_;
   1441 
    1442   /// Operation's memory descriptor
   1443   memory::desc* op_md_;
   1444 
   1445   /// CPU engine on which operation will be executed
   1446   const engine* cpu_engine_;
   1447 
   1448  public:
   1449   explicit MklDnnData(const engine* e)
   1450       : user_memory_(nullptr),
   1451         reorder_memory_(nullptr),
   1452         op_md_(nullptr),
   1453         cpu_engine_(e) {}
   1454 
   1455   ~MklDnnData() {
   1456     cpu_engine_ = nullptr;  // We don't own this.
   1457     delete (user_memory_);
   1458     delete (reorder_memory_);
   1459     delete (op_md_);
   1460   }
   1461 
   1462   inline void* GetTensorBuffer(const Tensor* tensor) const {
   1463     CHECK_NOTNULL(tensor);
   1464     return const_cast<void*>(
   1465         static_cast<const void*>(tensor->flat<T>().data()));
   1466   }
   1467 
    1468   /// Set user memory primitive using specified dimensions, memory format and
    1469   /// data_buffer. The element data type is picked up automatically from the
    1470   /// template type T used to instantiate this object.
    1471   ///
    1472   /// In a nutshell, this function allows the user to describe an input tensor
    1473   /// of an operation. E.g., a filter of Conv2D has shape {1, 2, 3, 4} and
    1474   /// memory format HWIO, and the buffer that contains the actual values is
    1475   /// pointed to by data_buffer.
   1476   inline void SetUsrMem(const memory::dims& dim, memory::format fm,
   1477                         void* data_buffer = nullptr) {
   1478     auto md = memory::desc(dim, MklDnnType<T>(), fm);
   1479     SetUsrMem(md, data_buffer);
   1480   }
   1481 
   1482   inline void SetUsrMem(const memory::dims& dim, memory::format fm,
   1483                         const Tensor* tensor) {
   1484     CHECK_NOTNULL(tensor);
   1485     SetUsrMem(dim, fm, GetTensorBuffer(tensor));
   1486   }
   1487 
   1488   /// Helper function to create memory descriptor in Blocked format
   1489   ///
   1490   /// @input: Tensor dimensions
   1491   /// @input: strides corresponding to dimensions. One can use utility
   1492   ///         function such as CalculateTFStrides to compute strides
   1493   ///         for given dimensions.
   1494   /// @return: memory::desc object corresponding to blocked memory format
   1495   ///          for given dimensions and strides.
   1496   static inline memory::desc CreateBlockedMemDesc(const memory::dims& dim,
   1497                                                   const memory::dims& strides) {
   1498     return CreateBlockedMemDescHelper(dim, strides, MklDnnType<T>());
   1499   }
   1500 
    1501   /// A version of SetUsrMem that allows the user to create memory in blocked
    1502   /// format. In addition to accepting dimensions, it also accepts strides.
    1503   /// This allows the user to create memory for a tensor in a format that is
    1504   /// not supported natively by MKL-DNN. E.g., MKL-DNN has no native format for
    1505   /// a 6-dimensional tensor, but by using the blocked format a user can still
    1506   /// create memory for a 6-D tensor, as sketched after the overloads below.
   1507   inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
   1508                         void* data_buffer = nullptr) {
   1509     CHECK_EQ(dim.size(), strides.size());
   1510     auto blocked_md = MklDnnData<T>::CreateBlockedMemDesc(dim, strides);
   1511     SetUsrMem(blocked_md, data_buffer);
   1512   }
   1513 
   1514   inline void SetUsrMem(const memory::dims& dim, const memory::dims& strides,
   1515                         const Tensor* tensor) {
   1516     CHECK_NOTNULL(tensor);
   1517     SetUsrMem(dim, strides, GetTensorBuffer(tensor));
   1518   }
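
           // A minimal sketch (illustrative only) of the 6-D case mentioned above;
           // 'dnn_data', 'dims_6d', and 'buffer' are hypothetical:
           //
           //   memory::dims dims_6d = {2, 2, 3, 3, 4, 4};
           //   memory::dims strides_6d = CalculateTFStrides(dims_6d);
           //   dnn_data.SetUsrMem(dims_6d, strides_6d, buffer);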
   1519 
   1520   /// A version of function to set user memory primitive that accepts memory
   1521   /// descriptor directly, instead of accepting dimensions and format. This
    1522   /// function is more generic than the one above, but the function above is
   1523   /// sufficient in most cases.
   1524   inline void SetUsrMem(const memory::desc& md, void* data_buffer = nullptr) {
   1525     auto pd = memory::primitive_desc(md, *cpu_engine_);
   1526     SetUsrMem(pd, data_buffer);
   1527   }
   1528 
   1529   /// A version of SetUsrMem with memory descriptor and tensor
   1530   inline void SetUsrMem(const memory::desc& md, const Tensor* tensor) {
   1531     CHECK_NOTNULL(tensor);
   1532     SetUsrMem(md, GetTensorBuffer(tensor));
   1533   }
   1534 
   1535   /// A version of function to set user memory primitive that accepts primitive
   1536   /// descriptor directly, instead of accepting dimensions and format. This
    1537   /// function is more generic than the one above, but the function above is
   1538   /// sufficient in most cases.
   1539   inline void SetUsrMem(const memory::primitive_desc& pd,
   1540                         void* data_buffer = nullptr) {
   1541     CHECK_NOTNULL(cpu_engine_);
   1542     // TODO(nhasabni): can we remove dynamic memory allocation?
   1543     if (data_buffer) {
   1544       user_memory_ = new memory(pd, data_buffer);
   1545     } else {
   1546       user_memory_ = new memory(pd);
   1547     }
   1548   }
   1549 
   1550   /// A version of SetUsrMem with primitive descriptor and tensor
   1551   inline void SetUsrMem(const memory::primitive_desc& pd,
   1552                         const Tensor* tensor) {
   1553     CHECK_NOTNULL(tensor);
   1554     SetUsrMem(pd, GetTensorBuffer(tensor));
   1555   }
   1556 
   1557   /// Get function for user memory primitive.
   1558   inline const memory* GetUsrMem() const { return user_memory_; }
   1559 
   1560   /// Get function for primitive descriptor of user memory primitive.
   1561   inline const memory::primitive_desc GetUsrMemPrimDesc() const {
   1562     CHECK_NOTNULL(user_memory_);
   1563     return user_memory_->get_primitive_desc();
   1564   }
   1565 
   1566   /// Get function for descriptor of user memory.
   1567   inline memory::desc GetUsrMemDesc() {
    1568     // This is ugly: MKL-DNN does not provide a const-qualified desc() method.
   1569     const memory::primitive_desc pd = GetUsrMemPrimDesc();
   1570     return const_cast<memory::primitive_desc*>(&pd)->desc();
   1571   }
   1572 
   1573   /// Get function for data buffer of user memory primitive.
   1574   inline void* GetUsrMemDataHandle() const {
   1575     CHECK_NOTNULL(user_memory_);
   1576     return user_memory_->get_data_handle();
   1577   }
   1578 
   1579   /// Set function for data buffer of user memory primitive.
   1580   inline void* SetUsrMemDataHandle(void* data_buffer) {
   1581     CHECK_NOTNULL(user_memory_);
   1582     CHECK_NOTNULL(data_buffer);
   1583     return user_memory_->set_data_handle(data_buffer);
   1584   }
   1585 
   1586   /// Set function for data buffer of user memory primitive.
   1587   inline void SetUsrMemDataHandle(const Tensor* tensor) {
   1588     CHECK_NOTNULL(user_memory_);
   1589     CHECK_NOTNULL(tensor);
   1590     user_memory_->set_data_handle(GetTensorBuffer(tensor));
   1591   }
   1592 
    1593   /// Get the memory primitive for input and output of an op. If an input
    1594   /// to an op requires a reorder, this function returns the memory primitive
    1595   /// for the reorder; otherwise it returns the user memory primitive.
   1596   ///
   1597   /// E.g., Conv2D(I, F) is a primitive with I and F being inputs. Then to
    1598   /// execute Conv2D, we need memory primitives for I and F. But if reorders
    1599   /// are required for I and F (say I_r is the reorder primitive for I and
    1600   /// F_r for F), then we need I_r and F_r to perform Conv2D.
   1601   inline const memory& GetOpMem() const {
   1602     return reorder_memory_ ? *reorder_memory_ : *user_memory_;
   1603   }
   1604 
   1605   /// Set memory descriptor of an operation in terms of dimensions and memory
    1606   /// format. E.g., for Conv2D, the dimensions would be the same as the user
    1607   /// dimensions, but memory::format would be mkldnn::any because we want
    1608   /// MKL-DNN to choose the best layout/format for the given input dimensions.
   1609   inline void SetOpMemDesc(const memory::dims& dim, memory::format fm) {
   1610     // TODO(nhasabni): can we remove dynamic memory allocation?
   1611     op_md_ = new memory::desc(dim, MklDnnType<T>(), fm);
   1612   }
   1613 
   1614   /// Get function for memory descriptor for an operation
   1615   inline const memory::desc& GetOpMemDesc() const { return *op_md_; }
   1616 
    1617   /// Predicate that checks if we need to reorder the user's memory into the
    1618   /// memory described by op_pd.
   1619   ///
   1620   /// @input: op_pd - memory primitive descriptor of the given input of an
   1621   ///               operation
   1622   /// @return: true in case reorder of input is needed; false, otherwise.
   1623   inline bool IsReorderNeeded(const memory::primitive_desc& op_pd) const {
   1624     CHECK_NOTNULL(user_memory_);
   1625     return op_pd != user_memory_->get_primitive_desc();
   1626   }
   1627 
    1628   /// Predicate that checks if we need to reorder the user's memory into the
    1629   /// given target memory format.
   1630   ///
   1631   /// @input: target_format - memory format of the given input of an
   1632   ///               operation
   1633   /// @return: true in case reorder of input is needed; false, otherwise.
   1634   inline bool IsReorderNeeded(const memory::format& target_format) const {
   1635     CHECK_NOTNULL(user_memory_);
   1636     return target_format !=
   1637            user_memory_->get_primitive_desc().desc().data.format;
   1638   }
   1639 
    1640   /// Function to create a reorder from the memory pointed to by 'from' to the
    1641   /// memory pointed to by 'to'. Returns the created primitive.
   1642   inline primitive CreateReorder(const memory* from, const memory* to) const {
   1643     CHECK_NOTNULL(from);
   1644     CHECK_NOTNULL(to);
   1645     return reorder(*from, *to);
   1646   }
   1647 
   1648   /// Function to handle input reordering
   1649   ///
   1650   /// Check if we need to reorder this input of an operation.
   1651   /// Return true and allocate reorder memory primitive if reorder is needed.
   1652   /// Otherwise, return false and do not allocate reorder memory primitive.
   1653   ///
   1654   /// To check if reorder is needed, this function compares memory primitive
   1655   /// descriptor of an operation (op_pd) for the given input with the
   1656   /// user-specified memory primitive descriptor.
   1657   ///
   1658   /// @input: op_pd - memory primitive descriptor of the given input of an
   1659   ///               operation
   1660   /// @input: net - net to which to add reorder primitive in case it is needed.
   1661   /// @return: true in case reorder of input is needed; false, otherwise.
   1662   inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
   1663                                   std::vector<primitive>* net) {
   1664     CHECK_NOTNULL(net);
   1665     CHECK_NOTNULL(user_memory_);
   1666     if (IsReorderNeeded(op_pd)) {
   1667       // TODO(nhasabni): can we remove dynamic memory allocation?
   1668       reorder_memory_ = new memory(op_pd);
   1669       net->push_back(CreateReorder(user_memory_, reorder_memory_));
   1670       return true;
   1671     }
   1672     return false;
   1673   }
   1674 
   1675   /// Overloaded version of above function that accepts memory buffer
   1676   /// where output of reorder needs to be stored.
   1677   ///
   1678   /// @input: op_pd - memory primitive descriptor of the given input of an
   1679   ///               operation
   1680   /// @reorder_data_handle - memory buffer where output of reorder needs to be
    1681   ///                        stored. The primitive does not check whether the
    1682   ///                        buffer is large enough to hold the result.
   1683   /// @input: net - net to which to add reorder primitive in case it is needed.
   1684   /// @return: true in case reorder of input is needed; false, otherwise.
   1685   inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
   1686                                   void* reorder_data_handle,
   1687                                   std::vector<primitive>* net) {
   1688     CHECK_NOTNULL(net);
   1689     CHECK_NOTNULL(reorder_data_handle);
   1690     CHECK_NOTNULL(user_memory_);
   1691     if (IsReorderNeeded(op_pd)) {
   1692       // TODO(nhasabni): can we remove dynamic memory allocation?
   1693       reorder_memory_ = new memory(op_pd, reorder_data_handle);
   1694       net->push_back(CreateReorder(user_memory_, reorder_memory_));
   1695       return true;
   1696     }
   1697     return false;
   1698   }
   1699 
   1700   /// Another overloaded version of CheckReorderToOpMem that accepts Tensor
   1701   /// where output of reorder needs to be stored.
   1702   ///
   1703   /// @input: op_pd - memory primitive descriptor of the given input of an
   1704   ///               operation
   1705   /// @reorder_tensor - Tensor whose buffer is to be used to store output of
    1706   ///                   reorder. The primitive does not check whether the
    1707   ///                   buffer is large enough to hold the result.
   1708   /// @input: net - net to which to add reorder primitive in case it is needed.
   1709   /// @return: true in case reorder of input is needed; false, otherwise.
   1710   inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
   1711                                   Tensor* reorder_tensor,
   1712                                   std::vector<primitive>* net) {
   1713     CHECK_NOTNULL(net);
   1714     CHECK_NOTNULL(reorder_tensor);
   1715     return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor), net);
   1716   }
   1717 
   1718   /// Function to handle output reorder
   1719   ///
    1720   /// This function is very similar to the input reordering function above.
    1721   /// The only difference is that this function does not add the reorder
    1722   /// primitive to the net. The reason is that the reorder primitive for the
    1723   /// output must be added to the net only after the operation has executed,
    1724   /// but we still need to prepare a temporary buffer in case an output
    1725   /// reorder is needed. This temporary buffer will hold the output of the
    1726   /// operation before it is fed to the reorder primitive.
   1727   ///
   1728   /// @input memory primitive descriptor for the given output of an operation
   1729   /// @return: true in case reorder of output is needed; false, otherwise.
   1730   inline bool PrepareReorderToUserMemIfReq(
   1731       const memory::primitive_desc& op_pd) {
   1732     CHECK_NOTNULL(user_memory_);
   1733     if (IsReorderNeeded(op_pd)) {
   1734       // TODO(nhasabni): can we remove dynamic memory allocation?
   1735       reorder_memory_ = new memory(op_pd);
   1736       return true;
   1737     }
   1738     return false;
   1739   }
   1740 
   1741   /// Function to actually insert reorder primitive in the net
   1742   ///
    1743   /// This function completes the remaining part of output reordering. It
    1744   /// inserts a reorder primitive from the temporary buffer that holds the
    1745   /// output to the user-specified output buffer.
   1746   ///
   1747   /// @input: net - net to which to add reorder primitive
   1748   inline void InsertReorderToUserMem(std::vector<primitive>* net) {
   1749     CHECK_NOTNULL(net);
   1750     CHECK_NOTNULL(user_memory_);
   1751     CHECK_NOTNULL(reorder_memory_);
   1752     net->push_back(CreateReorder(reorder_memory_, user_memory_));
   1753   }
   1754 };
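
         // A minimal end-to-end sketch (illustrative only) of how an MKL-DNN based
         // kernel might use MklDnnData for input and output reorders. The names
         // 'src_dims', 'out_dims', 'src_tensor', 'out_tensor', 'op_src_pd' and
         // 'op_out_pd' are hypothetical; the primitive descriptors would come from
         // the particular op (e.g., a convolution forward primitive descriptor).
         //
         //   auto cpu_engine = engine(engine::cpu, 0);
         //   std::vector<primitive> net;
         //
         //   MklDnnData<float> src(&cpu_engine);
         //   src.SetUsrMem(src_dims, memory::format::nhwc, &src_tensor);
         //   src.CheckReorderToOpMem(op_src_pd, &net);  // adds input reorder if needed
         //
         //   MklDnnData<float> out(&cpu_engine);
         //   out.SetUsrMem(out_dims, memory::format::nhwc, &out_tensor);
         //   bool out_reorder = out.PrepareReorderToUserMemIfReq(op_out_pd);
         //
         //   // ... create the op primitive using src.GetOpMem() / out.GetOpMem()
         //   //     and push it onto 'net' ...
         //
         //   if (out_reorder) out.InsertReorderToUserMem(&net);
         //   mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();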
   1755 
   1756 #endif  // INTEL_MKL_ML
   1757 
   1758 }  // namespace tensorflow
   1759 #endif  // INTEL_MKL
   1760 #endif  // TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
   1761