/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

namespace functor {

// Functor used by FusedBatchNormOp to do the computations.
template <typename Device, typename T, typename U>
struct FusedBatchNorm;
// Functor used by FusedBatchNormGradOp to do the computations when
// is_training=True.
template <typename Device, typename T, typename U>
struct FusedBatchNormGrad;

template <bool IsSame, typename Y, typename X, typename T>
struct CastIfNecessary {
  static inline void process(
      Y& y, X& x_shifted, const Eigen::DSizes<Eigen::Index, 2>& rest_by_depth,
      const CPUDevice& d) {
    y.reshape(rest_by_depth).device(d) = x_shifted.template cast<T>();
  }
};

template <typename Y, typename X, typename T>
struct CastIfNecessary<true, Y, X, T> {
  static inline void process(
      Y& y, X& x_shifted, const Eigen::DSizes<Eigen::Index, 2>& rest_by_depth,
      const CPUDevice& d) {
    y.reshape(rest_by_depth).device(d) = x_shifted;
  }
};

template <typename T, typename U>
struct FusedBatchNorm<CPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& x_input,
                  const Tensor& scale_input, const Tensor& offset_input,
                  const Tensor& estimated_mean_input,
                  const Tensor& estimated_variance_input, U epsilon,
                  Tensor* y_output, Tensor* batch_mean_output,
                  Tensor* batch_var_output, Tensor* saved_mean_output,
                  Tensor* saved_var_output, TensorFormat tensor_format,
                  bool is_training) {
    OP_REQUIRES(context, tensor_format == FORMAT_NHWC,
                errors::Internal("The CPU implementation of FusedBatchNorm "
                                 "only supports NHWC tensor format for now."));
    typename TTypes<T, 4>::ConstTensor x(x_input.tensor<T, 4>());
    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
    typename TTypes<U>::ConstVec offset(offset_input.vec<U>());
    typename TTypes<U>::ConstVec estimated_mean(estimated_mean_input.vec<U>());
    typename TTypes<U>::ConstVec estimated_variance(
        estimated_variance_input.vec<U>());
    typename TTypes<T, 4>::Tensor y(y_output->tensor<T, 4>());
    typename TTypes<U>::Vec batch_mean(batch_mean_output->vec<U>());
    typename TTypes<U>::Vec batch_var(batch_var_output->vec<U>());
    typename TTypes<U>::Vec saved_mean(saved_mean_output->vec<U>());
    typename TTypes<U>::Vec saved_var(saved_var_output->vec<U>());

    const CPUDevice& d = context->eigen_device<CPUDevice>();

    const int depth = x.dimension(3);
    const int size = x.size();
    const int rest_size = size / depth;
    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);

#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
    Eigen::array<int, 1> reduce_dims({0});
    Eigen::array<int, 2> bcast_spec({rest_size, 1});
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduce_dims;
    Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > bcast_spec;
    bcast_spec.set(0, rest_size);
#endif

    auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
    const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
    U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
    // This adjustment is for Bessel's correction.
    U rest_size_adjust =
        static_cast<U>(rest_size) / static_cast<U>(rest_size_minus_one);

    Eigen::Tensor<U, 1, Eigen::RowMajor> mean(depth);
    Eigen::Tensor<U, 1, Eigen::RowMajor> variance(depth);
    if (is_training) {
      mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
      batch_mean.device(d) = mean;
      saved_mean.device(d) = mean;
    } else {
      mean.device(d) = estimated_mean;
    }

    auto x_centered =
        x_rest_by_depth - mean.reshape(one_by_depth).broadcast(bcast_spec);

    if (is_training) {
      variance.device(d) = x_centered.square().sum(reduce_dims) * rest_size_inv;
      batch_var.device(d) = variance * rest_size_adjust;
      saved_var.device(d) = variance;
    } else {
      variance.device(d) = estimated_variance;
    }

    auto scaling_factor = ((variance + epsilon).rsqrt() * scale)
                              .eval()
                              .reshape(one_by_depth)
                              .broadcast(bcast_spec);
    auto x_scaled = x_centered * scaling_factor;
    auto x_shifted =
        x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec);

    // Explicitly checks the types of T and U and only casts x_shifted when
    // T != U. (Not doing so caused a 35-50% performance slowdown for
    // some compiler flags.)
    CastIfNecessary<std::is_same<T, U>::value, decltype(y), decltype(x_shifted),
                    T>::process(y, x_shifted, rest_by_depth, d);
  }
};

