/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/fused_batch_norm_op.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

namespace functor {

// Functor used by FusedBatchNormOp to do the computations.
template <typename Device, typename T, typename U>
struct FusedBatchNorm;
// Functor used by FusedBatchNormGradOp to do the computations when
// is_training=True.
template <typename Device, typename T, typename U>
struct FusedBatchNormGrad;

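// Helper that writes x_shifted into y's flattened 2-D view; the primary
// template casts to T, and the specialization below for IsSame == true skips
// the cast when T == U (see the note at the call site about the performance
// impact of an unnecessary cast).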
template <bool IsSame, typename Y, typename X, typename T>
struct CastIfNecessary {
  static inline void process(
      Y& y, X& x_shifted, const Eigen::DSizes<Eigen::Index, 2>& rest_by_depth,
      const CPUDevice& d) {
    y.reshape(rest_by_depth).device(d) = x_shifted.template cast<T>();
  }
};

template <typename Y, typename X, typename T>
struct CastIfNecessary<true, Y, X, T> {
  static inline void process(
      Y& y, X& x_shifted, const Eigen::DSizes<Eigen::Index, 2>& rest_by_depth,
      const CPUDevice& d) {
    y.reshape(rest_by_depth).device(d) = x_shifted;
  }
};

template <typename T, typename U>
struct FusedBatchNorm<CPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& x_input,
                  const Tensor& scale_input, const Tensor& offset_input,
                  const Tensor& estimated_mean_input,
                  const Tensor& estimated_variance_input, U epsilon,
                  Tensor* y_output, Tensor* batch_mean_output,
                  Tensor* batch_var_output, Tensor* saved_mean_output,
                  Tensor* saved_var_output, TensorFormat tensor_format,
                  bool is_training) {
    OP_REQUIRES(context, tensor_format == FORMAT_NHWC,
                errors::Internal("The CPU implementation of FusedBatchNorm "
                                 "only supports NHWC tensor format for now."));
    typename TTypes<T, 4>::ConstTensor x(x_input.tensor<T, 4>());
    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
    typename TTypes<U>::ConstVec offset(offset_input.vec<U>());
    typename TTypes<U>::ConstVec estimated_mean(estimated_mean_input.vec<U>());
    typename TTypes<U>::ConstVec estimated_variance(
        estimated_variance_input.vec<U>());
    typename TTypes<T, 4>::Tensor y(y_output->tensor<T, 4>());
    typename TTypes<U>::Vec batch_mean(batch_mean_output->vec<U>());
    typename TTypes<U>::Vec batch_var(batch_var_output->vec<U>());
    typename TTypes<U>::Vec saved_mean(saved_mean_output->vec<U>());
    typename TTypes<U>::Vec saved_var(saved_var_output->vec<U>());

    const CPUDevice& d = context->eigen_device<CPUDevice>();

    const int depth = x.dimension(3);
    const int size = x.size();
    const int rest_size = size / depth;
    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);

#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
    Eigen::array<int, 1> reduce_dims({0});
    Eigen::array<int, 2> bcast_spec({rest_size, 1});
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduce_dims;
    Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > bcast_spec;
    bcast_spec.set(0, rest_size);
#endif

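    // The NHWC input is viewed as a 2-D matrix of shape
    // [rest_size, depth] = [N * H * W, C]: per-channel statistics are computed
    // by reducing over dimension 0 (reduce_dims), and per-channel vectors are
    // broadcast back across the rows via one_by_depth and bcast_spec.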
    auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
    const int rest_size_minus_one = (rest_size > 1) ? (rest_size - 1) : 1;
    U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));
    // This adjustment is for Bessel's correction (unbiased variance estimate).
    U rest_size_adjust =
        static_cast<U>(rest_size) / static_cast<U>(rest_size_minus_one);

    Eigen::Tensor<U, 1, Eigen::RowMajor> mean(depth);
    Eigen::Tensor<U, 1, Eigen::RowMajor> variance(depth);
    if (is_training) {
      mean.device(d) = (x_rest_by_depth.sum(reduce_dims) * rest_size_inv);
      batch_mean.device(d) = mean;
      saved_mean.device(d) = mean;
    } else {
      mean.device(d) = estimated_mean;
    }

    auto x_centered =
        x_rest_by_depth - mean.reshape(one_by_depth).broadcast(bcast_spec);

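    // In training mode the population (biased) variance is used both to
    // normalize the batch and as saved_var, while batch_var reports the
    // Bessel-corrected estimate, variance * rest_size / (rest_size - 1).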
    if (is_training) {
      variance.device(d) = x_centered.square().sum(reduce_dims) * rest_size_inv;
      batch_var.device(d) = variance * rest_size_adjust;
      saved_var.device(d) = variance;
    } else {
      variance.device(d) = estimated_variance;
    }

    auto scaling_factor = ((variance + epsilon).rsqrt() * scale)
                              .eval()
                              .reshape(one_by_depth)
                              .broadcast(bcast_spec);
    auto x_scaled = x_centered * scaling_factor;
    auto x_shifted =
        x_scaled + offset.reshape(one_by_depth).broadcast(bcast_spec);

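    // At this point x_shifted holds
    //   (x - mean) * scale / sqrt(variance + epsilon) + offset
    // in the intermediate precision U; it still has to be written back to y
    // as T.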
    // Explicitly checks the types of T and U and only casts x_shifted when
    // T != U. (Not doing so caused a 35-50% performance slowdown for
    // some compiler flags.)
    CastIfNecessary<std::is_same<T, U>::value, decltype(y), decltype(x_shifted),
                    T>::process(y, x_shifted, rest_by_depth, d);
  }
};

template <typename T, typename U>
struct FusedBatchNormGrad<CPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
                  const Tensor& x_input, const Tensor& scale_input,
                  const Tensor& mean_input, const Tensor& variance_input,
                  U epsilon, Tensor* x_backprop_output,
                  Tensor* scale_backprop_output, Tensor* offset_backprop_output,
                  TensorFormat tensor_format) {
    OP_REQUIRES(context, tensor_format == FORMAT_NHWC,
                errors::Internal("The CPU implementation of FusedBatchNormGrad "
                                 "only supports NHWC tensor format for now."));
    typename TTypes<T, 4>::ConstTensor y_backprop(
        y_backprop_input.tensor<T, 4>());
    typename TTypes<T, 4>::ConstTensor x(x_input.tensor<T, 4>());
    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
    typename TTypes<U>::ConstVec mean(mean_input.vec<U>());
    typename TTypes<U>::ConstVec variance(variance_input.vec<U>());
    typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
    typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
    typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());

    // Note: the following formulas are used to compute the gradients for
    // back propagation.
    // x_backprop = scale * rsqrt(variance + epsilon) *
    //              [y_backprop - mean(y_backprop) - (x - mean(x)) *
    //              mean(y_backprop * (x - mean(x))) / (variance + epsilon)]
    // scale_backprop = sum(y_backprop *
    //                  (x - mean(x)) * rsqrt(variance + epsilon))
    // offset_backprop = sum(y_backprop)
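    // Here sum(.) and mean(.) denote per-channel reductions over the
    // rest_size = N * H * W elements of each channel.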

    const CPUDevice& d = context->eigen_device<CPUDevice>();
    const int depth = x.dimension(3);
    const int size = x.size();
    const int rest_size = size / depth;
    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);

#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
    Eigen::array<int, 1> reduce_dims({0});
    Eigen::array<int, 2> bcast_spec({rest_size, 1});
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduce_dims;
    Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > bcast_spec;
    bcast_spec.set(0, rest_size);
#endif

    auto x_rest_by_depth = x.reshape(rest_by_depth).template cast<U>();
    U rest_size_inv = static_cast<U>(1.0f / static_cast<U>(rest_size));

    auto x_mean_rest_by_depth =
        mean.reshape(one_by_depth).broadcast(bcast_spec);
    auto x_centered = (x_rest_by_depth - x_mean_rest_by_depth).eval();
    auto coef0 = (variance + epsilon).rsqrt();
    auto coef0_rest_by_depth =
        coef0.eval().reshape(one_by_depth).broadcast(bcast_spec);
    auto x_scaled = x_centered * coef0_rest_by_depth;

    auto y_backprop_rest_by_depth =
        y_backprop.eval().reshape(rest_by_depth).template cast<U>();
    scale_backprop.device(d) =
        (y_backprop_rest_by_depth * x_scaled).sum(reduce_dims);
    auto y_backprop_sum = y_backprop_rest_by_depth.sum(reduce_dims);
    offset_backprop.device(d) = y_backprop_sum;

    auto y_backprop_sum_one_by_depth =
        y_backprop_sum.eval().reshape(one_by_depth);
    auto y_backprop_mean_one_by_depth =
        y_backprop_sum_one_by_depth * rest_size_inv;
    auto y_backprop_mean_rest_by_depth =
        y_backprop_mean_one_by_depth.broadcast(bcast_spec);
    auto y_backprop_centered =
        y_backprop_rest_by_depth - y_backprop_mean_rest_by_depth;
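    // coef1 = scale * rsqrt(variance + epsilon) and
    // coef2 = mean(y_backprop * (x - mean(x))) / (variance + epsilon),
    // matching the x_backprop formula above.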
    auto coef1 =
        (scale * coef0).eval().reshape(one_by_depth).broadcast(bcast_spec);
    auto coef2 = (coef0.square() *
                  (y_backprop_rest_by_depth * x_centered).mean(reduce_dims))
                     .eval()
                     .reshape(one_by_depth)
                     .broadcast(bcast_spec);
    x_backprop.reshape(rest_by_depth).device(d) =
        (coef1 * (y_backprop_centered - x_centered * coef2)).template cast<T>();
  }
};

#if GOOGLE_CUDA
template <typename T, typename U>
struct FusedBatchNorm<GPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& x,
                  const Tensor& scale, const Tensor& offset,
                  const Tensor& estimated_mean,
                  const Tensor& estimated_variance, U epsilon, Tensor* y,
                  Tensor* batch_mean, Tensor* batch_var, Tensor* saved_mean,
                  Tensor* saved_inv_var, TensorFormat tensor_format,
                  bool is_training) {
    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));

    const int64 batch_size = GetTensorDim(x, tensor_format, 'N');
    const int64 channels = GetTensorDim(x, tensor_format, 'C');
    const int64 height = GetTensorDim(x, tensor_format, 'H');
    const int64 width = GetTensorDim(x, tensor_format, 'W');
    VLOG(2) << "FusedBatchNorm:"
            << " batch_size: " << batch_size << " channels: " << channels
            << " height: " << height << " width: " << width
            << " x shape: " << x.shape().DebugString()
            << " scale shape: " << scale.shape().DebugString()
            << " offset shape: " << offset.shape().DebugString()
            << " tensor format: " << tensor_format;

    // If the input is empty, return NaN mean/variance.
    if (x.shape().num_elements() == 0) {
      functor::SetNanFunctor<U> f;
      f(context->eigen_device<GPUDevice>(), batch_mean->flat<U>());
      f(context->eigen_device<GPUDevice>(), batch_var->flat<U>());
      return;
    }

    Tensor x_maybe_transformed = x;
    Tensor x_transformed;
    Tensor y_transformed;
    se::DeviceMemory<T> y_ptr;

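    // cuDNN's fused batch norm is invoked on NCHW (kBatchDepthYX) data, so
    // NHWC inputs are transposed into scratch tensors here and the result is
    // transposed back to NHWC at the end.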
    278       y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*y);
    279     } else if (tensor_format == FORMAT_NHWC) {
    280       OP_REQUIRES_OK(context, context->allocate_temp(
    281                                   DataTypeToEnum<T>::value,
    282                                   ShapeFromFormat(FORMAT_NCHW, batch_size,
    283                                                   height, width, channels),
    284                                   &x_transformed));
    285       functor::NHWCToNCHW<GPUDevice, T, 4>()(
    286           context->eigen_device<GPUDevice>(),
    287           const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(),
    288           x_transformed.tensor<T, 4>());
    289       x_maybe_transformed = x_transformed;
    290 
    291       OP_REQUIRES_OK(context, context->allocate_temp(
    292                                   DataTypeToEnum<T>::value,
    293                                   ShapeFromFormat(FORMAT_NCHW, batch_size,
    294                                                   height, width, channels),
    295                                   &y_transformed));
    296       y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(y_transformed);
    297     } else {
    298       context->SetStatus(
    299           errors::Internal("Unsupported tensor format: ", tensor_format));
    300       return;
    301     }
    302 
    303     se::dnn::BatchDescriptor x_desc;
    304     x_desc.set_count(batch_size)
    305         .set_feature_map_count(channels)
    306         .set_height(height)
    307         .set_width(width)
    308         .set_layout(se::dnn::DataLayout::kBatchDepthYX);
    309 
    310     se::dnn::BatchDescriptor scale_offset_desc;
    311     scale_offset_desc.set_count(1)
    312         .set_feature_map_count(channels)
    313         .set_height(1)
    314         .set_width(1)
    315         .set_layout(se::dnn::DataLayout::kBatchDepthYX);
    316 
    317     auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
    318     auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
    319     auto offset_ptr = StreamExecutorUtil::AsDeviceMemory<U>(offset);
    320     auto estimated_mean_ptr =
    321         StreamExecutorUtil::AsDeviceMemory<U>(estimated_mean);
    322     auto estimated_variance_ptr =
    323         StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance);
    324     auto batch_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_mean);
    325 
    326     auto batch_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*batch_var);
    327     auto saved_mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(*saved_mean);
    328     auto saved_inv_var_ptr =
    329         StreamExecutorUtil::AsDeviceMemory<U>(*saved_inv_var);
    330 
    331     GPUDevice d = context->eigen_device<GPUDevice>();
    332     using se::DeviceMemory;
    333     Tensor inv_var;
    334     OP_REQUIRES_OK(
    335         context, context->allocate_temp(DataTypeToEnum<U>::value,
    336                                         estimated_variance.shape(), &inv_var));
    337     auto inv_var_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_var);
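    // cuDNN works with the inverse variance, 1 / sqrt(variance + epsilon),
    // rather than the variance itself, so conversion callbacks are supplied
    // for both directions: var_to_inv_var converts the estimated (population)
    // variance on the way in, and inv_var_to_var rewrites batch_var back to a
    // regular variance after the forward pass.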
    std::function<const DeviceMemory<U>&()> var_to_inv_var =
        [d, epsilon, estimated_variance,
         &inv_var_ptr]() -> const DeviceMemory<U>& {
      auto estimated_variance_ptr =
          StreamExecutorUtil::AsDeviceMemory<U>(estimated_variance);
      const U* variance =
          static_cast<const U*>(estimated_variance_ptr.opaque());
      U* inv_variance = static_cast<U*>(inv_var_ptr.opaque());
      int channels = inv_var_ptr.ElementCount();
      VarianceToInvVariance<U>()(d, variance, epsilon, channels, inv_variance);
      return inv_var_ptr;
    };
    const int64 sample_size = batch_size * height * width;
    std::function<void()> inv_var_to_var = [d, &batch_var_ptr, epsilon,
                                            sample_size]() {
      U* variance = static_cast<U*>(batch_var_ptr.opaque());
      int channels = batch_var_ptr.ElementCount();
      InvVarianceToVariance<U>()(d, epsilon, sample_size, channels, variance);
    };

    bool cudnn_launch_status =
        stream
            ->ThenBatchNormalizationForward(
                x_ptr, scale_ptr, offset_ptr, estimated_mean_ptr,
                estimated_variance_ptr, x_desc, scale_offset_desc,
                static_cast<double>(epsilon), &y_ptr, &batch_mean_ptr,
                &batch_var_ptr, &saved_mean_ptr, &saved_inv_var_ptr,
                is_training, std::move(var_to_inv_var),
                std::move(inv_var_to_var))
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(
          errors::Internal("cuDNN launch failure: input shape (",
                           x.shape().DebugString(), ")"));
    }
    if (tensor_format == FORMAT_NHWC) {
      functor::NCHWToNHWC<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(y_transformed).tensor<T, 4>(),
          y->tensor<T, 4>());
    }
  }
};

template <typename T, typename U>
struct FusedBatchNormGrad<GPUDevice, T, U> {
  void operator()(OpKernelContext* context, const Tensor& y_backprop,
                  const Tensor& x, const Tensor& scale, const Tensor& mean,
                  const Tensor& inv_variance, U epsilon, Tensor* x_backprop,
                  Tensor* scale_backprop, Tensor* offset_backprop,
                  TensorFormat tensor_format) {
    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available"));

    const int64 batch_size = GetTensorDim(x, tensor_format, 'N');
    const int64 channels = GetTensorDim(x, tensor_format, 'C');
    const int64 height = GetTensorDim(x, tensor_format, 'H');
    const int64 width = GetTensorDim(x, tensor_format, 'W');

    VLOG(2) << "FusedBatchNormGrad:"
            << " batch_size: " << batch_size << " channels: " << channels
            << " height: " << height << " width: " << width
            << " y_backprop shape: " << y_backprop.shape().DebugString()
            << " x shape: " << x.shape().DebugString()
            << " scale shape: " << scale.shape().DebugString()
            << " tensor format: " << tensor_format;

    // Inputs
    Tensor y_backprop_maybe_transformed = y_backprop;
    Tensor x_maybe_transformed = x;
    Tensor y_backprop_transformed;
    Tensor x_transformed;

    // Outputs
    Tensor x_backprop_transformed;
    se::DeviceMemory<T> x_backprop_ptr;

    if (tensor_format == FORMAT_NCHW) {
      x_backprop_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*x_backprop);
    } else if (tensor_format == FORMAT_NHWC) {
      // Transform inputs from 'NHWC' to 'NCHW'
      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &y_backprop_transformed));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(y_backprop_maybe_transformed)
              .tensor<T, 4>(),
          y_backprop_transformed.tensor<T, 4>());
      y_backprop_maybe_transformed = y_backprop_transformed;

      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &x_transformed));
      functor::NHWCToNCHW<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(x_maybe_transformed).tensor<T, 4>(),
          x_transformed.tensor<T, 4>());
      x_maybe_transformed = x_transformed;

      // Allocate memory for transformed outputs in 'NCHW'
      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DataTypeToEnum<T>::value,
                                  ShapeFromFormat(FORMAT_NCHW, batch_size,
                                                  height, width, channels),
                                  &x_backprop_transformed));
      x_backprop_ptr =
          StreamExecutorUtil::AsDeviceMemory<T>(x_backprop_transformed);
    } else {
      context->SetStatus(
          errors::Internal("Unsupported tensor format: ", tensor_format));
      return;
    }

    se::dnn::BatchDescriptor x_desc;
    x_desc.set_count(batch_size)
        .set_feature_map_count(channels)
        .set_height(height)
        .set_width(width)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);

    se::dnn::BatchDescriptor scale_offset_desc;
    scale_offset_desc.set_count(1)
        .set_feature_map_count(channels)
        .set_height(1)
        .set_width(1)
        .set_layout(se::dnn::DataLayout::kBatchDepthYX);

    auto y_backprop_ptr =
        StreamExecutorUtil::AsDeviceMemory<T>(y_backprop_maybe_transformed);
    auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
    auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
    auto mean_ptr = StreamExecutorUtil::AsDeviceMemory<U>(mean);
    auto inv_variance_ptr = StreamExecutorUtil::AsDeviceMemory<U>(inv_variance);
    auto scale_backprop_ptr =
        StreamExecutorUtil::AsDeviceMemory<U>(*scale_backprop);
    auto offset_backprop_ptr =
        StreamExecutorUtil::AsDeviceMemory<U>(*offset_backprop);

    // The cuDNN kernel computes the inverse variance in the forward pass and
    // reuses it here in the backward pass.
    bool cudnn_launch_status =
        stream
            ->ThenBatchNormalizationBackward(
                y_backprop_ptr, x_ptr, scale_ptr, mean_ptr, inv_variance_ptr,
                x_desc, scale_offset_desc, static_cast<double>(epsilon),
                &x_backprop_ptr, &scale_backprop_ptr, &offset_backprop_ptr)
            .ok();

    if (!cudnn_launch_status) {
      context->SetStatus(
          errors::Internal("cuDNN launch failure: input shape (",
                           x.shape().DebugString(), ")"));
    }
    if (tensor_format == FORMAT_NHWC) {
      functor::NCHWToNHWC<GPUDevice, T, 4>()(
          context->eigen_device<GPUDevice>(),
          const_cast<const Tensor&>(x_backprop_transformed).tensor<T, 4>(),
          x_backprop->tensor<T, 4>());
    }
  }
};

// Forward declarations of the functor specializations for GPU.
#define DECLARE_GPU_SPEC(T, U)                                           \
  template <>                                                            \
  void FusedBatchNormFreezeGrad<GPUDevice, T, U>::operator()(            \
      const GPUDevice& d, const Tensor& y_backprop_input,                \
      const Tensor& x_input, const Tensor& scale_input,                  \
      const Tensor& mean_input, const Tensor& variance_input, U epsilon, \
      Tensor* x_backprop_output, Tensor* scale_backprop_output,          \
      Tensor* offset_backprop_output, typename TTypes<U>::Vec scratch1,  \
      typename TTypes<U>::Vec scratch2);                                 \
  extern template struct FusedBatchNormFreezeGrad<GPUDevice, T, U>;
DECLARE_GPU_SPEC(float, float);
DECLARE_GPU_SPEC(Eigen::half, float);

#endif  // GOOGLE_CUDA
}  // namespace functor

template <typename Device, typename T, typename U>
class FusedBatchNormOp : public OpKernel {
 public:
  explicit FusedBatchNormOp(OpKernelConstruction* context) : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = U(epsilon);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& x = context->input(0);
    const Tensor& scale = context->input(1);
    const Tensor& offset = context->input(2);
    const Tensor& estimated_mean = context->input(3);
    const Tensor& estimated_variance = context->input(4);

    OP_REQUIRES(context, x.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        x.shape().DebugString()));
    OP_REQUIRES(context, scale.dims() == 1,
                errors::InvalidArgument("scale must be 1-dimensional",
                                        scale.shape().DebugString()));
    OP_REQUIRES(context, offset.dims() == 1,
                errors::InvalidArgument("offset must be 1-dimensional",
                                        offset.shape().DebugString()));
    OP_REQUIRES(context, estimated_mean.dims() == 1,
                errors::InvalidArgument("estimated_mean must be 1-dimensional",
                                        estimated_mean.shape().DebugString()));
    OP_REQUIRES(
        context, estimated_variance.dims() == 1,
        errors::InvalidArgument("estimated_variance must be 1-dimensional",
                                estimated_variance.shape().DebugString()));
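    // In training mode the op computes the batch statistics itself, so the
    // estimated_mean and estimated_variance inputs are expected to be empty
    // placeholders.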
    if (is_training_) {
      OP_REQUIRES(
          context, estimated_mean.dim_size(0) == 0,
          errors::InvalidArgument("estimated_mean must be empty for training",
                                  estimated_mean.shape().DebugString()));
      OP_REQUIRES(context, estimated_variance.dim_size(0) == 0,
                  errors::InvalidArgument(
                      "estimated_variance must be empty for training",
                      estimated_variance.shape().DebugString()));
    }

    Tensor* y = nullptr;
    OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
                                {0}, 0, x.shape(), &y));
    Tensor* batch_mean = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(1, scale.shape(), &batch_mean));
    Tensor* batch_var = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, scale.shape(), &batch_var));
    Tensor* saved_mean = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(3, scale.shape(), &saved_mean));
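    // Output 4 holds either the saved variance (CPU/Eigen path) or the saved
    // inverse variance (GPU/cuDNN path); see the matching comment in
    // FusedBatchNormGradOp below.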
    Tensor* saved_maybe_inv_var = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(4, scale.shape(),
                                                     &saved_maybe_inv_var));

    functor::FusedBatchNorm<Device, T, U>()(
        context, x, scale, offset, estimated_mean, estimated_variance, epsilon_,
        y, batch_mean, batch_var, saved_mean, saved_maybe_inv_var,
        tensor_format_, is_training_);
  }

 private:
  U epsilon_;
  TensorFormat tensor_format_;
  bool is_training_;
};

template <typename Device, typename T, typename U>
class FusedBatchNormGradOp : public OpKernel {
 public:
  explicit FusedBatchNormGradOp(OpKernelConstruction* context)
      : OpKernel(context) {
    float epsilon;
    OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
    epsilon_ = U(epsilon);
    string tensor_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
    OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& y_backprop = context->input(0);
    const Tensor& x = context->input(1);
    const Tensor& scale = context->input(2);
    // When is_training=True, batch mean and variance/inverted variance are
    // saved in the forward pass to be reused here. When is_training=False,
    // population mean and variance need to be forwarded here to compute the
    // gradients.
    const Tensor& saved_mean_or_pop_mean = context->input(3);
    // The Eigen implementation saves variance in the forward pass, while cuDNN
    // saves inverted variance.
    const Tensor& saved_maybe_inv_var_or_pop_var = context->input(4);

    OP_REQUIRES(context, y_backprop.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        y_backprop.shape().DebugString()));
    OP_REQUIRES(context, x.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        x.shape().DebugString()));
    OP_REQUIRES(context, scale.dims() == 1,
                errors::InvalidArgument("scale must be 1-dimensional",
                                        scale.shape().DebugString()));
    OP_REQUIRES(
        context, saved_mean_or_pop_mean.dims() == 1,
        errors::InvalidArgument("saved mean must be 1-dimensional",
                                saved_mean_or_pop_mean.shape().DebugString()));
    OP_REQUIRES(context, saved_maybe_inv_var_or_pop_var.dims() == 1,
                errors::InvalidArgument(
                    "saved variance must be 1-dimensional",
                    saved_maybe_inv_var_or_pop_var.shape().DebugString()));

    Tensor* x_backprop = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, x.shape(), &x_backprop));

    const TensorShape& scale_offset_shape = scale.shape();
    Tensor* scale_backprop = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, scale_offset_shape,
                                                     &scale_backprop));
    Tensor* offset_backprop = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, scale_offset_shape,
                                                     &offset_backprop));
    // Two placeholders for estimated_mean and estimated_variance, which are
    // used for inference and thus not needed here for gradient computation.
    // They are filled with zeros so as to avoid NaN outputs.
    Tensor* placeholder_1 = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(3, TensorShape({}), &placeholder_1));
    functor::SetZeroFunctor<Device, float> f;
    f(context->eigen_device<Device>(), placeholder_1->flat<U>());
    Tensor* placeholder_2 = nullptr;
    OP_REQUIRES_OK(
        context, context->allocate_output(4, TensorShape({}), &placeholder_2));
    f(context->eigen_device<Device>(), placeholder_2->flat<U>());

    // If the input is empty, set the gradients w.r.t. scale/offset to zero.
    if (x.shape().num_elements() == 0) {
      functor::SetZeroFunctor<Device, U> f;
      f(context->eigen_device<Device>(), scale_backprop->flat<U>());
      f(context->eigen_device<Device>(), offset_backprop->flat<U>());
      return;
    }

    if (is_training_) {
      functor::FusedBatchNormGrad<Device, T, U>()(
          context, y_backprop, x, scale, saved_mean_or_pop_mean,
          saved_maybe_inv_var_or_pop_var, epsilon_, x_backprop, scale_backprop,
          offset_backprop, tensor_format_);

    } else {
      // The necessary layout conversion is currently done in Python.
      CHECK(tensor_format_ == FORMAT_NHWC)
          << "The implementation of FusedBatchNormGrad with is_training=False "
             "only supports "
          << "NHWC tensor format for now.";
      Tensor scratch1, scratch2;
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<U>::value,
                                            scale_offset_shape, &scratch1));
      OP_REQUIRES_OK(context,
                     context->allocate_temp(DataTypeToEnum<U>::value,
                                            scale_offset_shape, &scratch2));
      functor::FusedBatchNormFreezeGrad<Device, T, U>()(
          context->eigen_device<Device>(), y_backprop, x, scale,
          saved_mean_or_pop_mean, saved_maybe_inv_var_or_pop_var, epsilon_,
          x_backprop, scale_backprop, offset_backprop, scratch1.vec<U>(),
          scratch2.vec<U>());
    }
  }

 private:
  U epsilon_;
  TensorFormat tensor_format_;
  bool is_training_;
};

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNorm").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    FusedBatchNormOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNormGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
    FusedBatchNormGradOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<CPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<CPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<CPUDevice, Eigen::half, float>);

#if GOOGLE_CUDA

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNorm").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    FusedBatchNormOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(
    Name("FusedBatchNormGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    FusedBatchNormGradOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<float>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<GPUDevice, float, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormOp<GPUDevice, Eigen::half, float>);

REGISTER_KERNEL_BUILDER(Name("FusedBatchNormGradV2")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<Eigen::half>("T")
                            .TypeConstraint<float>("U"),
                        FusedBatchNormGradOp<GPUDevice, Eigen::half, float>);

#endif

}  // namespace tensorflow