/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/nn_ops.cc.

#ifdef INTEL_MKL

#include <algorithm>
#include <vector>

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/mkl_conv_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"

#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"

using mkldnn::convolution_backward_weights;
using mkldnn::memory;
using mkldnn::prop_kind;
using mkldnn::stream;
#endif
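// This file provides two implementations of the Conv2D filter-gradient
// kernel: with INTEL_MKL_ML defined, the legacy MKL-ML (dnn*_F32) path
// below is compiled; otherwise the newer MKL-DNN (mkldnn.hpp) path after
// the #else is used. Only one of the two paths is built at a time.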

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

#ifdef INTEL_MKL_ML

template <typename Device, class T>
class MklConv2DCustomBackpropFilterOp : public OpKernel {
 public:
  explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    int stride_n = GetTensorDim(strides_, data_format_, 'N');
    int stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }
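  // Note: GetTensorDim selects the stride of a logical dimension from the
  // 4-element "strides" attribute according to data_format_; e.g. with
  // data_format "NHWC" and strides {1, 2, 2, 1}, stride_n = stride_c = 1
  // (as required above) and both spatial strides are 2.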

  void Compute(OpKernelContext* context) override {
    MklConv2DGradFilterOpContext mkl_context;
    const Tensor& input = MklGetInput(context, 0);
    GetMklShape(context, 0, &(mkl_context.input_shape));
    bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();

    const Tensor& filter_sizes = MklGetInput(context, 1);

    const Tensor& out_backprop = MklGetInput(context, 2);
    GetMklShape(context, 2, &(mkl_context.out_backprop_shape));
    bool out_backprop_in_mkl_format =
        mkl_context.out_backprop_shape.IsMklTensor();

    TensorShape input_shape, filter_shape, out_backprop_shape;

    OP_REQUIRES(
        context, TensorShapeUtils::IsVector(filter_sizes.shape()),
        errors::InvalidArgument(
            "Conv2DCustomBackpropFilter: filter_sizes input must be 1-dim, "
            "not ",
            filter_sizes.dims()));
    OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                filter_sizes.vec<int32>(), &filter_shape));

    ConvBackpropDimensions backprop_dims;

    // Generate shape for input if input is in MKL format.
    if (input_in_mkl_format) {
      OP_REQUIRES(context, mkl_context.input_shape.GetDimension() == 4,
                  errors::InvalidArgument(
                      "Conv2DCustomBackpropFilter: input size must be 4-dim"));

      MklSizesToTFSizes(context, data_format_, mkl_context.input_shape,
                        &input_shape);
    } else {
      input_shape = input.shape();
    }

    // Generate shape for out_backprop if out_backprop is in MKL format.
    if (out_backprop_in_mkl_format) {
      OP_REQUIRES(
          context, mkl_context.out_backprop_shape.GetDimension() == 4,
          errors::InvalidArgument(
              "Conv2DCustomBackpropFilter: outbackprop size must be 4-dim"));

      MklSizesToTFSizes(context, data_format_, mkl_context.out_backprop_shape,
                        &out_backprop_shape);
    } else {
      out_backprop_shape = out_backprop.shape();
    }

    OP_REQUIRES_OK(context,
                   ConvBackpropComputeDimensions(
                       "Conv2DCustomBackpropFilter", /*num_spatial_dims=*/2,
                       input_shape, filter_shape, out_backprop_shape, strides_,
                       padding_, data_format_, &backprop_dims));

    int64 pad_top, pad_bottom;
    int64 pad_left, pad_right;
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                backprop_dims.spatial_dims[0].input_size,
                                backprop_dims.spatial_dims[0].filter_size,
                                backprop_dims.spatial_dims[0].stride, padding_,
                                &backprop_dims.spatial_dims[0].output_size,
                                &pad_top, &pad_bottom));
    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
                                backprop_dims.spatial_dims[1].input_size,
                                backprop_dims.spatial_dims[1].filter_size,
                                backprop_dims.spatial_dims[1].stride, padding_,
                                &backprop_dims.spatial_dims[1].output_size,
                                &pad_left, &pad_right));
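    // Example with SAME padding: for input_size = 28, filter_size = 3, and
    // stride = 2, output_size = ceil(28 / 2) = 14 and the total padding is
    // max((14 - 1) * 2 + 3 - 28, 0) = 1, split as pad_top = 0 and
    // pad_bottom = 1 (any odd excess goes to the bottom/right side).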

    // Create MKL primitives for convolution filter grad
    mkl_context.in_dims = input_in_mkl_format
                              ? mkl_context.input_shape.GetDimension()
                              : input.dims();
    mkl_context.out_dims = out_backprop_in_mkl_format
                               ? mkl_context.out_backprop_shape.GetDimension()
                               : out_backprop.dims();
    mkl_context.in_sizes[0] =
        static_cast<size_t>(backprop_dims.spatial_dims[1].input_size);
    mkl_context.in_sizes[1] =
        static_cast<size_t>(backprop_dims.spatial_dims[0].input_size);
    mkl_context.in_sizes[2] = static_cast<size_t>(backprop_dims.in_depth);
    mkl_context.in_sizes[3] = static_cast<size_t>(backprop_dims.batch_size);
    mkl_context.out_sizes[0] =
        static_cast<size_t>(backprop_dims.spatial_dims[1].output_size);
    mkl_context.out_sizes[1] =
        static_cast<size_t>(backprop_dims.spatial_dims[0].output_size);
    mkl_context.out_sizes[2] = static_cast<size_t>(backprop_dims.out_depth);
    mkl_context.out_sizes[3] = static_cast<size_t>(backprop_dims.batch_size);
    mkl_context.input_offsets[0] = static_cast<int>(-pad_left);
    mkl_context.input_offsets[1] = static_cast<int>(-pad_top);
    mkl_context.conv_strides[0] =
        static_cast<size_t>(backprop_dims.spatial_dims[1].stride);
    mkl_context.conv_strides[1] =
        static_cast<size_t>(backprop_dims.spatial_dims[0].stride);

    GetStridesFromSizes(data_format_, mkl_context.in_strides,
                        mkl_context.in_sizes);
    GetStridesFromSizes(data_format_, mkl_context.out_strides,
                        mkl_context.out_sizes);
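    // MKL-ML sizes are ordered (width, height, channels, batch); e.g. an
    // NHWC input with N=32, H=28, W=28, C=3 gives in_sizes = {28, 28, 3, 32}.
    // GetStridesFromSizes derives the element strides for the given
    // data_format; for NHWC (channels innermost) that is
    // in_strides = {3, 84, 1, 2352}.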

    // MKL expects dimension indices 0, 1, 2, and 3 to denote filter cols,
    // rows, input channels, and output channels, respectively.
    mkl_context.filter_dims = 4;
    mkl_context.filter_sizes[0] = backprop_dims.spatial_dims[1].filter_size;
    mkl_context.filter_sizes[1] = backprop_dims.spatial_dims[0].filter_size;
    mkl_context.filter_sizes[2] = backprop_dims.in_depth;
    mkl_context.filter_sizes[3] = backprop_dims.out_depth;

    // We want the filter grad to be in TF format, so set the strides
    // accordingly. Note the TF filter layout is
    // (rows, cols, in_depth, out_depth), with out_depth as the innermost
    // dimension.
    mkl_context.filter_strides[0] =
        backprop_dims.out_depth * backprop_dims.in_depth;
    mkl_context.filter_strides[1] = backprop_dims.out_depth *
                                    backprop_dims.in_depth *
                                    backprop_dims.spatial_dims[1].filter_size;
    mkl_context.filter_strides[2] = backprop_dims.out_depth;
    mkl_context.filter_strides[3] = 1;
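    // Worked example: a 3x3 filter with in_depth = 3 and out_depth = 16 in
    // HWIO layout gets strides {48, 144, 16, 1} in the (cols, rows, in, out)
    // index order used above: advancing one column skips 16 * 3 = 48
    // elements, one row skips 48 * 3 = 144, one input channel skips 16, and
    // out_depth is contiguous.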

    mkl_context.conv_strides[0] = backprop_dims.spatial_dims[1].stride;
    mkl_context.conv_strides[1] = backprop_dims.spatial_dims[0].stride;

    // Create convolution-grad-filter primitive
    CHECK_EQ(dnnConvolutionCreateBackwardFilter_F32(
                 &mkl_context.prim_conv_bwdfilter, nullptr,
                 dnnAlgorithmConvolutionDirect, mkl_context.in_dims,
                 mkl_context.in_sizes, mkl_context.out_sizes,
                 mkl_context.filter_sizes, mkl_context.conv_strides,
                 mkl_context.input_offsets, dnnBorderZeros),
             E_SUCCESS);
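    // MKL-ML follows a resource-map execution model: the primitive is
    // created once, buffers are bound by filling conv_res[dnnResourceSrc],
    // conv_res[dnnResourceDiffDst], and conv_res[dnnResourceDiffFilter]
    // (done in the MklPrepare* helpers below), and dnnExecute_F32 then runs
    // the primitive against those resources.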

    // Create the layouts for entities in the received context.
    mkl_context.MklCreateInputLayouts(context);

    // MKL needs the entities in its native format, so create temporary
    // tensors and buffers to convert the received entities if needed.
    Tensor mkl_tmp_input_buf_tensor, mkl_tmp_out_backprop_buf_tensor;
    // This preparation sets (1) dnnResourceSrc (2) dnnResourceDiffDst
    mkl_context.MklPrepareInputs(context, &mkl_tmp_input_buf_tensor,
                                 &mkl_tmp_out_backprop_buf_tensor);

    // The final conv-grad-filter should be in TF layout.
    Tensor* grad_filter;
    mkl_context.grad_filter_shape.SetMklTensor(false);
    mkl_context.grad_filter_shape.SetTfLayout(mkl_context.filter_dims,
                                              mkl_context.filter_sizes,
                                              mkl_context.filter_strides);
    AllocateOutputSetMklShape(context, 0, &grad_filter, filter_shape,
                              mkl_context.grad_filter_shape);

    // Need to set the member variable for the TF layout.
    mkl_context.lt_grad_filter = mkl_context.grad_filter_shape.GetTfLayout();

    // MKL conv-grad-filter might produce the grad in its internal layout.
    Tensor mkl_tmp_grad_filter_buf_tensor;
    // This preparation sets the conversion primitive if required and
    // allocates a temporary tensor and its buffer without doing conversions.
    // Also sets (3) dnnResourceDiffFilter accordingly.
    mkl_context.MklPrepareGradFilter(context, grad_filter,
                                     &mkl_tmp_grad_filter_buf_tensor);

    // After setting all the required dnnResources, ready for execution!
    CHECK_EQ(
        dnnExecute_F32(mkl_context.prim_conv_bwdfilter, mkl_context.conv_res),
        E_SUCCESS);

    // Convert the grad-filter to TF layout if MKL produced it in its
    // internal layout.
    if (mkl_context.convert_bwdfilter != nullptr) {
      void* mkl_buf_convert_grad_filter =
          const_cast<void*>(static_cast<const void*>(
              mkl_tmp_grad_filter_buf_tensor.flat<T>().data()));
      void* mkl_buf_grad_filter = const_cast<void*>(
          static_cast<const void*>(grad_filter->flat<T>().data()));
      CHECK_EQ(dnnConversionExecute_F32(mkl_context.convert_bwdfilter,
                                        mkl_buf_convert_grad_filter,
                                        mkl_buf_grad_filter),
               E_SUCCESS);
    }

    mkl_context.MklCleanup();
  }

 private:
  typedef struct {
    int in_dims;
    size_t in_sizes[4];
    size_t in_strides[4];
    int out_dims;
    size_t out_sizes[4];
    size_t out_strides[4];
    int filter_dims;
    size_t filter_sizes[4];
    size_t filter_strides[4];
    int input_offsets[2];
    size_t conv_strides[2];
    MklShape input_shape, grad_filter_shape, out_backprop_shape;
    dnnPrimitive_t prim_conv_bwdfilter = nullptr;
    dnnPrimitive_t convert_bwdfilter = nullptr;
    dnnLayout_t lt_input = nullptr;
    dnnLayout_t lt_grad_filter = nullptr;
    dnnLayout_t lt_out_backprop = nullptr;
    void* conv_res[dnnResourceNumber];

    void MklCleanup() {
      // Clean up member layouts and primitives, except "lt_grad_filter",
      // which points to the MklShape's TFLayout and is owned by it.
      bool input_in_mkl_format = input_shape.IsMklTensor();
      bool out_backprop_in_mkl_format = out_backprop_shape.IsMklTensor();
      if (!input_in_mkl_format) dnnLayoutDelete_F32(lt_input);
      if (!out_backprop_in_mkl_format) dnnLayoutDelete_F32(lt_out_backprop);
      if (convert_bwdfilter != nullptr) dnnDelete_F32(convert_bwdfilter);
      dnnDelete_F32(prim_conv_bwdfilter);
    }

    // Create MKL dnnLayout_t objects for tensors coming into the layer
    void MklCreateInputLayouts(OpKernelContext* context) {
      bool input_in_mkl_format = input_shape.IsMklTensor();
      if (input_in_mkl_format) {
        lt_input = static_cast<dnnLayout_t>(input_shape.GetCurLayout());
      } else {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_input, in_dims, in_sizes, in_strides),
                 E_SUCCESS);
      }

      bool out_backprop_in_mkl_format = out_backprop_shape.IsMklTensor();
      if (out_backprop_in_mkl_format) {
        lt_out_backprop =
            static_cast<dnnLayout_t>(out_backprop_shape.GetCurLayout());
      } else {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_out_backprop, out_dims, out_sizes,
                                     out_strides),
                 E_SUCCESS);
      }
    }

    // Compare incoming tensor layouts with MKL preferred layouts and convert
    // data to the preferred layout if necessary
    void MklPrepareInputs(OpKernelContext* context,
                          Tensor* mkl_tmp_input_buf_tensor,
                          Tensor* mkl_tmp_out_backprop_buf_tensor) {
      bool mkl_convert_input, mkl_convert_out_backprop;
      dnnPrimitive_t mkl_prim_convert_input, mkl_prim_convert_out_backprop;
      dnnLayout_t mkl_lt_internal_input, mkl_lt_internal_out_backprop;
      void *mkl_buf_convert_input, *mkl_buf_convert_out_backprop;

      mkl_prim_convert_input = nullptr;
      mkl_prim_convert_out_backprop = nullptr;
      mkl_lt_internal_input = nullptr;
      mkl_lt_internal_out_backprop = nullptr;
      mkl_buf_convert_input = nullptr;
      mkl_buf_convert_out_backprop = nullptr;

      // Compare with internal layouts and convert if needed
      const Tensor& input = MklGetInput(context, 0);
      void* mkl_buf_input =
          const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
                   &mkl_lt_internal_input, prim_conv_bwdfilter, dnnResourceSrc),
               E_SUCCESS);
      mkl_convert_input =
          !dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input);
      if (mkl_convert_input) {
        CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, lt_input,
                                         mkl_lt_internal_input),
                 E_SUCCESS);
        AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
                       &mkl_buf_convert_input);
        CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
                                          mkl_buf_convert_input),
                 E_SUCCESS);
        dnnDelete_F32(mkl_prim_convert_input);
      }
      dnnLayoutDelete_F32(mkl_lt_internal_input);

      conv_res[dnnResourceSrc] =
          (mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input;

      const Tensor& out_backprop = MklGetInput(context, 2);
      void* mkl_buf_out_backprop = const_cast<void*>(
          static_cast<const void*>(out_backprop.flat<T>().data()));

      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_out_backprop,
                                                prim_conv_bwdfilter,
                                                dnnResourceDiffDst),
               E_SUCCESS);
      mkl_convert_out_backprop =
          !dnnLayoutCompare_F32(mkl_lt_internal_out_backprop, lt_out_backprop);
      if (mkl_convert_out_backprop) {
        CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_out_backprop,
                                         lt_out_backprop,
                                         mkl_lt_internal_out_backprop),
                 E_SUCCESS);
        // Size the temporary buffer for the internal (destination) layout,
        // which is what the conversion below writes into it.
        AllocTmpBuffer(context, mkl_tmp_out_backprop_buf_tensor,
                       mkl_lt_internal_out_backprop,
                       &mkl_buf_convert_out_backprop);
        CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_out_backprop,
                                          mkl_buf_out_backprop,
                                          mkl_buf_convert_out_backprop),
                 E_SUCCESS);
        dnnDelete_F32(mkl_prim_convert_out_backprop);
      }
      dnnLayoutDelete_F32(mkl_lt_internal_out_backprop);

      conv_res[dnnResourceDiffDst] = (mkl_convert_out_backprop)
                                         ? mkl_buf_convert_out_backprop
                                         : mkl_buf_out_backprop;
    }

    void MklPrepareGradFilter(OpKernelContext* context, Tensor* grad_filter,
                              Tensor* mkl_tmp_grad_filter_buf_tensor) {
      bool mkl_convert_grad_filter;
      dnnLayout_t mkl_lt_internal_grad_filter = nullptr;
      void* mkl_buf_convert_grad_filter = nullptr;
      void* mkl_buf_grad_filter = const_cast<void*>(
          static_cast<const void*>(grad_filter->flat<T>().data()));
      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_grad_filter,
                                                prim_conv_bwdfilter,
                                                dnnResourceDiffFilter),
               E_SUCCESS);
      mkl_convert_grad_filter =
          !dnnLayoutCompare_F32(mkl_lt_internal_grad_filter, lt_grad_filter);
      if (mkl_convert_grad_filter) {
        CHECK_EQ(dnnConversionCreate_F32(&convert_bwdfilter,
                                         mkl_lt_internal_grad_filter,
                                         lt_grad_filter),
                 E_SUCCESS);
        AllocTmpBuffer(context, mkl_tmp_grad_filter_buf_tensor,
                       mkl_lt_internal_grad_filter,
                       &mkl_buf_convert_grad_filter);
      }
      dnnLayoutDelete_F32(mkl_lt_internal_grad_filter);

      conv_res[dnnResourceDiffFilter] = (mkl_convert_grad_filter)
                                            ? mkl_buf_convert_grad_filter
                                            : mkl_buf_grad_filter;
    }
  } MklConv2DGradFilterOpContext;

  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;
};

#define REGISTER_MKL_FILTER_KERNELS(T)                              \
  REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter")          \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklConv2DCustomBackpropFilterOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
#undef REGISTER_MKL_FILTER_KERNELS

#else

template <typename Device, class T, bool biasEnabled>
class MklConv2DCustomBackpropFilterOp
    : public MklConv2DBackpropCommonOp<Device, T> {
 public:
  explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context)
      : MklConv2DBackpropCommonOp<Device, T>(context) {}
  ~MklConv2DCustomBackpropFilterOp() {}

 private:
  void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
                         const MklDnnShape& filter_mkl_shape,
                         const MklDnnShape& obp_mkl_shape) {
    CHECK(!filter_mkl_shape.IsMklTensor())
        << "Conv2DBackpropFilter: filter should not be in MKL Layout";
  }

  size_t GetInputTensorIndexWithSizes() { return 1; /* filter index */ }

  TensorShape MakeInputTfShape(OpKernelContext* context,
                               const Tensor& input_tensor) {
    size_t input_idx = 0;
    return GetTfShape(context, input_idx);
  }

  TensorShape MakeFilterTfShape(OpKernelContext* context,
                                const Tensor& filter_tensor) {
    TensorShape filter_tf_shape;
    CHECK_EQ(TensorShapeUtils::IsVector(filter_tensor.shape()), true);
    CHECK_EQ(TensorShapeUtils::MakeShape(filter_tensor.vec<int32>(),
                                         &filter_tf_shape)
                 .ok(),
             true);
    return filter_tf_shape;
  }
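  // Note: for BackpropFilter the "filter" input holds the filter sizes, not
  // filter data; e.g. a 1-D tensor {3, 3, 3, 16} yields the TensorShape
  // [3, 3, 3, 16] of the filter gradient to be produced.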

  TensorShape GetOutputTfShape(const TensorShape& input_shape,
                               const TensorShape& filter_shape,
                               const TensorShape& outbprop_shape) {
    // The output of Conv2DBackpropFilter has the same shape as the filter.
    return filter_shape;
  }

  const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
                                    const memory::dims& fwd_filter_dims) {
    // The output of Conv2DBackpropFilter has the same shape as the filter.
    return fwd_filter_dims;
  }

  memory::format GetOutputFormat(const memory::format data_format) {
    // The output layout is TensorFlow's filter layout (HWIO).
    return memory::format::hwio;
  }

  void CreatePrimitive(OpKernelContext* context, const engine& cpu_engine,
                       const convolution_forward::primitive_desc& conv_fwd_pd,
                       MklDnnData<T>* input, MklDnnData<T>* filter,
                       MklDnnData<T>* outbackprop, MklDnnData<T>* output,
                       Tensor** output_tensor, const memory::dims& strides,
                       const memory::dims& padding_l,
                       const memory::dims& padding_r, padding_kind padding,
                       const memory::dims& bwd_output_dims,
                       memory::format bwd_output_format) {
    CHECK_NOTNULL(context);
    CHECK_NOTNULL(input);
    CHECK_NOTNULL(filter);
    CHECK_NOTNULL(outbackprop);
    CHECK_NOTNULL(output);
    CHECK_NOTNULL(output_tensor);

    MklDnnData<T>* bias_grad = nullptr;
    int depth = 0;
    if (biasEnabled) {
      // Data structure for bias_grad
      bias_grad = new MklDnnData<T>(&cpu_engine);
      TensorShape obp_tf_shape = GetTfShape(context, 2);
      depth = (MklConv2DBackpropCommonOp<Device, T>::GetTFDataFormat() ==
               FORMAT_NCHW)
                  ? obp_tf_shape.dim_size(1)
                  : obp_tf_shape.dim_size(3);
      memory::dims bias_grad_dims = {depth};
      bias_grad->SetOpMemDesc(bias_grad_dims, memory::format::x);
    }
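    // The bias gradient has one element per output channel, so depth is the
    // channel dimension of out_backprop; e.g. an NHWC out_backprop of shape
    // [32, 14, 14, 16] gives depth = 16.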

    // Create convolution backward weights primitive.
    auto bwd_desc =
        (biasEnabled && (bias_grad != nullptr))
            ? convolution_backward_weights::desc(
                  convolution_direct, input->GetOpMemDesc(),
                  output->GetOpMemDesc(), bias_grad->GetOpMemDesc(),
                  outbackprop->GetOpMemDesc(), strides, padding_l, padding_r,
                  padding)
            : convolution_backward_weights::desc(
                  convolution_direct, input->GetOpMemDesc(),
                  output->GetOpMemDesc(), outbackprop->GetOpMemDesc(), strides,
                  padding_l, padding_r, padding);

    auto bwd_pd = convolution_backward_weights::primitive_desc(
        bwd_desc, cpu_engine, conv_fwd_pd);
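    // In the MKL-DNN v0.x API, a backward primitive descriptor takes the
    // forward primitive descriptor (conv_fwd_pd) as a hint, so the backward
    // pass chooses memory formats consistent with the forward pass.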

    // Allocate output tensor.
    AllocateOutputTensor(context, bwd_pd, bwd_output_dims, bwd_output_format,
                         output_tensor);

    CHECK_NOTNULL(*output_tensor);
    // Set buffer handle using allocated output tensor.
    output->SetUsrMemDataHandle(*output_tensor);

    if (biasEnabled && (bias_grad != nullptr)) {
      // Allocate bias_grad tensor
      TensorShape bias_grad_shape({depth});
      Tensor* bias_grad_tensor = nullptr;
      AllocateBiasGradTensor(context, bias_grad_shape, &bias_grad_tensor);
      memory::dims bias_grad_dims = {depth};
      // Since bias is 1-D, we use MKL-DNN's format::x to represent it.
      auto bias_grad_md =
          memory::desc({bias_grad_dims}, MklDnnType<T>(), memory::format::x);
      bias_grad->SetUsrMem(bias_grad_md, bias_grad_tensor);
      bias_grad->SetUsrMemDataHandle(bias_grad_tensor);
    }

    if (biasEnabled && (bias_grad != nullptr)) {
      PrepareAndExecutePrimitive(bwd_pd, input, outbackprop, output, bias_grad);
    } else {
      PrepareAndExecutePrimitive(bwd_pd, input, outbackprop, output);
    }
    // bias_grad was heap-allocated above and the primitive has finished
    // executing by this point, so release it to avoid leaking it.
    delete bias_grad;
  }

  // Allocate output tensor.
  void AllocateOutputTensor(
      OpKernelContext* context,
      const convolution_backward_weights::primitive_desc& conv_pd,
      const memory::dims& output_dims_mkl_order,
      memory::format output_tf_format, Tensor** output_tensor) {
    CHECK_NOTNULL(output_tensor);

    // For BackpropFilter, we convert the output tensor back into TensorFlow
    // layout, because BackpropFilter is typically the last operator in the
    // graph that emits the filter gradient, which is then handed to the
    // ApplyGradient method to update the filter. This conversion could be
    // eliminated by forwarding the filter in MKL layout, if ApplyGradient
    // supported MKL layout propagation.
    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(false);
    // output_dims_mkl_order is in OIHW format.
    // Allocate shape of TF tensor in HWIO format.
    TensorShape output_tf_shape({output_dims_mkl_order[MklDnnDims::Dim_H],
                                 output_dims_mkl_order[MklDnnDims::Dim_W],
                                 output_dims_mkl_order[MklDnnDims::Dim_I],
                                 output_dims_mkl_order[MklDnnDims::Dim_O]});
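    // For example, OIHW dims {16, 3, 3, 3} (out, in, height, width) map to
    // the HWIO TensorShape [3, 3, 3, 16] that TensorFlow expects.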
    AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
                              output_mkl_shape);
  }

  // Allocate tensor for bias grad
  void AllocateBiasGradTensor(OpKernelContext* context,
                              const TensorShape& bias_grad_shape,
                              Tensor** bias_grad_tensor) {
    CHECK_NOTNULL(bias_grad_tensor);

    MklDnnShape bias_grad_mkl_shape;
    bias_grad_mkl_shape.SetMklTensor(false);
    AllocateOutputSetMklShape(context, 1, bias_grad_tensor, bias_grad_shape,
                              bias_grad_mkl_shape);
  }

  // Prepare and execute net - checks for input and output reorders.
  void PrepareAndExecutePrimitive(
      const convolution_backward_weights::primitive_desc& conv_pd,
      MklDnnData<T>* input, MklDnnData<T>* obp, MklDnnData<T>* output,
      MklDnnData<T>* bias_grad = nullptr) {
    // Create reorders between user layout and MKL layout, if needed, and
    // add them to the net before the convolution.
    std::vector<primitive> net;
    input->CheckReorderToOpMem(conv_pd.src_primitive_desc(), &net);
    obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net);

    // For BackpropFilter, we convert the output tensor back into TensorFlow
    // layout.
    bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
        conv_pd.diff_weights_primitive_desc());

    if (biasEnabled && (bias_grad != nullptr)) {
      net.push_back(convolution_backward_weights(
          conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem(),
          bias_grad->GetOpMem()));
    } else {
      net.push_back(convolution_backward_weights(
          conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem()));
    }

    if (output_reorder_required) {
      output->InsertReorderToUserMem(&net);
    }
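    // At this point the net holds, in order: any input/out_backprop
    // reorders, the backward-weights convolution, and any output reorder.
    // Submitting to an eager stream executes them synchronously.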

    stream(stream::kind::eager).submit(net).wait();
  }
};

#define REGISTER_MKL_FILTER_KERNELS(T)                                   \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_MklConv2DBackpropFilter")                                   \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<T>("T")                                        \
          .Label(mkl_op_registry::kMklOpLabel),                          \
      MklConv2DCustomBackpropFilterOp<CPUDevice, T, false>);             \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_MklConv2DBackpropFilterWithBias")                           \
          .Device(DEVICE_CPU)                                            \
          .TypeConstraint<T>("T")                                        \
          .Label(mkl_op_registry::kMklOpLabel),                          \
      MklConv2DCustomBackpropFilterOp<CPUDevice, T, true>);              \
  REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DBackpropFilterWithBias") \
                              .Device(DEVICE_CPU)                        \
                              .TypeConstraint<T>("T")                    \
                              .Label(mkl_op_registry::kMklOpLabel),      \
                          MklDummyOp<CPUDevice, T>);

TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
#undef REGISTER_MKL_FILTER_KERNELS

#endif  // INTEL_MKL_ML

}  // namespace tensorflow

#endif  // INTEL_MKL