/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

namespace functor {

// Functor used by FusedBatchNormOp to do the computations.
template <typename Device, typename T, typename U>
struct FusedBatchNorm;
// Functor used by FusedBatchNormGradOp to do the computations when
// is_training=True.
template <typename Device, typename T, typename U>
struct FusedBatchNormGrad;

template <typename T, typename U>
struct FusedBatchNorm<CPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& x_input,
                  const Tensor& scale_input, const Tensor& offset_input,
                  const Tensor& estimated_mean_input,
                  const Tensor& estimated_variance_input, U epsilon,
                  Tensor* y_output, Tensor* batch_mean_output,
                  Tensor* batch_var_output, Tensor* saved_mean_output,
                  Tensor* saved_var_output, TensorFormat tensor_format,
                  bool is_training) {
    OP_REQUIRES(context, tensor_format == FORMAT_NHWC,
                errors::Internal("The CPU implementation of FusedBatchNorm "
                                 "only supports NHWC tensor format for now."));
    typename TTypes<T, 4>::ConstTensor x(x_input.tensor<T, 4>());
    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
    typename TTypes<U>::ConstVec offset(offset_input.vec<U>());
    typename TTypes<U>::ConstVec estimated_mean(estimated_mean_input.vec<U>());
    typename TTypes<U>::ConstVec estimated_variance(
        estimated_variance_input.vec<U>());
    typename TTypes<T, 4>::Tensor y(y_output->tensor<T, 4>());
    typename TTypes<U>::Vec batch_mean(batch_mean_output->vec<U>());
    typename TTypes<U>::Vec batch_var(batch_var_output->vec<U>());
    typename TTypes<U>::Vec saved_mean(saved_mean_output->vec<U>());
    typename TTypes<U>::Vec saved_var(saved_var_output->vec<U>());

    const CPUDevice& d = context->eigen_device<CPUDevice>();

    const int depth = x.dimension(3);
    const int size = x.size();
    const int rest_size = size / depth;
    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
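
    // Index helpers used below: reductions run over dimension 0 (the
    // N*H*W "rest" dimension), and per-channel vectors of shape [1, depth]
    // are broadcast back to [rest_size, depth].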
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
    Eigen::array<int, 1> reduce_dims({0});
    Eigen::array<int, 2> bcast_spec({rest_size, 1});
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduce_dims;
    Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > bcast_spec;
    bcast_spec.set(0, rest_size);
#endif

    auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
    const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
    U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
    // This adjustment is for Bessel's correction
    U rest_size_adjust =
        static_cast<U>(rest_size) / static_cast<U>(rest_size_minus_one);

    Eigen::Tensor<U, 1, Eigen::RowMajor> mean(depth);
    Eigen::Tensor<U, 1, Eigen::RowMajor> variance(depth);
    if (is_training) {
      mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
      batch_mean.device(d) = mean;
      saved_mean.device(d) = mean;
    } else {
      mean.device(d) = estimated_mean;
    }

    auto x_centered =
        x_rest_by_depth - mean.reshape(one_by_depth).broadcast(bcast_spec);

    if (is_training) {
      variance.device(d) = x_centered.square().sum(reduce_dims) * rest_size_inv;
      batch_var.device(d) = variance * rest_size_adjust;
      saved_var.device(d) = variance;
    } else {
      variance.device(d) = estimated_variance;
    }

    auto scaling_factor = ((variance + epsilon).rsqrt() * scale)
                              .eval()
                              .reshape(one_by_depth)
                              .broadcast(bcast_spec);
    auto x_scaled = x_centered * scaling_factor;
    auto x_shifted =
        x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec);

    y.reshape(rest_by_depth).device(d) = x_shifted.template cast<T>();
  }
};

template <typename T, typename U>
struct FusedBatchNormGrad<CPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
                  const Tensor& x_input, const Tensor& scale_input,
                  const Tensor& mean_input, const Tensor& variance_input,
                  U epsilon, Tensor* x_backprop_output,
                  Tensor* scale_backprop_output, Tensor* offset_backprop_output,
                  TensorFormat tensor_format) {
    OP_REQUIRES(context, tensor_format == FORMAT_NHWC,
                errors::Internal("The CPU implementation of FusedBatchNormGrad "
                                 "only supports NHWC tensor format for now."));
    typename TTypes<T, 4>::ConstTensor y_backprop(
        y_backprop_input.tensor<T, 4>());
    typename TTypes<T, 4>::ConstTensor x(x_input.tensor<T, 4>());
    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
    typename TTypes<U>::ConstVec mean(mean_input.vec<U>());
    typename TTypes<U>::ConstVec variance(variance_input.vec<U>());
    typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
    typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
    typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());

    // Note: the following formulas are used to compute the gradients for
    // back propagation.
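    // (All mean() and sum() reductions below are taken per channel, over the
    // N*H*W elements of that channel.)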
    // x_backprop = scale * rsqrt(variance + epsilon) *
    //              [y_backprop - mean(y_backprop) - (x - mean(x)) *
    //              mean(y_backprop * (x - mean(x))) / (variance + epsilon)]
    // scale_backprop = sum(y_backprop *
    //                  (x - mean(x)) * rsqrt(variance + epsilon))
    // offset_backprop = sum(y_backprop)

    const CPUDevice& d = context->eigen_device<CPUDevice>();
    const int depth = x.dimension(3);
    const int size = x.size();
    const int rest_size = size / depth;
    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);

#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
    Eigen::array<int, 1> reduce_dims({0});
    Eigen::array<int, 2> bcast_spec({rest_size, 1});
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduce_dims;
    Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > bcast_spec;
    bcast_spec.set(0, rest_size);
#endif

    auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
    U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));

    auto x_mean_rest_by_depth =
        mean.reshape(one_by_depth).broadcast(bcast_spec);
    auto x_centered = (x_rest_by_depth - x_mean_rest_by_depth).eval();
    auto coef0 = (variance + epsilon).rsqrt();
    auto coef0_rest_by_depth =
        coef0.eval().reshape(one_by_depth).broadcast(bcast_spec);
    auto x_scaled = x_centered * coef0_rest_by_depth;

    auto y_backprop_rest_by_depth =
        y_backprop.eval().reshape(rest_by_depth).template cast<U>();
    scale_backprop.device(d) =
        (y_backprop_rest_by_depth * x_scaled).sum(reduce_dims);
    auto y_backprop_sum = y_backprop_rest_by_depth.sum(reduce_dims);
    offset_backprop.device(d) = y_backprop_sum;

    auto y_backprop_sum_one_by_depth =
        y_backprop_sum.eval().reshape(one_by_depth);
    auto y_backprop_mean_one_by_depth =
        y_backprop_sum_one_by_depth * rest_size_inv;
    auto y_backprop_mean_rest_by_depth =
        y_backprop_mean_one_by_depth.broadcast(bcast_spec);
    auto y_backprop_centered =
        y_backprop_rest_by_depth - y_backprop_mean_rest_by_depth;
    auto coef1 =
        (scale * coef0).eval().reshape(one_by_depth).broadcast(bcast_spec);
    auto coef2 = (coef0.square() *
                  (y_backprop_rest_by_depth * x_centered).mean(reduce_dims))
                     .eval()
                     .reshape(one_by_depth)
                     .broadcast(bcast_spec);
    x_backprop.reshape(rest_by_depth).device(d) =
        (coef1 * (y_backprop_centered - x_centered * coef2)).template cast<T>();
  }
};

#if GOOGLE_CUDA
template <typename T, typename U>
struct FusedBatchNorm<GPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& x,
                  const Tensor& scale, const Tensor& offset,
                  const Tensor& estimated_mean,
                  const Tensor& estimated_variance, U epsilon, Tensor* y,
                  Tensor* batch_mean, Tensor* batch_var, Tensor* saved_mean,
                  Tensor* saved_inv_var, TensorFormat tensor_format,
                  bool is_training) {
    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));

    const int64 batch_size = GetTensorDim(x, tensor_format, 'N');
    const int64 channels = GetTensorDim(x, tensor_format, 'C');
    const int64 height = GetTensorDim(x, tensor_format, 'H');
    const int64 width = GetTensorDim(x, tensor_format, 'W');
    VLOG(2) << "FusedBatchNorm:"
            << " batch_size: " << batch_size << " channels: " << channels
            << " height: " << height << " width: " << width
            << " x shape: " << x.shape().DebugString()
            << " scale shape: " << scale.shape().DebugString()
            << " offset shape: " << offset.shape().DebugString()
            << " tensor format: " << tensor_format;
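
    // This path drives cuDNN with NCHW (kBatchDepthYX) descriptors: NHWC
    // inputs are transposed into NCHW temporaries below, and the result is
    // transposed back to NHWC at the end.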

    // If input is empty, return NaN mean/variance
    if (x.shape().num_elements() == 0) {
      functor::SetNanFunctor<U> f;
      f(context->eigen_device<GPUDevice>(), batch_mean->flat<U>());
      f(context->eigen_device<GPUDevice>(), batch_var->flat<U>());
      return;
    }

    Tensor x_maybe_transformed = x;
    Tensor x_transformed;
    Tensor y_transformed;
    perftools::gputools::DeviceMemory<T> y_ptr;

    if (tensor_format == FORMAT_NCHW) {
      y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*y);
    } else if (tensor_format == FORMAT_NHWC) {
      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &x_transformed));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(),
          x_transformed.tensor<T, 4>());
      x_maybe_transformed = x_transformed;

      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &y_transformed));
      y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(y_transformed);
    } else {
      context->SetStatus(
          errors::Internal("Unsupported tensor format: ", tensor_format));
      return;
    }

    perftools::gputools::dnn::BatchDescriptor x_desc;
    x_desc.set_count(batch_size)
        .set_feature_map_count(channels)
        .set_height(height)
        .set_width(width)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);

    perftools::gputools::dnn::BatchDescriptor scale_offset_desc;
    scale_offset_desc.set_count(1)
        .set_feature_map_count(channels)
        .set_height(1)
        .set_width(1)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);

    auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
    auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
    auto offset_ptr = StreamExecutorUtil::AsDeviceMemory<U>(offset);
    auto estimated_mean_ptr =
        StreamExecutorUtil::AsDeviceMemory<U>(estimated_mean);
    auto estimated_variance_ptr =
        StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance);
    auto batch_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_mean);

    auto batch_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_var);
    auto saved_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*saved_mean);
    auto saved_inv_var_ptr =
        StreamExecutorUtil::AsDeviceMemory<U>(*saved_inv_var);

    GPUDevice d = context->eigen_device<GPUDevice>();
    using perftools::gputools::DeviceMemory;
    Tensor inv_var;
    OP_REQUIRES_OK(
        context, context->allocate_temp(DataTypeToEnum<U>::value,
                                        estimated_variance.shape(), &inv_var));
    auto inv_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_var);
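    // cuDNN's batch normalization works with inverse variance for its saved
    // statistics. These two callbacks let the StreamExecutor implementation
    // convert the estimated variance to inverse variance on the way in, and
    // the computed inverse variance back to variance on the way out, when it
    // needs to.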
    std::function<const DeviceMemory<U>&()> var_to_inv_var =
        [d, epsilon, estimated_variance,
         &inv_var_ptr]() -> const DeviceMemory<U>& {
      auto estimated_variance_ptr =
          StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance);
      const U* variance =
          static_cast<const U*>(estimated_variance_ptr.opaque());
      U* inv_variance = static_cast<U*>(inv_var_ptr.opaque());
      int channels = inv_var_ptr.ElementCount();
      VarianceToInvVariance<U>()(d, variance, epsilon, channels, inv_variance);
      return inv_var_ptr;
    };
    const int64 sample_size = batch_size * height * width;
    std::function<void()> inv_var_to_var = [d, &batch_var_ptr, epsilon,
                                            sample_size]() {
      U* variance = static_cast<U*>(batch_var_ptr.opaque());
      int channels = batch_var_ptr.ElementCount();
      InvVarianceToVariance<U>()(d, epsilon, sample_size, channels, variance);
    };

    bool cudnn_launch_status =
        stream
            ->ThenBatchNormalizationForward(
                x_ptr, scale_ptr, offset_ptr, estimated_mean_ptr,
                estimated_variance_ptr, x_desc, scale_offset_desc,
                static_cast<double>(epsilon), &y_ptr, &batch_mean_ptr,
                &batch_var_ptr, &saved_mean_ptr, &saved_inv_var_ptr,
                is_training, std::move(var_to_inv_var),
                std::move(inv_var_to_var))
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(
          errors::Internal("cuDNN launch failure : input shape (",
                           x.shape().DebugString(), ")"));
    }
    if (tensor_format == FORMAT_NHWC) {
      functor::NCHWToNHWC<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(y_transformed).tensor<T, 4>(),
          y->tensor<T, 4>());
    }
  }
};

template <typename T, typename U>
struct FusedBatchNormGrad<GPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& y_backprop,
                  const Tensor& x, const Tensor& scale, const Tensor& mean,
                  const Tensor& inv_variance, U epsilon, Tensor* x_backprop,
                  Tensor* scale_backprop, Tensor* offset_backprop,
                  TensorFormat tensor_format) {
    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));

    const int64 batch_size = GetTensorDim(x, tensor_format, 'N');
    const int64 channels = GetTensorDim(x, tensor_format, 'C');
    const int64 height = GetTensorDim(x, tensor_format, 'H');
    const int64 width = GetTensorDim(x, tensor_format, 'W');

    VLOG(2) << "FusedBatchNormGrad:"
            << " batch_size: " << batch_size << " channels: " << channels
            << " height: " << height << " width: " << width
            << " y_backprop shape: " << y_backprop.shape().DebugString()
            << " x shape: " << x.shape().DebugString()
            << " scale shape: " << scale.shape().DebugString()
            << " tensor format: " << tensor_format;

    // Inputs
    Tensor y_backprop_maybe_transformed = y_backprop;
    Tensor x_maybe_transformed = x;
    Tensor y_backprop_transformed;
    Tensor x_transformed;

    // Outputs
    Tensor x_backprop_transformed;
    perftools::gputools::DeviceMemory<T> x_backprop_ptr;

    if (tensor_format == FORMAT_NCHW) {
      x_backprop_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*x_backprop);
    } else if (tensor_format == FORMAT_NHWC) {
      // Transform inputs from 'NHWC' to 'NCHW'
      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &y_backprop_transformed));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(y_backprop_maybe_transformed)
              .tensor<T, 4>(),
          y_backprop_transformed.tensor<T, 4>());
      y_backprop_maybe_transformed = y_backprop_transformed;

      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &x_transformed));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(),
          x_transformed.tensor<T, 4>());
      x_maybe_transformed = x_transformed;

      // Allocate memory for transformed outputs in 'NCHW'
      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &x_backprop_transformed));
      x_backprop_ptr =
          StreamExecutorUtil::AsDeviceMemory<T>(x_backprop_transformed);
    } else {
      context->SetStatus(
          errors::Internal("Unsupported tensor format: ", tensor_format));
      return;
    }
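
    // x_desc describes the (possibly transposed) NCHW activation tensors;
    // scale_offset_desc describes the per-channel parameter vectors as
    // [1, channels, 1, 1] in the same layout.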
    perftools::gputools::dnn::BatchDescriptor x_desc;
    x_desc.set_count(batch_size)
        .set_feature_map_count(channels)
        .set_height(height)
        .set_width(width)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);

    perftools::gputools::dnn::BatchDescriptor scale_offset_desc;
    scale_offset_desc.set_count(1)
        .set_feature_map_count(channels)
        .set_height(1)
        .set_width(1)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);

    auto y_backprop_ptr =
        StreamExecutorUtil::AsDeviceMemory<T>(y_backprop_maybe_transformed);
    auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
    auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
    auto mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(mean);
    auto inv_variance_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_variance);
    auto scale_backprop_ptr =
        StreamExecutorUtil::AsDeviceMemory<U>(*scale_backprop);
    auto offset_backprop_ptr =
        StreamExecutorUtil::AsDeviceMemory<U>(*offset_backprop);

    // The cuDNN kernel outputs inverse variance in the forward pass and
    // reuses it in the backward pass.
    bool cudnn_launch_status =
        stream
            ->ThenBatchNormalizationBackward(
                y_backprop_ptr, x_ptr, scale_ptr, mean_ptr, inv_variance_ptr,
                x_desc, scale_offset_desc, static_cast<double>(epsilon),
                &x_backprop_ptr, &scale_backprop_ptr, &offset_backprop_ptr)
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(
          errors::Internal("cuDNN launch failure : input shape (",
                           x.shape().DebugString(), ")"));
    }
    if (tensor_format == FORMAT_NHWC) {
      functor::NCHWToNHWC<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(x_backprop_transformed).tensor<T, 4>(),
          x_backprop->tensor<T, 4>());
    }
  }
};

// Forward declarations of the functor specializations for GPU.
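// FusedBatchNormFreezeGrad (declared in fused_batch_norm_op.h) computes the
// gradient for the is_training=False case; it is invoked from
// FusedBatchNormGradOp::Compute below.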
#define DECLARE_GPU_SPEC(T, U)                                           \
  template <>                                                            \
  void FusedBatchNormFreezeGrad<GPUDevice, T, U>::operator()(            \
      const GPUDevice& d, const Tensor& y_backprop_input,                \
      const Tensor& x_input, const Tensor& scale_input,                  \
      const Tensor& mean_input, const Tensor& variance_input, U epsilon, \
      Tensor* x_backprop_output, Tensor* scale_backprop_output,          \
      Tensor* offset_backprop_output, typename TTypes<U>::Vec scratch1,  \
      typename TTypes<U>::Vec scratch2);                                 \
  extern template struct FusedBatchNormFreezeGrad<GPUDevice, T, U>;
DECLARE_GPU_SPEC(float, float);
DECLARE_GPU_SPEC(Eigen::half, float);

#endif  // GOOGLE_CUDA
}  // namespace functor

template <typename Device, typename T, typename U>
class FusedBatchNormOp : public OpKernel {
 public:
  explicit FusedBatchNormOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = U(epsilon);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& x = context->input(0);
    const Tensor& scale = context->input(1);
    const Tensor& offset = context->input(2);
    const Tensor& estimated_mean = context->input(3);
    const Tensor& estimated_variance = context->input(4);

    OP_REQUIRES(context, x.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        x.shape().DebugString()));
    OP_REQUIRES(context, scale.dims() == 1,
                errors::InvalidArgument("scale must be 1-dimensional",
                                        scale.shape().DebugString()));
    OP_REQUIRES(context, offset.dims() == 1,
                errors::InvalidArgument("offset must be 1-dimensional",
                                        offset.shape().DebugString()));
    OP_REQUIRES(context, estimated_mean.dims() == 1,
                errors::InvalidArgument("estimated_mean must be 1-dimensional",
                                        estimated_mean.shape().DebugString()));
    OP_REQUIRES(
        context, estimated_variance.dims() == 1,
        errors::InvalidArgument("estimated_variance must be 1-dimensional",
                                estimated_variance.shape().DebugString()));
    if (is_training_) {
      OP_REQUIRES(
          context, estimated_mean.dim_size(0) == 0,
          errors::InvalidArgument("estimated_mean must be empty for training",
                                  estimated_mean.shape().DebugString()));
      OP_REQUIRES(context, estimated_variance.dim_size(0) == 0,
                  errors::InvalidArgument(
                      "estimated_variance must be empty for training",
                      estimated_variance.shape().DebugString()));
    }

    Tensor* y = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, x.shape(), &y));
    Tensor* batch_mean = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, scale.shape(), &batch_mean));
    Tensor* batch_var = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, scale.shape(), &batch_var));
    Tensor* saved_mean = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(3, scale.shape(), &saved_mean));
    Tensor* saved_maybe_inv_var = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(4, scale.shape(),
                                                     &saved_maybe_inv_var));
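
    // Dispatch to the device-specific functor above. Note that when training
    // on GPU the fifth output holds inverse variance (from cuDNN), while the
    // CPU implementation stores plain variance there.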
    functor::FusedBatchNorm<Device, T, U>()(
        context, x, scale, offset, estimated_mean, estimated_variance,
        epsilon_, y, batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
        tensor_format_, is_training_);
  }

 private:
  U epsilon_;
  TensorFormat tensor_format_;
  bool is_training_;
};

template <typename Device, typename T, typename U>
class FusedBatchNormGradOp : public OpKernel {
 public:
  explicit FusedBatchNormGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = U(epsilon);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& y_backprop = context->input(0);
    const Tensor& x = context->input(1);
    const Tensor& scale = context->input(2);
    // When is_training=True, batch mean and variance/inverted variance are
    // saved in the forward pass to be reused here. When is_training=False,
    // population mean and variance need to be forwarded here to compute the
    // gradients.
    const Tensor& saved_mean_or_pop_mean = context->input(3);
    // The Eigen implementation saves variance in the forward pass, while cuDNN
    // saves inverted variance.
    const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);

    OP_REQUIRES(context, y_backprop.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        y_backprop.shape().DebugString()));
    OP_REQUIRES(context, x.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        x.shape().DebugString()));
    OP_REQUIRES(context, scale.dims() == 1,
                errors::InvalidArgument("scale must be 1-dimensional",
                                        scale.shape().DebugString()));
    OP_REQUIRES(
        context, saved_mean_or_pop_mean.dims() == 1,
        errors::InvalidArgument("saved mean must be 1-dimensional",
                                saved_mean_or_pop_mean.shape().DebugString()));
    OP_REQUIRES(context, saved_maybe_inv_var_or_pop_var.dims() == 1,
                errors::InvalidArgument(
                    "saved variance must be 1-dimensional",
                    saved_maybe_inv_var_or_pop_var.shape().DebugString()));

    Tensor* x_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, x.shape(), &x_backprop));

    const TensorShape& scale_offset_shape = scale.shape();
    Tensor* scale_backprop = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, scale_offset_shape,
                                                     &scale_backprop));
    Tensor* offset_backprop = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, scale_offset_shape,
                                                     &offset_backprop));
    // Two placeholders for estimated_mean and estimated_variance, which are
    // used for inference and thus not needed here for gradient computation.
    // They are filled with zeros so as to avoid NaN outputs.
    Tensor* placeholder_1 = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(3, TensorShape({}), &placeholder_1));
    functor::SetZeroFunctor<Device, U> f;
    f(context->eigen_device<Device>(), placeholder_1->flat<U>());
    Tensor* placeholder_2 = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(4, TensorShape({}), &placeholder_2));
    f(context->eigen_device<Device>(), placeholder_2->flat<U>());

    // If input is empty, set gradients w.r.t scale/offset to zero.
    if (x.shape().num_elements() == 0) {
      functor::SetZeroFunctor<Device, U> f;
      f(context->eigen_device<Device>(), scale_backprop->flat<U>());
      f(context->eigen_device<Device>(), offset_backprop->flat<U>());
      return;
    }

    if (is_training_) {
      functor::FusedBatchNormGrad<Device, T, U>()(
          context, y_backprop, x, scale, saved_mean_or_pop_mean,
          saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
          offset_backprop, tensor_format_);

    } else {
      // Necessary layout conversion is currently done in Python.
      CHECK(tensor_format_ == FORMAT_NHWC)
          << "The implementation of FusedBatchNormGrad with is_training=False "
             "only supports "
          << "NHWC tensor format for now.";
      Tensor scratch1, scratch2;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<U>::value,
                                            scale_offset_shape, &scratch1));
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<U>::value,
                                            scale_offset_shape, &scratch2));
      functor::FusedBatchNormFreezeGrad<Device, T, U>()(
          context->eigen_device<Device>(), y_backprop, x, scale,
          saved_mean_or_pop_mean, saved_maybe_inv_var_or_pop_var, epsilon_,
          x_backprop, scale_backprop, offset_backprop, scratch1.vec<U>(),
          scratch2.vec<U>());
    }
  }

 private:
  U epsilon_;
  TensorFormat tensor_format_;
  bool is_training_;
};

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNorm").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    FusedBatchNormOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNormGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    FusedBatchNormGradOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<CPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<CPUDevice, Eigen::half, float>);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNorm").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    FusedBatchNormOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNormGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    FusedBatchNormGradOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<GPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<GPUDevice, Eigen::half, float>);

#endif

}  // namespace tensorflow