Home | History | Annotate | Download | only in kernels
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_
     17 #define TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_
     18 
     19 #ifdef INTEL_MKL
     20 #include <string>
     21 #include <vector>
     22 #include "tensorflow/core/util/mkl_util.h"
     23 #include "tensorflow/core/util/padding.h"
     24 
     25 #ifndef INTEL_MKL_ML
     26 #include "mkldnn.hpp"
     27 using mkldnn::memory;
     28 using mkldnn::pooling_backward;
     29 using mkldnn::pooling_forward;
     30 using mkldnn::stream;
     31 #endif
     32 
     33 namespace tensorflow {
     34 
     35 typedef Eigen::ThreadPoolDevice CPUDevice;
     36 
     37 struct MklPoolParameters {
     38   int depth;
     39 
     40   int tensor_in_cols;
     41   int tensor_in_rows;
     42   int tensor_in_batch;
     43 
     44   int window_rows;
     45   int window_cols;
     46   int depth_window;
     47 
     48   int row_stride;
     49   int col_stride;
     50   int depth_stride;
     51 
     52   int64 out_height;
     53   int64 out_width;
     54   int out_depth;
     55 
     56   int64 pad_left;
     57   int64 pad_right;
     58   int64 pad_top;
     59   int64 pad_bottom;
     60   int pad_depth;
     61 
     62   TensorFormat data_format;
     63   MklPoolParameters()
     64       : depth(0),
     65         tensor_in_cols(0),
     66         tensor_in_rows(0),
     67         tensor_in_batch(0),
     68         window_rows(0),
     69         window_cols(0),
     70         depth_window(0),
     71         row_stride(0),
     72         col_stride(0),
     73         depth_stride(0),
     74         out_height(0),
     75         out_width(0),
     76         out_depth(0),
     77         pad_left(0),
     78         pad_right(0),
     79         pad_top(0),
     80         pad_bottom(0),
     81         pad_depth(0),
     82         data_format(TensorFormat::FORMAT_NCHW) {}
     83 
     84   // Updates context->status if there is an invalid input.
     85   void Init(OpKernelContext* context, const std::vector<int32>& ksize,
     86             const std::vector<int32>& stride, Padding padding,
     87             TensorFormat data_format, const TensorShape& tensor_in_shape);
     88 #ifdef INTEL_MKL_ML
     89   void Init(OpKernelContext* context, const std::vector<int32>& ksize,
     90             const std::vector<int32>& stride, Padding padding,
     91             TensorFormat data_format, const MklShape* mkl_in_shape);
     92 #else
     93   void Init(OpKernelContext* context, const std::vector<int32>& ksize,
     94             const std::vector<int32>& stride, Padding padding,
     95             TensorFormat data_format, const MklDnnShape* mkl_in_shape);
     96 #endif
     97 
     98  private:
     99   // Common initialization for TensorFlow and MKL formats
    100   void Init(OpKernelContext* context, const std::vector<int32>& ksize,
    101             const std::vector<int32>& stride, Padding padding,
    102             TensorFormat data_format);
    103 };
    104 
    105 #ifndef INTEL_MKL_ML
    106 
    107 template <class T>
    108 class MklPoolingOpBase : public OpKernel {
    109  public:
    110   explicit MklPoolingOpBase(OpKernelConstruction* context)
    111       : OpKernel(context), workspace_enabled_(false) {
    112     string data_format;
    113     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    114     OP_REQUIRES(context, FormatFromString(data_format, &this->data_format_tf_),
    115                 errors::InvalidArgument("Invalid data format"));
    116     this->data_format_mkldnn_ =
    117         TFDataFormatToMklDnnDataFormat(this->data_format_tf_);
    118     OP_REQUIRES_OK(context, context->GetAttr("ksize", &this->ksize_));
    119     OP_REQUIRES(context, this->ksize_.size() == 4,
    120                 errors::InvalidArgument("Sliding window ksize field must "
    121                                         "specify 4 dimensions"));
    122     OP_REQUIRES_OK(context, context->GetAttr("strides", &this->stride_));
    123     OP_REQUIRES(context, this->stride_.size() == 4,
    124                 errors::InvalidArgument("Sliding window strides field must "
    125                                         "specify 4 dimensions"));
    126     OP_REQUIRES_OK(context, context->GetAttr("padding", &this->padding_));
    127     OP_REQUIRES(context, this->ksize_[0] == 1 && this->stride_[0] == 1,
    128                 errors::Unimplemented("Pooling is not yet supported on the "
    129                                       "batch dimension."));
    130 
    131     // We may not get this attribute for this node if it does not go through
    132     // graph rewrite pass. So we do not check for error while retrieving this
    133     // attribute value.
    134     context->GetAttr("workspace_enabled", &this->workspace_enabled_);
    135   }
    136   void Compute(OpKernelContext* context) override = 0;
    137 
    138  protected:
    139   // Calculate output shape of pooling op in MKL-DNN and TensorFlow order.
    140   // MKL-DNN uses NCHW for output order. But TensorFlow output will be in
    141   // NHWC or NCHW format depending on data format. Function expects
    142   // output height and output width to have already been int32
    143   // bounds-checked
    144   void GetOutputDims(const MklPoolParameters& mkl_pool_params,
    145                      memory::dims* output_dims_mkl_order) {
    146     // MKL-DNN always needs output in NCHW format.
    147     *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
    148                               mkl_pool_params.out_depth,
    149                               static_cast<int>(mkl_pool_params.out_height),
    150                               static_cast<int>(mkl_pool_params.out_width)};
    151   }
    152 
    153   void InitMklPoolParameters(OpKernelContext* context,
    154                              MklPoolParameters* pool_params,
    155                              const MklDnnShape& original_input_mkl_shape,
    156                              const TensorShape& input_tensor_shape) {
    157     if (!original_input_mkl_shape.IsMklTensor()) {
    158       pool_params->Init(context, this->ksize_, this->stride_, this->padding_,
    159                         this->data_format_tf_, input_tensor_shape);
    160     } else {
    161       pool_params->Init(context, this->ksize_, this->stride_, this->padding_,
    162                         this->data_format_tf_, &original_input_mkl_shape);
    163     }
    164   }
    165 
    166   // Checks to make sure that the memory we need to allocate
    167   // is a multiple of sizeof(T)
    168   // returns the number of elements
    169   size_t GetNumTElements(const memory::primitive_desc& pd) {
    170     size_t num_bytes = pd.get_size();
    171     size_t ret_val = num_bytes / sizeof(T);
    172     if (num_bytes % sizeof(T) != 0) {
    173       ret_val++;
    174     }
    175     return ret_val;
    176   }
    177 
    178   std::vector<int32> ksize_;
    179   std::vector<int32> stride_;
    180   Padding padding_;
    181   TensorFormat data_format_tf_;
    182   memory::format data_format_mkldnn_;
    183   bool workspace_enabled_;
    184 };
    185 
    186 template <class T>
    187 class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
    188  public:
    189   explicit MklPoolingForwardOpBase<T>(OpKernelConstruction* context)
    190       : MklPoolingOpBase<T>(context) {}
    191   void Compute(OpKernelContext* context) override = 0;
    192 
    193  protected:
    194   void ConfigureInput(OpKernelContext* context,
    195                       const MklDnnShape& input_mkl_shape,
    196                       const Tensor& input_tensor,
    197                       MklPoolParameters* pool_params,
    198                       MklDnnData<T>* dnn_data_input) {
    199     CHECK_NOTNULL(pool_params);
    200     CHECK_NOTNULL(dnn_data_input);
    201     TensorShape input_tensor_shape = input_tensor.shape();
    202     memory::desc input_md =
    203         input_mkl_shape.IsMklTensor()
    204             ? input_mkl_shape.GetMklLayout()
    205             : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
    206                                                      this->data_format_tf_),
    207                            MklDnnType<T>(), this->data_format_mkldnn_);
    208     dnn_data_input->SetUsrMem(input_md, &input_tensor);
    209     this->InitMklPoolParameters(context, pool_params, input_mkl_shape,
    210                                 input_tensor_shape);
    211   }
    212 
    213   void AllocateOutputTensor(
    214       OpKernelContext* context,
    215       const pooling_forward::primitive_desc& pool_fwd_prim_desc,
    216       const memory::dims output_dims_mkl_order,
    217       const memory::format& output_tf_format, Tensor** output_tensor) {
    218     CHECK_NOTNULL(output_tensor);
    219     memory::primitive_desc dst_pd = pool_fwd_prim_desc.dst_primitive_desc();
    220 
    221     MklDnnShape output_mkl_shape;
    222     output_mkl_shape.SetMklTensor(true);
    223     output_mkl_shape.SetMklLayout(&dst_pd);
    224     output_mkl_shape.SetElemType(MklDnnType<T>());
    225     output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
    226                                  output_dims_mkl_order, output_tf_format);
    227     TensorShape output_tf_shape;
    228 
    229     // only allocate enough space for the elements we need.
    230     output_tf_shape.AddDim(this->GetNumTElements(dst_pd));
    231     AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
    232                               output_tf_shape, output_mkl_shape);
    233     CHECK_NOTNULL(*output_tensor);
    234   }
    235 
    236   void PrepareAndExecuteNet(
    237       const pooling_forward::primitive_desc& pool_fwd_desc,
    238       const MklDnnData<T>* src, MklDnnData<T>* dst,
    239       MklDnnData<uint8>* wksp = nullptr) {
    240     std::vector<primitive> net;
    241 
    242     // Create pooling primitive and add it to net
    243     if (wksp != nullptr) {
    244       net.push_back(pooling_forward(pool_fwd_desc, src->GetOpMem(),
    245                                     dst->GetOpMem(), wksp->GetOpMem()));
    246     } else {
    247       net.push_back(
    248           pooling_forward(pool_fwd_desc, src->GetOpMem(), dst->GetOpMem()));
    249     }
    250     stream(stream::kind::eager).submit(net).wait();
    251   }
    252 
    253   void SanityCheckInput(OpKernelContext* context, const Tensor& input_tensor,
    254                         const MklDnnShape& input_mkl_shape) {
    255     if (!input_mkl_shape.IsMklTensor()) {
    256       OP_REQUIRES(context, input_tensor.dims() == 4,
    257                   errors::InvalidArgument("Input must be 4-dimensional"));
    258     } else {
    259       OP_REQUIRES(context, input_mkl_shape.GetDimension() == 4,
    260                   errors::InvalidArgument("Input shape must be "
    261                                           "4-dimensional"));
    262     }
    263   }
    264   // .Input("value: T")
    265   // .Output("output: T")
    266   const int kInputTensorIndexInput = 0;
    267   const int kOutputTensorIndexOutput = 0;
    268 };  // MklPoolingForwardBaseOp
    269 
    270 template <class T>
    271 class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> {
    272  public:
    273   explicit MklPoolingBackwardOpBase<T>(OpKernelConstruction* context)
    274       : MklPoolingOpBase<T>(context) {}
    275   void Compute(OpKernelContext* context) override = 0;
    276 
    277  protected:
    278   const int kOutputTensorIndexOutput = 0;
    279 
    280   void AllocateOutputTensor(
    281       OpKernelContext* context,
    282       const pooling_backward::primitive_desc& pool_bkwd_prim_desc,
    283       const memory::dims output_dims_mkl_order,
    284       const memory::format& output_tf_format, Tensor** output_tensor) {
    285     CHECK_NOTNULL(output_tensor);
    286     memory::primitive_desc dst_pd =
    287         pool_bkwd_prim_desc.diff_src_primitive_desc();
    288     MklDnnShape output_mkl_shape;
    289     output_mkl_shape.SetMklTensor(true);
    290     output_mkl_shape.SetMklLayout(&dst_pd);
    291     output_mkl_shape.SetElemType(MklDnnType<T>());
    292     output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
    293                                  output_dims_mkl_order, output_tf_format);
    294 
    295     TensorShape output_tf_shape;
    296     output_tf_shape.AddDim(this->GetNumTElements(dst_pd));
    297     AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
    298                               output_tf_shape, output_mkl_shape);
    299     CHECK_NOTNULL(*output_tensor);
    300   }
    301 
    302   void PrepareAndExecuteNet(
    303       const pooling_backward::primitive_desc& pool_bkwd_desc,
    304       MklDnnData<T>* input_gradient_diff_dst, MklDnnData<T>* output_diff_src,
    305       const memory::primitive_desc& target_diff_dst_pd,
    306       const MklDnnData<uint8>* workspace = nullptr) {
    307     std::vector<primitive> net;
    308 
    309     // If the input gradient isn't in the same format as the output
    310     // reorder it to the same format as the output
    311     input_gradient_diff_dst->CheckReorderToOpMem(target_diff_dst_pd, &net);
    312 
    313     // Create pooling primitive and add it to net
    314     if (nullptr == workspace) {
    315       net.push_back(pooling_backward(pool_bkwd_desc,
    316                                      input_gradient_diff_dst->GetOpMem(),
    317                                      output_diff_src->GetOpMem()));
    318     } else {
    319       net.push_back(
    320           pooling_backward(pool_bkwd_desc, input_gradient_diff_dst->GetOpMem(),
    321                            workspace->GetOpMem(), output_diff_src->GetOpMem()));
    322     }
    323     stream(stream::kind::eager).submit(net).wait();
    324   }
    325 
    326   // Max Pooling and Avg Pooling have slightly different implementations
    327   // Takes the Tensor containing original input data and the original
    328   // mkl Dnn Shape and populates other data
    329   memory::desc ConfigureOriginalInput(
    330       OpKernelContext* context, const Tensor& tensor_original_input_shape,
    331       const MklDnnShape& original_input_mkl_shape,
    332       memory::dims* original_input_dims_nchw, MklPoolParameters* pool_params,
    333       const TensorShape& input_tensor_shape) {
    334     CHECK_NOTNULL(original_input_dims_nchw);
    335     CHECK_NOTNULL(pool_params);
    336     this->InitMklPoolParameters(context, pool_params, original_input_mkl_shape,
    337                                 input_tensor_shape);
    338 
    339     *original_input_dims_nchw =
    340         original_input_mkl_shape.IsMklTensor()
    341             ? original_input_mkl_shape.GetSizesAsMklDnnDims()
    342             : TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
    343                                         this->data_format_tf_);
    344 
    345     return original_input_mkl_shape.IsMklTensor()
    346                ? original_input_mkl_shape.GetMklLayout()
    347                : memory::desc(*original_input_dims_nchw, MklDnnType<T>(),
    348                               this->data_format_mkldnn_);
    349   }
    350 
    351   memory::desc ConfigureOriginalOutput(
    352       const MklPoolParameters& pool_params,
    353       const MklDnnShape& original_output_mkl_shape,
    354       memory::dims output_dims_mkl_order) {
    355     this->GetOutputDims(pool_params, &output_dims_mkl_order);
    356 
    357     return original_output_mkl_shape.IsMklTensor()
    358                ? original_output_mkl_shape.GetMklLayout()
    359                : memory::desc(output_dims_mkl_order, MklDnnType<T>(),
    360                               this->data_format_mkldnn_);
    361   }
    362 
    363   memory::desc ConfigureInputGradient(
    364       const MklDnnShape& input_gradient_mkl_shape,
    365       const Tensor& input_gradient_tensor,
    366       MklDnnData<T>* input_gradient_dnn_data,
    367       const memory::desc& original_output_md) {
    368     // Configure the gradient as is
    369     memory::desc original_input_grad_md =
    370         input_gradient_mkl_shape.IsMklTensor()
    371             ? input_gradient_mkl_shape.GetMklLayout()
    372             : memory::desc(
    373                   TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(),
    374                                             this->data_format_tf_),
    375                   MklDnnType<T>(), this->data_format_mkldnn_);
    376 
    377     input_gradient_dnn_data->SetUsrMem(original_input_grad_md,
    378                                        &input_gradient_tensor);
    379 
    380     // Check to see if input grad diff dst is in the right format
    381     // Create a new memory descriptor with the same shape as the
    382     // original, but the format of the other tensors.
    383     memory::format original_output_format =
    384         static_cast<memory::format>(original_output_md.data.format);
    385     bool grad_reorder_needed =
    386         input_gradient_dnn_data->IsReorderNeeded(original_output_format);
    387     memory::dims diff_dst_dims =
    388         input_gradient_mkl_shape.IsMklTensor()
    389             ? input_gradient_mkl_shape.GetSizesAsMklDnnDims()
    390             : TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(),
    391                                         this->data_format_tf_);
    392     memory::desc target_diff_dst_md =
    393         memory::desc(diff_dst_dims, MklDnnType<T>(), original_output_format);
    394 
    395     return grad_reorder_needed ? target_diff_dst_md : original_input_grad_md;
    396   }
    397 };
    398 #endif  // INTEL_MKL_ML
    399 
    400 //-------------------------------------------------------------------
    401 // Utility functions
    402 
    403 typedef struct {
    404   size_t in_dim;
    405   size_t in_sizes[4];
    406   size_t in_strides[4];
    407   size_t out_sizes[4];
    408   size_t out_strides[4];
    409   int in_offset[4];
    410   size_t kernel_stride[2];
    411   size_t kernel_size[2];
    412 } MklPoolingOpParams;
    413 
    414 // Transfers the right parameters for pooling to the op parameters
    415 // Updates context->status if there is an invalid input.
    416 void ExtractMklOpParams(OpKernelContext* context, TensorFormat data_format,
    417                         const MklPoolParameters& params,
    418                         MklPoolingOpParams* mkl_params);
    419 }  // namespace tensorflow
    420 
    421 #endif  // INTEL_MKL
    422 #endif  // TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_
    423