/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Common declarations shared by the MKL (and MKL-DNN) pooling kernels:
// the pooling hyper-parameter struct, CRTP-style base classes for the
// forward and backward pooling ops, and a POD parameter bundle for the
// legacy MKL-ML path. Only compiled when INTEL_MKL is defined.

#ifndef TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_
#define TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_

#ifdef INTEL_MKL
#include <string>
#include <vector>
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"

// The MKL-DNN (DNN library) path; INTEL_MKL_ML selects the older MKL-ML path.
#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"
using mkldnn::memory;
using mkldnn::pooling_backward;
using mkldnn::pooling_forward;
using mkldnn::stream;
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

// Pooling hyper-parameters consumed by the MKL pooling kernels. One of the
// Init() overloads below validates the op attributes/input shape and fills
// in every field; on invalid input, Init() updates context->status instead
// of filling the struct.
struct MklPoolParameters {
  int depth;

  // Input tensor extents.
  int tensor_in_cols;
  int tensor_in_rows;
  int tensor_in_batch;

  // Pooling window extents.
  int window_rows;
  int window_cols;
  int depth_window;

  // Strides between successive window positions.
  int row_stride;
  int col_stride;
  int depth_stride;

  // Computed output extents.
  int64 out_height;
  int64 out_width;
  int out_depth;

  // Padding applied to each edge of the input.
  int64 pad_left;
  int64 pad_right;
  int64 pad_top;
  int64 pad_bottom;
  int pad_depth;

  TensorFormat data_format;
  // Default-constructs with all extents zeroed and NCHW data format.
  MklPoolParameters()
      : depth(0),
        tensor_in_cols(0),
        tensor_in_rows(0),
        tensor_in_batch(0),
        window_rows(0),
        window_cols(0),
        depth_window(0),
        row_stride(0),
        col_stride(0),
        depth_stride(0),
        out_height(0),
        out_width(0),
        out_depth(0),
        pad_left(0),
        pad_right(0),
        pad_top(0),
        pad_bottom(0),
        pad_depth(0),
        data_format(TensorFormat::FORMAT_NCHW) {}

  // Initialize from a plain TensorFlow input shape.
  // Updates context->status if there is an invalid input.
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format, const TensorShape& tensor_in_shape);
#ifdef INTEL_MKL_ML
  // Initialize from a legacy MKL-ML shape descriptor.
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format, const MklShape* mkl_in_shape);
#else
  // Initialize from an MKL-DNN shape descriptor.
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format, const MklDnnShape* mkl_in_shape);
#endif

 private:
  // Common initialization for TensorFlow and MKL formats.
  void Init(OpKernelContext* context, const std::vector<int32>& ksize,
            const std::vector<int32>& stride, Padding padding,
            TensorFormat data_format);
};

#ifndef INTEL_MKL_ML

