/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.
#ifdef INTEL_MKL

#include <string.h>
#include <map>
#include <string>
#include <vector>

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/mkl_conv_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

#include "tensorflow/core/util/mkl_util.h"

#ifndef INTEL_MKL_ML

#include "mkldnn.hpp"

using mkldnn::prop_kind;
using mkldnn::stream;

using mkldnn::convolution_direct;
using mkldnn::convolution_forward;
#else
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

// MKL-DNN is now the default. MKL-ML must be specified explicitly.
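// Two implementations of MklConv2DOp follow: a legacy MKL-ML path (compiled
// when INTEL_MKL_ML is defined) built on the dnn* C API, and an MKL-DNN path
// built on the mkldnn C++ API. Exactly one of the two is compiled in.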
#ifdef INTEL_MKL_ML

template <typename Device, typename T, bool biasEnabled>
class MklConv2DOp : public OpKernel {
 public:
  ~MklConv2DOp() {}

  explicit MklConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));

    const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
    const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    MklConv2DOpContext mkl_context;
    const Tensor& input = MklGetInput(context, 0);
    GetMklShape(context, 0, &(mkl_context.input_shape));
    bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();

    const Tensor& filter = MklGetInput(context, 1);
    MklShape mkl_filter_shape;
    GetMklShape(context, 1, &mkl_filter_shape);
    CHECK(!mkl_filter_shape.IsMklTensor())
        << "Conv filter should not be in MKL Layout";

    if (biasEnabled) {
      const Tensor& bias = MklGetInput(context, 2);
      OP_REQUIRES(context, bias.dims() == 1,
                  errors::InvalidArgument("bias must be 1-dimensional: ",
                                          bias.shape().DebugString()));
    }

    if (!input_in_mkl_format) {
      OP_REQUIRES(context, input.dims() == 4,
                  errors::InvalidArgument("input must be 4-dimensional: ",
                                          input.shape().DebugString()));
    }

    OP_REQUIRES(context, filter.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter.shape().DebugString()));

    for (int i = 0; i < 3; i++) {
      OP_REQUIRES(
          context,
          FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
          errors::InvalidArgument("filter too large"));
    }

    const int64 input_depth =
        input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'C')
                            : GetTensorDim(input, data_format_, 'C');
    OP_REQUIRES(context, input_depth == filter.dim_size(2),
                errors::InvalidArgument(
                    "input and filter must have the same depth: ", input_depth,
                    " vs ", filter.dim_size(2)));
    // The last dimension of the filter is out_depth.
    const int out_depth = static_cast<int>(filter.dim_size(3));

    // The second dimension of the input is rows/height.
    // The first dimension of the filter is rows/height.
    const int64 input_rows_raw =
        input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'H')
                            : GetTensorDim(input, data_format_, 'H');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input rows too large"));
    const int input_rows = static_cast<int>(input_rows_raw);
    const int filter_rows = static_cast<int>(filter.dim_size(0));

    // The third dimension of the input is columns/width.
    // The second dimension of the filter is columns/width.
    const int64 input_cols_raw =
        input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'W')
                            : GetTensorDim(input, data_format_, 'W');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("Input cols too large"));
    const int input_cols = static_cast<int>(input_cols_raw);
    const int filter_cols = static_cast<int>(filter.dim_size(1));

    // The first dimension of the input is batch.
    const int64 input_batch_raw =
        input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'N')
                            : GetTensorDim(input, data_format_, 'N');
    OP_REQUIRES(
        context,
        FastBoundsCheck(input_batch_raw, std::numeric_limits<int>::max()),
        errors::InvalidArgument("batch is too large"));
    const int batch = static_cast<int>(input_batch_raw);

    // For now we take the stride from the second and third dimensions only
    // (we do not support striding on the batch or depth dimension).
    const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    const int stride_cols = GetTensorDim(strides_, data_format_, 'W');

    int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_rows, filter_rows, stride_rows,
                                         padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(context,
                   GetWindowedOutputSize(input_cols, filter_cols, stride_cols,
                                         padding_, &out_cols, &pad_cols));
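    // For reference, a sketch of the standard TF windowed-output math that
    // GetWindowedOutputSize computes (integer arithmetic; the returned pad
    // is the before-padding):
    //   VALID: out = ceil((in - filter + 1) / stride), pad = 0
    //   SAME:  out = ceil(in / stride),
    //          pad = max((out - 1) * stride + filter - in, 0) / 2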
    TensorShape out_shape =
        ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    Tensor* output = nullptr;

    // If there is nothing to compute, return.
    if (out_shape.num_elements() == 0) {
      // Nothing to do; allocate the output tensor and return.
      MklShape mkl_output_mkl_shape;
      mkl_output_mkl_shape.SetMklTensor(false);
      AllocateOutputSetMklShape(context, 0, &output, input.shape(),
                                mkl_output_mkl_shape);
      return;
    }

    if (batch == 0) {
      // Nothing to do; allocate the output tensor and return.
      MklShape mkl_output_mkl_shape;
      mkl_output_mkl_shape.SetMklTensor(false);
      AllocateOutputSetMklShape(context, 0, &output, input.shape(),
                                mkl_output_mkl_shape);
      return;
    }

    // Create MKL convolution primitives.
    mkl_context.in_dims = input_in_mkl_format
                              ? mkl_context.input_shape.GetDimension()
                              : input.dims();
    mkl_context.filter_dims = filter.dims();

    mkl_context.in_sizes[MklDims::W] = static_cast<size_t>(input_cols);
    mkl_context.in_sizes[MklDims::H] = static_cast<size_t>(input_rows);
    mkl_context.in_sizes[MklDims::C] = static_cast<size_t>(input_depth);
    mkl_context.in_sizes[MklDims::N] = static_cast<size_t>(batch);

    mkl_context.out_sizes[MklDims::W] = static_cast<size_t>(out_cols);
    mkl_context.out_sizes[MklDims::H] = static_cast<size_t>(out_rows);
    mkl_context.out_sizes[MklDims::C] = static_cast<size_t>(out_depth);
    mkl_context.out_sizes[MklDims::N] = static_cast<size_t>(batch);

    mkl_context.input_offset[0] = static_cast<int>(-pad_cols);
    mkl_context.input_offset[1] = static_cast<int>(-pad_rows);

    mkl_context.conv_stride[0] = static_cast<size_t>(stride_cols);
    mkl_context.conv_stride[1] = static_cast<size_t>(stride_rows);

    GetStridesFromSizes(data_format_, mkl_context.out_strides,
                        mkl_context.out_sizes);
    GetStridesFromSizes(data_format_, mkl_context.in_strides,
                        mkl_context.in_sizes);

    // TF filter dimension order (out_depth, in_depth, cols, rows) ->
    // MKL filter dimension order (out_depth, in_depth, rows, cols)
    mkl_context.filter_sizes[0] = filter.dim_size(1);  // cols
    mkl_context.filter_sizes[1] = filter.dim_size(0);  // rows
    mkl_context.filter_sizes[2] = filter.dim_size(2);  // in_depth
    mkl_context.filter_sizes[3] = filter.dim_size(3);  // out_depth

    // TF filter layout - (rows, cols, in_depth, out_depth)
    mkl_context.filter_strides[0] =
        filter.dim_size(2) * filter.dim_size(3);  // cols
    mkl_context.filter_strides[1] =
        filter.dim_size(1) * filter.dim_size(2) * filter.dim_size(3);  // rows
    mkl_context.filter_strides[2] = filter.dim_size(3);  // in_depth
    mkl_context.filter_strides[3] = 1;                   // out_depth
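    // Worked example (hypothetical shapes): a 3x3 filter with in_depth = 64
    // and out_depth = 128 has TF shape (3, 3, 64, 128), so the element
    // strides in that HWIO buffer are rows = 3*64*128 = 24576,
    // cols = 64*128 = 8192, in_depth = 128, and out_depth = 1, matching the
    // assignments above.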

    if (biasEnabled) {
      const Tensor& bias = MklGetInput(context, 2);
      mkl_context.bias_sizes[0] = {static_cast<size_t>(bias.dim_size(0))};
      mkl_context.bias_strides[0] = {1};
    }

    // Create the convolution primitive.
    if (biasEnabled) {
      CHECK_EQ(
          dnnConvolutionCreateForwardBias_F32(
              &mkl_context.prim_fwd, nullptr, dnnAlgorithmConvolutionDirect,
              mkl_context.in_dims, mkl_context.in_sizes, mkl_context.out_sizes,
              mkl_context.filter_sizes, mkl_context.conv_stride,
              mkl_context.input_offset, dnnBorderZeros),
          E_SUCCESS);
    } else {
      CHECK_EQ(
          dnnConvolutionCreateForward_F32(
              &mkl_context.prim_fwd, nullptr, dnnAlgorithmConvolutionDirect,
              mkl_context.in_dims, mkl_context.in_sizes, mkl_context.out_sizes,
              mkl_context.filter_sizes, mkl_context.conv_stride,
              mkl_context.input_offset, dnnBorderZeros),
          E_SUCCESS);
    }

    TensorShape mkl_output_tf_shape;
    MklShape mkl_output_mkl_shape;
    mkl_output_mkl_shape.SetMklTensor(true);
    mkl_output_mkl_shape.SetMklLayout(mkl_context.prim_fwd, dnnResourceDst);
    mkl_output_mkl_shape.SetTfLayout(mkl_context.in_dims, mkl_context.out_sizes,
                                     mkl_context.out_strides);
    // MKL might change the dimension ordering, so create a mapping to
    // recover the original TF dimension order.
    mkl_output_mkl_shape.SetTfDimOrder(mkl_context.in_dims, data_format_);

    mkl_output_tf_shape.AddDim(
        dnnLayoutGetMemorySize_F32(
            static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
        sizeof(T));
    AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape,
                              mkl_output_mkl_shape);
    // Filter output to be used in backprop_input.
    TensorShape mkl_filter_output_tf_shape;
    MklShape mkl_filter_output_mkl_shape;
    mkl_filter_output_mkl_shape.SetMklTensor(true);
    mkl_filter_output_mkl_shape.SetMklLayout(mkl_context.prim_fwd,
                                             dnnResourceFilter);

    size_t filter_sizes[4] = {filter.dim_size(0), filter.dim_size(1),
                              filter.dim_size(2), filter.dim_size(3)};
    mkl_filter_output_mkl_shape.SetTfLayout(filter.dims(), filter_sizes,
                                            mkl_context.filter_strides);

    mkl_filter_output_mkl_shape.SetTfDimOrder(mkl_context.filter_dims,
                                              data_format_);
    mkl_filter_output_tf_shape.AddDim(
        dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
            mkl_filter_output_mkl_shape.GetMklLayout())) /
        sizeof(T));
    AllocateOutputSetMklShape(context, 1, &mkl_context.output_filter,
                              mkl_filter_output_tf_shape,
                              mkl_filter_output_mkl_shape);

    mkl_context.conv_res[dnnResourceDst] =
        static_cast<void*>(output->flat<T>().data());

    mkl_context.MklCreateInputLayouts(context);

    // Temporary tensors used to allocate scratch buffers.
    Tensor mkl_tmp_input_buf_tensor, mkl_tmp_filter_buf_tensor,
        mkl_tmp_bias_buf_tensor;
    mkl_context.MklPrepareConvolutionInputs(context, &mkl_tmp_input_buf_tensor,
                                            &mkl_tmp_filter_buf_tensor,
                                            &mkl_tmp_bias_buf_tensor);

    // Execute the convolution.
    CHECK_EQ(dnnExecute_F32(mkl_context.prim_fwd, mkl_context.conv_res),
             E_SUCCESS);

    mkl_context.MklCleanup();
  }
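
  // Summary of the MKL-ML flow above: validate shapes, create the forward
  // primitive, allocate both outputs with MKL layouts attached, convert any
  // input whose layout differs from the primitive's preferred layout, run
  // dnnExecute_F32, and release the layouts and the primitive.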

 private:
  typedef struct {
    int in_dims;
    size_t in_sizes[4];
    size_t in_strides[4];
    size_t out_sizes[4];
    size_t out_strides[4];
    int filter_dims;
    size_t filter_sizes[4];
    size_t filter_strides[4];
    size_t bias_sizes[1];
    size_t bias_strides[1];
    int input_offset[2];
    size_t conv_stride[2];
    MklShape input_shape;
    dnnPrimitive_t prim_fwd;
    void* conv_res[dnnResourceNumber];
    dnnLayout_t lt_filter, lt_bias, lt_input;
    Tensor* output_filter = nullptr;

    // Create MKL dnnLayout_t objects for the tensors coming into the layer.
    void MklCreateInputLayouts(OpKernelContext* context) {
      bool input_in_mkl_format = input_shape.IsMklTensor();
      if (input_in_mkl_format) {
        lt_input = static_cast<dnnLayout_t>(input_shape.GetCurLayout());
      } else {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_input, in_dims, in_sizes, in_strides),
                 E_SUCCESS);
      }

      CHECK_EQ(dnnLayoutCreate_F32(&lt_filter, filter_dims, filter_sizes,
                                   filter_strides),
               E_SUCCESS);

      if (biasEnabled) {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_bias, 1, bias_sizes, bias_strides),
                 E_SUCCESS);
      }
    }

    // Compare incoming tensor layouts with MKL's preferred layouts and
    // convert the data to the preferred layout if necessary.
    void MklPrepareConvolutionInputs(OpKernelContext* context,
                                     Tensor* mkl_tmp_input_buf_tensor,
                                     Tensor* mkl_tmp_filter_buf_tensor,
                                     Tensor* mkl_tmp_bias_buf_tensor) {
      bool mkl_convert_input, mkl_convert_filter, mkl_convert_bias;
      dnnPrimitive_t mkl_prim_convert_filter, mkl_prim_convert_bias,
          mkl_prim_convert_input;
      dnnLayout_t mkl_lt_internal_filter, mkl_lt_internal_bias,
          mkl_lt_internal_input;
      void *mkl_buf_convert_input, *mkl_buf_convert_filter,
          *mkl_buf_convert_bias;
      mkl_prim_convert_filter = nullptr;
      mkl_prim_convert_bias = nullptr;
      mkl_prim_convert_input = nullptr;
      mkl_lt_internal_filter = nullptr;
      mkl_lt_internal_bias = nullptr;
      mkl_lt_internal_input = nullptr;
      mkl_buf_convert_input = nullptr;
      mkl_buf_convert_filter = nullptr;
      mkl_buf_convert_bias = nullptr;

      // Compare with internal layouts and convert if needed.
      const Tensor& input = MklGetInput(context, 0);
      void* mkl_buf_input =
          const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_input,
                                                prim_fwd, dnnResourceSrc),
               E_SUCCESS);
      mkl_convert_input =
          !dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input);
      if (mkl_convert_input) {
        CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, lt_input,
                                         mkl_lt_internal_input),
                 E_SUCCESS);
        AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
                       &mkl_buf_convert_input);
        CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
                                          mkl_buf_convert_input),
                 E_SUCCESS);
        dnnDelete_F32(mkl_prim_convert_input);
      }
      dnnLayoutDelete_F32(mkl_lt_internal_input);

      conv_res[dnnResourceSrc] =
          (mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input;

      const Tensor& filter = MklGetInput(context, 1);
      void* mkl_buf_filter =
          const_cast<void*>(static_cast<const void*>(filter.flat<T>().data()));
      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_filter,
                                                prim_fwd, dnnResourceFilter),
               E_SUCCESS);
      mkl_convert_filter =
          !dnnLayoutCompare_F32(mkl_lt_internal_filter, lt_filter);
      if (mkl_convert_filter) {
        CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_filter, lt_filter,
                                         mkl_lt_internal_filter),
                 E_SUCCESS);

        // Rather than converting into a temporary buffer, write the
        // converted filter directly into the filter output tensor.
        mkl_buf_convert_filter = const_cast<void*>(
            static_cast<const void*>(output_filter->flat<T>().data()));

        CHECK_EQ(
            dnnConversionExecute_F32(mkl_prim_convert_filter, mkl_buf_filter,
                                     mkl_buf_convert_filter),
            E_SUCCESS);
        dnnDelete_F32(mkl_prim_convert_filter);
      }
      dnnLayoutDelete_F32(mkl_lt_internal_filter);

      conv_res[dnnResourceFilter] =
          (mkl_convert_filter) ? mkl_buf_convert_filter : mkl_buf_filter;

      if (biasEnabled) {
        const Tensor& bias = MklGetInput(context, 2);
        void* mkl_buf_bias =
            const_cast<void*>(static_cast<const void*>(bias.flat<T>().data()));
        CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_bias,
                                                  prim_fwd, dnnResourceBias),
                 E_SUCCESS);
        mkl_convert_bias = !dnnLayoutCompare_F32(mkl_lt_internal_bias, lt_bias);
        if (mkl_convert_bias) {
          CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_bias, lt_bias,
                                           mkl_lt_internal_bias),
                   E_SUCCESS);
          AllocTmpBuffer(context, mkl_tmp_bias_buf_tensor, mkl_lt_internal_bias,
                         &mkl_buf_convert_bias);
          CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_bias, mkl_buf_bias,
                                            mkl_buf_convert_bias),
                   E_SUCCESS);
          dnnDelete_F32(mkl_prim_convert_bias);
        }
        dnnLayoutDelete_F32(mkl_lt_internal_bias);

        conv_res[dnnResourceBias] =
            (mkl_convert_bias) ? mkl_buf_convert_bias : mkl_buf_bias;
      }
    }

    void MklCleanup() {
      bool input_in_mkl_format = input_shape.IsMklTensor();
      dnnDelete_F32(prim_fwd);
      if (!input_in_mkl_format) dnnLayoutDelete_F32(lt_input);
      dnnLayoutDelete_F32(lt_filter);
      if (biasEnabled) dnnLayoutDelete_F32(lt_bias);
    }
  } MklConv2DOpContext;

  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;
};

#else

template <typename Device, typename T, bool biasEnabled>
class MklConv2DOp : public OpKernel {
 public:
  ~MklConv2DOp() {}

  explicit MklConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));

    const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
    const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);

      // Input tensors
      const Tensor& src_tensor = MklGetInput(context, kInputIndex_Src);
      const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter);

      MklDnnShape src_mkl_shape, filter_mkl_shape;
      GetMklShape(context, kInputIndex_Src, &src_mkl_shape);
      GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape);
      OP_REQUIRES(context, !filter_mkl_shape.IsMklTensor(),
                  errors::InvalidArgument("Filter should not be in "
                                          "Mkl Layout"));

      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> filter(&cpu_engine);
      MklDnnData<T> output(&cpu_engine);

      memory::dims src_dims, filter_dims, padding_l, padding_r, strides;
      memory::dims output_dims_tf_order, output_dims_mkl_order;

      // Get the shapes of the input tensors in MKL-DNN order.
      MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
      auto src_tf_shape = GetTfShape(context, kInputIndex_Src);
      auto filter_tf_shape = GetTfShape(context, kInputIndex_Filter);
      conv_utl.GetConvFwdSizesInMklOrder(
          src_tf_shape, filter_tf_shape, &src_dims, &filter_dims, &strides,
          &output_dims_tf_order, &output_dims_mkl_order, &padding_l,
          &padding_r);
      if (!context->status().ok()) return;

      // Check for a corner case: if there is nothing to compute, return.
      TensorShape output_tf_shape = MklDnnDimsToTFShape(output_dims_tf_order);

      // Corner cases: output with 0 elements and 0 batch size.
      Tensor* output_tensor = nullptr;
      if (output_tf_shape.num_elements() == 0 || output_dims_tf_order[0] == 0) {
        // TODO(jbobba): Verify correctness here
        //               Need semantics for Null MKL tensor
        MklDnnShape output_mkl_shape;
        output_mkl_shape.SetMklTensor(false);
        AllocateOutputSetMklShape(context, kOutputIndex_Dst, &output_tensor,
                                  src_tf_shape, output_mkl_shape);

        // MklConv2D also outputs the converted filter as the 2nd output of
        // Conv2D.
        filter_mkl_shape.SetMklTensor(false);
        Tensor* output_filter_tensor = nullptr;
        AllocateOutputSetMklShape(context, kOutputIndex_Filter,
                                  &output_filter_tensor, filter_tf_shape,
                                  filter_mkl_shape);
        return;
      }

      // Create memory for user data.
      // Describe what the inputs and outputs of the convolution look like,
      // and specify the buffers containing the actual input and output data.
      auto tf_fmt = TFDataFormatToMklDnnDataFormat(data_format_);
      // If the input is in MKL layout, simply grab that layout; otherwise,
      // construct a TF layout. Although the required input shape (src_dims)
      // is in MKL-DNN order, the layout itself is TensorFlow's layout (NHWC
      // or NCHW depending on data format).
      auto src_md = src_mkl_shape.IsMklTensor()
                        ? src_mkl_shape.GetMklLayout()
                        : memory::desc(src_dims, MklDnnType<T>(), tf_fmt);
      src.SetUsrMem(src_md, &src_tensor);
      // Although the required filter shape (filter_dims) is in MKL-DNN
      // order, the layout is TensorFlow's layout (HWIO).
      auto filter_md = filter_mkl_shape.IsMklTensor()  // Should NEVER be true
                           ? filter_mkl_shape.GetMklLayout()
                           : memory::desc(filter_dims, MklDnnType<T>(),
                                          memory::format::hwio);
      filter.SetUsrMem(filter_md, &filter_tensor);

      // Set the output shape (output_dims) in MKL-DNN order. For now the
      // output layout is TensorFlow's (NHWC or NCHW depending on data
      // format), but the Mkl layout of the output is later propagated
      // directly to the next op.
      output.SetUsrMem(output_dims_mkl_order, tf_fmt);

      // Create memory descriptors for the convolution data with no
      // specified format.
      src.SetOpMemDesc(src_dims, memory::format::any);
      filter.SetOpMemDesc(filter_dims, memory::format::any);
      output.SetOpMemDesc(output_dims_mkl_order, memory::format::any);
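
      // memory::format::any lets MKL-DNN pick its preferred layouts when the
      // primitive descriptor is created below; PrepareAndExecuteNet then
      // inserts reorders for any user buffer whose layout differs.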

      // If bias is enabled, do the same steps as above for bias.
      if (biasEnabled) {
        MklDnnData<T> bias(&cpu_engine);
        memory::dims bias_size;
        conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_size);
        const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
        bias.SetUsrMem(bias_size, memory::format::x, &bias_tensor);
        bias.SetOpMemDesc(bias_size, memory::format::any);

        // Create the convolution primitive with bias.
        auto conv_desc = convolution_forward::desc(
            prop_kind::forward, convolution_direct, src.GetOpMemDesc(),
            filter.GetOpMemDesc(), bias.GetOpMemDesc(), output.GetOpMemDesc(),
            strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_));

        auto conv_prim_desc =
            convolution_forward::primitive_desc(conv_desc, cpu_engine);
        AllocateOutputTensor(context, conv_prim_desc, output_dims_mkl_order,
                             tf_fmt, &output_tensor);
        // Set the data handle for the output.
        output.SetUsrMemDataHandle(output_tensor);

        Tensor* filter_out_tensor = nullptr;
        AllocateFilterOutputTensor(context, conv_prim_desc,
                                   TFShapeToMklDnnDims(filter_tf_shape),
                                   &filter_out_tensor);

        PrepareAndExecuteNet(conv_prim_desc, &src, &filter, &bias, &output,
                             filter_out_tensor);
      } else {
        // Create the convolution primitive without bias.
        auto conv_desc = convolution_forward::desc(
            prop_kind::forward, convolution_direct, src.GetOpMemDesc(),
            filter.GetOpMemDesc(), output.GetOpMemDesc(), strides, padding_l,
            padding_r, TFPaddingToMklDnnPadding(padding_));

        auto conv_prim_desc =
            convolution_forward::primitive_desc(conv_desc, cpu_engine);
        AllocateOutputTensor(context, conv_prim_desc, output_dims_mkl_order,
                             tf_fmt, &output_tensor);
        // Set the data handle for the output.
        output.SetUsrMemDataHandle(output_tensor);

        Tensor* filter_out_tensor = nullptr;
        AllocateFilterOutputTensor(context, conv_prim_desc,
                                   TFShapeToMklDnnDims(filter_tf_shape),
                                   &filter_out_tensor);
        PrepareAndExecuteNet(conv_prim_desc, &src, &filter, nullptr, &output,
                             filter_out_tensor);
      }
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + std::string(e.message) + ", in file " +
                         std::string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }
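
  // A minimal sketch of the same MKL-DNN sequence in isolation (mkldnn 0.x
  // API as used in this file; src_md/weights_md/dst_md and the *_mem memory
  // objects are placeholders, not kernel code):
  //   engine eng(engine::cpu, 0);
  //   auto desc = convolution_forward::desc(prop_kind::forward,
  //       convolution_direct, src_md, weights_md, dst_md, strides,
  //       padding_l, padding_r, padding_kind::zero);
  //   auto pd = convolution_forward::primitive_desc(desc, eng);
  //   std::vector<primitive> net;
  //   net.push_back(convolution_forward(pd, src_mem, weights_mem, dst_mem));
  //   stream(stream::kind::eager).submit(net).wait();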

 private:
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;
  const int kInputIndex_Src = 0, kInputIndex_Filter = 1, kInputIndex_Bias = 2;
  const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1;

  // Allocate the output tensor.
  void AllocateOutputTensor(
      OpKernelContext* context,
      const convolution_forward::primitive_desc& conv_prim_desc,
      const memory::dims& output_dims_mkl_order,
      memory::format output_tf_format, Tensor** output_tensor) {
    CHECK_NOTNULL(output_tensor);
    auto dst_pd = conv_prim_desc.dst_primitive_desc();

    // Allocate the shape of the Mkl tensor.
    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    // Allocate the TF tensor as a flat buffer sized to the MKL layout; the
    // logical shape travels in the MklDnnShape.
    TensorShape output_tf_shape;
    output_tf_shape.AddDim((dst_pd.get_size() / sizeof(T)));

    AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
                              output_tf_shape, output_mkl_shape);
  }

  // Allocate the filter output tensor.
  void AllocateFilterOutputTensor(
      OpKernelContext* context,
      const convolution_forward::primitive_desc& conv_prim_desc,
      const memory::dims& filter_dims_tf_order, Tensor** filter_tensor) {
    CHECK_NOTNULL(filter_tensor);
    auto filter_pd = conv_prim_desc.weights_primitive_desc();

    // Allocate the shape of the Mkl tensor.
    MklDnnShape filter_mkl_shape;
    filter_mkl_shape.SetMklTensor(true);
    filter_mkl_shape.SetMklLayout(&filter_pd);
    filter_mkl_shape.SetElemType(MklDnnType<T>());

    // The format of the filter is actually OIhw8i8o, but TF doesn't support
    // this format. Just use format::blocked for now because the layout
    // is stored in the MKL data.
    filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(),
                                 filter_dims_tf_order, memory::format::blocked);

    // Allocate the data space for the filter to propagate as a TF tensor.
    TensorShape filter_tf_shape;
    filter_tf_shape.AddDim((filter_pd.get_size() / sizeof(T)));

    AllocateOutputSetMklShape(context, kOutputIndex_Filter, filter_tensor,
                              filter_tf_shape, filter_mkl_shape);
  }

  // Prepare and execute the net; checks for input and output reorders.
  void PrepareAndExecuteNet(
      const convolution_forward::primitive_desc& conv_prim_desc,
      MklDnnData<T>* src, MklDnnData<T>* filter, MklDnnData<T>* bias,
      MklDnnData<T>* output, Tensor* filter_out_tensor) {
    CHECK_NOTNULL(filter_out_tensor);

    // Create reorders between the user layout and the MKL layout if needed,
    // and add them to the net before the convolution. There is no need to
    // check for an output reorder because the output layout is propagated
    // to the next layer.
    std::vector<primitive> net;
    src->CheckReorderToOpMem(conv_prim_desc.src_primitive_desc(), &net);

    // Rather than reordering to a temporary buffer, reorder directly into
    // the filter output tensor.
    filter->CheckReorderToOpMem(conv_prim_desc.weights_primitive_desc(),
                                filter->GetTensorBuffer(filter_out_tensor),
                                &net);

    // Create the convolution primitive and add it to the net.
    if (bias) {
      CHECK_EQ(biasEnabled, true);
      net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
                                        filter->GetOpMem(), bias->GetOpMem(),
                                        output->GetOpMem()));
    } else {
      CHECK_EQ(biasEnabled, false);
      net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
                                        filter->GetOpMem(),
                                        output->GetOpMem()));
    }

    // The eager stream executes the whole net synchronously.
    stream(stream::kind::eager).submit(net).wait();
  }
};

#endif

#define REGISTER_MKL_CPU(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("_MklConv2D")                        \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklConv2DOp<CPUDevice, T, false>);        \
  REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias")                \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklConv2DOp<CPUDevice, T, true>);         \
  REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DWithBias")          \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklDummyOp<CPUDevice, T>);

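// Register the MKL kernels for float. TF_CALL_float below expands
// REGISTER_MKL_CPU(float), so float is currently the only supported type;
// the kMklOpLabel label keeps these kernels from shadowing the default
// Eigen Conv2D kernels.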
TF_CALL_float(REGISTER_MKL_CPU);

}  // namespace tensorflow
#endif  // INTEL_MKL