template <typename T, typename U>
struct FusedBatchNormGrad<CPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
                  const Tensor& x_input, const Tensor& scale_input,
                  const Tensor& mean_input, const Tensor& variance_input,
                  U epsilon, Tensor* x_backprop_output,
                  Tensor* scale_backprop_output, Tensor* offset_backprop_output,
                  TensorFormat tensor_format) {
    OP_REQUIRES(context, tensor_format == FORMAT_NHWC,
                errors::Internal("The CPU implementation of FusedBatchNormGrad "
                                 "only supports NHWC tensor format for now."));
    typename TTypes<T, 4>::ConstTensor y_backprop(
        y_backprop_input.tensor<T, 4>());
    typename TTypes<T, 4>::ConstTensor x(x_input.tensor<T, 4>());
    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
    typename TTypes<U>::ConstVec mean(mean_input.vec<U>());
    typename TTypes<U>::ConstVec variance(variance_input.vec<U>());
    typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
    typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
    typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());

    // Note: the following formulas are used to compute the gradients for
    // back propagation.
    // x_backprop = scale * rsqrt(variance + epsilon) *
    //              [y_backprop - mean(y_backprop) - (x - mean(x)) *
    //              mean(y_backprop * (x - mean(x))) / (variance + epsilon)]
    // scale_backprop = sum(y_backprop *
    //                  (x - mean(x)) * rsqrt(variance + epsilon))
    // offset_backprop = sum(y_backprop)

    const CPUDevice& d = context->eigen_device<CPUDevice>();
    const int depth = x.dimension(3);
    const int size = x.size();
    const int rest_size = size / depth;
    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);

#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
    Eigen::array<int, 1> reduce_dims({0});
    Eigen::array<int, 2> bcast_spec({rest_size, 1});
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduce_dims;
    Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > bcast_spec;
    bcast_spec.set(0, rest_size);
#endif

    auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
    U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));

    auto x_mean_rest_by_depth =
        mean.reshape(one_by_depth).broadcast(bcast_spec);
    auto x_centered = (x_rest_by_depth - x_mean_rest_by_depth).eval();
    auto coef0 = (variance + epsilon).rsqrt();
    auto coef0_rest_by_depth =
        coef0.eval().reshape(one_by_depth).broadcast(bcast_spec);
    auto x_scaled = x_centered * coef0_rest_by_depth;

    auto y_backprop_rest_by_depth =
        y_backprop.eval().reshape(rest_by_depth).template cast<U>();
    scale_backprop.device(d) =
        (y_backprop_rest_by_depth * x_scaled).sum(reduce_dims);
    auto y_backprop_sum = y_backprop_rest_by_depth.sum(reduce_dims);
    offset_backprop.device(d) = y_backprop_sum;

    auto y_backprop_sum_one_by_depth =
        y_backprop_sum.eval().reshape(one_by_depth);
    auto y_backprop_mean_one_by_depth =
        y_backprop_sum_one_by_depth * rest_size_inv;
    auto y_backprop_mean_rest_by_depth =
        y_backprop_mean_one_by_depth.broadcast(bcast_spec);
    auto y_backprop_centered =
        y_backprop_rest_by_depth - y_backprop_mean_rest_by_depth;
    auto coef1 =
        (scale * coef0).eval().reshape(one_by_depth).broadcast(bcast_spec);
    auto coef2 = (coef0.square() *
                  (y_backprop_rest_by_depth * x_centered).mean(reduce_dims))
                     .eval()
                     .reshape(one_by_depth)
                     .broadcast(bcast_spec);
    x_backprop.reshape(rest_by_depth).device(d) =
        (coef1 * (y_backprop_centered - x_centered * coef2)).template cast<T>();
  }
};

