Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
// See docs in ../ops/nn_ops.cc. This op kernel uses the MKL library: it creates
// MKL layouts and primitives, then uses the MKL DNN primitives to compute the
// gradient of a convolution with respect to its input.
     19 
     20 #ifdef INTEL_MKL
     21 
     22 #define USE_EIGEN_TENSOR
     23 #define EIGEN_USE_THREADS
     24 #include <algorithm>
     25 #include <vector>
     26 #include "mkl_dnn.h"
     27 #include "mkl_dnn_types.h"
     28 #include "tensorflow/core/framework/numeric_op.h"
     29 #include "tensorflow/core/framework/op_kernel.h"
     30 #include "tensorflow/core/framework/register_types.h"
     31 #include "tensorflow/core/framework/tensor.h"
     32 #include "tensorflow/core/framework/tensor_shape.h"
     33 #include "tensorflow/core/framework/tensor_slice.h"
     34 #include "tensorflow/core/kernels/conv_grad_ops.h"
     35 #include "tensorflow/core/kernels/mkl_conv_ops.h"
     36 #include "tensorflow/core/kernels/ops_util.h"
     37 #include "tensorflow/core/lib/core/errors.h"
     38 #include "tensorflow/core/lib/gtl/array_slice.h"
     39 #include "tensorflow/core/platform/logging.h"
     40 #include "tensorflow/core/platform/macros.h"
     41 #include "tensorflow/core/util/mkl_util.h"
     42 #include "tensorflow/core/util/padding.h"
     43 #include "tensorflow/core/util/tensor_format.h"
     44 #include "tensorflow/core/util/use_cudnn.h"
     45 #include "tensorflow/core/util/work_sharder.h"
     46 
     47 #ifndef INTEL_MKL_ML
     48 #include "mkldnn.hpp"
     49 
     50 using mkldnn::convolution_backward_data;
     51 using mkldnn::prop_kind;
     52 using mkldnn::stream;
     53 #endif
     54 
     55 namespace tensorflow {
     56 
     57 typedef Eigen::ThreadPoolDevice CPUDevice;
     58 
     59 #ifdef INTEL_MKL_ML
     60 
     61 template <typename Device, class T>
     62 class MklConv2DCustomBackpropInputOp : public OpKernel {
     63  public:
     64   ~MklConv2DCustomBackpropInputOp() {}
     65   explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context)
     66       : OpKernel(context) {
     67     string dataformat;
     68     OP_REQUIRES_OK(context, context->GetAttr("data_format", &dataformat));
     69     OP_REQUIRES(context, FormatFromString(dataformat, &data_format),
     70                 errors::InvalidArgument("Invalid data format"));
     71     OP_REQUIRES_OK(context, context->GetAttr("strides", &strides));
     72     int stride_n = GetTensorDim(strides, data_format, 'N');
     73     int stride_c = GetTensorDim(strides, data_format, 'C');
     74     OP_REQUIRES(
     75         context, (stride_n == 1 && stride_c == 1),
     76         errors::InvalidArgument("Current implementation does not yet support "
     77                                 "strides in the batch and depth dimensions."));
     78 
     79     OP_REQUIRES_OK(context, context->GetAttr("padding", &padding));
     80   }
     81 
     82   void Compute(OpKernelContext* context) override {
     83     MklConvBackInputOpContext mkl_context;
     84     const Tensor& input = MklGetInput(context, 0);
     85     const Tensor& filter = MklGetInput(context, 1);
     86 
     87     GetMklShape(context, 1, &(mkl_context.filter_shape));
     88     bool filter_in_mkl_format = mkl_context.filter_shape.IsMklTensor();
     89 
     90     const Tensor& out_backprop = MklGetInput(context, 2);
     91     GetMklShape(context, 2, &(mkl_context.outback_shape));
     92     bool outback_in_mkl_format = mkl_context.outback_shape.IsMklTensor();
     93 
     94     TensorShape input_shape, filter_shape, outback_shape;
     95 
     96     // Generate input shape.
     97     OP_REQUIRES(
     98         context, TensorShapeUtils::IsVector(input.shape()),
     99         errors::InvalidArgument(
    100             "Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
    101             input.dims()));
    102     OP_REQUIRES_OK(
    103         context, TensorShapeUtils::MakeShape(input.vec<int32>(), &input_shape));
    104 
    105     // Generate shape for filter prop if input is in MKL format.
    106     if (filter_in_mkl_format) {
    107       OP_REQUIRES(context, mkl_context.filter_shape.GetDimension() == 4,
    108                   errors::InvalidArgument(
    109                       "Conv2DCustomBackpropInput: size must be 4-dim"));
    110 
    111       const int64* filter_sizes =
    112           (const int64*)mkl_context.filter_shape.GetSizes();
    113       const int64 filter_dims = mkl_context.filter_shape.GetDimension();
    114 
    115       OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
    116                                   filter_sizes, filter_dims, &filter_shape));
    117     } else {
    118       filter_shape = filter.shape();
    119     }
    120 
    121     // Generate shape for outback prop if input is in MKL format.
    122     if (outback_in_mkl_format) {
    123       OP_REQUIRES(context, mkl_context.outback_shape.GetDimension() == 4,
    124                   errors::InvalidArgument(
    125                       "Conv2DCustomBackpropInput: size must be 4-dim"));
    126 
    127       MklSizesToTFSizes(context, data_format, mkl_context.outback_shape,
    128                         &outback_shape);
    129     } else {
    130       outback_shape = out_backprop.shape();
    131     }
    132 
    133     ConvBackpropDimensions dims;
    134     OP_REQUIRES_OK(
    135         context,
    136         ConvBackpropComputeDimensions(
    137             "Conv2DCustomBackpropInput", /*num_spatial_dims=*/2, input_shape,
    138             filter_shape, outback_shape, strides, padding, data_format, &dims));
    139 
    140     int64 pad_top, pad_bottom;
    141     int64 pad_left, pad_right;
    142     OP_REQUIRES_OK(
    143         context,
    144         GetWindowedOutputSizeVerbose(
    145             dims.spatial_dims[0].input_size, dims.spatial_dims[0].filter_size,
    146             dims.spatial_dims[0].stride, padding,
    147             &dims.spatial_dims[0].output_size, &pad_top, &pad_bottom));
    148     OP_REQUIRES_OK(
    149         context,
    150         GetWindowedOutputSizeVerbose(
    151             dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
    152             dims.spatial_dims[1].stride, padding,
    153             &dims.spatial_dims[1].output_size, &pad_left, &pad_right));
    154 
    155     mkl_context.in_dims = 4;
    156 
    157     mkl_context.in_sizes[0] =
    158         static_cast<size_t>(dims.spatial_dims[1].input_size);
    159     mkl_context.in_sizes[1] =
    160         static_cast<size_t>(dims.spatial_dims[0].input_size);
    161     mkl_context.in_sizes[2] = static_cast<size_t>(dims.in_depth);
    162     mkl_context.in_sizes[3] = static_cast<size_t>(dims.batch_size);
    163 
    164     mkl_context.out_sizes[0] =
    165         static_cast<size_t>(dims.spatial_dims[1].output_size);
    166     mkl_context.out_sizes[1] =
    167         static_cast<size_t>(dims.spatial_dims[0].output_size);
    168     mkl_context.out_sizes[2] = static_cast<size_t>(dims.out_depth);
    169     mkl_context.out_sizes[3] = static_cast<size_t>(dims.batch_size);
    170 
    171     mkl_context.input_offset[0] = static_cast<int>(-pad_left);
    172     mkl_context.input_offset[1] = static_cast<int>(-pad_top);
    173 
    174     mkl_context.conv_strides[0] =
    175         static_cast<size_t>(dims.spatial_dims[1].stride);
    176     mkl_context.conv_strides[1] =
    177         static_cast<size_t>(dims.spatial_dims[0].stride);
    178 
    179     GetStridesFromSizes(data_format, mkl_context.out_strides,
    180                         mkl_context.out_sizes);
    181     GetStridesFromSizes(data_format, mkl_context.in_strides,
    182                         mkl_context.in_sizes);
    183 
    184     mkl_context.filter_size[0] = dims.spatial_dims[1].filter_size;
    185     mkl_context.filter_size[1] = dims.spatial_dims[0].filter_size;
    186     mkl_context.filter_size[2] = dims.in_depth;
    187     mkl_context.filter_size[3] = dims.out_depth;
    188 
    189     mkl_context.filter_stride[0] =
    190         mkl_context.filter_size[2] * mkl_context.filter_size[3];
    191     mkl_context.filter_stride[1] = mkl_context.filter_size[2] *
    192                                    mkl_context.filter_size[0] *
    193                                    mkl_context.filter_size[3];
    194     mkl_context.filter_stride[2] = mkl_context.filter_size[3];
    195     mkl_context.filter_stride[3] = 1;
    196 
    197     CHECK_EQ(
    198         dnnConvolutionCreateBackwardData_F32(
    199             &mkl_context.prim_bwddata, NULL, dnnAlgorithmConvolutionDirect,
    200             mkl_context.in_dims, mkl_context.in_sizes, mkl_context.out_sizes,
    201             mkl_context.filter_size, mkl_context.conv_strides,
    202             mkl_context.input_offset, dnnBorderZeros),
    203         E_SUCCESS);
    204 
    205     // Allocate output tensor and shape
    206     TensorShape mkl_out_shape;
    207     MklShape mklOutputShape;
    208     mklOutputShape.SetMklTensor(true);
    209     mklOutputShape.SetMklLayout(mkl_context.prim_bwddata, dnnResourceDiffSrc);
    210     mklOutputShape.SetTfLayout(mkl_context.in_dims, mkl_context.in_sizes,
    211                                mkl_context.in_strides);
    212     // MKL might change the dimension ordering.
    213     // Create mapping to recover the original TF dimension order
    214     mklOutputShape.SetTfDimOrder(mkl_context.in_dims, data_format);
    215 
    216     Tensor* in_backprop = nullptr;
    217     mkl_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
    218                              mklOutputShape.GetMklLayout())) /
    219                          sizeof(T));
    220     AllocateOutputSetMklShape(context, 0, &in_backprop, mkl_out_shape,
    221                               mklOutputShape);
    222 
    223     mkl_context.conv_res[dnnResourceDiffSrc] =
    224         static_cast<void*>(const_cast<T*>(in_backprop->flat<T>().data()));
    225 
    226     mkl_context.MklCreateInputLayouts(context);
    227     Tensor mkl_tmp_outbackprop_buf_tensor, mkl_tmp_filter_buf_tensor;
    228     mkl_context.MklPrepareConvolutionInputs(
    229         context, &mkl_tmp_outbackprop_buf_tensor, &mkl_tmp_filter_buf_tensor);
    230 
    231     CHECK_EQ(dnnExecute_F32(mkl_context.prim_bwddata, mkl_context.conv_res),
    232              E_SUCCESS);
    233     mkl_context.MklCleanup();
    234   }
    235 
    236  private:
    237   typedef struct {
    238     int in_dims;
    239     size_t in_sizes[4];
    240     size_t in_strides[4];
    241     size_t out_sizes[4];
    242     size_t out_strides[4];
    243     int input_offset[2];
    244     size_t filter_size[4];
    245     size_t filter_stride[4];
    246     size_t conv_strides[2];
    247     MklShape filter_shape, outback_shape;
    248     dnnPrimitive_t prim_bwddata;
    249     void* conv_res[dnnResourceNumber];
    250     dnnLayout_t lt_filter, lt_outbackprop;
    251 
    252     // Create MKL dnnLayout_t objects for tensors coming into the layer
    253     void MklCreateInputLayouts(OpKernelContext* context) {
    254       bool filter_in_mkl_format = filter_shape.IsMklTensor();
    255       bool outback_in_mkl_format = outback_shape.IsMklTensor();
    256       if (filter_in_mkl_format) {
    257         lt_filter = (dnnLayout_t)filter_shape.GetCurLayout();
    258       } else {
    259         CHECK_EQ(dnnLayoutCreate_F32(&lt_filter, in_dims, filter_size,
    260                                      filter_stride),
    261                  E_SUCCESS);
    262       }
    263 
    264       if (outback_in_mkl_format) {
    265         lt_outbackprop = (dnnLayout_t)outback_shape.GetCurLayout();
    266       } else {
    267         CHECK_EQ(dnnLayoutCreate_F32(&lt_outbackprop, in_dims, out_sizes,
    268                                      out_strides),
    269                  E_SUCCESS);
    270       }
    271     }
    272 
    273     // Compare incoming input tensor layouts with MKL preferred layouts and
    274     // convert data to the preferred layout if necessary
    275     void MklPrepareConvolutionInputs(OpKernelContext* context,
    276                                      Tensor* mkl_tmp_outbackprop_buf_tensor,
    277                                      Tensor* mkl_tmp_filter_buf_tensor) {
    278       dnnPrimitive_t mkl_convert_filter = nullptr,
    279                      mkl_convert_outbackprop = nullptr;
    280       void *mkl_filter_buf = nullptr, *mkl_outbackprop_buf = nullptr;
    281       dnnLayout_t mkl_lt_filter_internal = nullptr,
    282                   mkl_lt_outbackprop_internal = nullptr;
    283       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
    284                    &mkl_lt_filter_internal, prim_bwddata, dnnResourceFilter),
    285                E_SUCCESS);
    286 
    287       const Tensor& filter = MklGetInput(context, 1);
    288 
    289       CHECK_EQ(
    290           dnnLayoutCreateFromPrimitive_F32(&mkl_lt_outbackprop_internal,
    291                                            prim_bwddata, dnnResourceDiffDst),
    292           E_SUCCESS);
    293       if (!dnnLayoutCompare_F32(mkl_lt_filter_internal, lt_filter)) {
    294         // Create conversion primitive
    295         CHECK_EQ(dnnConversionCreate_F32(&mkl_convert_filter, lt_filter,
    296                                          mkl_lt_filter_internal),
    297                  E_SUCCESS);
    298 
    299         AllocTmpBuffer(context, mkl_tmp_filter_buf_tensor,
    300                        mkl_lt_filter_internal, &mkl_filter_buf);
    301         CHECK_EQ(
    302             dnnConversionExecute_F32(
    303                 mkl_convert_filter,
    304                 static_cast<void*>(const_cast<T*>(filter.flat<T>().data())),
    305                 mkl_filter_buf),
    306             E_SUCCESS);
    307 
    308         // Assign filter buf to resources[] for convolution.
    309         conv_res[dnnResourceFilter] = mkl_filter_buf;
    310         dnnDelete_F32(mkl_convert_filter);
    311       } else {
    312         // If we do not need any layout conversion for filter, then
    313         // we directly assign input filter to resources[].
    314         conv_res[dnnResourceFilter] =
    315             static_cast<void*>(const_cast<T*>(filter.flat<T>().data()));
    316       }
    317       dnnLayoutDelete_F32(mkl_lt_filter_internal);
    318       const Tensor& out_backprop = MklGetInput(context, 2);
    319       // --
    320       // We do similar steps as above for outputbackprop.
    321       if (!dnnLayoutCompare_F32(mkl_lt_outbackprop_internal, lt_outbackprop)) {
    322         CHECK_EQ(
    323             dnnConversionCreate_F32(&mkl_convert_outbackprop, lt_outbackprop,
    324                                     mkl_lt_outbackprop_internal),
    325             E_SUCCESS);
    326         AllocTmpBuffer(context, mkl_tmp_outbackprop_buf_tensor,
    327                        mkl_lt_outbackprop_internal, &mkl_outbackprop_buf);
    328 
    329         CHECK_EQ(dnnConversionExecute_F32(mkl_convert_outbackprop,
    330                                           static_cast<void*>(const_cast<T*>(
    331                                               out_backprop.flat<T>().data())),
    332                                           mkl_outbackprop_buf),
    333                  E_SUCCESS);
    334 
    335         conv_res[dnnResourceDiffDst] = mkl_outbackprop_buf;
    336         dnnDelete_F32(mkl_convert_outbackprop);
    337       } else {
    338         conv_res[dnnResourceDiffDst] =
    339             static_cast<void*>(const_cast<T*>(out_backprop.flat<T>().data()));
    340       }
    341       dnnLayoutDelete_F32(mkl_lt_outbackprop_internal);
    342     }
    343 
    344     // Cleanup member layouts and primitives
    345     void MklCleanup() {
    346       bool filter_in_mkl_format = filter_shape.IsMklTensor();
    347       bool outback_in_mkl_format = outback_shape.IsMklTensor();
    348       if (!filter_in_mkl_format) dnnLayoutDelete_F32(lt_filter);
    349       if (!outback_in_mkl_format) dnnLayoutDelete_F32(lt_outbackprop);
    350       dnnDelete_F32(prim_bwddata);
    351     }
    352   } MklConvBackInputOpContext;
    353 
    354   std::vector<int32> strides;
    355   Padding padding;
    356   TensorFormat data_format;
    357 };
    358 
    359 #else
    360 
    361 template <typename Device, class T>
    362 class MklConv2DCustomBackpropInputOp
    363     : public MklConv2DBackpropCommonOp<Device, T> {
    364  public:
    365   explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context)
    366       : MklConv2DBackpropCommonOp<Device, T>(context) {}
    367   ~MklConv2DCustomBackpropInputOp() {}
    368 
    369  private:
    370   const int kInputIndex_Filter = 1, kInputIndex_InputSizes = 0,
    371             kInputIndex_OutBackProp = 2;
    372   void ValidateMklShapes(const MklDnnShape& input_mkl_shape,
    373                          const MklDnnShape& filter_mkl_shape,
    374                          const MklDnnShape& obp_mkl_shape) {
    375     // Tensor that feeds to 'Input' slot of BackpropInput is always just a shape
    376     // of the Tensor and never an actual tensor. So it will never be in MKL
    377     // layout.
    378     CHECK(!input_mkl_shape.IsMklTensor())
    379         << "Conv2DBackpropInput: input should not be in MKL Layout";
    380   }
    381 
    382   size_t GetInputTensorIndexWithSizes() { return kInputIndex_InputSizes; }
    383 
    384   TensorShape MakeInputTfShape(OpKernelContext* context,
    385                                const Tensor& input_tensor) {
    386     TensorShape input_tf_shape;
    387     CHECK_EQ(TensorShapeUtils::IsVector(input_tensor.shape()), true);
    388     CHECK_EQ(
    389         TensorShapeUtils::MakeShape(input_tensor.vec<int32>(), &input_tf_shape)
    390             .ok(),
    391         true);
    392     return input_tf_shape;
    393   }
    394 
    395   TensorShape MakeFilterTfShape(OpKernelContext* context,
    396                                 const Tensor& filter_tensor) {
    397     return GetTfShape(context, kInputIndex_Filter);
    398   }
    399 
    400   TensorShape GetOutputTfShape(const TensorShape& input_shape,
    401                                const TensorShape& filter_shape,
    402                                const TensorShape& outbprop_shape) {
    403     // Output Shape of Conv2DBackpropInput is same as shape of Conv2D 'input'.
    404     return input_shape;
    405   }
    406 
    407   const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
    408                                     const memory::dims& fwd_filter_dims) {
    409     // Output Shape of Conv2DBackpropInput is same as shape of Conv2D 'input'.
    410     return fwd_input_dims;
    411   }
    412 
    413   memory::format GetOutputFormat(const memory::format data_format) {
    414     // Output layout is Tensorflow's layout in data format order.
    415     return data_format;
    416   }
    417 
    418   void CreatePrimitive(OpKernelContext* context, const engine& cpu_engine,
    419                        const convolution_forward::primitive_desc& conv_fwd_pd,
    420                        MklDnnData<T>* input, MklDnnData<T>* filter,
    421                        MklDnnData<T>* outbackprop, MklDnnData<T>* output,
    422                        Tensor** output_tensor, const memory::dims& strides,
    423                        const memory::dims& padding_l,
    424                        const memory::dims& padding_r, padding_kind padding,
    425                        const memory::dims& bwd_output_dims,
    426                        memory::format bwd_output_format) {
    427     CHECK_NOTNULL(context);
    428     CHECK_NOTNULL(input);
    429     CHECK_NOTNULL(filter);
    430     CHECK_NOTNULL(outbackprop);
    431     CHECK_NOTNULL(output);
    432     CHECK_NOTNULL(output_tensor);
    433 
    434     // Create convolution backward data primitive.
    435     auto bwd_desc = convolution_backward_data::desc(
    436         convolution_direct, output->GetOpMemDesc(), filter->GetOpMemDesc(),
    437         outbackprop->GetOpMemDesc(), strides, padding_l, padding_r, padding);
    438 
    439     auto bwd_pd = convolution_backward_data::primitive_desc(
    440         bwd_desc, cpu_engine, conv_fwd_pd);
    441 
    442     // Allocate output tensor in TensorFlow and MKL layout.
    443     AllocateOutputTensor(context, bwd_pd, bwd_output_dims, bwd_output_format,
    444                          output_tensor);
    445     CHECK_NOTNULL(*output_tensor);
    446     // Set buffer handle using allocated output tensor.
    447     output->SetUsrMemDataHandle(*output_tensor);
    448 
    449     PrepareAndExecutePrimitive(bwd_pd, filter, outbackprop, output);
    450   }
    451 
    452   // Allocate output tensor.
    453   void AllocateOutputTensor(
    454       OpKernelContext* context,
    455       const convolution_backward_data::primitive_desc& conv_pd,
    456       const memory::dims& output_dims_mkl_order,
    457       memory::format output_tf_format, Tensor** output_tensor) {
    458     CHECK_NOTNULL(output_tensor);
    459 
    460     // Output primitive descriptor for backward data is diff_src.
    461     auto dst_pd = conv_pd.diff_src_primitive_desc();
    462 
    463     // Allocate shape of Mkl tensor.
    464     MklDnnShape output_mkl_shape;
    465     output_mkl_shape.SetMklTensor(true);
    466     output_mkl_shape.SetMklLayout(&dst_pd);
    467     output_mkl_shape.SetElemType(MklDnnType<T>());
    468     output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
    469                                  output_dims_mkl_order, output_tf_format);
    470 
    471     // Allocate shape of TF tensor.
    472     TensorShape output_tf_shape;
    473     output_tf_shape.AddDim(dst_pd.get_size() / sizeof(T));
    474 
    475     AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
    476                               output_mkl_shape);
    477   }
    478 
    479   // Prepare and execute net - checks for input and output reorders.
    480   void PrepareAndExecutePrimitive(
    481       const convolution_backward_data::primitive_desc& conv_pd,
    482       MklDnnData<T>* filter, MklDnnData<T>* obp, MklDnnData<T>* output) {
    483     // Create reorders between user layout and MKL layout if it is needed and
    484     // add it to the net before convolution.
    485     std::vector<primitive> net;
    486     filter->CheckReorderToOpMem(conv_pd.weights_primitive_desc(), &net);
    487     obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net);
    488 
    489     net.push_back(convolution_backward_data(
    490         conv_pd, obp->GetOpMem(), filter->GetOpMem(), output->GetOpMem()));
    491 
    492     stream(stream::kind::eager).submit(net).wait();
    493   }
    494 };
    495 
    496 #endif  // INTEL_MKL_ML
    497 
    498 #define REGISTER_MKL_CPU_KERNELS(T)                                 \
    499   REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput")           \
    500                               .Device(DEVICE_CPU)                   \
    501                               .TypeConstraint<T>("T")               \
    502                               .Label(mkl_op_registry::kMklOpLabel), \
    503                           MklConv2DCustomBackpropInputOp<CPUDevice, T>);
    504 
    505 TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
    506 #undef REGISTER_MKL_CPU_KERNELS
    507 
    508 }  // namespace tensorflow
    509 #endif  // INTEL_MKL
    510