// Abstract base for the MKL-DNN pooling kernels (both forward and backward).
// Reads and validates the common pooling attributes (data_format, ksize,
// strides, padding, workspace_enabled) at construction time and provides
// shared helpers for shape/size computation. Subclasses implement Compute().
template <class T>
class MklPoolingOpBase : public OpKernel {
 public:
  explicit MklPoolingOpBase(OpKernelConstruction* context)
      : OpKernel(context), workspace_enabled_(false) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &this->data_format_tf_),
                errors::InvalidArgument("Invalid data format"));
    this->data_format_mkldnn_ =
        TFDataFormatToMklDnnDataFormat(this->data_format_tf_);
    OP_REQUIRES_OK(context, context->GetAttr("ksize", &this->ksize_));
    OP_REQUIRES(context, this->ksize_.size() == 4,
                errors::InvalidArgument("Sliding window ksize field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &this->stride_));
    OP_REQUIRES(context, this->stride_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &this->padding_));
    // Pooling over the batch dimension (index 0) is not supported.
    OP_REQUIRES(context, this->ksize_[0] == 1 && this->stride_[0] == 1,
                errors::Unimplemented("Pooling is not yet supported on the "
                                      "batch dimension."));

    // We may not get this attribute for this node if it does not go through
    // graph rewrite pass. So we do not check for error while retrieving this
    // attribute value.
    context->GetAttr("workspace_enabled", &this->workspace_enabled_);
  }
  void Compute(OpKernelContext* context) override = 0;

 protected:
  // Calculate output shape of pooling op in MKL-DNN and TensorFlow order.
  // MKL-DNN uses NCHW for output order. But TensorFlow output will be in
  // NHWC or NCHW format depending on data format. Function expects
  // output height and output width to have already been int32
  // bounds-checked.
  void GetOutputDims(const MklPoolParameters& mkl_pool_params,
                     memory::dims* output_dims_mkl_order) {
    // MKL-DNN always needs output in NCHW format.
    *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
                              mkl_pool_params.out_depth,
                              static_cast<int>(mkl_pool_params.out_height),
                              static_cast<int>(mkl_pool_params.out_width)};
  }

  // Populates *pool_params from the kernel's attributes, dispatching to the
  // TensorFlow-shape or MKL-shape Init() overload depending on whether the
  // input tensor carries an MKL layout. Updates context->status on error.
  void InitMklPoolParameters(OpKernelContext* context,
                             MklPoolParameters* pool_params,
                             const MklDnnShape& original_input_mkl_shape,
                             const TensorShape& input_tensor_shape) {
    if (!original_input_mkl_shape.IsMklTensor()) {
      pool_params->Init(context, this->ksize_, this->stride_, this->padding_,
                        this->data_format_tf_, input_tensor_shape);
    } else {
      pool_params->Init(context, this->ksize_, this->stride_, this->padding_,
                        this->data_format_tf_, &original_input_mkl_shape);
    }
  }

  // Checks to make sure that the memory we need to allocate
  // is a multiple of sizeof(T); returns the number of elements,
  // rounded up so the allocation covers pd.get_size() bytes.
  size_t GetNumTElements(const memory::primitive_desc& pd) {
    size_t num_bytes = pd.get_size();
    size_t ret_val = num_bytes / sizeof(T);
    if (num_bytes % sizeof(T) != 0) {
      ret_val++;
    }
    return ret_val;
  }

  std::vector<int32> ksize_;            // pooling window, attr "ksize"
  std::vector<int32> stride_;           // window strides, attr "strides"
  Padding padding_;                     // attr "padding"
  TensorFormat data_format_tf_;         // TF-side layout (NHWC/NCHW)
  memory::format data_format_mkldnn_;   // MKL-DNN equivalent of the above
  bool workspace_enabled_;              // optional attr, defaults to false
};

// Base for the forward pooling kernels: input configuration, output
// allocation, and primitive execution shared by max and average pooling.
template <class T>
class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
 public:
  explicit MklPoolingForwardOpBase<T>(OpKernelConstruction* context)
      : MklPoolingOpBase<T>(context) {}
  void Compute(OpKernelContext* context) override = 0;

 protected:
  // Builds the user memory descriptor for input_tensor — from the MKL layout
  // if present, otherwise from the TF shape converted to NCHW order — binds
  // it to dnn_data_input, and initializes *pool_params.
  void ConfigureInput(OpKernelContext* context,
                      const MklDnnShape& input_mkl_shape,
                      const Tensor& input_tensor,
                      MklPoolParameters* pool_params,
                      MklDnnData<T>* dnn_data_input) {
    CHECK_NOTNULL(pool_params);
    CHECK_NOTNULL(dnn_data_input);
    TensorShape input_tensor_shape = input_tensor.shape();
    memory::desc input_md =
        input_mkl_shape.IsMklTensor()
            ? input_mkl_shape.GetMklLayout()
            : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
                                                     this->data_format_tf_),
                           MklDnnType<T>(), this->data_format_mkldnn_);
    dnn_data_input->SetUsrMem(input_md, &input_tensor);
    this->InitMklPoolParameters(context, pool_params, input_mkl_shape,
                                input_tensor_shape);
  }

  // Allocates the op's output tensor, flagged as an MKL tensor whose MKL
  // layout is the forward primitive's destination descriptor. The flat TF
  // shape is sized in T elements to cover dst_pd.get_size() bytes.
  void AllocateOutputTensor(
      OpKernelContext* context,
      const pooling_forward::primitive_desc& pool_fwd_prim_desc,
      const memory::dims output_dims_mkl_order,
      const memory::format& output_tf_format, Tensor** output_tensor) {
    CHECK_NOTNULL(output_tensor);
    memory::primitive_desc dst_pd = pool_fwd_prim_desc.dst_primitive_desc();

    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);
    TensorShape output_tf_shape;

    // only allocate enough space for the elements we need.
    output_tf_shape.AddDim(this->GetNumTElements(dst_pd));
    AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
                              output_tf_shape, output_mkl_shape);
    CHECK_NOTNULL(*output_tensor);
  }

  // Builds a single-primitive net running the forward pooling primitive
  // (with an extra workspace output when wksp is non-null, as max pooling
  // needs for its backward pass) and executes it eagerly.
  void PrepareAndExecuteNet(
      const pooling_forward::primitive_desc& pool_fwd_desc,
      const MklDnnData<T>* src, MklDnnData<T>* dst,
      MklDnnData<uint8>* wksp = nullptr) {
    std::vector<primitive> net;

    // Create pooling primitive and add it to net
    if (wksp != nullptr) {
      net.push_back(pooling_forward(pool_fwd_desc, src->GetOpMem(),
                                    dst->GetOpMem(), wksp->GetOpMem()));
    } else {
      net.push_back(
          pooling_forward(pool_fwd_desc, src->GetOpMem(), dst->GetOpMem()));
    }
    stream(stream::kind::eager).submit(net).wait();
  }

  // Verifies that the input is 4-dimensional, in whichever representation
  // (plain TF tensor or MKL tensor) it arrived. Updates context->status on
  // failure via OP_REQUIRES.
  void SanityCheckInput(OpKernelContext* context, const Tensor& input_tensor,
                        const MklDnnShape& input_mkl_shape) {
    if (!input_mkl_shape.IsMklTensor()) {
      OP_REQUIRES(context, input_tensor.dims() == 4,
                  errors::InvalidArgument("Input must be 4-dimensional"));
    } else {
      OP_REQUIRES(context, input_mkl_shape.GetDimension() == 4,
                  errors::InvalidArgument("Input shape must be "
                                          "4-dimensional"));
    }
  }
  // .Input("value: T")
  // .Output("output: T")
  const int kInputTensorIndexInput = 0;
  const int kOutputTensorIndexOutput = 0;
};  // MklPoolingForwardBaseOp

