/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifdef INTEL_MKL

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/util/tensor_format.h"

#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#include "tensorflow/core/util/mkl_util.h"

#ifndef INTEL_MKL_ML
#include "mkldnn.hpp"

using mkldnn::batch_normalization_backward;
using mkldnn::batch_normalization_forward;
using mkldnn::prop_kind;
using mkldnn::stream;
using mkldnn::use_global_stats;
using mkldnn::use_scale_shift;
#endif

// TODO(inteltf) Address comments from PR 8968.

namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;

#ifdef INTEL_MKL_ML

template <typename Device, typename T>
class MklFusedBatchNormOp : public OpKernel {
 public:
  explicit MklFusedBatchNormOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = T(epsilon);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
  }

  void Compute(OpKernelContext* context) override {
    MklFusedBatchNormOpContext mkl_context;
    const Tensor& input = MklGetInput(context, 0);
    const Tensor& scale = MklGetInput(context, 1);
    const Tensor& shift = MklGetInput(context, 2);
    const Tensor& est_mean = MklGetInput(context, 3);
    const Tensor& est_variance = MklGetInput(context, 4);

    GetMklShape(context, 0, &(mkl_context.mkl_shape_input_shape));
    bool input_in_mkl_format = mkl_context.mkl_shape_input_shape.IsMklTensor();

    if (!input_in_mkl_format) {
      OP_REQUIRES(context, input.dims() == 4,
                  errors::InvalidArgument("input must be 4-dimensional",
                                          input.shape().DebugString()));
    }
    OP_REQUIRES(context, scale.dims() == 1,
                errors::InvalidArgument("scale must be 1-dimensional",
                                        scale.shape().DebugString()));
    OP_REQUIRES(context, shift.dims() == 1,
                errors::InvalidArgument("offset must be 1-dimensional",
                                        shift.shape().DebugString()));
    OP_REQUIRES(context, est_mean.dims() == 1,
                errors::InvalidArgument("estimated_mean must be 1-dimensional",
                                        est_mean.shape().DebugString()));

    OP_REQUIRES(
        context, est_variance.dims() == 1,
        errors::InvalidArgument("estimated_variance must be 1-dimensional",
                                est_variance.shape().DebugString()));

    if (is_training_) {
      OP_REQUIRES(context, est_mean.dim_size(0) == 0,
errors::InvalidArgument("estimated_mean empty for training", 95 est_mean.shape().DebugString())); 96 OP_REQUIRES(context, est_variance.dim_size(0) == 0, 97 errors::InvalidArgument( 98 "estimated_variance must be empty for training", 99 est_variance.shape().DebugString())); 100 } 101 102 unsigned int flag_batch_norm = 103 is_training_ ? dnnUseScaleShift 104 : (dnnUseInputMeanVariance | dnnUseScaleShift); 105 106 mkl_context.MklExtractParams(context, tensor_format_); 107 108 // Create layout only for input data as it is used in Op primitive. 109 mkl_context.MklCreateInputLayout(context); 110 111 // Create Op primitive. 112 CHECK_EQ(dnnBatchNormalizationCreateForward_v2_F32( 113 &(mkl_context.mkl_prim_batchnorm), nullptr, 114 mkl_context.mkl_lt_input, static_cast<float>(epsilon_), 115 flag_batch_norm), 116 E_SUCCESS); 117 118 // Temporary tensors with buffers for the context inputs, if 119 // conversion to MKL-Op specific layouts are required. It is assumed here 120 // that TF's 1D tensors (scale, shift, est_mean, and est_variance) won't 121 // require any conversion. 122 // Since scale-shift is combined in MKL, a buffer is required. 123 Tensor mkl_tmp_input_buf_tensor, mkl_tmp_scale_shift_buf_tensor; 124 mkl_context.MklPrepareContextInputs(context, &mkl_tmp_input_buf_tensor, 125 &mkl_tmp_scale_shift_buf_tensor); 126 127 // Output data in MKL layout 128 Tensor* output = nullptr; 129 TensorShape tf_shape_output; 130 MklShape mkl_shape_output; 131 mkl_shape_output.SetMklTensor(true); 132 mkl_shape_output.SetMklLayout(mkl_context.mkl_prim_batchnorm, 133 dnnResourceDst); 134 mkl_shape_output.SetTfLayout(mkl_context.mkl_params.in_dim, 135 mkl_context.mkl_params.in_sizes, 136 mkl_context.mkl_params.in_strides); 137 mkl_shape_output.SetTfDimOrder(mkl_context.mkl_params.in_dim, 138 tensor_format_); 139 tf_shape_output.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>( 140 mkl_shape_output.GetMklLayout())) / 141 sizeof(T)); 142 AllocateOutputSetMklShape(context, 0, &output, tf_shape_output, 143 mkl_shape_output); 144 mkl_context.mkl_res_batchnorm[dnnResourceDst] = 145 static_cast<void*>(output->flat<T>().data()); 146 147 // Batch mean in TF layout 148 Tensor* batch_mean = nullptr; 149 MklShape mkl_shape_batch_mean; 150 mkl_shape_batch_mean.SetMklTensor(false); 151 AllocateOutputSetMklShape(context, 1, &batch_mean, scale.shape(), 152 mkl_shape_batch_mean); 153 // Batch variance in TF layout 154 Tensor* batch_variance = nullptr; 155 MklShape mkl_shape_batch_variance; 156 mkl_shape_batch_variance.SetMklTensor(false); 157 AllocateOutputSetMklShape(context, 2, &batch_variance, scale.shape(), 158 mkl_shape_batch_variance); 159 // If training mode, set dnnResourceMean and dnnResourceVariance to 160 // output tensors for batch mean and variance. 161 // Otherwise, set dnnResourceMean and dnnResourceVariance to 162 // estimated mean and variance. 163 if (is_training_) 164 mkl_context.MklSetMeanVariance(*batch_mean, *batch_variance); 165 else 166 mkl_context.MklSetMeanVariance(est_mean, est_variance); 167 168 // Now that all resources are set, it is ready for dnnExecute 169 CHECK_EQ(dnnExecute_F32(mkl_context.mkl_prim_batchnorm, 170 mkl_context.mkl_res_batchnorm), 171 E_SUCCESS); 172 173 // Mean and variance (without Bessel's correction) saved for backward 174 // computation to serve as pre-computed mean and variance. 
    Tensor* saved_mean = nullptr;
    MklShape mkl_shape_saved_mean;
    mkl_shape_saved_mean.SetMklTensor(false);
    AllocateOutputSetMklShape(context, 3, &saved_mean, scale.shape(),
                              mkl_shape_saved_mean);
    std::memcpy(
        reinterpret_cast<char*>(saved_mean->flat<float>().data()),
        reinterpret_cast<char*>(mkl_context.mkl_res_batchnorm[dnnResourceMean]),
        scale.NumElements() * sizeof(float));
    Tensor* saved_variance = nullptr;
    MklShape mkl_shape_saved_variance;
    mkl_shape_saved_variance.SetMklTensor(false);
    AllocateOutputSetMklShape(context, 4, &saved_variance, scale.shape(),
                              mkl_shape_saved_variance);
    std::memcpy(reinterpret_cast<char*>(saved_variance->flat<float>().data()),
                reinterpret_cast<char*>(
                    mkl_context.mkl_res_batchnorm[dnnResourceVariance]),
                scale.NumElements() * sizeof(float));

    // Bessel's correction on variance, if training mode is on
    if (is_training_) {
      float* p_var = static_cast<float*>(batch_variance->flat<T>().data());
      auto depth = mkl_context.mkl_params.depth;
      size_t orig_size = mkl_context.mkl_params.in_sizes[0] *
                         mkl_context.mkl_params.in_sizes[1] *
                         mkl_context.mkl_params.in_sizes[3];
      size_t adjust_size = orig_size - 1;
      float adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
      for (int i = 0; i < depth; i++) p_var[i] = adjust_factor * p_var[i];
    }

    mkl_context.MklCleanup();
  }

 private:
  T epsilon_;
  TensorFormat tensor_format_;
  bool is_training_;

  // Structure containing all info for MklOp
  typedef struct {
    // Parameters used for input and output layouts
    struct MklBatchNormParams {
      // BatchNormOp src (input) dimensions, sizes, and strides
      size_t in_dim;
      size_t in_sizes[4];
      size_t in_strides[4];
      size_t depth;  // Batch normalization is done per channel.
    } mkl_params;

    MklShape mkl_shape_input_shape;

    // MKL primitive and resources for BatchNormOp
    dnnPrimitive_t mkl_prim_batchnorm = nullptr;
    void* mkl_res_batchnorm[dnnResourceNumber];

    // MKL layouts for inputs in the context
    dnnLayout_t mkl_lt_input = nullptr;

    void MklCleanup() {
      bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
      if (!input_in_mkl_format) dnnLayoutDelete_F32(mkl_lt_input);
      if (mkl_prim_batchnorm != nullptr) dnnDelete_F32(mkl_prim_batchnorm);
    }

    void MklExtractParams(OpKernelContext* context,
                          const TensorFormat& tensor_format) {
      const Tensor& input = MklGetInput(context, 0);
      bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
      mkl_params.in_dim = input_in_mkl_format
                              ? mkl_shape_input_shape.GetDimension()
                              : input.dims();
      mkl_params.in_sizes[0] = static_cast<size_t>(
          input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[0]
                              : GetTensorDim(input, tensor_format, 'W'));
      mkl_params.in_sizes[1] = static_cast<size_t>(
          input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[1]
                              : GetTensorDim(input, tensor_format, 'H'));
      mkl_params.in_sizes[2] = static_cast<size_t>(
          input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[2]
                              : GetTensorDim(input, tensor_format, 'C'));
      mkl_params.in_sizes[3] = static_cast<size_t>(
          input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[3]
                              : GetTensorDim(input, tensor_format, 'N'));
      mkl_params.depth = mkl_params.in_sizes[2];
      GetStridesFromSizes(tensor_format, mkl_params.in_strides,
                          mkl_params.in_sizes);
    }

    void MklCreateInputLayout(OpKernelContext* context) {
      const Tensor& input = MklGetInput(context, 0);
      bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
      if (input_in_mkl_format) {
        mkl_lt_input =
            static_cast<dnnLayout_t>(mkl_shape_input_shape.GetCurLayout());
      } else {
        CHECK_EQ(
            dnnLayoutCreate_F32(&mkl_lt_input, mkl_params.in_dim,
                                mkl_params.in_sizes, mkl_params.in_strides),
            E_SUCCESS);
      }
    }

    void MklPrepareContextInputs(OpKernelContext* context,
                                 Tensor* mkl_tmp_input_buf_tensor,
                                 Tensor* mkl_tmp_scale_shift_buf_tensor) {
      bool mkl_convert_input;
      dnnPrimitive_t mkl_prim_convert_input = nullptr;
      dnnLayout_t mkl_lt_internal_input = nullptr;
      void* mkl_buf_converted_input = nullptr;
      // Compare with internal layouts and convert if needed
      const Tensor& input = MklGetInput(context, 0);
      void* mkl_buf_input =
          const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
                   &mkl_lt_internal_input, mkl_prim_batchnorm, dnnResourceSrc),
               E_SUCCESS);
      mkl_convert_input =
          !dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_input);
      if (mkl_convert_input) {
        CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, mkl_lt_input,
                                         mkl_lt_internal_input),
                 E_SUCCESS);
        AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
                       &mkl_buf_converted_input);
        CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
                                          mkl_buf_converted_input),
                 E_SUCCESS);
        dnnDelete_F32(mkl_prim_convert_input);
      }
      dnnLayoutDelete_F32(mkl_lt_internal_input);
      mkl_res_batchnorm[dnnResourceSrc] =
          (mkl_convert_input) ? mkl_buf_converted_input : mkl_buf_input;

      // The scale-shift layout is created from the primitive, so no
      // conversion is needed; however, a buffer has to be allocated.
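      // MKL packs scale and shift into one [2 x depth] buffer: the first
      // `depth` entries hold the scale values and the next `depth` entries
      // hold the shift values (see the copy loop below).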
      dnnLayout_t mkl_lt_scale_shift = nullptr;
      void* mkl_buf_scale_shift = nullptr;
      CHECK_EQ(
          dnnLayoutCreateFromPrimitive_F32(
              &mkl_lt_scale_shift, mkl_prim_batchnorm, dnnResourceScaleShift),
          E_SUCCESS);
      AllocTmpBuffer(context, mkl_tmp_scale_shift_buf_tensor,
                     mkl_lt_scale_shift, &mkl_buf_scale_shift);
      // Fill the scale-shift buffer with data; the buffer is treated as a
      // 2D array.
      const Tensor& scale = MklGetInput(context, 1);
      const Tensor& shift = MklGetInput(context, 2);
      float* buf_scale_shift = static_cast<float*>(mkl_buf_scale_shift);
      float* buf_scale = const_cast<float*>(
          static_cast<const float*>(scale.flat<float>().data()));
      float* buf_shift = const_cast<float*>(
          static_cast<const float*>(shift.flat<float>().data()));
      auto depth = mkl_params.depth;
      for (int i = 0; i < depth; i++) {
        buf_scale_shift[i] = buf_scale[i];
        buf_scale_shift[i + depth] = buf_shift[i];
      }
      mkl_res_batchnorm[dnnResourceScaleShift] = mkl_buf_scale_shift;
    }

    inline void MklSetMeanVariance(const Tensor& mean,
                                   const Tensor& variance) {
      mkl_res_batchnorm[dnnResourceMean] = const_cast<void*>(
          static_cast<const void*>(mean.flat<float>().data()));
      mkl_res_batchnorm[dnnResourceVariance] = const_cast<void*>(
          static_cast<const void*>(variance.flat<float>().data()));
    }
  } MklFusedBatchNormOpContext;
};

template <typename Device, typename T>
class MklFusedBatchNormGradOp : public OpKernel {
 public:
  explicit MklFusedBatchNormGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = T(epsilon);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
  }

  void Compute(OpKernelContext* context) override {
    MklFusedBatchNormGradOpContext mkl_context;

    const Tensor& out_backprop = MklGetInput(context, 0);
    const Tensor& input = MklGetInput(context, 1);
    const Tensor& scale = MklGetInput(context, 2);
    const Tensor& saved_mean = MklGetInput(context, 3);
    const Tensor& saved_var = MklGetInput(context, 4);

    // Here scale, mean, and variance are 1D tensors and are assumed to
    // have the same layout in MKL and TF.
    GetMklShape(context, 0, &(mkl_context.mkl_shape_out_backprop));
    GetMklShape(context, 1, &(mkl_context.mkl_shape_input_shape));

    bool input_in_mkl_format = mkl_context.mkl_shape_input_shape.IsMklTensor();
    bool out_backprop_in_mkl_format =
        mkl_context.mkl_shape_out_backprop.IsMklTensor();
    if (!out_backprop_in_mkl_format) {
      OP_REQUIRES(context, out_backprop.dims() == 4,
                  errors::InvalidArgument("input must be 4-dimensional",
                                          out_backprop.shape().DebugString()));
    }
    if (!input_in_mkl_format) {
      OP_REQUIRES(context, input.dims() == 4,
                  errors::InvalidArgument("input must be 4-dimensional",
                                          input.shape().DebugString()));
    }
    OP_REQUIRES(context, scale.dims() == 1,
                errors::InvalidArgument("scale must be 1-dimensional",
                                        scale.shape().DebugString()));
    OP_REQUIRES(context, saved_mean.dims() == 1,
                errors::InvalidArgument("saved mean must be 1-dimensional",
                                        saved_mean.shape().DebugString()));
    OP_REQUIRES(context, saved_var.dims() == 1,
                errors::InvalidArgument("saved variance must be 1-dimensional",
                                        saved_var.shape().DebugString()));

    mkl_context.MklExtractParams(context, tensor_format_);

    mkl_context.MklCreateInputLayout(context);

    unsigned int flag_batch_norm_grad = dnnUseScaleShift;

    // Create Backward Op primitive.
    CHECK_EQ(dnnBatchNormalizationCreateBackward_v2_F32(
                 &(mkl_context.mkl_prim_batchnorm_bwd), nullptr,
                 mkl_context.mkl_lt_input, static_cast<float>(epsilon_),
                 flag_batch_norm_grad),
             E_SUCCESS);

    // Temporary tensors and their buffers, in case conversion is required
    Tensor mkl_tmp_input_buf_tensor, mkl_tmp_outbackprop_buf_tensor,
        mkl_tmp_scaleshift_buf_tensor;
    mkl_context.MklPrepareContextInputs(context, &mkl_tmp_input_buf_tensor,
                                        &mkl_tmp_outbackprop_buf_tensor,
                                        &mkl_tmp_scaleshift_buf_tensor);

    // Allocate tensor for grad w.r.t. input (x)
    Tensor* in_backprop = nullptr;
    TensorShape tf_shape_in_backprop;
    MklShape mkl_shape_in_backprop;
    mkl_shape_in_backprop.SetMklTensor(true);
    mkl_shape_in_backprop.SetMklLayout(mkl_context.mkl_prim_batchnorm_bwd,
                                       dnnResourceDiffSrc);
    mkl_shape_in_backprop.SetTfLayout(mkl_context.mkl_params.in_dims,
                                      mkl_context.mkl_params.in_sizes,
                                      mkl_context.mkl_params.in_strides);
    mkl_shape_in_backprop.SetTfDimOrder(mkl_context.mkl_params.in_dims,
                                        tensor_format_);
    tf_shape_in_backprop.AddDim(
        dnnLayoutGetMemorySize_F32(
            static_cast<dnnLayout_t>(mkl_shape_in_backprop.GetMklLayout())) /
        sizeof(T));
    AllocateOutputSetMklShape(context, 0, &in_backprop, tf_shape_in_backprop,
                              mkl_shape_in_backprop);
    mkl_context.mkl_res_batchnorm_bwd[dnnResourceDiffSrc] =
        static_cast<void*>(in_backprop->flat<T>().data());

    // grad_scale and grad_shift are combined together in MKL,
    // so create a single temporary buffer for them.
    // Also set dnnResourceDiffScaleShift to the temporary buffer.
    Tensor mkl_tmp_grad_scale_shift_buf_tensor;
    mkl_context.MklPrepareGradScaleShift(context,
                                         &mkl_tmp_grad_scale_shift_buf_tensor);

    // All dnn resources are set now; ready to execute.
    CHECK_EQ(dnnExecute_F32(mkl_context.mkl_prim_batchnorm_bwd,
                            mkl_context.mkl_res_batchnorm_bwd),
             E_SUCCESS);

    // Now separate out the scale and shift gradients and copy them to
    // individual tensors.
    const TensorShape& tf_shape_scale_shift = scale.shape();
    // Allocate tensor for grad w.r.t. scale (gamma)
    Tensor* scale_backprop = nullptr;
    MklShape mkl_shape_scale_backprop;
    AllocateOutputSetMklShape(context, 1, &scale_backprop,
                              tf_shape_scale_shift, mkl_shape_scale_backprop);

    // Allocate tensor for grad w.r.t. shift (beta)
    Tensor* shift_backprop = nullptr;
    MklShape mkl_shape_shift_backprop;
    AllocateOutputSetMklShape(context, 2, &shift_backprop,
                              tf_shape_scale_shift, mkl_shape_shift_backprop);

    // Copy scale and shift grads to the TF output tensors.
    float* mkl_buf_scale_shift = const_cast<float*>(static_cast<const float*>(
        mkl_tmp_grad_scale_shift_buf_tensor.flat<T>().data()));
    float* tf_buf_scale = const_cast<float*>(
        static_cast<const float*>(scale_backprop->flat<T>().data()));
    float* tf_buf_shift = const_cast<float*>(
        static_cast<const float*>(shift_backprop->flat<T>().data()));
    auto depth = mkl_context.mkl_params.depth;
    for (int i = 0; i < depth; i++) {
      tf_buf_scale[i] = mkl_buf_scale_shift[i];
      tf_buf_shift[i] = mkl_buf_scale_shift[i + depth];
    }

    // Two placeholders for estimated_mean and estimated_variance, which are
    // used for inference and thus not needed here for gradient computation.
    Tensor* placeholder_1 = nullptr;
    MklShape mkl_shape_placeholder_1;
    AllocateOutputSetMklShape(context, 3, &placeholder_1, TensorShape({}),
                              mkl_shape_placeholder_1);
    Tensor* placeholder_2 = nullptr;
    MklShape mkl_shape_placeholder_2;
    AllocateOutputSetMklShape(context, 4, &placeholder_2, TensorShape({}),
                              mkl_shape_placeholder_2);

    mkl_context.MklCleanup();
  }

 private:
  T epsilon_;
  TensorFormat tensor_format_;

  // Structure containing all info for MklOp
  typedef struct {
    // Parameters used for input and output layouts
    struct MklBatchNormParams {
      // BatchNormOp src (input) dimensions, sizes, and strides
      size_t in_dims;
      size_t in_sizes[4];
      size_t in_strides[4];
      size_t depth;  // Batch normalization is done per channel.
    } mkl_params;

    MklShape mkl_shape_out_backprop;
    MklShape mkl_shape_input_shape;

    // MKL primitive and resources for BatchNormOp
    dnnPrimitive_t mkl_prim_batchnorm_bwd = nullptr;
    void* mkl_res_batchnorm_bwd[dnnResourceNumber];

    // MKL layouts for inputs in the context
    dnnLayout_t mkl_lt_out_backprop = nullptr;
    dnnLayout_t mkl_lt_input = nullptr;

    void MklCleanup() {
      bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
      bool out_backprop_in_mkl_format = mkl_shape_out_backprop.IsMklTensor();
      if (!input_in_mkl_format) dnnLayoutDelete_F32(mkl_lt_input);
      if (!out_backprop_in_mkl_format) dnnLayoutDelete_F32(mkl_lt_out_backprop);

      dnnDelete_F32(mkl_prim_batchnorm_bwd);
    }

    void MklExtractParams(OpKernelContext* context,
                          const TensorFormat& tensor_format) {
      const Tensor& input = MklGetInput(context, 1);
      bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
      mkl_params.in_dims = input_in_mkl_format
                               ? mkl_shape_input_shape.GetDimension()
                               : input.dims();
      mkl_params.in_sizes[0] = static_cast<size_t>(
          input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[0]
                              : GetTensorDim(input, tensor_format, 'W'));
      mkl_params.in_sizes[1] = static_cast<size_t>(
          input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[1]
                              : GetTensorDim(input, tensor_format, 'H'));
      mkl_params.in_sizes[2] = static_cast<size_t>(
          input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[2]
                              : GetTensorDim(input, tensor_format, 'C'));
      mkl_params.in_sizes[3] = static_cast<size_t>(
          input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[3]
                              : GetTensorDim(input, tensor_format, 'N'));
      mkl_params.depth = mkl_params.in_sizes[2];
      GetStridesFromSizes(tensor_format, mkl_params.in_strides,
                          mkl_params.in_sizes);
    }

    void MklCreateInputLayout(OpKernelContext* context) {
      bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
      if (input_in_mkl_format) {
        mkl_lt_input =
            static_cast<dnnLayout_t>(mkl_shape_input_shape.GetCurLayout());
      } else {
        CHECK_EQ(
            dnnLayoutCreate_F32(&mkl_lt_input, mkl_params.in_dims,
                                mkl_params.in_sizes, mkl_params.in_strides),
            E_SUCCESS);
      }

      bool out_backprop_in_mkl_format = mkl_shape_out_backprop.IsMklTensor();
      if (out_backprop_in_mkl_format) {
        mkl_lt_out_backprop =
            static_cast<dnnLayout_t>(mkl_shape_out_backprop.GetCurLayout());
      } else {
        CHECK_EQ(
            dnnLayoutCreate_F32(&mkl_lt_out_backprop, mkl_params.in_dims,
                                mkl_params.in_sizes, mkl_params.in_strides),
            E_SUCCESS);
      }
    }

    void MklPrepareContextInputs(OpKernelContext* context,
                                 Tensor* mkl_tmp_input_buf_tensor,
                                 Tensor* mkl_tmp_outbackprop_buf_tensor,
                                 Tensor* mkl_tmp_scaleshift_buf_tensor) {
      bool mkl_convert_input;
      dnnPrimitive_t mkl_prim_convert_input = nullptr;
      dnnLayout_t mkl_lt_internal_input = nullptr;
      void* mkl_buf_converted_input = nullptr;
      // Compare with internal layouts and convert if needed
      const Tensor& input = MklGetInput(context, 1);
      void* mkl_buf_input =
          const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
      CHECK_EQ(
          dnnLayoutCreateFromPrimitive_F32(
              &mkl_lt_internal_input, mkl_prim_batchnorm_bwd, dnnResourceSrc),
          E_SUCCESS);
      mkl_convert_input =
          !dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_input);
      if (mkl_convert_input) {
        CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, mkl_lt_input,
                                         mkl_lt_internal_input),
                 E_SUCCESS);
        AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
                       &mkl_buf_converted_input);
        CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
                                          mkl_buf_converted_input),
                 E_SUCCESS);
        dnnDelete_F32(mkl_prim_convert_input);
      }
      dnnLayoutDelete_F32(mkl_lt_internal_input);
      mkl_res_batchnorm_bwd[dnnResourceSrc] =
          (mkl_convert_input) ? mkl_buf_converted_input : mkl_buf_input;

      bool mkl_convert_out_backprop;
      dnnPrimitive_t mkl_prim_convert_out_backprop = nullptr;
      dnnLayout_t mkl_lt_internal_out_backprop = nullptr;
      void* mkl_buf_converted_out_backprop = nullptr;
      // Compare with internal layouts and convert if needed
      const Tensor& out_backprop = MklGetInput(context, 0);
      void* mkl_buf_out_backprop = const_cast<void*>(
          static_cast<const void*>(out_backprop.flat<T>().data()));
      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_out_backprop,
                                                mkl_prim_batchnorm_bwd,
                                                dnnResourceDiffDst),
               E_SUCCESS);
      mkl_convert_out_backprop = !dnnLayoutCompare_F32(
          mkl_lt_internal_out_backprop, mkl_lt_out_backprop);
      if (mkl_convert_out_backprop) {
        CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_out_backprop,
                                         mkl_lt_out_backprop,
                                         mkl_lt_internal_out_backprop),
                 E_SUCCESS);
        AllocTmpBuffer(context, mkl_tmp_outbackprop_buf_tensor,
                       mkl_lt_internal_out_backprop,
                       &mkl_buf_converted_out_backprop);
        CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_out_backprop,
                                          mkl_buf_out_backprop,
                                          mkl_buf_converted_out_backprop),
                 E_SUCCESS);
        dnnDelete_F32(mkl_prim_convert_out_backprop);
      }
      dnnLayoutDelete_F32(mkl_lt_internal_out_backprop);
      mkl_res_batchnorm_bwd[dnnResourceDiffDst] =
          (mkl_convert_out_backprop) ? mkl_buf_converted_out_backprop
                                     : mkl_buf_out_backprop;

      // Set dnnResourceMean and dnnResourceVariance
      const Tensor& saved_mean = MklGetInput(context, 3);
      const Tensor& saved_var = MklGetInput(context, 4);
      void* mkl_buf_saved_mean = const_cast<void*>(
          static_cast<const void*>(saved_mean.flat<T>().data()));
      void* mkl_buf_saved_var = const_cast<void*>(
          static_cast<const void*>(saved_var.flat<T>().data()));
      mkl_res_batchnorm_bwd[dnnResourceMean] = mkl_buf_saved_mean;
      mkl_res_batchnorm_bwd[dnnResourceVariance] = mkl_buf_saved_var;

      // Set dnnResourceScaleShift
      // Note: the backward op needs only the current values of the scale
      // parameters; the shift parameters may be garbage and won't be used.
      const Tensor& scale = MklGetInput(context, 2);
      dnnLayout_t mkl_lt_scale_shift = nullptr;
      void* mkl_buf_scale_shift = nullptr;
      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_scale_shift,
                                                mkl_prim_batchnorm_bwd,
                                                dnnResourceScaleShift),
               E_SUCCESS);
      AllocTmpBuffer(context, mkl_tmp_scaleshift_buf_tensor, mkl_lt_scale_shift,
                     &mkl_buf_scale_shift);
      float* pscale = const_cast<float*>(
          static_cast<const float*>(scale.flat<T>().data()));
      float* pscale_shift = static_cast<float*>(mkl_buf_scale_shift);
      auto depth = mkl_params.depth;
      for (int i = 0; i < depth; i++) pscale_shift[i] = pscale[i];
      mkl_res_batchnorm_bwd[dnnResourceScaleShift] = mkl_buf_scale_shift;
      dnnLayoutDelete_F32(mkl_lt_scale_shift);
    }

    void MklPrepareGradScaleShift(OpKernelContext* context,
                                  Tensor* mkl_tmp_grad_scale_shift_buf_tensor) {
      dnnLayout_t mkl_lt_grad_scaleshift = nullptr;
      void* mkl_buf_grad_scaleshift = nullptr;
      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_grad_scaleshift,
                                                mkl_prim_batchnorm_bwd,
                                                dnnResourceDiffScaleShift),
               E_SUCCESS);
      AllocTmpBuffer(context, mkl_tmp_grad_scale_shift_buf_tensor,
                     mkl_lt_grad_scaleshift, &mkl_buf_grad_scaleshift);
      mkl_res_batchnorm_bwd[dnnResourceDiffScaleShift] =
          mkl_buf_grad_scaleshift;
      dnnLayoutDelete_F32(mkl_lt_grad_scaleshift);
    }
  } MklFusedBatchNormGradOpContext;
};
#endif

#ifndef INTEL_MKL_ML

template <typename Device, typename T>
class MklFusedBatchNormOp : public OpKernel {
 public:
  explicit MklFusedBatchNormOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = T(epsilon);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
  }

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);
      const size_t kSrcIndex = 0;       // index of src input tensor
      const size_t kScaleIndex = 1;     // index of scale tensor
      const size_t kShiftIndex = 2;     // index of shift tensor
      const size_t kMeanIndex = 3;      // index of est_mean tensor
      const size_t kVarianceIndex = 4;  // index of est_variance tensor

      const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
      const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
      const Tensor& shift_tensor = MklGetInput(context, kShiftIndex);
      const Tensor& est_mean_tensor = MklGetInput(context, kMeanIndex);
      const Tensor& est_variance_tensor = MklGetInput(context, kVarianceIndex);

      TensorShape tf_shape_src;
      MklDnnShape dnn_shape_src;
      GetMklShape(context, kSrcIndex, &dnn_shape_src);

      if (dnn_shape_src.IsMklTensor()) {
        tf_shape_src = dnn_shape_src.GetTfShape();
        OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      } else {
        tf_shape_src = src_tensor.shape();
        OP_REQUIRES(context, src_tensor.dims() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      }
      OP_REQUIRES(context, scale_tensor.dims() == 1,
                  errors::InvalidArgument("scale must be 1-dimensional",
                                          scale_tensor.shape().DebugString()));
      OP_REQUIRES(context, shift_tensor.dims() == 1,
                  errors::InvalidArgument("offset must be 1-dimensional",
                                          shift_tensor.shape().DebugString()));
      OP_REQUIRES(
          context, est_mean_tensor.dims() == 1,
          errors::InvalidArgument("estimated_mean must be 1-dimensional",
                                  est_mean_tensor.shape().DebugString()));
      OP_REQUIRES(
          context, est_variance_tensor.dims() == 1,
          errors::InvalidArgument("estimated_variance must be 1-dimensional",
                                  est_variance_tensor.shape().DebugString()));

      if (is_training_) {
        OP_REQUIRES(
            context, est_mean_tensor.dim_size(0) == 0,
            errors::InvalidArgument(
                "estimated_mean must be empty for training",
                est_mean_tensor.shape().DebugString()));
        OP_REQUIRES(context, est_variance_tensor.dim_size(0) == 0,
                    errors::InvalidArgument(
                        "estimated_variance must be empty for training",
                        est_variance_tensor.shape().DebugString()));
      }

      // Special case: input with 0 elements and 0 batch size.
      Tensor* dst_tensor = nullptr;
      if (tf_shape_src.num_elements() == 0) {
        HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(),
                         &dst_tensor);
        return;
      }

      if (dnn_shape_src.IsMklTensor())
        depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
      else
        ExtractParams(context);

      // Indices of output tensors
      const size_t kDstIndex = 0;

      // allocate 4 output TF tensors
      Tensor* batch_mean_tensor = nullptr;
      Tensor* batch_variance_tensor = nullptr;
      Tensor* saved_mean_tensor = nullptr;
      Tensor* saved_variance_tensor = nullptr;
      AllocateTFOutputs(context, scale_tensor.shape(), &batch_mean_tensor,
                        &batch_variance_tensor, &saved_mean_tensor,
                        &saved_variance_tensor);

      if (is_training_)
        SetMeanVariance(*batch_mean_tensor, *batch_variance_tensor);
      else
        SetMeanVariance(est_mean_tensor, est_variance_tensor);

      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> dst(&cpu_engine);

      memory::format format_m;
      if (dnn_shape_src.IsMklTensor()) {
        if (dnn_shape_src.IsTensorInNCHWFormat()) {
          format_m = memory::format::nchw;
        } else {
          format_m = memory::format::nhwc;
        }
      } else {
        format_m = TFDataFormatToMklDnnDataFormat(tensor_format_);
      }

      // set src primitive
      memory::dims src_dims;
      if (dnn_shape_src.IsMklTensor()) {
        src_dims = TFShapeToMklDnnDimsInNCHW(dnn_shape_src.GetTfShape(),
                                             tensor_format_);
      } else {
        src_dims =
            TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);
      }

      auto src_md = dnn_shape_src.IsMklTensor()
                        ? dnn_shape_src.GetMklLayout()
                        : memory::desc(src_dims, MklDnnType<T>(), format_m);
      src.SetUsrMem(src_md, &src_tensor);

      // set weights primitive
      // MKL-DNN packs scale & shift as "weights":
      // <scale>...<scale><shift>...<shift>
      auto weights_desc =
          memory::desc({2, depth_}, MklDnnType<T>(), memory::format::nc);
      auto weights_pd = memory::primitive_desc(weights_desc, cpu_engine);
      auto weights_m = memory(weights_pd);
      T* weights_data = reinterpret_cast<T*>(weights_m.get_data_handle());
      T* scale_tf =
          reinterpret_cast<T*>(const_cast<T*>(scale_tensor.flat<T>().data()));
      T* shift_tf =
          reinterpret_cast<T*>(const_cast<T*>(shift_tensor.flat<T>().data()));

      for (int k = 0; k < depth_; k++) {
        weights_data[k] = scale_tf[k];
        weights_data[k + depth_] = shift_tf[k];
      }

      // set mean primitive
      auto mean_desc =
          memory::desc({1, depth_}, MklDnnType<T>(), memory::format::nc);
      auto mean_pd = memory::primitive_desc(mean_desc, cpu_engine);
      char* saved_mean_data_tf =
          reinterpret_cast<char*>(saved_mean_tensor->flat<T>().data());
      std::memcpy(saved_mean_data_tf, reinterpret_cast<char*>(mean_values_),
                  depth_ * sizeof(T));
      auto mean_m =
          memory(mean_pd, reinterpret_cast<void*>(saved_mean_data_tf));

      // set variance primitive
      auto variance_desc =
          memory::desc({1, depth_}, MklDnnType<T>(), memory::format::nc);
      auto variance_pd = memory::primitive_desc(variance_desc, cpu_engine);
      char* saved_variance_data_tf =
          reinterpret_cast<char*>(saved_variance_tensor->flat<T>().data());
      std::memcpy(saved_variance_data_tf,
                  reinterpret_cast<char*>(variance_values_),
                  depth_ * sizeof(T));
      auto variance_m = memory(variance_pd, saved_variance_data_tf);

      prop_kind pk = (is_training_) ? prop_kind::forward_training
                                    : prop_kind::forward_scoring;
      auto bnrm_fwd_desc = batch_normalization_forward::desc(
          pk, src.GetUsrMemDesc(), epsilon_,
          is_training_ ? use_scale_shift
                       : (use_scale_shift | use_global_stats));
      auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc(
          bnrm_fwd_desc, cpu_engine);

      // allocate dst tensor
      MklDnnShape dnn_shape_dst;
      TensorShape tf_shape_dst;
      if (dnn_shape_src.IsMklTensor()) {
        dnn_shape_dst.SetMklTensor(true);
        auto dst_pd = bnrm_fwd_pd.dst_primitive_desc();
        dnn_shape_dst.SetMklLayout(&dst_pd);
        dnn_shape_dst.SetElemType(MklDnnType<T>());
        dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(), src_dims,
                                  format_m);
        tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T));
      } else {
        dnn_shape_dst.SetMklTensor(false);
        tf_shape_dst = src_tensor.shape();
      }
      AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst,
                                dnn_shape_dst);

      // The output of batchnorm has the same shape as the input.
      dst.SetUsrMem(src_md, dst_tensor);

      primitive bnrm_fwd_op;
      if (is_training_) {
        bnrm_fwd_op =
            batch_normalization_forward(bnrm_fwd_pd, src.GetOpMem(), weights_m,
                                        dst.GetOpMem(), mean_m, variance_m);
      } else {
        bnrm_fwd_op = batch_normalization_forward(
            bnrm_fwd_pd, src.GetOpMem(), mean_m, variance_m,
            (const primitive::at)weights_m, dst.GetOpMem());
      }
      std::vector<primitive> net;
      net.push_back(bnrm_fwd_op);
      stream(stream::kind::eager).submit(net).wait();

      // copy batch_mean data
      T* batch_mean_data_tf =
          reinterpret_cast<T*>(batch_mean_tensor->flat<T>().data());
      std::memcpy(reinterpret_cast<char*>(batch_mean_data_tf),
                  reinterpret_cast<char*>(mean_m.get_data_handle()),
                  depth_ * sizeof(T));

      // copy batch_variance data, applying Bessel's correction
      // if training mode is on
      float adjust_factor = 1.0;
      if (is_training_) {
        size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3];
        size_t adjust_size = orig_size - 1;
        adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
      }
      for (int k = 0; k < depth_; k++)
        batch_variance_tensor->flat<T>().data()[k] =
            (reinterpret_cast<T*>(variance_m.get_data_handle()))[k] *
            adjust_factor;
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  T epsilon_;
  TensorFormat tensor_format_;
  bool is_training_;
  T* mean_values_;
  T* variance_values_;
  size_t depth_;  // batch normalization is done per channel

  void ExtractParams(OpKernelContext* context) {
    const Tensor& input = MklGetInput(context, 0);
    depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
  }

  void SetMeanVariance(const Tensor& mean, const Tensor& variance) {
    mean_values_ = reinterpret_cast<T*>(const_cast<T*>(mean.flat<T>().data()));
    variance_values_ =
        reinterpret_cast<T*>(const_cast<T*>(variance.flat<T>().data()));
  }

  void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
                        TensorShape tf_shape_scale, Tensor** dst_tensor) {
    CHECK_NOTNULL(dst_tensor);

    const size_t kDstIndex = 0;
    MklDnnShape dnn_shape_dst;
    dnn_shape_dst.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDstIndex, dst_tensor, tf_shape_src,
                              dnn_shape_dst);
    CHECK_NOTNULL(*dst_tensor);
    memset(const_cast<char*>((*dst_tensor)->tensor_data().data()), 0,
           (*dst_tensor)->tensor_data().size());

    Tensor* batch_mean_tensor = nullptr;
    Tensor* batch_variance_tensor = nullptr;
    Tensor* saved_mean_tensor = nullptr;
    Tensor* saved_variance_tensor = nullptr;
    AllocateTFOutputs(context, tf_shape_scale, &batch_mean_tensor,
                      &batch_variance_tensor, &saved_mean_tensor,
                      &saved_variance_tensor);
  }

  void AllocateTFOutputs(OpKernelContext* context, TensorShape tf_shape_scale,
                         Tensor** batch_mean_tensor,
                         Tensor** batch_variance_tensor,
                         Tensor** saved_mean_tensor,
                         Tensor** saved_variance_tensor) {
    CHECK_NOTNULL(batch_mean_tensor);
    CHECK_NOTNULL(batch_variance_tensor);
    CHECK_NOTNULL(saved_mean_tensor);
    CHECK_NOTNULL(saved_variance_tensor);

    const size_t kBatchMeanIndex = 1;
    const size_t kBatchVarianceIndex = 2;
    const size_t kSavedMeanIndex = 3;
    const size_t kSavedVarianceIndex = 4;

    // allocate batch mean output tensor
    MklDnnShape mkl_shape_batch_mean;
    mkl_shape_batch_mean.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kBatchMeanIndex, batch_mean_tensor,
                              tf_shape_scale, mkl_shape_batch_mean);
    CHECK_NOTNULL(*batch_mean_tensor);
    // set NAN mean value in case of empty input tensor
    for (int k = 0; k < tf_shape_scale.num_elements(); k++)
      (*batch_mean_tensor)->flat<T>().data()[k] = NAN;

    // allocate batch variance output tensor
    MklDnnShape mkl_shape_batch_variance;
    mkl_shape_batch_variance.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kBatchVarianceIndex,
                              batch_variance_tensor, tf_shape_scale,
                              mkl_shape_batch_variance);
    CHECK_NOTNULL(*batch_variance_tensor);
    // set NAN variance value in case of empty input tensor
    for (int k = 0; k < tf_shape_scale.num_elements(); k++)
      (*batch_variance_tensor)->flat<T>().data()[k] = NAN;

    // Mean and variance (without Bessel's correction) saved for backward
    // computation to serve as pre-computed mean and variance.
    MklDnnShape mkl_shape_saved_mean;
    mkl_shape_saved_mean.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kSavedMeanIndex, saved_mean_tensor,
                              tf_shape_scale, mkl_shape_saved_mean);
    CHECK_NOTNULL(*saved_mean_tensor);
    // set NAN mean value in case of empty input tensor
    for (int k = 0; k < tf_shape_scale.num_elements(); k++)
      (*saved_mean_tensor)->flat<T>().data()[k] = NAN;

    MklDnnShape mkl_shape_saved_variance;
    mkl_shape_saved_variance.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kSavedVarianceIndex,
                              saved_variance_tensor, tf_shape_scale,
                              mkl_shape_saved_variance);
    CHECK_NOTNULL(*saved_variance_tensor);
    // set NAN variance value in case of empty input tensor
    for (int k = 0; k < tf_shape_scale.num_elements(); k++)
      (*saved_variance_tensor)->flat<T>().data()[k] = NAN;
  }
};

template <typename Device, typename T>
class MklFusedBatchNormGradOp : public OpKernel {
 public:
  explicit MklFusedBatchNormGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = T(epsilon);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
  }

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);
      const size_t kDiffDstIndex = 0;   // index of diff_dst tensor
      const size_t kSrcIndex = 1;       // index of src input tensor
      const size_t kScaleIndex = 2;     // index of scale tensor
      const size_t kMeanIndex = 3;      // index of saved_mean tensor
      const size_t kVarianceIndex = 4;  // index of saved_variance tensor
      const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIndex);
      const Tensor& src_tensor = MklGetInput(context, kSrcIndex);
      const Tensor& scale_tensor = MklGetInput(context, kScaleIndex);
      const Tensor& saved_mean_tensor = MklGetInput(context, kMeanIndex);
      const Tensor& saved_variance_tensor =
          MklGetInput(context, kVarianceIndex);

      MklDnnShape dnn_shape_src, dnn_shape_diff_dst;
      GetMklShape(context, kSrcIndex, &dnn_shape_src);
      GetMklShape(context, kDiffDstIndex, &dnn_shape_diff_dst);
      TensorShape tf_shape_src, tf_shape_diff_dst;

      if (dnn_shape_diff_dst.IsMklTensor()) {
        tf_shape_diff_dst = dnn_shape_diff_dst.GetTfShape();
        OP_REQUIRES(
            context, dnn_shape_diff_dst.GetDimension() == 4,
            errors::InvalidArgument("input must be 4-dimensional",
                                    diff_dst_tensor.shape().DebugString()));
      } else {
        tf_shape_diff_dst = diff_dst_tensor.shape();
        OP_REQUIRES(
            context, diff_dst_tensor.dims() == 4,
            errors::InvalidArgument("input must be 4-dimensional",
                                    diff_dst_tensor.shape().DebugString()));
      }

      if (dnn_shape_src.IsMklTensor()) {
        tf_shape_src = dnn_shape_src.GetTfShape();
        OP_REQUIRES(context, dnn_shape_src.GetDimension() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      } else {
        tf_shape_src = src_tensor.shape();
        OP_REQUIRES(context, src_tensor.dims() == 4,
                    errors::InvalidArgument("input must be 4-dimensional",
                                            src_tensor.shape().DebugString()));
      }

      OP_REQUIRES(context, scale_tensor.dims() == 1,
                  errors::InvalidArgument("scale must be 1-dimensional",
                                          scale_tensor.shape().DebugString()));
      OP_REQUIRES(
          context, saved_mean_tensor.dims() == 1,
          errors::InvalidArgument("saved mean must be 1-dimensional",
                                  saved_mean_tensor.shape().DebugString()));

      OP_REQUIRES(
          context, saved_variance_tensor.dims() == 1,
          errors::InvalidArgument("saved variance must be 1-dimensional",
                                  saved_variance_tensor.shape().DebugString()));

      Tensor* diff_src_tensor = nullptr;
      if (tf_shape_src.num_elements() == 0 ||
          tf_shape_diff_dst.num_elements() == 0) {
        HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(),
                         &diff_src_tensor);
        return;
      }

      if (dnn_shape_src.IsMklTensor())
        depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C);
      else
        ExtractParams(context);

      memory::format format_m;
      if (dnn_shape_src.IsMklTensor()) {
        if (dnn_shape_src.IsTensorInNCHWFormat())
          format_m = memory::format::nchw;
        else
          format_m = memory::format::nhwc;
      } else {
        format_m = TFDataFormatToMklDnnDataFormat(tensor_format_);
      }

      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> mean(&cpu_engine);
      MklDnnData<T> variance(&cpu_engine);
      MklDnnData<T> diff_dst(&cpu_engine);
      MklDnnData<T> diff_src(&cpu_engine);

      memory::dims src_dims, diff_dst_dims;
      if (dnn_shape_src.IsMklTensor())
        src_dims = TFShapeToMklDnnDimsInNCHW(dnn_shape_src.GetTfShape(),
                                             tensor_format_);
      else
        src_dims =
            TFShapeToMklDnnDimsInNCHW(src_tensor.shape(), tensor_format_);

      if (dnn_shape_diff_dst.IsMklTensor())
        diff_dst_dims = TFShapeToMklDnnDimsInNCHW(
            dnn_shape_diff_dst.GetTfShape(), tensor_format_);
      else
        diff_dst_dims =
            TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), tensor_format_);

      // set src and diff_dst primitives
      memory::desc src_md({}, memory::data_undef, memory::format_undef);
      memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef);
      if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
        if (dnn_shape_src.IsMklTensor()) {
          src_md = dnn_shape_src.GetMklLayout();
          diff_dst_md = src_md;
        } else {
          diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
          src_md = diff_dst_md;
        }
      } else {
        src_md = memory::desc(src_dims, MklDnnType<T>(), format_m);
        diff_dst_md = src_md;
      }
      src.SetUsrMem(src_md, &src_tensor);
      diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);

      // weights -- DNN packs scales/shifts as weights in order of
      // scale, ..., scale, shift, ..., shift
      auto weights_desc =
          memory::desc({2, depth_}, MklDnnType<T>(), memory::format::nc);
      auto weights_pd = memory::primitive_desc(weights_desc, cpu_engine);
      auto weights_m = memory(weights_pd);
      T* weights_data = reinterpret_cast<T*>(weights_m.get_data_handle());
      T* scale_tf =
          reinterpret_cast<T*>(const_cast<T*>(scale_tensor.flat<T>().data()));
      for (int k = 0; k < depth_; k++) {
        weights_data[k] = scale_tf[k];
        weights_data[k + depth_] = 0;
      }

      // set mean primitive
      memory::dims mv_dims = GetMeanVarianceDims();
      mean.SetUsrMem(mv_dims, memory::format::nc,
                     const_cast<void*>(static_cast<const void*>(
                         saved_mean_tensor.flat<T>().data())));
      mean.SetOpMemDesc(mv_dims,
                        memory::format::nc);

      // set variance primitive
      variance.SetUsrMem(mv_dims, memory::format::nc,
                         const_cast<void*>(static_cast<const void*>(
                             saved_variance_tensor.flat<T>().data())));
      variance.SetOpMemDesc(mv_dims, memory::format::nc);

      // set diff_weight primitive
      auto diff_weights_desc =
          memory::desc({2, depth_}, MklDnnType<T>(), memory::format::nc);
      auto diff_weights_pd =
          memory::primitive_desc(diff_weights_desc, cpu_engine);
      auto diff_weights_m = memory(diff_weights_pd);

      auto bnrm_fwd_desc = batch_normalization_forward::desc(
          prop_kind::forward_training, src.GetUsrMemDesc(), epsilon_,
          is_training_ ? use_scale_shift
                       : (use_scale_shift | use_global_stats));
      auto bnrm_fwd_pd = batch_normalization_forward::primitive_desc(
          bnrm_fwd_desc, cpu_engine);

      // Indices of output tensors
      const size_t kDiffSrcIndex = 0;  // index of diff_src tensor

      // allocate diff_src tensor
      MklDnnShape dnn_shape_diff_src;
      TensorShape tf_shape_diff_src;
      if (dnn_shape_src.IsMklTensor()) {
        dnn_shape_diff_src.SetMklTensor(true);
        auto diff_src_pd = bnrm_fwd_pd.dst_primitive_desc();
        dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
        dnn_shape_diff_src.SetElemType(MklDnnType<T>());
        dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(), src_dims,
                                       format_m);
        dnn_shape_diff_src.SetTfDimOrder(dnn_shape_src.GetDimension(),
                                         tensor_format_);
        tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
      } else {
        dnn_shape_diff_src.SetMklTensor(false);
        tf_shape_diff_src = src_tensor.shape();
      }
      AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor,
                                tf_shape_diff_src, dnn_shape_diff_src);

      diff_src.SetUsrMem(src_md, diff_src_tensor);

      prop_kind pk = prop_kind::backward;
      auto bnrm_bwd_desc = batch_normalization_backward::desc(
          pk, diff_src.GetUsrMemDesc(), src.GetUsrMemDesc(), epsilon_,
          /* For inference, specify use_global_stats:
             1. on fwd prop, use the mean and variance
                provided as inputs;
             2. on bwd prop, mean and variance are
                treated as constants, which reduces
                the amount of MKL computation.
          */
          is_training_ ? use_scale_shift
                       : (use_scale_shift | use_global_stats));
      auto bnrm_bwd_pd = batch_normalization_backward::primitive_desc(
          bnrm_bwd_desc, cpu_engine, bnrm_fwd_pd);

      auto bnrm_bwd_op = batch_normalization_backward(
          bnrm_bwd_pd, src.GetOpMem(), mean.GetOpMem(), variance.GetOpMem(),
          diff_dst.GetOpMem(), weights_m, diff_src.GetOpMem(), diff_weights_m);

      std::vector<primitive> net;
      net.push_back(bnrm_bwd_op);
      stream(stream::kind::eager).submit(net).wait();

      // allocate 4 output TF tensors
      Tensor* diff_scale_tensor = nullptr;
      Tensor* diff_shift_tensor = nullptr;
      AllocateTFOutputs(context, scale_tensor.shape(), &diff_scale_tensor,
                        &diff_shift_tensor);

      // copy data: diff_scale and diff_shift
      T* diff_weights_data_dnn =
          reinterpret_cast<T*>(diff_weights_m.get_data_handle());
      for (int i = 0; i < depth_; i++) {
        diff_scale_tensor->flat<T>().data()[i] = diff_weights_data_dnn[i];
        diff_shift_tensor->flat<T>().data()[i] =
            diff_weights_data_dnn[i + depth_];
      }
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  T epsilon_;
  TensorFormat tensor_format_;
  int depth_;  // batch normalization is done per channel
  bool is_training_;

  void ExtractParams(OpKernelContext* context) {
    const Tensor& input = MklGetInput(context, 0);
    depth_ = static_cast<int>(GetTensorDim(input, tensor_format_, 'C'));
  }

  void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src,
                        TensorShape tf_shape_scale_shift,
                        Tensor** diff_src_tensor) {
    const size_t kDiffSrcIndex = 0;

    MklDnnShape dnn_shape_diff_src;
    dnn_shape_diff_src.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffSrcIndex, diff_src_tensor,
                              tf_shape_src, dnn_shape_diff_src);
    for (size_t i = 0; i < (*diff_src_tensor)->shape().num_elements(); i++)
      (*diff_src_tensor)->flat<T>().data()[i] = 0;

    Tensor* diff_scale_tensor = nullptr;
    Tensor* diff_shift_tensor = nullptr;
    AllocateTFOutputs(context, tf_shape_scale_shift, &diff_scale_tensor,
                      &diff_shift_tensor);
  }

  void AllocateTFOutputs(OpKernelContext* context,
                         TensorShape tf_shape_scale_shift,
                         Tensor** diff_scale_tensor,
                         Tensor** diff_shift_tensor) {
    CHECK_NOTNULL(diff_scale_tensor);
    CHECK_NOTNULL(diff_shift_tensor);

    const size_t kDiffScaleIndex = 1;
    const size_t kDiffShiftIndex = 2;
    const size_t kP1Index = 3;
    const size_t kP2Index = 4;

    // separate out scale and shift grad and copy to individual tensors
    MklDnnShape mkl_shape_diff_scale;
    mkl_shape_diff_scale.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor,
                              tf_shape_scale_shift, mkl_shape_diff_scale);
    CHECK_NOTNULL(*diff_scale_tensor);
    for (size_t i = 0; i < (*diff_scale_tensor)->shape().num_elements(); i++)
      (*diff_scale_tensor)->flat<T>().data()[i] = 0;

    MklDnnShape mkl_shape_diff_shift;
    mkl_shape_diff_shift.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor,
                              tf_shape_scale_shift, mkl_shape_diff_shift);
    CHECK_NOTNULL(*diff_shift_tensor);
    for (size_t i = 0; i < (*diff_shift_tensor)->shape().num_elements(); i++)
      (*diff_shift_tensor)->flat<T>().data()[i] = 0;

    // Placeholders for estimated_mean and estimated_variance, which are
    // used for inference and thus not needed here for gradient computation.
    Tensor *p1_tensor = nullptr, *p2_tensor = nullptr;
    MklDnnShape mkl_shape_p;
    mkl_shape_p.SetMklTensor(false);
    AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}),
                              mkl_shape_p);
    AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}),
                              mkl_shape_p);
  }

  memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); }
};

#endif

#define REGISTER_MKL_CPU(T)                           \
  REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNorm")  \
                              .Device(DEVICE_CPU)     \
                              .TypeConstraint<T>("T") \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklFusedBatchNormOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_CPU);
#undef REGISTER_MKL_CPU

#define REGISTER_MKL_CPU(T)                              \
  REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNormGrad") \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<T>("T")    \
                              .Label(mkl_op_registry::kMklOpLabel), \
                          MklFusedBatchNormGradOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_CPU);
#undef REGISTER_MKL_CPU
}  // namespace tensorflow

#endif  // INTEL_MKL