/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/nn_ops.cc.

#ifdef INTEL_MKL

#include <algorithm>
#include <memory>
#include <vector>

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/mkl_conv_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"

#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"

#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"

using mkldnn::convolution_backward_weights;
using mkldnn::memory;
using mkldnn::prop_kind;
using mkldnn::stream;
#endif

namespace tensorflow {
56 typedef Eigen::ThreadPoolDevice CPUDevice; 57 58 #ifdef INTEL_MKL_ML 59 60 template <typename Device, class T> 61 class MklConv2DCustomBackpropFilterOp : public OpKernel { 62 public: 63 explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context) 64 : OpKernel(context) { 65 string data_format; 66 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); 67 OP_REQUIRES(context, FormatFromString(data_format, &data_format_), 68 errors::InvalidArgument("Invalid data format")); 69 70 OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); 71 int stride_n = GetTensorDim(strides_, data_format_, 'N'); 72 int stride_c = GetTensorDim(strides_, data_format_, 'C'); 73 OP_REQUIRES( 74 context, (stride_n == 1 && stride_c == 1), 75 errors::InvalidArgument("Current implementation does not yet support " 76 "strides in the batch and depth dimensions.")); 77 OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); 78 } 79 80 void Compute(OpKernelContext* context) override { 81 MklConv2DGradFilterOpContext mkl_context; 82 const Tensor& input = MklGetInput(context, 0); 83 GetMklShape(context, 0, &(mkl_context.input_shape)); 84 bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor(); 85 86 const Tensor& filter_sizes = MklGetInput(context, 1); 87 88 const Tensor& out_backprop = MklGetInput(context, 2); 89 GetMklShape(context, 2, &(mkl_context.out_backprop_shape)); 90 bool out_backprop_in_mkl_format = 91 mkl_context.out_backprop_shape.IsMklTensor(); 92 93 TensorShape input_shape, filter_shape, out_backprop_shape; 94 95 OP_REQUIRES( 96 context, TensorShapeUtils::IsVector(filter_sizes.shape()), 97 errors::InvalidArgument( 98 "Conv2DCustomBackpropFilter: filter_sizes input must be 1-dim, " 99 "not ", 100 filter_sizes.dims())); 101 OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( 102 filter_sizes.vec<int32>(), &filter_shape)); 103 104 ConvBackpropDimensions backprop_dims; 105 106 // Generate shape for input if input is in MKL format. 
107 if (input_in_mkl_format) { 108 OP_REQUIRES(context, mkl_context.input_shape.GetDimension() == 4, 109 errors::InvalidArgument( 110 "Conv2DCustomBackpropFilter: input size must be 4-dim")); 111 112 MklSizesToTFSizes(context, data_format_, mkl_context.input_shape, 113 &input_shape); 114 } else { 115 input_shape = input.shape(); 116 } 117 118 // Generate shape for outback prop if input is in MKL format. 119 if (out_backprop_in_mkl_format) { 120 OP_REQUIRES( 121 context, mkl_context.out_backprop_shape.GetDimension() == 4, 122 errors::InvalidArgument( 123 "Conv2DCustomBackpropFilter: outbackprop size must be 4-dim")); 124 125 MklSizesToTFSizes(context, data_format_, mkl_context.out_backprop_shape, 126 &out_backprop_shape); 127 } else { 128 out_backprop_shape = out_backprop.shape(); 129 } 130 131 OP_REQUIRES_OK(context, 132 ConvBackpropComputeDimensions( 133 "Conv2DCustomBackpropFilter", /*num_spatial_dims=*/2, 134 input_shape, filter_shape, out_backprop_shape, strides_, 135 padding_, data_format_, &backprop_dims)); 136 137 int64 pad_top, pad_bottom; 138 int64 pad_left, pad_right; 139 OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( 140 backprop_dims.spatial_dims[0].input_size, 141 backprop_dims.spatial_dims[0].filter_size, 142 backprop_dims.spatial_dims[0].stride, padding_, 143 &backprop_dims.spatial_dims[0].output_size, 144 &pad_top, &pad_bottom)); 145 OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( 146 backprop_dims.spatial_dims[1].input_size, 147 backprop_dims.spatial_dims[1].filter_size, 148 backprop_dims.spatial_dims[1].stride, padding_, 149 &backprop_dims.spatial_dims[1].output_size, 150 &pad_left, &pad_right)); 151 152 // Create MKL primitives for convolution filter grad 153 mkl_context.in_dims = input_in_mkl_format 154 ? mkl_context.input_shape.GetDimension() 155 : input.dims(); 156 mkl_context.out_dims = out_backprop_in_mkl_format 157 ? 
mkl_context.out_backprop_shape.GetDimension() 158 : out_backprop.dims(); 159 mkl_context.in_sizes[0] = 160 static_cast<size_t>(backprop_dims.spatial_dims[1].input_size); 161 mkl_context.in_sizes[1] = 162 static_cast<size_t>(backprop_dims.spatial_dims[0].input_size); 163 mkl_context.in_sizes[2] = static_cast<size_t>(backprop_dims.in_depth); 164 mkl_context.in_sizes[3] = static_cast<size_t>(backprop_dims.batch_size); 165 mkl_context.out_sizes[0] = 166 static_cast<size_t>(backprop_dims.spatial_dims[1].output_size); 167 mkl_context.out_sizes[1] = 168 static_cast<size_t>(backprop_dims.spatial_dims[0].output_size); 169 mkl_context.out_sizes[2] = static_cast<size_t>(backprop_dims.out_depth); 170 mkl_context.out_sizes[3] = static_cast<size_t>(backprop_dims.batch_size); 171 mkl_context.input_offsets[0] = static_cast<int>(-pad_left); 172 mkl_context.input_offsets[1] = static_cast<int>(-pad_top); 173 mkl_context.conv_strides[0] = 174 static_cast<size_t>(backprop_dims.spatial_dims[1].stride); 175 mkl_context.conv_strides[1] = 176 static_cast<size_t>(backprop_dims.spatial_dims[0].stride); 177 178 GetStridesFromSizes(data_format_, mkl_context.in_strides, 179 mkl_context.in_sizes); 180 GetStridesFromSizes(data_format_, mkl_context.out_strides, 181 mkl_context.out_sizes); 182 183 // MKL understands dimensions in 0, 1, 2, and 3 indices denotes 184 // filter cols, rows, input channels, and output depth/channels. 185 mkl_context.filter_dims = 4; 186 mkl_context.filter_sizes[0] = backprop_dims.spatial_dims[1].filter_size; 187 mkl_context.filter_sizes[1] = backprop_dims.spatial_dims[0].filter_size; 188 mkl_context.filter_sizes[2] = backprop_dims.in_depth; 189 mkl_context.filter_sizes[3] = backprop_dims.out_depth; 190 191 // We want filter grad to be in TF format, so 192 // make the strides accordingly to reflect this fact. 193 // Note TF filter layout : (rows, cols, in_depth, out_depth), 194 // while row is the innermost dimension. 
195 mkl_context.filter_strides[0] = 196 backprop_dims.out_depth * backprop_dims.in_depth; 197 mkl_context.filter_strides[1] = backprop_dims.out_depth * 198 backprop_dims.in_depth * 199 backprop_dims.spatial_dims[1].filter_size; 200 mkl_context.filter_strides[2] = backprop_dims.out_depth; 201 mkl_context.filter_strides[3] = 1; 202 203 mkl_context.conv_strides[0] = backprop_dims.spatial_dims[1].stride; 204 mkl_context.conv_strides[1] = backprop_dims.spatial_dims[0].stride; 205 206 // Create convolution-grad-filter primitive 207 CHECK_EQ(dnnConvolutionCreateBackwardFilter_F32( 208 &mkl_context.prim_conv_bwdfilter, nullptr, 209 dnnAlgorithmConvolutionDirect, mkl_context.in_dims, 210 mkl_context.in_sizes, mkl_context.out_sizes, 211 mkl_context.filter_sizes, mkl_context.conv_strides, 212 mkl_context.input_offsets, dnnBorderZeros), 213 E_SUCCESS); 214 215 // Create the layouts for entities in received context. 216 mkl_context.MklCreateInputLayouts(context); 217 218 // Mkl needs the entities in its native format. 219 // So create temporary tensors along with buffers to 220 // convert the received entities. 221 Tensor mkl_tmp_input_buf_tensor, mkl_tmp_out_backprop_buf_tensor; 222 // This preparation sets (1) dnnResourceSrc (2) dnnResourceDiffDst 223 mkl_context.MklPrepareInputs(context, &mkl_tmp_input_buf_tensor, 224 &mkl_tmp_out_backprop_buf_tensor); 225 226 // Final conv-grad-filter should be in TF layout. 
227 Tensor* grad_filter; 228 mkl_context.grad_filter_shape.SetMklTensor(false); 229 mkl_context.grad_filter_shape.SetTfLayout(mkl_context.filter_dims, 230 mkl_context.filter_sizes, 231 mkl_context.filter_strides); 232 AllocateOutputSetMklShape(context, 0, &grad_filter, filter_shape, 233 mkl_context.grad_filter_shape); 234 235 // Need to set member variable for TF layout 236 mkl_context.lt_grad_filter = mkl_context.grad_filter_shape.GetTfLayout(); 237 238 // MKL conv-grad-filter might produce grad in its internal layout 239 Tensor mkl_tmp_grad_filter_buf_tensor; 240 // This preparation sets conversion primitive if required 241 // and allocates temporary tensor and its buffer without doing conversions. 242 // Also sets (3) dnnResourceDiffFilter accordingly 243 mkl_context.MklPrepareGradFilter(context, grad_filter, 244 &mkl_tmp_grad_filter_buf_tensor); 245 246 // After setting all the required dnnResources, ready for execution! 247 CHECK_EQ( 248 dnnExecute_F32(mkl_context.prim_conv_bwdfilter, mkl_context.conv_res), 249 E_SUCCESS); 250 251 // Convert grad-filter to TF layout 252 if (mkl_context.convert_bwdfilter != nullptr) { 253 void* mkl_buf_convert_grad_filter = 254 const_cast<void*>(static_cast<const void*>( 255 mkl_tmp_grad_filter_buf_tensor.flat<T>().data())); 256 void* mkl_buf_grad_filter = const_cast<void*>( 257 static_cast<const void*>(grad_filter->flat<T>().data())); 258 CHECK_EQ(dnnConversionExecute_F32(mkl_context.convert_bwdfilter, 259 mkl_buf_convert_grad_filter, 260 mkl_buf_grad_filter), 261 E_SUCCESS); 262 } 263 264 mkl_context.MklCleanup(); 265 } 266 267 private: 268 typedef struct { 269 int in_dims; 270 size_t in_sizes[4]; 271 size_t in_strides[4]; 272 int out_dims; 273 size_t out_sizes[4]; 274 size_t out_strides[4]; 275 int filter_dims; 276 size_t filter_sizes[4]; 277 size_t filter_strides[4]; 278 int input_offsets[2]; 279 size_t conv_strides[2]; 280 MklShape input_shape, grad_filter_shape, out_backprop_shape; 281 dnnPrimitive_t prim_conv_bwdfilter = 
nullptr; 282 dnnPrimitive_t convert_bwdfilter = nullptr; 283 dnnLayout_t lt_input = nullptr; 284 dnnLayout_t lt_grad_filter = nullptr; 285 dnnLayout_t lt_out_backprop = nullptr; 286 void* conv_res[dnnResourceNumber]; 287 288 void MklCleanup() { 289 // Cleanup member layouts and primitives except "lt_grad_filter_" 290 // which points to MklShape's TFLayout 291 bool input_in_mkl_format = input_shape.IsMklTensor(); 292 bool out_backprop_in_mkl_format = out_backprop_shape.IsMklTensor(); 293 if (!input_in_mkl_format) dnnLayoutDelete_F32(lt_input); 294 if (!out_backprop_in_mkl_format) dnnLayoutDelete_F32(lt_out_backprop); 295 if (convert_bwdfilter != nullptr) dnnDelete_F32(convert_bwdfilter); 296 dnnDelete_F32(prim_conv_bwdfilter); 297 } 298 299 // Create MKL dnnLayout_t objects for tensors coming into the layer 300 void MklCreateInputLayouts(OpKernelContext* context) { 301 bool input_in_mkl_format = input_shape.IsMklTensor(); 302 if (input_in_mkl_format) { 303 lt_input = static_cast<dnnLayout_t>(input_shape.GetCurLayout()); 304 } else { 305 CHECK_EQ(dnnLayoutCreate_F32(<_input, in_dims, in_sizes, in_strides), 306 E_SUCCESS); 307 } 308 309 bool out_backprop_in_mkl_format = out_backprop_shape.IsMklTensor(); 310 if (out_backprop_in_mkl_format) { 311 lt_out_backprop = 312 static_cast<dnnLayout_t>(out_backprop_shape.GetCurLayout()); 313 } else { 314 CHECK_EQ(dnnLayoutCreate_F32(<_out_backprop, out_dims, out_sizes, 315 out_strides), 316 E_SUCCESS); 317 } 318 } 319 320 // Compare incoming tensor layouts with MKL preferred layouts and convert 321 // data to the preferred layout if necessary 322 void MklPrepareInputs(OpKernelContext* context, 323 Tensor* mkl_tmp_input_buf_tensor, 324 Tensor* mkl_tmp_out_backprop_buf_tensor) { 325 bool mkl_convert_input, mkl_convert_out_backprop; 326 dnnPrimitive_t mkl_prim_convert_input, mkl_prim_convert_out_backprop; 327 dnnLayout_t mkl_lt_internal_input, mkl_lt_internal_out_backprop; 328 void *mkl_buf_convert_input, 
*mkl_buf_convert_out_backprop; 329 330 mkl_prim_convert_input = nullptr; 331 mkl_prim_convert_out_backprop = nullptr; 332 mkl_lt_internal_input = nullptr; 333 mkl_lt_internal_out_backprop = nullptr; 334 mkl_buf_convert_input = nullptr; 335 mkl_buf_convert_out_backprop = nullptr; 336 337 // Compare with internal layouts and convert if needed 338 const Tensor& input = MklGetInput(context, 0); 339 void* mkl_buf_input = 340 const_cast<void*>(static_cast<const void*>(input.flat<T>().data())); 341 CHECK_EQ(dnnLayoutCreateFromPrimitive_F32( 342 &mkl_lt_internal_input, prim_conv_bwdfilter, dnnResourceSrc), 343 E_SUCCESS); 344 mkl_convert_input = 345 !dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input); 346 if (mkl_convert_input) { 347 CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, lt_input, 348 mkl_lt_internal_input), 349 E_SUCCESS); 350 AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input, 351 &mkl_buf_convert_input); 352 CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input, 353 mkl_buf_convert_input), 354 E_SUCCESS); 355 dnnDelete_F32(mkl_prim_convert_input); 356 } 357 dnnLayoutDelete_F32(mkl_lt_internal_input); 358 359 conv_res[dnnResourceSrc] = 360 (mkl_convert_input) ? 
mkl_buf_convert_input : mkl_buf_input; 361 362 const Tensor& out_backprop = MklGetInput(context, 2); 363 void* mkl_buf_out_backprop = const_cast<void*>( 364 static_cast<const void*>(out_backprop.flat<T>().data())); 365 366 CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_out_backprop, 367 prim_conv_bwdfilter, 368 dnnResourceDiffDst), 369 E_SUCCESS); 370 mkl_convert_out_backprop = 371 !dnnLayoutCompare_F32(mkl_lt_internal_out_backprop, lt_out_backprop); 372 if (mkl_convert_out_backprop) { 373 CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_out_backprop, 374 lt_out_backprop, 375 mkl_lt_internal_out_backprop), 376 E_SUCCESS); 377 AllocTmpBuffer(context, mkl_tmp_out_backprop_buf_tensor, 378 lt_out_backprop, &mkl_buf_convert_out_backprop); 379 CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_out_backprop, 380 mkl_buf_out_backprop, 381 mkl_buf_convert_out_backprop), 382 E_SUCCESS); 383 dnnDelete_F32(mkl_prim_convert_out_backprop); 384 } 385 dnnLayoutDelete_F32(mkl_lt_internal_out_backprop); 386 387 conv_res[dnnResourceDiffDst] = (mkl_convert_out_backprop) 388 ? 
mkl_buf_convert_out_backprop 389 : mkl_buf_out_backprop; 390 } 391 392 void MklPrepareGradFilter(OpKernelContext* context, Tensor* grad_filter, 393 Tensor* mkl_tmp_grad_filter_buf_tensor) { 394 bool mkl_convert_grad_filter; 395 dnnLayout_t mkl_lt_internal_grad_filter = nullptr; 396 void* mkl_buf_convert_grad_filter = nullptr; 397 void* mkl_buf_grad_filter = const_cast<void*>( 398 static_cast<const void*>(grad_filter->flat<T>().data())); 399 CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_grad_filter, 400 prim_conv_bwdfilter, 401 dnnResourceDiffFilter), 402 E_SUCCESS); 403 mkl_convert_grad_filter = 404 !dnnLayoutCompare_F32(mkl_lt_internal_grad_filter, lt_grad_filter); 405 if (mkl_convert_grad_filter) { 406 CHECK_EQ(dnnConversionCreate_F32(&convert_bwdfilter, 407 mkl_lt_internal_grad_filter, 408 lt_grad_filter), 409 E_SUCCESS); 410 AllocTmpBuffer(context, mkl_tmp_grad_filter_buf_tensor, 411 mkl_lt_internal_grad_filter, 412 &mkl_buf_convert_grad_filter); 413 } 414 dnnLayoutDelete_F32(mkl_lt_internal_grad_filter); 415 416 conv_res[dnnResourceDiffFilter] = (mkl_convert_grad_filter) 417 ? 
mkl_buf_convert_grad_filter 418 : mkl_buf_grad_filter; 419 } 420 } MklConv2DGradFilterOpContext; 421 422 std::vector<int32> strides_; 423 Padding padding_; 424 TensorFormat data_format_; 425 }; 426 427 #define REGISTER_MKL_FILTER_KERNELS(T) \ 428 REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \ 429 .Device(DEVICE_CPU) \ 430 .TypeConstraint<T>("T") \ 431 .Label(mkl_op_registry::kMklOpLabel), \ 432 MklConv2DCustomBackpropFilterOp<CPUDevice, T>); 433 TF_CALL_float(REGISTER_MKL_FILTER_KERNELS); 434 #undef REGISTER_MKL_FILTER_KERNELS 435 436 #else 437 438 template <typename Device, class T, bool biasEnabled> 439 class MklConv2DCustomBackpropFilterOp 440 : public MklConv2DBackpropCommonOp<Device, T> { 441 public: 442 explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context) 443 : MklConv2DBackpropCommonOp<Device, T>(context) {} 444 ~MklConv2DCustomBackpropFilterOp() {} 445 446 private: 447 void ValidateMklShapes(const MklDnnShape& input_mkl_shape, 448 const MklDnnShape& filter_mkl_shape, 449 const MklDnnShape& obp_mkl_shape) { 450 CHECK(!filter_mkl_shape.IsMklTensor()) 451 << "Conv2DBackpropFilter: filter should not be in MKL Layout"; 452 } 453 454 size_t GetInputTensorIndexWithSizes() { return 1; /* filter index */ } 455 456 TensorShape MakeInputTfShape(OpKernelContext* context, 457 const Tensor& input_tensor) { 458 size_t input_idx = 0; 459 return GetTfShape(context, input_idx); 460 } 461 462 TensorShape MakeFilterTfShape(OpKernelContext* context, 463 const Tensor& filter_tensor) { 464 TensorShape filter_tf_shape; 465 CHECK_EQ(TensorShapeUtils::IsVector(filter_tensor.shape()), true); 466 CHECK_EQ(TensorShapeUtils::MakeShape(filter_tensor.vec<int32>(), 467 &filter_tf_shape) 468 .ok(), 469 true); 470 return filter_tf_shape; 471 } 472 473 TensorShape GetOutputTfShape(const TensorShape& input_shape, 474 const TensorShape& filter_shape, 475 const TensorShape& outbprop_shape) { 476 // Shape of output of Conv2DBackpropFilter is same as shape of filter. 
477 return filter_shape; 478 } 479 480 const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims, 481 const memory::dims& fwd_filter_dims) { 482 // Shape of output of Conv2DBackpropFilter is same as shape of filter. 483 return fwd_filter_dims; 484 } 485 486 memory::format GetOutputFormat(const memory::format data_format) { 487 // Output layout is Tensorflow's filter layout (HWIO). 488 return memory::format::hwio; 489 } 490 491 void CreatePrimitive(OpKernelContext* context, const engine& cpu_engine, 492 const convolution_forward::primitive_desc& conv_fwd_pd, 493 MklDnnData<T>* input, MklDnnData<T>* filter, 494 MklDnnData<T>* outbackprop, MklDnnData<T>* output, 495 Tensor** output_tensor, const memory::dims& strides, 496 const memory::dims& padding_l, 497 const memory::dims& padding_r, padding_kind padding, 498 const memory::dims& bwd_output_dims, 499 memory::format bwd_output_format) { 500 CHECK_NOTNULL(context); 501 CHECK_NOTNULL(input); 502 CHECK_NOTNULL(filter); 503 CHECK_NOTNULL(outbackprop); 504 CHECK_NOTNULL(output); 505 CHECK_NOTNULL(output_tensor); 506 507 MklDnnData<T>* bias_grad = nullptr; 508 int depth = 0; 509 if (biasEnabled) { 510 // Data structure for bias_grad 511 bias_grad = new MklDnnData<T>(&cpu_engine); 512 TensorShape obp_tf_shape = GetTfShape(context, 2); 513 depth = (MklConv2DBackpropCommonOp<Device, T>::GetTFDataFormat() == 514 FORMAT_NCHW) 515 ? obp_tf_shape.dim_size(1) 516 : obp_tf_shape.dim_size(3); 517 memory::dims bias_grad_dims = {depth}; 518 bias_grad->SetOpMemDesc(bias_grad_dims, memory::format::x); 519 } 520 521 // Create convolution backward weights primitive. 522 auto bwd_desc = 523 (biasEnabled && (bias_grad != nullptr)) 524 ? 
convolution_backward_weights::desc( 525 convolution_direct, input->GetOpMemDesc(), 526 output->GetOpMemDesc(), bias_grad->GetOpMemDesc(), 527 outbackprop->GetOpMemDesc(), strides, padding_l, padding_r, 528 padding) 529 : convolution_backward_weights::desc( 530 convolution_direct, input->GetOpMemDesc(), 531 output->GetOpMemDesc(), outbackprop->GetOpMemDesc(), strides, 532 padding_l, padding_r, padding); 533 534 auto bwd_pd = convolution_backward_weights::primitive_desc( 535 bwd_desc, cpu_engine, conv_fwd_pd); 536 537 // Allocate output tensor. 538 AllocateOutputTensor(context, bwd_pd, bwd_output_dims, bwd_output_format, 539 output_tensor); 540 541 CHECK_NOTNULL(*output_tensor); 542 // Set buffer handle using allocated output tensor. 543 output->SetUsrMemDataHandle(*output_tensor); 544 545 if (biasEnabled && (bias_grad != nullptr)) { 546 // Allocate bias_grad tensor 547 TensorShape bias_grad_shape({depth}); 548 Tensor* bias_grad_tensor = nullptr; 549 AllocateBiasGradTensor(context, bias_grad_shape, &bias_grad_tensor); 550 memory::dims bias_grad_dims = {depth}; 551 // Since Bias is 1D, we use format::x from MKLDNN to represent it. 552 auto bias_grad_md = 553 memory::desc({bias_grad_dims}, MklDnnType<T>(), memory::format::x); 554 bias_grad->SetUsrMem(bias_grad_md, bias_grad_tensor); 555 bias_grad->SetUsrMemDataHandle(bias_grad_tensor); 556 } 557 558 if (biasEnabled && (bias_grad != nullptr)) { 559 PrepareAndExecutePrimitive(bwd_pd, input, outbackprop, output, bias_grad); 560 } else { 561 PrepareAndExecutePrimitive(bwd_pd, input, outbackprop, output); 562 } 563 } 564 565 // Allocate output tensor. 566 void AllocateOutputTensor( 567 OpKernelContext* context, 568 const convolution_backward_weights::primitive_desc& conv_pd, 569 const memory::dims& output_dims_mkl_order, 570 memory::format output_tf_format, Tensor** output_tensor) { 571 CHECK_NOTNULL(output_tensor); 572 573 // For BackpropFilter, we convert the output tensor back in Tensorflow 574 // layout. 
Because typically, BackpropFilter is the last operator in the 575 // graph that emit filter gradient that is provided to ApplyGradient 576 // method to update the filter. But it may be possible to eliminate this 577 // by forwarding filter in MKL layout if we support ApplyGradient method 578 // for MKL layout propagation. 579 MklDnnShape output_mkl_shape; 580 output_mkl_shape.SetMklTensor(false); 581 // output_dims_mkl_order is in OIHW format. 582 // Allocate shape of TF tensor in HWIO format. 583 TensorShape output_tf_shape({output_dims_mkl_order[MklDnnDims::Dim_H], 584 output_dims_mkl_order[MklDnnDims::Dim_W], 585 output_dims_mkl_order[MklDnnDims::Dim_I], 586 output_dims_mkl_order[MklDnnDims::Dim_O]}); 587 AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape, 588 output_mkl_shape); 589 } 590 591 // Allocate tensor for bias grad 592 void AllocateBiasGradTensor(OpKernelContext* context, 593 const TensorShape& bias_grad_shape, 594 Tensor** bias_grad_tensor) { 595 CHECK_NOTNULL(bias_grad_tensor); 596 597 MklDnnShape bias_grad_mkl_shape; 598 bias_grad_mkl_shape.SetMklTensor(false); 599 AllocateOutputSetMklShape(context, 1, bias_grad_tensor, bias_grad_shape, 600 bias_grad_mkl_shape); 601 } 602 603 // Prepare and execute net - checks for input and output reorders. 604 void PrepareAndExecutePrimitive( 605 const convolution_backward_weights::primitive_desc& conv_pd, 606 MklDnnData<T>* input, MklDnnData<T>* obp, MklDnnData<T>* output, 607 MklDnnData<T>* bias_grad = nullptr) { 608 // Create reorders between user layout and MKL layout if it is needed and 609 // add it to the net before convolution. 610 std::vector<primitive> net; 611 input->CheckReorderToOpMem(conv_pd.src_primitive_desc(), &net); 612 obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net); 613 614 // For BackpropFilter, we convert the output tensor back in Tensorflow 615 // layout. 
616 bool output_reorder_required = output->PrepareReorderToUserMemIfReq( 617 conv_pd.diff_weights_primitive_desc()); 618 619 if (biasEnabled && (bias_grad != nullptr)) { 620 net.push_back(convolution_backward_weights( 621 conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem(), 622 bias_grad->GetOpMem())); 623 } else { 624 net.push_back(convolution_backward_weights( 625 conv_pd, input->GetOpMem(), obp->GetOpMem(), output->GetOpMem())); 626 } 627 628 if (output_reorder_required) { 629 output->InsertReorderToUserMem(&net); 630 } 631 632 stream(stream::kind::eager).submit(net).wait(); 633 } 634 }; 635 636 #define REGISTER_MKL_FILTER_KERNELS(T) \ 637 REGISTER_KERNEL_BUILDER( \ 638 Name("_MklConv2DBackpropFilter") \ 639 .Device(DEVICE_CPU) \ 640 .TypeConstraint<T>("T") \ 641 .Label(mkl_op_registry::kMklOpLabel), \ 642 MklConv2DCustomBackpropFilterOp<CPUDevice, T, false>); \ 643 REGISTER_KERNEL_BUILDER( \ 644 Name("_MklConv2DBackpropFilterWithBias") \ 645 .Device(DEVICE_CPU) \ 646 .TypeConstraint<T>("T") \ 647 .Label(mkl_op_registry::kMklOpLabel), \ 648 MklConv2DCustomBackpropFilterOp<CPUDevice, T, true>); \ 649 REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DBackpropFilterWithBias") \ 650 .Device(DEVICE_CPU) \ 651 .TypeConstraint<T>("T") \ 652 .Label(mkl_op_registry::kMklOpLabel), \ 653 MklDummyOp<CPUDevice, T>); 654 655 TF_CALL_float(REGISTER_MKL_FILTER_KERNELS); 656 #undef REGISTER_MKL_FILTER_KERNELS 657 658 #endif // INTEL_MKL_ML 659 660 } // namespace tensorflow 661 662 #endif // INTEL_MKL 663