Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
// See docs in ../ops/nn_ops.cc. This op kernel uses the MKL library: it
// creates MKL layouts and primitives, and uses MKL-DNN primitives to compute
// the convolution backward bias.
     19 
     20 #ifdef INTEL_MKL
     21 
     22 #define USE_EIGEN_TENSOR
     23 #define EIGEN_USE_THREADS
     24 
     25 #include "tensorflow/core/framework/numeric_op.h"
     26 #include "tensorflow/core/framework/op_kernel.h"
     27 #include "tensorflow/core/framework/register_types.h"
     28 #include "tensorflow/core/framework/tensor.h"
     29 #include "tensorflow/core/framework/tensor_shape.h"
     30 #include "tensorflow/core/framework/tensor_slice.h"
     31 #include "tensorflow/core/kernels/ops_util.h"
     32 #include "tensorflow/core/lib/core/errors.h"
     33 #include "tensorflow/core/lib/gtl/array_slice.h"
     34 #include "tensorflow/core/platform/logging.h"
     35 #include "tensorflow/core/platform/macros.h"
     36 #include "tensorflow/core/util/padding.h"
     37 #include "tensorflow/core/util/tensor_format.h"
     38 #include "tensorflow/core/util/use_cudnn.h"
     39 #include "tensorflow/core/util/work_sharder.h"
     40 
     41 #include "mkl_dnn.h"
     42 #include "mkl_dnn_types.h"
     43 #include "tensorflow/core/util/mkl_util.h"
     44 
     45 namespace tensorflow {
     46 
     47 typedef Eigen::ThreadPoolDevice CPUDevice;
     48 
     49 template <typename Device, class T>
     50 class MklConv2DCustomBackpropBiasOp : public OpKernel {
     51  public:
     52   explicit MklConv2DCustomBackpropBiasOp(OpKernelConstruction* context)
     53       : OpKernel(context) {
     54     string data_format;
     55     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
     56     OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
     57                 errors::InvalidArgument("Invalid data format"));
     58   }
     59   ~MklConv2DCustomBackpropBiasOp() {}
     60 
     61   void Compute(OpKernelContext* context) override {
     62     MklConvBackBiasOpContext mkl_context;
     63     const Tensor& input = MklGetInput(context, 0);
     64     GetMklShape(context, 0, &mkl_context.input_shape);
     65     bool input_is_mkl = mkl_context.input_shape.IsMklTensor();
     66 
     67     if (input_is_mkl) {
     68       OP_REQUIRES(
     69           context, mkl_context.input_shape.GetDimension() == 4,
     70           errors::InvalidArgument("Input tensor must be 4-dimensional"));
     71     } else {
     72       OP_REQUIRES(context, input.dims() == 4,
     73                   errors::InvalidArgument("input must be 4-dimensional",
     74                                           input.shape().DebugString()));
     75     }
     76 
     77     if (input_is_mkl) {
     78       mkl_context.c_size = mkl_context.input_shape.GetSizes()[MklDims::C];
     79     } else if (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW) {
     80       mkl_context.c_size = GetTensorDim(input, data_format_, 'C');
     81     } else {
     82       errors::InvalidArgument("Unknown format ",
     83                               " Format must be either NCHW or NHWC. ");
     84     }
     85     TensorShape output_shape{mkl_context.c_size};
     86 
     87     Tensor* bias_backprop = nullptr;
     88     MklShape output_mkl_shape;
     89     output_mkl_shape.SetMklTensor(false);
     90     AllocateOutputSetMklShape(context, 0, &bias_backprop, output_shape,
     91                               output_mkl_shape);
     92 
     93     mkl_context.in_dims = 4;
     94 
     95     if (input_is_mkl) {  // get the shape from the mkl shape
     96       mkl_context.in_sizes[MklDims::W] =
     97           mkl_context.input_shape.GetSizes()[MklDims::W];
     98       mkl_context.in_sizes[MklDims::H] =
     99           mkl_context.input_shape.GetSizes()[MklDims::H];
    100       mkl_context.in_sizes[MklDims::C] =
    101           mkl_context.input_shape.GetSizes()[MklDims::C];
    102       mkl_context.in_sizes[MklDims::N] =
    103           mkl_context.input_shape.GetSizes()[MklDims::N];
    104     } else {
    105       mkl_context.in_sizes[MklDims::W] = GetTensorDim(input, data_format_, 'W');
    106       mkl_context.in_sizes[MklDims::H] = GetTensorDim(input, data_format_, 'H');
    107       mkl_context.in_sizes[MklDims::C] = GetTensorDim(input, data_format_, 'C');
    108       mkl_context.in_sizes[MklDims::N] = GetTensorDim(input, data_format_, 'N');
    109       GetStridesFromSizes(data_format_, mkl_context.in_strides,
    110                           mkl_context.in_sizes);
    111     }
    112 
    113     mkl_context.out_sizes[0] = mkl_context.c_size;
    114     mkl_context.out_strides[0] = 1;
    115 
    116     CHECK_EQ(
    117         dnnConvolutionCreateBackwardBias_F32(
    118             &mkl_context.prim_conv_bwdbias, NULL, dnnAlgorithmConvolutionDirect,
    119             mkl_context.in_dims, mkl_context.in_sizes),
    120         E_SUCCESS);
    121 
    122     mkl_context.MklCreateInputLayouts(context);
    123 
    124     Tensor mkl_tmp_input_buf, mkl_tmp_outbackprop_buf;
    125     mkl_context.MklPrepareConvolutionInputs(context, &mkl_tmp_input_buf);
    126     mkl_context.MklPrepareConvolutionOutputs(context, &mkl_tmp_outbackprop_buf,
    127                                              bias_backprop);
    128 
    129     CHECK_EQ(
    130         dnnExecute_F32(mkl_context.prim_conv_bwdbias, mkl_context.conv_res),
    131         E_SUCCESS);
    132     if (mkl_context.should_convert_output) {
    133       CHECK_EQ(dnnConversionExecute_F32(
    134                    mkl_context.convert_outbackprop, mkl_context.outbackprop_buf,
    135                    static_cast<void*>(bias_backprop->flat<T>().data())),
    136                E_SUCCESS);
    137     }
    138     // deletes layouts
    139     mkl_context.MklCleanup();
    140   }
    141 
    142  private:
    143   typedef struct {
    144     int in_dims;
    145     int c_size;
    146     size_t in_sizes[4];
    147     size_t in_strides[4];
    148     size_t out_sizes[1];
    149     size_t out_strides[1];
    150     size_t filter_sizes[4];
    151     size_t filter_strides[4];
    152     int input_offset[2];
    153     size_t conv_stride[2];
    154     MklShape input_shape;
    155     dnnPrimitive_t prim_conv_bwdbias;
    156     void* conv_res[dnnResourceNumber];
    157     dnnLayout_t lt_input, lt_outbackprop;
    158     bool should_convert_output;
    159     dnnPrimitive_t convert_outbackprop;
    160     void* outbackprop_buf;
    161 
    162     // Create MKL dnnLayout_t objects for tensors coming into the layer
    163     void MklCreateInputLayouts(OpKernelContext* context) {
    164       bool input_is_mkl = input_shape.IsMklTensor();
    165 
    166       CHECK_EQ(dnnLayoutCreate_F32(&lt_outbackprop, 1, out_sizes, out_strides),
    167                E_SUCCESS);
    168       if (input_is_mkl) {
    169         lt_input = static_cast<dnnLayout_t>(input_shape.GetCurLayout());
    170       } else {
    171         CHECK_EQ(dnnLayoutCreate_F32(&lt_input, in_dims, in_sizes, in_strides),
    172                  E_SUCCESS);
    173       }
    174     }
    175 
    176     // Compare incoming output tensor layouts with MKL preferred layouts and
    177     // convert data to the preferred layout if necessary
    178     void MklPrepareConvolutionOutputs(OpKernelContext* context,
    179                                       Tensor* mkl_tmp_outbackprop_buf,
    180                                       Tensor* bias_backprop) {
    181       dnnLayout_t mkl_prim_internal_outbackprop = nullptr;
    182       CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_prim_internal_outbackprop,
    183                                                 prim_conv_bwdbias,
    184                                                 dnnResourceDiffBias),
    185                E_SUCCESS);
    186       should_convert_output =
    187           !dnnLayoutCompare_F32(lt_outbackprop, mkl_prim_internal_outbackprop);
    188       if (should_convert_output) {
    189         CHECK_EQ(dnnConversionCreate_F32(&convert_outbackprop,
    190                                          mkl_prim_internal_outbackprop,
    191                                          lt_outbackprop),
    192                  E_SUCCESS);
    193         AllocTmpBuffer(context, mkl_tmp_outbackprop_buf,
    194                        mkl_prim_internal_outbackprop, &outbackprop_buf);
    195         conv_res[dnnResourceDiffBias] = outbackprop_buf;
    196       } else {
    197         conv_res[dnnResourceDiffBias] =
    198             static_cast<void*>(const_cast<T*>(bias_backprop->flat<T>().data()));
    199       }
    200 
    201       dnnLayoutDelete_F32(mkl_prim_internal_outbackprop);
    202     }
    203 
    204     // Compare incoming input tensor layouts with MKL preferred layouts and
    205     // convert data to the preferred layout if necessary
    206     void MklPrepareConvolutionInputs(OpKernelContext* context,
    207                                      Tensor* mkl_tmp_input_buf) {
    208       dnnLayout_t mkl_prim_internal_input = nullptr;
    209       dnnPrimitive_t mkl_convert_input = nullptr;
    210       void* input_buf = nullptr;
    211       const Tensor& input = MklGetInput(context, 0);
    212 
    213       CHECK_EQ(
    214           dnnLayoutCreateFromPrimitive_F32(
    215               &mkl_prim_internal_input, prim_conv_bwdbias, dnnResourceDiffDst),
    216           E_SUCCESS);
    217 
    218       if (!dnnLayoutCompare_F32(lt_input, mkl_prim_internal_input)) {
    219         CHECK_EQ(dnnConversionCreate_F32(&mkl_convert_input, lt_input,
    220                                          mkl_prim_internal_input),
    221                  E_SUCCESS);
    222         AllocTmpBuffer(context, mkl_tmp_input_buf, mkl_prim_internal_input,
    223                        &input_buf);
    224         CHECK_EQ(dnnConversionExecute_F32(
    225                      mkl_convert_input,
    226                      static_cast<void*>(const_cast<T*>(input.flat<T>().data())),
    227                      input_buf),
    228                  E_SUCCESS);
    229         conv_res[dnnResourceDiffDst] = input_buf;
    230         dnnDelete_F32(mkl_convert_input);
    231       } else {
    232         conv_res[dnnResourceDiffDst] =
    233             static_cast<void*>(const_cast<T*>(input.flat<T>().data()));
    234       }
    235 
    236       dnnLayoutDelete_F32(mkl_prim_internal_input);
    237     }
    238 
    239     // Cleanup member layouts and primitives
    240     void MklCleanup() {
    241       bool input_is_mkl = input_shape.IsMklTensor();
    242       if (!input_is_mkl) dnnLayoutDelete_F32(lt_input);
    243       dnnLayoutDelete_F32(lt_outbackprop);
    244 
    245       if (should_convert_output) dnnDelete_F32(convert_outbackprop);
    246       dnnDelete_F32(prim_conv_bwdbias);
    247     }
    248   } MklConvBackBiasOpContext;
    249 
    250   TensorFormat data_format_;
    251   TF_DISALLOW_COPY_AND_ASSIGN(MklConv2DCustomBackpropBiasOp);
    252 };
    253 
// Register the MKL-labeled CPU kernel for _MklConv2DWithBiasBackpropBias.
// Only float is instantiated: the kernel calls the *_F32 MKL-DNN entry
// points exclusively.
#define REGISTER_CPU_KERNELS(T)                                     \
  REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBiasBackpropBias")    \
                              .Device(DEVICE_CPU)                   \
                              .TypeConstraint<T>("T")               \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklConv2DCustomBackpropBiasOp<CPUDevice, T>);

TF_CALL_float(REGISTER_CPU_KERNELS);
#undef REGISTER_CPU_KERNELS
    263 } /* namespace tensorflow */
    264 #endif /* INTEL_MKL */
    265