// Base for the backward (gradient) pooling kernels. Helpers reconstruct the
// memory descriptors of the original forward input/output and configure the
// incoming gradient, reordering it when its format differs from the
// original output's.
template <class T>
class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> {
 public:
  explicit MklPoolingBackwardOpBase<T>(OpKernelConstruction* context)
      : MklPoolingOpBase<T>(context) {}
  void Compute(OpKernelContext* context) override = 0;

 protected:
  const int kOutputTensorIndexOutput = 0;

  // Allocates the gradient output tensor using the backward primitive's
  // diff_src descriptor; same MKL-tensor/size conventions as the forward
  // AllocateOutputTensor above.
  void AllocateOutputTensor(
      OpKernelContext* context,
      const pooling_backward::primitive_desc& pool_bkwd_prim_desc,
      const memory::dims output_dims_mkl_order,
      const memory::format& output_tf_format, Tensor** output_tensor) {
    CHECK_NOTNULL(output_tensor);
    memory::primitive_desc dst_pd =
        pool_bkwd_prim_desc.diff_src_primitive_desc();
    MklDnnShape output_mkl_shape;
    output_mkl_shape.SetMklTensor(true);
    output_mkl_shape.SetMklLayout(&dst_pd);
    output_mkl_shape.SetElemType(MklDnnType<T>());
    output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(),
                                 output_dims_mkl_order, output_tf_format);