#if GOOGLE_CUDA
template <typename T, typename U>
struct FusedBatchNorm<GPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& x,
                  const Tensor& scale, const Tensor& offset,
                  const Tensor& estimated_mean,
                  const Tensor& estimated_variance, U epsilon, Tensor* y,
                  Tensor* batch_mean, Tensor* batch_var, Tensor* saved_mean,
                  Tensor* saved_inv_var, TensorFormat tensor_format,
                  bool is_training) {
    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));

    const int64 batch_size = GetTensorDim(x, tensor_format, 'N');
    const int64 channels = GetTensorDim(x, tensor_format, 'C');
    const int64 height = GetTensorDim(x, tensor_format, 'H');
    const int64 width = GetTensorDim(x, tensor_format, 'W');
    VLOG(2) << "FusedBatchNorm:"
            << " batch_size: " << batch_size << " channels: " << channels
            << " height: " << height << " width: " << width
            << " x shape: " << x.shape().DebugString()
            << " scale shape: " << scale.shape().DebugString()
            << " offset shape: " << offset.shape().DebugString()
            << " tensor format: " << tensor_format;

    // If the input is empty, return NaN mean/variance.
    if (x.shape().num_elements() == 0) {
      functor::SetNanFunctor<U> f;
      f(context->eigen_device<GPUDevice>(), batch_mean->flat<U>());
      f(context->eigen_device<GPUDevice>(), batch_var->flat<U>());
      return;
    }

    Tensor x_maybe_transformed = x;
    Tensor x_transformed;
    Tensor y_transformed;
    se::DeviceMemory<T> y_ptr;

    if (tensor_format == FORMAT_NCHW) {
      y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*y);
    } else if (tensor_format == FORMAT_NHWC) {
      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &x_transformed));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(),
          x_transformed.tensor<T, 4>());
      x_maybe_transformed = x_transformed;

      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &y_transformed));
      y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(y_transformed);
    } else {
      context->SetStatus(
          errors::Internal("Unsupported tensor format: ", tensor_format));
      return;
    }

    se::dnn::BatchDescriptor x_desc;
    x_desc.set_count(batch_size)
        .set_feature_map_count(channels)
        .set_height(height)
        .set_width(width)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);

    se::dnn::BatchDescriptor scale_offset_desc;
    scale_offset_desc.set_count(1)
        .set_feature_map_count(channels)
        .set_height(1)
        .set_width(1)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);

    auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
    auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
    auto offset_ptr = StreamExecutorUtil::AsDeviceMemory<U>(offset);
    auto estimated_mean_ptr =
        StreamExecutorUtil::AsDeviceMemory<U>(estimated_mean);
    auto estimated_variance_ptr =
        StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance);
    auto batch_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_mean);

    auto batch_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_var);
    auto saved_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*saved_mean);
    auto saved_inv_var_ptr =
        StreamExecutorUtil::AsDeviceMemory<U>(*saved_inv_var);

    GPUDevice d = context->eigen_device<GPUDevice>();
    using se::DeviceMemory;
    Tensor inv_var;
    OP_REQUIRES_OK(
        context, context->allocate_temp(DataTypeToEnum<U>::value,
                                        estimated_variance.shape(), &inv_var));
    auto inv_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_var);
    // Callbacks passed to ThenBatchNormalizationForward for converting between
    // variance and inverse variance on the device: var_to_inv_var turns the
    // estimated (population) variance into inverse variance, and
    // inv_var_to_var converts the computed batch inverse variance back to
    // variance in place.
    std::function<const DeviceMemory<U>&()> var_to_inv_var =
        [d, epsilon, estimated_variance,
         &inv_var_ptr]() -> const DeviceMemory<U>& {
      auto estimated_variance_ptr =
          StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance);
      const U* variance =
          static_cast<const U*>(estimated_variance_ptr.opaque());
      U* inv_variance = static_cast<U*>(inv_var_ptr.opaque());
      int channels = inv_var_ptr.ElementCount();
      VarianceToInvVariance<U>()(d, variance, epsilon, channels, inv_variance);
      return inv_var_ptr;
    };
    const int64 sample_size = batch_size * height * width;
    std::function<void()> inv_var_to_var = [d, &batch_var_ptr, epsilon,
                                            sample_size]() {
      U* variance = static_cast<U*>(batch_var_ptr.opaque());
      int channels = batch_var_ptr.ElementCount();
      InvVarianceToVariance<U>()(d, epsilon, sample_size, channels, variance);
    };

    bool cudnn_launch_status =
        stream
            ->ThenBatchNormalizationForward(
                x_ptr, scale_ptr, offset_ptr, estimated_mean_ptr,
                estimated_variance_ptr, x_desc, scale_offset_desc,
                static_cast<double>(epsilon), &y_ptr, &batch_mean_ptr,
                &batch_var_ptr, &saved_mean_ptr, &saved_inv_var_ptr,
                is_training, std::move(var_to_inv_var),
                std::move(inv_var_to_var))
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(
          errors::Internal("cuDNN launch failure : input shape (",
                           x.shape().DebugString(), ")"));
    }
    if (tensor_format == FORMAT_NHWC) {
      functor::NCHWToNHWC<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(y_transformed).tensor<T, 4>(),
          y->tensor<T, 4>());
    }
  }
};

