Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
     17 #define TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
     18 
     19 #include <limits>
     20 #include <string>
     21 #include <vector>
     22 
     23 #include "tensorflow/core/framework/numeric_op.h"
     24 #include "tensorflow/core/framework/op_kernel.h"
     25 #include "tensorflow/core/framework/register_types.h"
     26 #include "tensorflow/core/framework/tensor.h"
     27 #include "tensorflow/core/framework/tensor_shape.h"
     28 #include "tensorflow/core/framework/tensor_slice.h"
     29 #include "tensorflow/core/kernels/bounds_check.h"
     30 #include "tensorflow/core/kernels/conv_grad_ops.h"
     31 #include "tensorflow/core/kernels/ops_util.h"
     32 #include "tensorflow/core/lib/core/errors.h"
     33 #include "tensorflow/core/lib/gtl/array_slice.h"
     34 #include "tensorflow/core/lib/strings/numbers.h"
     35 #include "tensorflow/core/lib/strings/str_util.h"
     36 #include "tensorflow/core/platform/logging.h"
     37 #include "tensorflow/core/platform/macros.h"
     38 #include "tensorflow/core/util/padding.h"
     39 #include "tensorflow/core/util/tensor_format.h"
     40 
     41 #include "tensorflow/core/util/mkl_util.h"
     42 
     43 #ifndef INTEL_MKL_ML
     44 #include "mkldnn.hpp"
     45 
     46 using mkldnn::prop_kind;
     47 using mkldnn::stream;
     48 
     49 using mkldnn::convolution_direct;
     50 using mkldnn::convolution_forward;
     51 #endif
     52 
     53 namespace tensorflow {
     54 
     55 #ifndef INTEL_MKL_ML
     56 
     57 class MklDnnConvUtil {
     58  protected:
     59   OpKernelContext* context_;  // We don't own this.
     60   std::vector<int32> strides_;
     61   Padding padding_;
     62   TensorFormat data_format_;
     63 
     64  public:
     65   MklDnnConvUtil(OpKernelContext* context, const std::vector<int32>& strides,
     66                  Padding pad, TensorFormat fm)
     67       : context_(context), strides_(strides), padding_(pad), data_format_(fm) {}
     68 
     69   virtual ~MklDnnConvUtil() { context_ = nullptr; }
     70 
     71   // Calculate Convolution strides
     72   virtual inline void GetStridesInMklOrder(memory::dims* strides) {
     73     // For now we take the stride from the second and third dimensions only
     74     // (we do not support striding on the batch or depth dimension).
     75     CHECK_NOTNULL(strides);
     76     int stride_rows = GetTensorDim(strides_, data_format_, 'H');
     77     int stride_cols = GetTensorDim(strides_, data_format_, 'W');
     78     *strides = {stride_rows, stride_cols};
     79   }
     80 
     81   // Calculate Convolution input size in MKL-DNN order. MKL-DNN
     82   // requires input in NCHW format. Function does not return anything.
     83   // But errors arising from sanity checks are returned in context's
     84   // status.
     85   virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape,
     86                                              memory::dims* input_dims) {
     87 #define CHECK_BOUNDS(val, err_msg)                                     \
     88   do {                                                                 \
     89     OP_REQUIRES(context_,                                              \
     90                 FastBoundsCheck(val, std::numeric_limits<int>::max()), \
     91                 errors::InvalidArgument(err_msg));                     \
     92   } while (0)
     93 
     94     CHECK_NOTNULL(input_dims);
     95 
     96     // Input channel
     97     int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C');
     98     int input_depth = static_cast<int>(input_depth_raw);
     99 
    100     // Input rows/height
    101     int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
    102     CHECK_BOUNDS(input_rows_raw, "Input rows too large");
    103     int input_rows = static_cast<int>(input_rows_raw);
    104 
    105     // Input columns/width
    106     int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
    107     CHECK_BOUNDS(input_cols_raw, "Input cols too large");
    108     int input_cols = static_cast<int>(input_cols_raw);
    109 
    110     // Input batch
    111     int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N');
    112     CHECK_BOUNDS(input_batch_raw, "Input batch too large");
    113     int input_batch = static_cast<int>(input_batch_raw);
    114 
    115 #undef CHECK_BOUNDS
    116 
    117     // MKL-DNN always requires input in NCHW format.
    118     std::vector<int> mkldnn_sizes(4, -1);
    119     mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
    120     mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
    121     mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
    122     mkldnn_sizes[MklDnnDims::Dim_W] = input_cols;
    123 
    124     *input_dims = mkldnn_sizes;
    125   }
    126 
    127   // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
    128   // requires filter in OIHW format. Function does not return anything.
    129   // But errors arising from sanity checks are returned in context's
    130   // status.
    131   //
    132   // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
    133   // requires filter in OIHW format. Function does not return anything.
    134   // But errors arising from sanity checks are returned in context's
    135   // status. This function differs from GetConvFilterSizeInMklOrder in
    136   // parameter for input - it accepts src_shape since Convolution Backward
    137   // Input gets shape of input tensor rather than actual tensor (Convolution
    138   // forward gets actual tensor as input).
    139   //
    140   // TODO(nhasabni): Add similar function for input and filter in MklShape.
    141   virtual inline void GetFilterSizeInMklOrder(const TensorShape& input_shape,
    142                                               const TensorShape& filter_shape,
    143                                               memory::dims* filter_dims) {
    144     CHECK_NOTNULL(filter_dims);
    145 
    146     OP_REQUIRES(context_, filter_shape.dims() == 4,
    147                 errors::InvalidArgument("filter must be 4-dimensional: ",
    148                                         filter_shape.DebugString()));
    149 
    150     for (int i = 0; i < 3; i++) {
    151       OP_REQUIRES(context_,
    152                   FastBoundsCheck(filter_shape.dim_size(i),
    153                                   std::numeric_limits<int>::max()),
    154                   errors::InvalidArgument("filter too large"));
    155     }
    156 
    157     int input_depth = GetTensorDim(input_shape, data_format_, 'C');
    158 
    159     OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
    160                 errors::InvalidArgument(
    161                     "input and filter must have the same depth: ", input_depth,
    162                     " vs ", filter_shape.dim_size(2)));
    163 
    164     // TF filter is always in (rows, cols, in_depth, out_depth) order.
    165     int filter_rows = static_cast<int>(filter_shape.dim_size(0));
    166     int filter_cols = static_cast<int>(filter_shape.dim_size(1));
    167     int in_depth = static_cast<int>(filter_shape.dim_size(2));
    168     int out_depth = static_cast<int>(filter_shape.dim_size(3));
    169 
    170     // MKL-DNN always needs filter in OIHW format.
    171     // OIHW = (out_depth, in_depth, rows, cols)
    172     std::vector<int> mkldnn_sizes(4, -1);
    173     mkldnn_sizes[MklDnnDims::Dim_O] = out_depth;
    174     mkldnn_sizes[MklDnnDims::Dim_I] = in_depth;
    175     mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
    176     mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols;
    177 
    178     *filter_dims = mkldnn_sizes;
    179   }
    180 
    181   // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
    182   // requires filter in OIHW format. Function does not return anything.
    183   // But errors arising from sanity checks are returned in context's
    184   // status.
    185   virtual inline void GetFilterSizeInMklOrder(size_t src_index,
    186                                               size_t filter_index,
    187                                               memory::dims* filter_dims) {
    188     CHECK_NOTNULL(filter_dims);
    189     GetFilterSizeInMklOrder(GetTfShape(context_, src_index),
    190                             GetTfShape(context_, filter_index), filter_dims);
    191   }
    192 
    193   // Calculate Bias size for 2D Convolution. Function does not return
    194   // anything, but sets error in context status.
    195   virtual inline void GetBiasSizeInMklOrder(size_t bias_index,
    196                                             memory::dims* bias_dims) {
    197     const Tensor& bias = MklGetInput(context_, bias_index);
    198     OP_REQUIRES(context_, bias.dims() == 1,
    199                 errors::InvalidArgument("bias must be 1-dimensional: ",
    200                                         bias.shape().DebugString()));
    201 
    202     *bias_dims = {static_cast<int>(bias.dim_size(0))};
    203   }
    204 
    205   // Function to calculate output and padding size for 2D convolution.
    206   //
    207   // Calculate output shape of Convolution in MKL-DNN and TensorFlow order.
    208   // MKL-DNN uses NCHW for output order. But TensorFlow output will be in
    209   // NHWC or NCHW format depending on data format. Function also calculates
    210   // left, right, top and bottom pads. Function does not return any status -
    211   // status is returned via context status.
    212   //
    213   // TODO(nhasabni): Add similar function for input and filter in MklShape.
    214   virtual inline void GetOutputAndPadSizeInMklOrder(
    215       const TensorShape& input_shape, const TensorShape& filter_shape,
    216       const memory::dims& strides, memory::dims* output_dims_tf_order,
    217       memory::dims* output_dims_mkl_order, memory::dims* pad_l,
    218       memory::dims* pad_r) {
    219     CHECK_NOTNULL(output_dims_tf_order);
    220     CHECK_NOTNULL(output_dims_mkl_order);
    221     CHECK_NOTNULL(pad_l);
    222     CHECK_NOTNULL(pad_r);
    223 
    224     int input_rows = GetTensorDim(input_shape, data_format_, 'H');
    225     int input_cols = GetTensorDim(input_shape, data_format_, 'W');
    226 
    227     // The first dimension for filter is rows/height.
    228     int filter_rows = filter_shape.dim_size(0);
    229     // The second dimension for filter is cols/width.
    230     int filter_cols = filter_shape.dim_size(1);
    231 
    232     // Stride is vector of 2 elements: {s_r, s_c}
    233     int stride_rows = strides[0];
    234     int stride_cols = strides[1];
    235 
    236     // Output batch is same as input batch.
    237     int out_batch = GetTensorDim(input_shape, data_format_, 'N');
    238     // Output depth is same as last dimension for filter.
    239     int out_depth = filter_shape.dim_size(3);
    240 
    241     int64 out_rows = 0, out_cols = 0;
    242     int64 pad_top = 0, pad_bottom = 0, pad_left, pad_right;
    243 
    244     OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
    245                                  input_rows, filter_rows, stride_rows, padding_,
    246                                  &out_rows, &pad_top, &pad_bottom));
    247     OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
    248                                  input_cols, filter_cols, stride_cols, padding_,
    249                                  &out_cols, &pad_left, &pad_right));
    250 
    251     // Tensorflow output is in data_format order. (NHWC or NCHW)
    252     TensorShape out_shape =
    253         ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, out_depth);
    254     *output_dims_tf_order = TFShapeToMklDnnDims(out_shape);
    255 
    256     // MKL-DNN always needs output in NCHW format.
    257     std::vector<int> mkldnn_sizes(4, -1);
    258     mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
    259     mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
    260     mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
    261     mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
    262     *output_dims_mkl_order = mkldnn_sizes;
    263 
    264     // Now handle padding. MKL-DNN uses asymetric padding.
    265     *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
    266     *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
    267   }
    268 
    269   // Calculate output and pad size of forward Convolution operator.
    270   // See comment on GetConvOutputAndPadSizeInMklOrder for parameters.
    271   //
    272   // Function does not return anything, but sets error in context status.
    273   inline void GetOutputAndPadSizeInMklOrder(
    274       size_t src_index, size_t filter_index, const memory::dims& strides,
    275       memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
    276       memory::dims* pad_l, memory::dims* pad_r) {
    277     CHECK_NOTNULL(output_dims_tf_order);
    278     CHECK_NOTNULL(output_dims_mkl_order);
    279     CHECK_NOTNULL(pad_l);
    280     CHECK_NOTNULL(pad_r);
    281 
    282     auto input_tf_shape = GetTfShape(context_, src_index);
    283     auto filter_tf_shape = GetTfShape(context_, filter_index);
    284 
    285     OP_REQUIRES(context_, input_tf_shape.dims() == 4,
    286                 errors::InvalidArgument("input must be 4-dimensional",
    287                                         input_tf_shape.DebugString()));
    288 
    289     GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape, strides,
    290                                   output_dims_tf_order, output_dims_mkl_order,
    291                                   pad_l, pad_r);
    292   }
    293 
    294   // Wrapper function to calculate input, filter, and output sizes of
    295   // 2D Convolution in MKL order (NCHW for input and output; OIHW for filter.)
    296   // Function also calculates output shape in Tensorflow order. Additionally, it
    297   // also calculates strides and paddings for 2D Convolution.
    298   //
    299   // Function does not return anything, but sets error in context status.
    300   inline void GetConvFwdSizesInMklOrder(
    301       const TensorShape& input_shape, const TensorShape& filter_shape,
    302       memory::dims* input_dims, memory::dims* filter_dims,
    303       memory::dims* strides, memory::dims* output_dims_tf_order,
    304       memory::dims* output_dims_mkl_order, memory::dims* pad_l,
    305       memory::dims* pad_r) {
    306     CHECK_NOTNULL(input_dims);
    307     CHECK_NOTNULL(filter_dims);
    308     CHECK_NOTNULL(strides);
    309     CHECK_NOTNULL(output_dims_tf_order);
    310     CHECK_NOTNULL(output_dims_mkl_order);
    311     CHECK_NOTNULL(pad_l);
    312     CHECK_NOTNULL(pad_r);
    313 
    314     GetInputSizeInMklOrder(input_shape, input_dims);
    315     if (!context_->status().ok()) return;
    316     GetFilterSizeInMklOrder(input_shape, filter_shape, filter_dims);
    317     if (!context_->status().ok()) return;
    318     GetStridesInMklOrder(strides);
    319     GetOutputAndPadSizeInMklOrder(input_shape, filter_shape, *strides,
    320                                   output_dims_tf_order, output_dims_mkl_order,
    321                                   pad_l, pad_r);
    322     if (!context_->status().ok()) return;
    323   }
    324 };
    325 
    326 /////////////////////////////////////////////////////////////////////
    327 ///  Common class that implements Conv2DBackpropFilter and Input
    328 /////////////////////////////////////////////////////////////////////
    329 
    330 template <typename Device, class T>
    331 class MklConv2DBackpropCommonOp : public OpKernel {
    332  public:
    333   ~MklConv2DBackpropCommonOp() {}
    334   explicit MklConv2DBackpropCommonOp(OpKernelConstruction* context)
    335       : OpKernel(context) {
    336     string data_format_str;
    337     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
    338     OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
    339                 errors::InvalidArgument("Invalid data format"));
    340     OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    341     int stride_n = GetTensorDim(strides_, data_format_, 'N');
    342     int stride_c = GetTensorDim(strides_, data_format_, 'C');
    343     OP_REQUIRES(
    344         context, (stride_n == 1 && stride_c == 1),
    345         errors::InvalidArgument("Current implementation does not yet support "
    346                                 "strides in the batch and depth dimensions."));
    347 
    348     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
    349   }
    350 
    351   void Compute(OpKernelContext* context) override {
    352     try {
    353       auto cpu_engine = engine(engine::cpu, 0);
    354 
    355       // Prepare common tensors for Conv2DBackpropInput and
    356       // Conv2DBackpropFilter.
    357       MklDnnData<T> input(&cpu_engine);
    358       MklDnnData<T> filter(&cpu_engine);
    359       MklDnnData<T> outbackprop(&cpu_engine);
    360       MklDnnData<T> output(&cpu_engine);
    361 
    362       // Input tensors
    363       const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
    364       const Tensor& input_tensor = MklGetInput(context, kInputIdx);
    365       const Tensor& filter_tensor = MklGetInput(context, kFilterIdx);
    366       const Tensor& outbprop_tensor = MklGetInput(context, kOutbpropIdx);
    367 
    368       MklDnnShape input_mkl_shape, filter_mkl_shape, outbprop_mkl_shape;
    369       GetMklShape(context, kInputIdx, &input_mkl_shape);
    370       GetMklShape(context, kFilterIdx, &filter_mkl_shape);
    371       GetMklShape(context, kOutbpropIdx, &outbprop_mkl_shape);
    372       // Allow operator-specific sanity checking of shapes.
    373       ValidateMklShapes(input_mkl_shape, filter_mkl_shape, outbprop_mkl_shape);
    374 
    375       // Allow operator-specific generation of shapes.
    376       // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a
    377       // tensor containing shape of filter. So filter.shape() is not
    378       // a correct way to get filter shape. These operator-specific calls
    379       // allow this class to handle this case.
    380       TensorShape input_tf_shape = MakeInputTfShape(context, input_tensor);
    381       TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor);
    382       TensorShape outbprop_tf_shape = GetTfShape(context, kOutbpropIdx);
    383 
    384       // Corner cases: output with 0 elements and 0 batch size.
    385       Tensor* output_tensor = nullptr;
    386       if (input_tf_shape.num_elements() == 0 ||
    387           filter_tf_shape.num_elements() == 0 ||
    388           outbprop_tf_shape.num_elements() == 0) {
    389         MklDnnShape output_mkl_shape;
    390         output_mkl_shape.SetMklTensor(false);
    391         TensorShape output_tf_shape = GetOutputTfShape(
    392             input_tf_shape, filter_tf_shape, outbprop_tf_shape);
    393         const int kOutputIdx = 0;
    394         AllocateOutputSetMklShape(context, kOutputIdx, &output_tensor,
    395                                   output_tf_shape, output_mkl_shape);
    396         CHECK_NOTNULL(output_tensor);
    397 
    398         // if output tensor has more than 0 elements, we need to 0 them out.
    399         for (size_t i = 0; i < output_tf_shape.num_elements(); ++i) {
    400           output_tensor->flat<T>().data()[i] = 0;
    401         }
    402 
    403         return;
    404       }
    405 
    406       // By default, all dims are in MKL order. Only dims in TF order
    407       // are those with prefix tf_order.
    408       memory::dims outbprop_dims, fwd_input_dims, fwd_filter_dims;
    409       memory::dims padding_l, padding_r, strides, fwd_output_dims;
    410       memory::dims fwd_output_dims_tf_order;
    411 
    412       // Get forward convolution parameters.
    413       MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
    414       conv_utl.GetConvFwdSizesInMklOrder(
    415           input_tf_shape, filter_tf_shape, &fwd_input_dims, &fwd_filter_dims,
    416           &strides, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_l,
    417           &padding_r);
    418       if (!context->status().ok()) return;
    419 
    420       // Create Convolution forward descriptor since Convolution backward
    421       // API needs it. For that, we first need to create input, filter
    422       // and output memory descriptors.
    423       auto tf_fmt = TFDataFormatToMklDnnDataFormat(data_format_);
    424       // If input is in MKL layout, then simply grab input layout; otherwise,
    425       // construct input TF layout. For TF layout, although input shape
    426       // required is in MKL-DNN order, the layout is Tensorflow's layout
    427       // (NHWC or NCHW depending on data format).
    428       auto fwd_input_md =
    429           input_mkl_shape.IsMklTensor()
    430               ? input_mkl_shape.GetMklLayout()
    431               : memory::desc(fwd_input_dims, MklDnnType<T>(), tf_fmt);
    432       // If filter is in MKL layout, then simply grab filter layout; otherwise
    433       // construct filter in TF layout. For TF layout, filter is in HWIO format.
    434       auto fwd_filter_md = filter_mkl_shape.IsMklTensor()
    435                                ? filter_mkl_shape.GetMklLayout()
    436                                : memory::desc(fwd_filter_dims, MklDnnType<T>(),
    437                                               memory::format::hwio);
    438       // Tensorflow Output of Conv2D is in data_format order.
    439       auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(), tf_fmt);
    440       auto fwd_desc = convolution_forward::desc(
    441           prop_kind::forward, convolution_direct, fwd_input_md, fwd_filter_md,
    442           fwd_out_md, strides, padding_l, padding_r,
    443           TFPaddingToMklDnnPadding(padding_));
    444       auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);
    445 
    446       // Create memory for user data. Describe how the inputs and outputs of
    447       // Convolution look like. Also specify buffers containing actual input
    448       // and output data.
    449 
    450       // Since this is a common class for both Conv2DBackpropFilter and
    451       // Conv2DBackpropInput, we skip SetUsrMem call for input tensor (for
    452       // Conv2DBackpropInput) and for filter tensor (for
    453       // conv2DBackpropFilter) depending on which tensor is int32 type.
    454       size_t input_with_sizes = GetInputTensorIndexWithSizes();
    455       if (input_with_sizes != kInputIdx) {
    456         // Shape of Conv2DBackpropFilter's input is same as Conv2D input.
    457         input.SetUsrMem(fwd_input_md, &input_tensor);
    458       } else if (input_with_sizes != kFilterIdx) {
    459         // Shape of Conv2DBackpropInput's filter is same as Conv2D filter.
    460         filter.SetUsrMem(fwd_filter_md, &filter_tensor);
    461       }
    462 
    463       conv_utl.GetInputSizeInMklOrder(outbprop_tf_shape, &outbprop_dims);
    464       if (!context->status().ok()) return;
    465       if (outbprop_mkl_shape.IsMklTensor()) {
    466         // If outbackprop is in Mkl layout, then simply grab it.
    467         auto outbprop_md = outbprop_mkl_shape.GetMklLayout();
    468         outbackprop.SetUsrMem(outbprop_md, &outbprop_tensor);
    469       } else {
    470         // If outbackprop is in TensorFlow layout, then we need to create memory
    471         // descriptor for it. Outbackprop shape is data format order.
    472         outbackprop.SetUsrMem(outbprop_dims, tf_fmt, &outbprop_tensor);
    473       }
    474 
    475       // Operator specific call to get output shape and data_format.
    476       auto bwd_output_dims = GetOutputDims(fwd_input_dims, fwd_filter_dims);
    477       auto bwd_output_format = GetOutputFormat(tf_fmt);
    478       output.SetUsrMem(bwd_output_dims, bwd_output_format);
    479 
    480       // Create memory descriptors for convolution data w/ no specified format.
    481       input.SetOpMemDesc(fwd_input_dims, memory::format::any);
    482       filter.SetOpMemDesc(fwd_filter_dims, memory::format::any);
    483       outbackprop.SetOpMemDesc(outbprop_dims, memory::format::any);
    484       output.SetOpMemDesc(bwd_output_dims, memory::format::any);
    485 
    486       // Operator-specific call to create and execute primitive.
    487       CreatePrimitive(context, cpu_engine, fwd_pd, &input, &filter,
    488                       &outbackprop, &output, &output_tensor, strides, padding_l,
    489                       padding_r, TFPaddingToMklDnnPadding(padding_),
    490                       bwd_output_dims, bwd_output_format);
    491     } catch (mkldnn::error& e) {
    492       string error_msg = "Status: " + std::to_string(e.status) +
    493                          ", message: " + string(e.message) + ", in file " +
    494                          string(__FILE__) + ":" + std::to_string(__LINE__);
    495       OP_REQUIRES_OK(
    496           context,
    497           errors::Aborted("Operation received an exception:", error_msg));
    498     }
    499   }
    500 
    501   /// Pure virtual function to allow operator to check for validity of input
    502   /// shapes. Function asserts that input shapes are valid.
    503   virtual void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
    504                                  const MklDnnShape& filter_mkl_shape,
    505                                  const MklDnnShape& outbprop_mkl_shape) = 0;
    506 
    507   /// Operator-specific function that returns index of input that is
    508   /// representing input sizes. For Conv2DBackpropFilter it returns 1 since
    509   /// filter for this operator is filter shape. For Conv2DBackpropInput it
    510   /// returns 0 (for input).
    511   virtual size_t GetInputTensorIndexWithSizes() = 0;
    512 
    513   /// Get TensorFlow shape of input tensor.
    514   virtual TensorShape MakeInputTfShape(OpKernelContext* context,
    515                                        const Tensor& input_tensor) = 0;
    516 
    517   /// Get TensorFlow shape of filter tensor.
    518   virtual TensorShape MakeFilterTfShape(OpKernelContext* context,
    519                                         const Tensor& filter_tensor) = 0;
    520 
    521   /// Get the TensorFlow shape of output tensor.
    522   virtual TensorShape GetOutputTfShape(const TensorShape& input_shape,
    523                                        const TensorShape& filter_shape,
    524                                        const TensorShape& outbprop_shape) = 0;
    525 
    526   /// Get shape of output in MKL-DNN order. Computes shape of output from
    527   /// input shape (fwd_input_dims) and filter shape (fwd_filter_dims).
    528   virtual const memory::dims& GetOutputDims(
    529       const memory::dims& fwd_input_dims,
    530       const memory::dims& fwd_filter_dims) = 0;
    531 
    532   /// Get data_format of output in MKL-DNN order. If output data format is
    533   /// same as input data format, then it simply returns value of data_format
    534   /// parameter as it is.
    535   virtual memory::format GetOutputFormat(const memory::format data_format) = 0;
    536 
    537   /// Create and execute the primitive storing output in the output_tensor.
    538   virtual void CreatePrimitive(
    539       OpKernelContext* context, const engine& cpu_engine,
    540       const convolution_forward::primitive_desc& conv_fwd_pd,
    541       MklDnnData<T>* input, MklDnnData<T>* filter, MklDnnData<T>* outbackprop,
    542       MklDnnData<T>* output, Tensor** output_tensor,
    543       const memory::dims& strides, const memory::dims& padding_l,
    544       const memory::dims& padding_r, padding_kind padding,
    545       const memory::dims& bwd_output_dims,
    546       memory::format bwd_output_format) = 0;
    547 
    548   // Get the data_format {NCHW, NHWC}
    549   TensorFormat GetTFDataFormat() { return data_format_; }
    550 
    551  private:
    552   std::vector<int32> strides_;
    553   Padding padding_;
    554   TensorFormat data_format_;
    555 };
    556 #endif  // INTEL_MKL_ML
    557 
    558 /////////////////////////////////////////////////////////////////////
    559 ///  Dummy Mkl op that is just used for operators that are intermediate
    560 ///  output of node fusion in the graph
    561 /////////////////////////////////////////////////////////////////////
    562 
    563 template <typename Device, typename T>
    564 class MklDummyOp : public OpKernel {
    565  public:
    566   ~MklDummyOp() {}
    567 
    568   explicit MklDummyOp(OpKernelConstruction* context) : OpKernel(context) {}
    569 
    570   void Compute(OpKernelContext* context) override {
    571     TF_CHECK_OK(
    572         errors::Unimplemented("This is a dummy op."
    573                               "It should not have been invoked."));
    574   }
    575 };
    576 
    577 }  // namespace tensorflow
    578 
    579 #endif  // TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
    580