    TensorShape output_tf_shape;
    output_tf_shape.AddDim(this->GetNumTElements(dst_pd));
    AllocateOutputSetMklShape(context, kOutputTensorIndexOutput, output_tensor,
                              output_tf_shape, output_mkl_shape);
    CHECK_NOTNULL(*output_tensor);
  }

  // Reorders the incoming gradient to target_diff_dst_pd if needed, then
  // builds and eagerly executes the backward pooling primitive (passing the
  // forward workspace when provided).
  void PrepareAndExecuteNet(
      const pooling_backward::primitive_desc& pool_bkwd_desc,
      MklDnnData<T>* input_gradient_diff_dst, MklDnnData<T>* output_diff_src,
      const memory::primitive_desc& target_diff_dst_pd,
      const MklDnnData<uint8>* workspace = nullptr) {
    std::vector<primitive> net;

    // If the input gradient isn't in the same format as the output
    // reorder it to the same format as the output
    input_gradient_diff_dst->CheckReorderToOpMem(target_diff_dst_pd, &net);

    // Create pooling primitive and add it to net
    if (nullptr == workspace) {
      net.push_back(pooling_backward(pool_bkwd_desc,
                                     input_gradient_diff_dst->GetOpMem(),
                                     output_diff_src->GetOpMem()));
    } else {
      net.push_back(
          pooling_backward(pool_bkwd_desc, input_gradient_diff_dst->GetOpMem(),
                           workspace->GetOpMem(), output_diff_src->GetOpMem()));
    }
    stream(stream::kind::eager).submit(net).wait();
  }

  // Max Pooling and Avg Pooling have slightly different implementations
  // Takes the Tensor containing original input data and the original
  // mkl Dnn Shape and populates other data.
  // Returns the memory descriptor of the original forward input and fills
  // *original_input_dims_nchw and *pool_params as side effects.
  // NOTE(review): tensor_original_input_shape is unused in this body —
  // presumably consumed by subclass overrides; confirm against callers.
  memory::desc ConfigureOriginalInput(
      OpKernelContext* context, const Tensor& tensor_original_input_shape,
      const MklDnnShape& original_input_mkl_shape,
      memory::dims* original_input_dims_nchw, MklPoolParameters* pool_params,
      const TensorShape& input_tensor_shape) {
    CHECK_NOTNULL(original_input_dims_nchw);
    CHECK_NOTNULL(pool_params);
    this->InitMklPoolParameters(context, pool_params, original_input_mkl_shape,
                                input_tensor_shape);

    *original_input_dims_nchw =
        original_input_mkl_shape.IsMklTensor()
            ? original_input_mkl_shape.GetSizesAsMklDnnDims()
            : TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
                                        this->data_format_tf_);

    return original_input_mkl_shape.IsMklTensor()
               ? original_input_mkl_shape.GetMklLayout()
               : memory::desc(*original_input_dims_nchw, MklDnnType<T>(),
                              this->data_format_mkldnn_);
  }

  // Returns the memory descriptor of the original forward output: the MKL
  // layout if the forward output was an MKL tensor, otherwise one built
  // from the pooling output dims in NCHW order.
  memory::desc ConfigureOriginalOutput(
      const MklPoolParameters& pool_params,
      const MklDnnShape& original_output_mkl_shape,
      memory::dims output_dims_mkl_order) {
    this->GetOutputDims(pool_params, &output_dims_mkl_order);

    return original_output_mkl_shape.IsMklTensor()
               ? original_output_mkl_shape.GetMklLayout()
               : memory::desc(output_dims_mkl_order, MklDnnType<T>(),
                              this->data_format_mkldnn_);
  }

  // Binds the incoming gradient tensor to input_gradient_dnn_data and
  // returns the descriptor the backward primitive should consume: the
  // gradient's own layout when it already matches the original output's
  // format, otherwise a descriptor in that format (triggering a reorder
  // in PrepareAndExecuteNet).
  memory::desc ConfigureInputGradient(
      const MklDnnShape& input_gradient_mkl_shape,
      const Tensor& input_gradient_tensor,
      MklDnnData<T>* input_gradient_dnn_data,
      const memory::desc& original_output_md) {
    // Configure the gradient as is
    memory::desc original_input_grad_md =
        input_gradient_mkl_shape.IsMklTensor()
            ? input_gradient_mkl_shape.GetMklLayout()
            : memory::desc(
                  TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(),
                                            this->data_format_tf_),
                  MklDnnType<T>(), this->data_format_mkldnn_);

    input_gradient_dnn_data->SetUsrMem(original_input_grad_md,
                                       &input_gradient_tensor);

    // Check to see if input grad diff dst is in the right format
    // Create a new memory descriptor with the same shape as the
    // original, but the format of the other tensors.
    memory::format original_output_format =
        static_cast<memory::format>(original_output_md.data.format);
    bool grad_reorder_needed =
        input_gradient_dnn_data->IsReorderNeeded(original_output_format);
    memory::dims diff_dst_dims =
        input_gradient_mkl_shape.IsMklTensor()
            ? input_gradient_mkl_shape.GetSizesAsMklDnnDims()
            : TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(),
                                        this->data_format_tf_);
    memory::desc target_diff_dst_md =
        memory::desc(diff_dst_dims, MklDnnType<T>(), original_output_format);

    return grad_reorder_needed ? target_diff_dst_md : original_input_grad_md;
  }
};
#endif  // INTEL_MKL_ML

//-------------------------------------------------------------------
// Utility functions

// POD parameter bundle for the legacy MKL-ML pooling path; populated by
// ExtractMklOpParams() below.
typedef struct {
  size_t in_dim;
  size_t in_sizes[4];
  size_t in_strides[4];
  size_t out_sizes[4];
  size_t out_strides[4];
  int in_offset[4];
  size_t kernel_stride[2];
  size_t kernel_size[2];
} MklPoolingOpParams;

// Transfers the right parameters for pooling to the op parameters
// Updates context->status if there is an invalid input.
void ExtractMklOpParams(OpKernelContext* context, TensorFormat data_format,
                        const MklPoolParameters& params,
                        MklPoolingOpParams* mkl_params);
}  // namespace tensorflow

#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_KERNELS_MKL_POOLING_OPS_COMMON_H_