template <typename T, typename U>
struct FusedBatchNormGrad<GPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& y_backprop,
                  const Tensor& x, const Tensor& scale, const Tensor& mean,
                  const Tensor& inv_variance, U epsilon, Tensor* x_backprop,
                  Tensor* scale_backprop, Tensor* offset_backprop,
                  TensorFormat tensor_format) {
    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));

    const int64 batch_size = GetTensorDim(x, tensor_format, 'N');
    const int64 channels = GetTensorDim(x, tensor_format, 'C');
    const int64 height = GetTensorDim(x, tensor_format, 'H');
    const int64 width = GetTensorDim(x, tensor_format, 'W');

    VLOG(2) << "FusedBatchNormGrad:"
            << " batch_size: " << batch_size << " channels: " << channels
            << " height: " << height << " width: " << width
            << " y_backprop shape: " << y_backprop.shape().DebugString()
            << " x shape: " << x.shape().DebugString()
            << " scale shape: " << scale.shape().DebugString()
            << " tensor format: " << tensor_format;

    // Inputs
    Tensor y_backprop_maybe_transformed = y_backprop;
    Tensor x_maybe_transformed = x;
    Tensor y_backprop_transformed;
    Tensor x_transformed;

    // Outputs
    Tensor x_backprop_transformed;
    se::DeviceMemory<T> x_backprop_ptr;

    if (tensor_format == FORMAT_NCHW) {
      x_backprop_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*x_backprop);
    } else if (tensor_format == FORMAT_NHWC) {
      // Transform inputs from 'NHWC' to 'NCHW'.
      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &y_backprop_transformed));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(y_backprop_maybe_transformed)
              .tensor<T, 4>(),
          y_backprop_transformed.tensor<T, 4>());
      y_backprop_maybe_transformed = y_backprop_transformed;

      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &x_transformed));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(),
          x_transformed.tensor<T, 4>());
      x_maybe_transformed = x_transformed;

      // Allocate memory for the transformed outputs in 'NCHW'.
      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &x_backprop_transformed));
      x_backprop_ptr =
          StreamExecutorUtil::AsDeviceMemory<T>(x_backprop_transformed);
    } else {
      context->SetStatus(
          errors::Internal("Unsupported tensor format: ", tensor_format));
      return;
    }

    se::dnn::BatchDescriptor x_desc;
    x_desc.set_count(batch_size)
        .set_feature_map_count(channels)
        .set_height(height)
        .set_width(width)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);

    se::dnn::BatchDescriptor scale_offset_desc;
    scale_offset_desc.set_count(1)
        .set_feature_map_count(channels)
        .set_height(1)
        .set_width(1)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);

    auto y_backprop_ptr =
        StreamExecutorUtil::AsDeviceMemory<T>(y_backprop_maybe_transformed);
    auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
    auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
    auto mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(mean);
    auto inv_variance_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_variance);
    auto scale_backprop_ptr =
        StreamExecutorUtil::AsDeviceMemory<U>(*scale_backprop);
    auto offset_backprop_ptr =
        StreamExecutorUtil::AsDeviceMemory<U>(*offset_backprop);

    // The cuDNN kernel outputs inverse variance in the forward pass and
    // reuses it in the backward pass.
    bool cudnn_launch_status =
        stream
            ->ThenBatchNormalizationBackward(
                y_backprop_ptr, x_ptr, scale_ptr, mean_ptr, inv_variance_ptr,
                x_desc, scale_offset_desc, static_cast<double>(epsilon),
                &x_backprop_ptr, &scale_backprop_ptr, &offset_backprop_ptr)
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(
          errors::Internal("cuDNN launch failure : input shape (",
                           x.shape().DebugString(), ")"));
    }
    if (tensor_format == FORMAT_NHWC) {
      functor::NCHWToNHWC<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(x_backprop_transformed).tensor<T, 4>(),
          x_backprop->tensor<T, 4>());
    }
  }
};

