/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

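// Functors and OpKernel implementations for the FusedBatchNorm and
// FusedBatchNormGrad ops (and their V2 variants), with CPU (Eigen) and
// GPU (cuDNN via StreamExecutor) code paths.
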
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

namespace functor {

// Functor used by FusedBatchNormOp to do the computations.
template <typename Device, typename T, typename U>
struct FusedBatchNorm;
// Functor used by FusedBatchNormGradOp to do the computations when
// is_training=True.
template <typename Device, typename T, typename U>
struct FusedBatchNormGrad;

template <typename T, typename U>
struct FusedBatchNorm<CPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& x_input,
                  const Tensor& scale_input, const Tensor& offset_input,
                  const Tensor& estimated_mean_input,
                  const Tensor& estimated_variance_input, U epsilon,
                  Tensor* y_output, Tensor* batch_mean_output,
                  Tensor* batch_var_output, Tensor* saved_mean_output,
                  Tensor* saved_var_output, TensorFormat tensor_format,
                  bool is_training) {
    OP_REQUIRES(context, tensor_format == FORMAT_NHWC,
                errors::Internal("The CPU implementation of FusedBatchNorm "
                                 "only supports NHWC tensor format for now."));
    typename TTypes<T, 4>::ConstTensor x(x_input.tensor<T, 4>());
    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
    typename TTypes<U>::ConstVec offset(offset_input.vec<U>());
    typename TTypes<U>::ConstVec estimated_mean(estimated_mean_input.vec<U>());
    typename TTypes<U>::ConstVec estimated_variance(
        estimated_variance_input.vec<U>());
    typename TTypes<T, 4>::Tensor y(y_output->tensor<T, 4>());
    typename TTypes<U>::Vec batch_mean(batch_mean_output->vec<U>());
    typename TTypes<U>::Vec batch_var(batch_var_output->vec<U>());
    typename TTypes<U>::Vec saved_mean(saved_mean_output->vec<U>());
    typename TTypes<U>::Vec saved_var(saved_var_output->vec<U>());

    const CPUDevice& d = context->eigen_device<CPUDevice>();

    const int depth = x.dimension(3);
    const int size = x.size();
    const int rest_size = size / depth;
    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);

#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
    Eigen::array<int, 1> reduce_dims({0});
    Eigen::array<int, 2> bcast_spec({rest_size, 1});
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduce_dims;
    Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > bcast_spec;
    bcast_spec.set(0, rest_size);
#endif

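    // View the 4-D NHWC input as a 2-D (rest_size, depth) matrix: per-channel
    // statistics are reductions over dimension 0, and per-channel parameters
    // are broadcast back along dimension 0 via bcast_spec.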
    auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
    const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
    U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
    // This adjustment is for Bessel's correction
    U rest_size_adjust =
        static_cast<U>(rest_size) / static_cast<U>(rest_size_minus_one);

    Eigen::Tensor<U, 1, Eigen::RowMajor> mean(depth);
    Eigen::Tensor<U, 1, Eigen::RowMajor> variance(depth);
    if (is_training) {
      mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
      batch_mean.device(d) = mean;
      saved_mean.device(d) = mean;
    } else {
      mean.device(d) = estimated_mean;
    }

    auto x_centered =
        x_rest_by_depth - mean.reshape(one_by_depth).broadcast(bcast_spec);

    if (is_training) {
      variance.device(d) = x_centered.square().sum(reduce_dims) * rest_size_inv;
      batch_var.device(d) = variance * rest_size_adjust;
      saved_var.device(d) = variance;
    } else {
      variance.device(d) = estimated_variance;
    }

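    // Normalize: y = (x - mean) * rsqrt(variance + epsilon) * scale + offset,
    // computed on the 2-D (rest_size, depth) view and cast back to T.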
    auto scaling_factor = ((variance + epsilon).rsqrt() * scale)
                              .eval()
                              .reshape(one_by_depth)
                              .broadcast(bcast_spec);
    auto x_scaled = x_centered * scaling_factor;
    auto x_shifted =
        x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec);

    y.reshape(rest_by_depth).device(d) = x_shifted.template cast<T>();
  }
};

template <typename T, typename U>
struct FusedBatchNormGrad<CPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
                  const Tensor& x_input, const Tensor& scale_input,
                  const Tensor& mean_input, const Tensor& variance_input,
                  U epsilon, Tensor* x_backprop_output,
                  Tensor* scale_backprop_output, Tensor* offset_backprop_output,
                  TensorFormat tensor_format) {
    OP_REQUIRES(context, tensor_format == FORMAT_NHWC,
                errors::Internal("The CPU implementation of FusedBatchNormGrad "
                                 "only supports NHWC tensor format for now."));
    typename TTypes<T, 4>::ConstTensor y_backprop(
        y_backprop_input.tensor<T, 4>());
    typename TTypes<T, 4>::ConstTensor x(x_input.tensor<T, 4>());
    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
    typename TTypes<U>::ConstVec mean(mean_input.vec<U>());
    typename TTypes<U>::ConstVec variance(variance_input.vec<U>());
    typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
    typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
    typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());

    // Note: the following formulas are used to compute the gradients for
    // back propagation.
    // x_backprop = scale * rsqrt(variance + epsilon) *
    //              [y_backprop - mean(y_backprop) - (x - mean(x)) *
    //              mean(y_backprop * (x - mean(x))) / (variance + epsilon)]
    // scale_backprop = sum(y_backprop *
    //                  (x - mean(x)) * rsqrt(variance + epsilon))
    // offset_backprop = sum(y_backprop)

    const CPUDevice& d = context->eigen_device<CPUDevice>();
    const int depth = x.dimension(3);
    const int size = x.size();
    const int rest_size = size / depth;
    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);

#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
    Eigen::array<int, 1> reduce_dims({0});
    Eigen::array<int, 2> bcast_spec({rest_size, 1});
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduce_dims;
    Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > bcast_spec;
    bcast_spec.set(0, rest_size);
#endif

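    // Same 2-D (rest_size, depth) view as in the forward pass; all reductions
    // and broadcasts below are per channel.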
    auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
    U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));

    auto x_mean_rest_by_depth =
        mean.reshape(one_by_depth).broadcast(bcast_spec);
    auto x_centered = (x_rest_by_depth - x_mean_rest_by_depth).eval();
    auto coef0 = (variance + epsilon).rsqrt();
    auto coef0_rest_by_depth =
        coef0.eval().reshape(one_by_depth).broadcast(bcast_spec);
    auto x_scaled = x_centered * coef0_rest_by_depth;

    auto y_backprop_rest_by_depth =
        y_backprop.eval().reshape(rest_by_depth).template cast<U>();
    scale_backprop.device(d) =
        (y_backprop_rest_by_depth * x_scaled).sum(reduce_dims);
    auto y_backprop_sum = y_backprop_rest_by_depth.sum(reduce_dims);
    offset_backprop.device(d) = y_backprop_sum;

    auto y_backprop_sum_one_by_depth =
        y_backprop_sum.eval().reshape(one_by_depth);
    auto y_backprop_mean_one_by_depth =
        y_backprop_sum_one_by_depth * rest_size_inv;
    auto y_backprop_mean_rest_by_depth =
        y_backprop_mean_one_by_depth.broadcast(bcast_spec);
    auto y_backprop_centered =
        y_backprop_rest_by_depth - y_backprop_mean_rest_by_depth;
    auto coef1 =
        (scale * coef0).eval().reshape(one_by_depth).broadcast(bcast_spec);
    auto coef2 = (coef0.square() *
                  (y_backprop_rest_by_depth * x_centered).mean(reduce_dims))
                     .eval()
                     .reshape(one_by_depth)
                     .broadcast(bcast_spec);
    x_backprop.reshape(rest_by_depth).device(d) =
        (coef1 * (y_backprop_centered - x_centered * coef2)).template cast<T>();
  }
};

#if GOOGLE_CUDA
template <typename T, typename U>
struct FusedBatchNorm<GPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& x,
                  const Tensor& scale, const Tensor& offset,
                  const Tensor& estimated_mean,
                  const Tensor& estimated_variance, U epsilon, Tensor* y,
                  Tensor* batch_mean, Tensor* batch_var, Tensor* saved_mean,
                  Tensor* saved_inv_var, TensorFormat tensor_format,
                  bool is_training) {
    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));

    const int64 batch_size = GetTensorDim(x, tensor_format, 'N');
    const int64 channels = GetTensorDim(x, tensor_format, 'C');
    const int64 height = GetTensorDim(x, tensor_format, 'H');
    const int64 width = GetTensorDim(x, tensor_format, 'W');
    VLOG(2) << "FusedBatchNorm:"
            << " batch_size: " << batch_size << " channels: " << channels
            << " height: " << height << " width: " << width
            << " x shape: " << x.shape().DebugString()
            << " scale shape: " << scale.shape().DebugString()
            << " offset shape: " << offset.shape().DebugString()
            << " tensor format: " << tensor_format;

    // If the input is empty, return NaN mean/variance.
    if (x.shape().num_elements() == 0) {
      functor::SetNanFunctor<U> f;
      f(context->eigen_device<GPUDevice>(), batch_mean->flat<U>());
      f(context->eigen_device<GPUDevice>(), batch_var->flat<U>());
      return;
    }

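    // cuDNN batch normalization expects NCHW, so NHWC inputs and outputs are
    // staged through temporary NCHW tensors.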
    Tensor x_maybe_transformed = x;
    Tensor x_transformed;
    Tensor y_transformed;
    perftools::gputools::DeviceMemory<T> y_ptr;

    if (tensor_format == FORMAT_NCHW) {
      y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*y);
    } else if (tensor_format == FORMAT_NHWC) {
      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &x_transformed));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(),
          x_transformed.tensor<T, 4>());
      x_maybe_transformed = x_transformed;

      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &y_transformed));
      y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(y_transformed);
    } else {
      context->SetStatus(
          errors::Internal("Unsupported tensor format: ", tensor_format));
      return;
    }

    perftools::gputools::dnn::BatchDescriptor x_desc;
    x_desc.set_count(batch_size)
        .set_feature_map_count(channels)
        .set_height(height)
        .set_width(width)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);

    perftools::gputools::dnn::BatchDescriptor scale_offset_desc;
    scale_offset_desc.set_count(1)
        .set_feature_map_count(channels)
        .set_height(1)
        .set_width(1)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);

    auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
    auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
    auto offset_ptr = StreamExecutorUtil::AsDeviceMemory<U>(offset);
    auto estimated_mean_ptr =
        StreamExecutorUtil::AsDeviceMemory<U>(estimated_mean);
    auto estimated_variance_ptr =
        StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance);
    auto batch_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_mean);

    auto batch_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_var);
    auto saved_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*saved_mean);
    auto saved_inv_var_ptr =
        StreamExecutorUtil::AsDeviceMemory<U>(*saved_inv_var);

    GPUDevice d = context->eigen_device<GPUDevice>();
    using perftools::gputools::DeviceMemory;
    Tensor inv_var;
    OP_REQUIRES_OK(
        context, context->allocate_temp(DataTypeToEnum<U>::value,
                                        estimated_variance.shape(), &inv_var));
    auto inv_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_var);
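    // cuDNN works with inverse variance internally. var_to_inv_var converts
    // the estimated (population) variance to inverse variance before the
    // forward call; inv_var_to_var converts the batch inverse variance that
    // cuDNN produced back to variance in place, using epsilon and the
    // per-channel sample size.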
    std::function<const DeviceMemory<U>&()> var_to_inv_var =
        [d, epsilon, estimated_variance,
         &inv_var_ptr]() -> const DeviceMemory<U>& {
      auto estimated_variance_ptr =
          StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance);
      const U* variance =
          static_cast<const U*>(estimated_variance_ptr.opaque());
      U* inv_variance = static_cast<U*>(inv_var_ptr.opaque());
      int channels = inv_var_ptr.ElementCount();
      VarianceToInvVariance<U>()(d, variance, epsilon, channels, inv_variance);
      return inv_var_ptr;
    };
    const int64 sample_size = batch_size * height * width;
    std::function<void()> inv_var_to_var = [d, &batch_var_ptr, epsilon,
                                            sample_size]() {
      U* variance = static_cast<U*>(batch_var_ptr.opaque());
      int channels = batch_var_ptr.ElementCount();
      InvVarianceToVariance<U>()(d, epsilon, sample_size, channels, variance);
    };

    bool cudnn_launch_status =
        stream
            ->ThenBatchNormalizationForward(
                x_ptr, scale_ptr, offset_ptr, estimated_mean_ptr,
                estimated_variance_ptr, x_desc, scale_offset_desc,
                static_cast<double>(epsilon), &y_ptr, &batch_mean_ptr,
                &batch_var_ptr, &saved_mean_ptr, &saved_inv_var_ptr,
                is_training, std::move(var_to_inv_var),
                std::move(inv_var_to_var))
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(
          errors::Internal("cuDNN launch failure: input shape (",
                           x.shape().DebugString(), ")"));
    }
    if (tensor_format == FORMAT_NHWC) {
      functor::NCHWToNHWC<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(y_transformed).tensor<T, 4>(),
          y->tensor<T, 4>());
    }
  }
};

template <typename T, typename U>
struct FusedBatchNormGrad<GPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& y_backprop,
                  const Tensor& x, const Tensor& scale, const Tensor& mean,
                  const Tensor& inv_variance, U epsilon, Tensor* x_backprop,
                  Tensor* scale_backprop, Tensor* offset_backprop,
                  TensorFormat tensor_format) {
    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));

    const int64 batch_size = GetTensorDim(x, tensor_format, 'N');
    const int64 channels = GetTensorDim(x, tensor_format, 'C');
    const int64 height = GetTensorDim(x, tensor_format, 'H');
    const int64 width = GetTensorDim(x, tensor_format, 'W');

    VLOG(2) << "FusedBatchNormGrad:"
            << " batch_size: " << batch_size << " channels: " << channels
            << " height: " << height << " width: " << width
            << " y_backprop shape: " << y_backprop.shape().DebugString()
            << " x shape: " << x.shape().DebugString()
            << " scale shape: " << scale.shape().DebugString()
            << " tensor format: " << tensor_format;

    // Inputs
    Tensor y_backprop_maybe_transformed = y_backprop;
    Tensor x_maybe_transformed = x;
    Tensor y_backprop_transformed;
    Tensor x_transformed;

    // Outputs
    Tensor x_backprop_transformed;
    perftools::gputools::DeviceMemory<T> x_backprop_ptr;

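    // As in the forward pass, NHWC tensors are transposed into temporary NCHW
    // buffers for cuDNN, and x_backprop is transposed back at the end.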
    if (tensor_format == FORMAT_NCHW) {
      x_backprop_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*x_backprop);
    } else if (tensor_format == FORMAT_NHWC) {
      // Transform inputs from 'NHWC' to 'NCHW'
      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &y_backprop_transformed));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(y_backprop_maybe_transformed)
              .tensor<T, 4>(),
          y_backprop_transformed.tensor<T, 4>());
      y_backprop_maybe_transformed = y_backprop_transformed;

      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &x_transformed));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(),
          x_transformed.tensor<T, 4>());
      x_maybe_transformed = x_transformed;

      // Allocate memory for transformed outputs in 'NCHW'
      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &x_backprop_transformed));
      x_backprop_ptr =
          StreamExecutorUtil::AsDeviceMemory<T>(x_backprop_transformed);
    } else {
      context->SetStatus(
          errors::Internal("Unsupported tensor format: ", tensor_format));
      return;
    }

    perftools::gputools::dnn::BatchDescriptor x_desc;
    x_desc.set_count(batch_size)
        .set_feature_map_count(channels)
        .set_height(height)
        .set_width(width)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);

    perftools::gputools::dnn::BatchDescriptor scale_offset_desc;
    scale_offset_desc.set_count(1)
        .set_feature_map_count(channels)
        .set_height(1)
        .set_width(1)
        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);

    auto y_backprop_ptr =
        StreamExecutorUtil::AsDeviceMemory<T>(y_backprop_maybe_transformed);
    auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
    auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
    auto mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(mean);
    auto inv_variance_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_variance);
    auto scale_backprop_ptr =
        StreamExecutorUtil::AsDeviceMemory<U>(*scale_backprop);
    auto offset_backprop_ptr =
        StreamExecutorUtil::AsDeviceMemory<U>(*offset_backprop);

    // The cuDNN kernel outputs the inverse variance in the forward pass and
    // reuses it here in the backward pass.
    bool cudnn_launch_status =
        stream
            ->ThenBatchNormalizationBackward(
                y_backprop_ptr, x_ptr, scale_ptr, mean_ptr, inv_variance_ptr,
                x_desc, scale_offset_desc, static_cast<double>(epsilon),
                &x_backprop_ptr, &scale_backprop_ptr, &offset_backprop_ptr)
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(
          errors::Internal("cuDNN launch failure: input shape (",
                           x.shape().DebugString(), ")"));
    }
    if (tensor_format == FORMAT_NHWC) {
      functor::NCHWToNHWC<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(x_backprop_transformed).tensor<T, 4>(),
          x_backprop->tensor<T, 4>());
    }
  }
};

// Forward declarations of the functor specializations for GPU.
#define DECLARE_GPU_SPEC(T, U)                                           \
  template <>                                                            \
  void FusedBatchNormFreezeGrad<GPUDevice, T, U>::operator()(            \
      const GPUDevice& d, const Tensor& y_backprop_input,                \
      const Tensor& x_input, const Tensor& scale_input,                  \
      const Tensor& mean_input, const Tensor& variance_input, U epsilon, \
      Tensor* x_backprop_output, Tensor* scale_backprop_output,          \
      Tensor* offset_backprop_output, typename TTypes<U>::Vec scratch1,  \
      typename TTypes<U>::Vec scratch2);                                 \
  extern template struct FusedBatchNormFreezeGrad<GPUDevice, T, U>;
DECLARE_GPU_SPEC(float, float);
DECLARE_GPU_SPEC(Eigen::half, float);

#endif  // GOOGLE_CUDA
}  // namespace functor

template <typename Device, typename T, typename U>
class FusedBatchNormOp : public OpKernel {
 public:
  explicit FusedBatchNormOp(OpKernelConstruction* context) : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = U(epsilon);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& x = context->input(0);
    const Tensor& scale = context->input(1);
    const Tensor& offset = context->input(2);
    const Tensor& estimated_mean = context->input(3);
    const Tensor& estimated_variance = context->input(4);

    OP_REQUIRES(context, x.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        x.shape().DebugString()));
    OP_REQUIRES(context, scale.dims() == 1,
                errors::InvalidArgument("scale must be 1-dimensional",
                                        scale.shape().DebugString()));
    OP_REQUIRES(context, offset.dims() == 1,
                errors::InvalidArgument("offset must be 1-dimensional",
                                        offset.shape().DebugString()));
    OP_REQUIRES(context, estimated_mean.dims() == 1,
                errors::InvalidArgument("estimated_mean must be 1-dimensional",
                                        estimated_mean.shape().DebugString()));
    OP_REQUIRES(
        context, estimated_variance.dims() == 1,
        errors::InvalidArgument("estimated_variance must be 1-dimensional",
                                estimated_variance.shape().DebugString()));
    if (is_training_) {
      OP_REQUIRES(
          context, estimated_mean.dim_size(0) == 0,
          errors::InvalidArgument("estimated_mean must be empty for training",
                                  estimated_mean.shape().DebugString()));
      OP_REQUIRES(context, estimated_variance.dim_size(0) == 0,
                  errors::InvalidArgument(
                      "estimated_variance must be empty for training",
                      estimated_variance.shape().DebugString()));
    }

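    // Allocate the five outputs: y matches the shape of x; the mean/variance
    // outputs are one element per channel, i.e. the shape of scale.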
    Tensor* y = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, x.shape(), &y));
    Tensor* batch_mean = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, scale.shape(), &batch_mean));
    Tensor* batch_var = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, scale.shape(), &batch_var));
    Tensor* saved_mean = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(3, scale.shape(), &saved_mean));
    Tensor* saved_maybe_inv_var = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(4, scale.shape(),
                                                     &saved_maybe_inv_var));

    functor::FusedBatchNorm<Device, T, U>()(
        context, x, scale, offset, estimated_mean, estimated_variance, epsilon_,
        y, batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
        tensor_format_, is_training_);
  }

 private:
  U epsilon_;
  TensorFormat tensor_format_;
  bool is_training_;
};

template <typename Device, typename T, typename U>
class FusedBatchNormGradOp : public OpKernel {
 public:
  explicit FusedBatchNormGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = U(epsilon);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& y_backprop = context->input(0);
    const Tensor& x = context->input(1);
    const Tensor& scale = context->input(2);
    // When is_training=True, batch mean and variance/inverted variance are
    // saved in the forward pass to be reused here. When is_training=False,
    // population mean and variance need to be forwarded here to compute the
    // gradients.
    const Tensor& saved_mean_or_pop_mean = context->input(3);
    // The Eigen implementation saves variance in the forward pass, while cuDNN
    // saves inverted variance.
    const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);

    OP_REQUIRES(context, y_backprop.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        y_backprop.shape().DebugString()));
    OP_REQUIRES(context, x.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        x.shape().DebugString()));
    OP_REQUIRES(context, scale.dims() == 1,
                errors::InvalidArgument("scale must be 1-dimensional",
                                        scale.shape().DebugString()));
    OP_REQUIRES(
        context, saved_mean_or_pop_mean.dims() == 1,
        errors::InvalidArgument("saved mean must be 1-dimensional",
                                saved_mean_or_pop_mean.shape().DebugString()));
    OP_REQUIRES(context, saved_maybe_inv_var_or_pop_var.dims() == 1,
                errors::InvalidArgument(
                    "saved variance must be 1-dimensional",
                    saved_maybe_inv_var_or_pop_var.shape().DebugString()));

    Tensor* x_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, x.shape(), &x_backprop));

    const TensorShape& scale_offset_shape = scale.shape();
    Tensor* scale_backprop = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, scale_offset_shape,
                                                     &scale_backprop));
    Tensor* offset_backprop = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, scale_offset_shape,
                                                     &offset_backprop));
    // Two placeholders for estimated_mean and estimated_variance, which are
    // used for inference and thus not needed here for gradient computation.
    // They are filled with zeros so as to avoid NaN outputs.
    Tensor* placeholder_1 = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(3, TensorShape({}), &placeholder_1));
    functor::SetZeroFunctor<Device, U> f;
    f(context->eigen_device<Device>(), placeholder_1->flat<U>());
    Tensor* placeholder_2 = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(4, TensorShape({}), &placeholder_2));
    f(context->eigen_device<Device>(), placeholder_2->flat<U>());

    // If the input is empty, set the gradients w.r.t. scale/offset to zero.
    if (x.shape().num_elements() == 0) {
      functor::SetZeroFunctor<Device, U> f;
      f(context->eigen_device<Device>(), scale_backprop->flat<U>());
      f(context->eigen_device<Device>(), offset_backprop->flat<U>());
      return;
    }

    if (is_training_) {
      functor::FusedBatchNormGrad<Device, T, U>()(
          context, y_backprop, x, scale, saved_mean_or_pop_mean,
          saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
          offset_backprop, tensor_format_);

    } else {
      // Necessary layout conversion is currently done in Python.
      CHECK(tensor_format_ == FORMAT_NHWC)
          << "The implementation of FusedBatchNormGrad with is_training=False "
             "only supports "
          << "NHWC tensor format for now.";
      Tensor scratch1, scratch2;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<U>::value,
                                            scale_offset_shape, &scratch1));
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<U>::value,
                                            scale_offset_shape, &scratch2));
      functor::FusedBatchNormFreezeGrad<Device, T, U>()(
          context->eigen_device<Device>(), y_backprop, x, scale,
          saved_mean_or_pop_mean, saved_maybe_inv_var_or_pop_var, epsilon_,
          x_backprop, scale_backprop, offset_backprop, scratch1.vec<U>(),
          scratch2.vec<U>());
    }
  }

 private:
  U epsilon_;
  TensorFormat tensor_format_;
  bool is_training_;
};

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNorm").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    FusedBatchNormOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNormGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    FusedBatchNormGradOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<CPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<CPUDevice, Eigen::half, float>);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNorm").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    FusedBatchNormOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNormGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    FusedBatchNormGradOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<GPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<GPUDevice, Eigen::half, float>);

#endif

}  // namespace tensorflow