// Forward declarations of the functor specializations for GPU.
#define DECLARE_GPU_SPEC(T, U)                                           \
  template <>                                                            \
  void FusedBatchNormFreezeGrad<GPUDevice, T, U>::operator()(            \
      const GPUDevice& d, const Tensor& y_backprop_input,                \
      const Tensor& x_input, const Tensor& scale_input,                  \
      const Tensor& mean_input, const Tensor& variance_input, U epsilon, \
      Tensor* x_backprop_output, Tensor* scale_backprop_output,          \
      Tensor* offset_backprop_output, typename TTypes<U>::Vec scratch1,  \
      typename TTypes<U>::Vec scratch2);                                 \
  extern template struct FusedBatchNormFreezeGrad<GPUDevice, T, U>;
DECLARE_GPU_SPEC(float, float);
DECLARE_GPU_SPEC(Eigen::half, float);

#endif  // GOOGLE_CUDA
}  // namespace functor

template <typename Device, typename T, typename U>
class FusedBatchNormOp : public OpKernel {
 public:
  explicit FusedBatchNormOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = U(epsilon);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& x = context->input(0);
    const Tensor& scale = context->input(1);
    const Tensor& offset = context->input(2);
    const Tensor& estimated_mean = context->input(3);
    const Tensor& estimated_variance = context->input(4);

    OP_REQUIRES(context, x.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        x.shape().DebugString()));
    OP_REQUIRES(context, scale.dims() == 1,
                errors::InvalidArgument("scale must be 1-dimensional",
                                        scale.shape().DebugString()));
    OP_REQUIRES(context, offset.dims() == 1,
                errors::InvalidArgument("offset must be 1-dimensional",
                                        offset.shape().DebugString()));
    OP_REQUIRES(context, estimated_mean.dims() == 1,
                errors::InvalidArgument("estimated_mean must be 1-dimensional",
                                        estimated_mean.shape().DebugString()));
    OP_REQUIRES(
        context, estimated_variance.dims() == 1,
        errors::InvalidArgument("estimated_variance must be 1-dimensional",
                                estimated_variance.shape().DebugString()));
    if (is_training_) {
      OP_REQUIRES(
          context, estimated_mean.dim_size(0) == 0,
          errors::InvalidArgument("estimated_mean must be empty for training",
                                  estimated_mean.shape().DebugString()));
      OP_REQUIRES(context, estimated_variance.dim_size(0) == 0,
                  errors::InvalidArgument(
                      "estimated_variance must be empty for training",
                      estimated_variance.shape().DebugString()));
    }

    Tensor* y = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, x.shape(), &y));
    Tensor* batch_mean = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, scale.shape(), &batch_mean));
    Tensor* batch_var = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, scale.shape(), &batch_var));
    Tensor* saved_mean = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(3, scale.shape(), &saved_mean));
    Tensor* saved_maybe_inv_var = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(4, scale.shape(),
                                                     &saved_maybe_inv_var));

    functor::FusedBatchNorm<Device, T, U>()(
        context, x, scale, offset, estimated_mean, estimated_variance, epsilon_,
        y, batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
        tensor_format_, is_training_);
  }

 private:
  U epsilon_;
  TensorFormat tensor_format_;
  bool is_training_;
};

template <typename Device, typename T, typename U>
class FusedBatchNormGradOp : public OpKernel {
 public:
  explicit FusedBatchNormGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = U(epsilon);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& y_backprop = context->input(0);
    const Tensor& x = context->input(1);
    const Tensor& scale = context->input(2);
    // When is_training=True, batch mean and variance/inverted variance are
    // saved in the forward pass to be reused here. When is_training=False,
    // population mean and variance need to be forwarded here to compute the
    // gradients.
    const Tensor& saved_mean_or_pop_mean = context->input(3);
    // The Eigen implementation saves variance in the forward pass, while cuDNN
    // saves inverted variance.
    const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);

    OP_REQUIRES(context, y_backprop.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        y_backprop.shape().DebugString()));
    OP_REQUIRES(context, x.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        x.shape().DebugString()));
    OP_REQUIRES(context, scale.dims() == 1,
                errors::InvalidArgument("scale must be 1-dimensional",
                                        scale.shape().DebugString()));
    OP_REQUIRES(
        context, saved_mean_or_pop_mean.dims() == 1,
        errors::InvalidArgument("saved mean must be 1-dimensional",
                                saved_mean_or_pop_mean.shape().DebugString()));
    OP_REQUIRES(context, saved_maybe_inv_var_or_pop_var.dims() == 1,
                errors::InvalidArgument(
                    "saved variance must be 1-dimensional",
                    saved_maybe_inv_var_or_pop_var.shape().DebugString()));

    Tensor* x_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, x.shape(), &x_backprop));

    const TensorShape& scale_offset_shape = scale.shape();
    Tensor* scale_backprop = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, scale_offset_shape,
                                                     &scale_backprop));
    Tensor* offset_backprop = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, scale_offset_shape,
                                                     &offset_backprop));
    // Two placeholders for estimated_mean and estimated_variance, which are
    // used for inference and thus not needed here for gradient computation.
    // They are filled with zeros so as to avoid NaN outputs.
    Tensor* placeholder_1 = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(3, TensorShape({}), &placeholder_1));
    functor::SetZeroFunctor<Device, float> f;
    f(context->eigen_device<Device>(), placeholder_1->flat<U>());
    Tensor* placeholder_2 = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(4, TensorShape({}), &placeholder_2));
    f(context->eigen_device<Device>(), placeholder_2->flat<U>());

    // If the input is empty, set the gradients w.r.t. scale/offset to zero.
    if (x.shape().num_elements() == 0) {
      functor::SetZeroFunctor<Device, U> f;
      f(context->eigen_device<Device>(), scale_backprop->flat<U>());
      f(context->eigen_device<Device>(), offset_backprop->flat<U>());
      return;
    }

    if (is_training_) {
      functor::FusedBatchNormGrad<Device, T, U>()(
          context, y_backprop, x, scale, saved_mean_or_pop_mean,
          saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
          offset_backprop, tensor_format_);

    } else {
      // The necessary layout conversion is currently done in Python.
      CHECK(tensor_format_ == FORMAT_NHWC)
          << "The implementation of FusedBatchNormGrad with is_training=False "
             "only supports NHWC tensor format for now.";
      Tensor scratch1, scratch2;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<U>::value,
                                            scale_offset_shape, &scratch1));
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<U>::value,
                                            scale_offset_shape, &scratch2));
      functor::FusedBatchNormFreezeGrad<Device, T, U>()(
          context->eigen_device<Device>(), y_backprop, x, scale,
          saved_mean_or_pop_mean, saved_maybe_inv_var_or_pop_var, epsilon_,
          x_backprop, scale_backprop, offset_backprop, scratch1.vec<U>(),
          scratch2.vec<U>());
    }
  }

 private:
  U epsilon_;
  TensorFormat tensor_format_;
  bool is_training_;
};

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNorm").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    FusedBatchNormOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNormGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    FusedBatchNormGradOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<CPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<CPUDevice, Eigen::half, float>);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNorm").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    FusedBatchNormOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNormGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    FusedBatchNormGradOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<GPUDevice, float, float>);
REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<GPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<GPUDevice, Eigen::half, float>);

#endif  // GOOGLE_CUDA

}  // namespace tensorflow