/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_FUSED_BATCH_NORM_OP_H_
#define TENSORFLOW_KERNELS_FUSED_BATCH_NORM_OP_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"

namespace tensorflow {
namespace functor {

#if GOOGLE_CUDA

// There is a behavior difference between cuDNN v4 and v5 with regard to the
// scaling factor for the function cudnnBatchNormalizationForwardInference.
// This functor corrects the scaling factor when cuDNN v4 is used, so that
// the inconsistency is hidden from TensorFlow users.
// Details: in cuDNN v4, y = bnScale * (x - mean) * variance + bnBias;
// in v5, y = bnScale * (x - mean) / sqrt(variance + epsilon) + bnBias.
// The template is instantiated with T as float in batch_norm_ops.cu.cc; for
// other types, the instantiation needs to be added accordingly.
template <class T>
struct VarianceToInvVariance {
  void operator()(const Eigen::GpuDevice& d, const T* variance, double epsilon,
                  int channels, T* inv_variance);
};

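// A minimal CPU reference sketch (not part of the real kernels) of what the
// functor above computes on the GPU, assuming the v5-style definition
// inv_variance = 1 / sqrt(variance + epsilon). The name
// VarianceToInvVarianceRef is illustrative only; the actual conversion is a
// CUDA kernel instantiated in the .cu.cc file.
template <class T>
inline void VarianceToInvVarianceRef(const T* variance, double epsilon,
                                     int channels, T* inv_variance) {
  for (int c = 0; c < channels; ++c) {
    // inv_variance[c] = rsqrt(variance[c] + epsilon)
    inv_variance[c] =
        T(1) / Eigen::numext::sqrt(variance[c] + static_cast<T>(epsilon));
  }
}
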
// This function converts the inverse variance produced by the cuDNN forward
// training pass back to variance, which TensorFlow needs to compute the
// running variance.
// The template is instantiated with T as float in batch_norm_ops.cu.cc; for
// other types, the instantiation needs to be added accordingly.
template <class T>
struct InvVarianceToVariance {
  void operator()(const Eigen::GpuDevice& d, double epsilon, int sample_size,
                  int channels, T* variance);
};

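// A minimal CPU reference sketch (not part of the real kernels) of the
// inverse conversion, assuming it undoes the formula above in place and
// rescales by sample_size / (sample_size - 1) (Bessel's correction) so the
// running variance tracks an unbiased estimate. The name
// InvVarianceToVarianceRef is illustrative only.
template <class T>
inline void InvVarianceToVarianceRef(double epsilon, int sample_size,
                                     int channels, T* variance) {
  for (int c = 0; c < channels; ++c) {
    const T inv_var = variance[c];  // Buffer holds 1 / sqrt(var + epsilon).
    T var = T(1) / (inv_var * inv_var) - static_cast<T>(epsilon);
    var *= static_cast<T>(sample_size) /
           static_cast<T>(sample_size > 1 ? sample_size - 1 : 1);
    variance[c] = var > T(0) ? var : T(0);  // Clamp away small negatives.
  }
}
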
// This function sets a GPU tensor to NaNs.
template <class T>
struct SetNanFunctor {
  void operator()(const Eigen::GpuDevice& d, typename TTypes<T>::Flat out);
};
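
// Illustrative usage sketch (the names `d` and `output` are assumptions, not
// part of this header):
//
//   SetNanFunctor<float>()(d, output->flat<float>());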

#endif  // GOOGLE_CUDA

// Functor used by FusedBatchNormGradOp to do the computations when
// is_training=False. Both the CPU and GPU kernels use this functor.
template <typename Device, typename T, typename U>
struct FusedBatchNormFreezeGrad {
  void operator()(const Device& d, const Tensor& y_backprop_input,
                  const Tensor& x_input, const Tensor& scale_input,
                  const Tensor& pop_mean_input,
                  const Tensor& pop_variance_input, U epsilon,
                  Tensor* x_backprop_output, Tensor* scale_backprop_output,
                  Tensor* offset_backprop_output,
                  typename TTypes<U>::Vec scratch1,
                  typename TTypes<U>::Vec scratch2) {
    typename TTypes<T, 4>::ConstTensor y_backprop(
        y_backprop_input.tensor<T, 4>());
    typename TTypes<T, 4>::ConstTensor input(x_input.tensor<T, 4>());
    typename TTypes<U>::ConstVec scale(scale_input.vec<U>());
    typename TTypes<U>::ConstVec pop_mean(pop_mean_input.vec<U>());
    typename TTypes<U>::ConstVec pop_var(pop_variance_input.vec<U>());
    typename TTypes<T, 4>::Tensor x_backprop(x_backprop_output->tensor<T, 4>());
    typename TTypes<U>::Vec scale_backprop(scale_backprop_output->vec<U>());
    typename TTypes<U>::Vec offset_backprop(offset_backprop_output->vec<U>());

    const int depth = pop_mean.dimension(0);
    const int rest_size = input.size() / depth;

    Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
    Eigen::array<int, 1> reduction_axis{0};
    Eigen::array<int, 2> rest_by_one({rest_size, 1});
#else
    Eigen::IndexList<Eigen::type2index<1>, Eigen::Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
    Eigen::IndexList<Eigen::Index, Eigen::type2index<1> > rest_by_one;
    rest_by_one.set(0, rest_size);
#endif

    // offset_backprop = sum(y_backprop)
    // scale_backprop =
    //     sum(y_backprop * (x - pop_mean) * rsqrt(pop_var + epsilon))
    // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon))

    auto y_backprop_rest_by_depth =
        y_backprop.reshape(rest_by_depth).template cast<U>();
    auto input_rest_by_depth = input.reshape(rest_by_depth).template cast<U>();

    offset_backprop.device(d) = y_backprop_rest_by_depth.sum(reduction_axis);

    // scratch1 = rsqrt(pop_var + epsilon)
    scratch1.device(d) = (pop_var + pop_var.constant(epsilon)).rsqrt();

    // scratch2 = sum(y_backprop * (x - pop_mean))
    scratch2.device(d) =
        (y_backprop_rest_by_depth *
         (input_rest_by_depth -
          pop_mean.reshape(one_by_depth).broadcast(rest_by_one)))
            .sum(reduction_axis);

    x_backprop.reshape(rest_by_depth).device(d) =
        (y_backprop_rest_by_depth * ((scratch1 * scale)
                                         .eval()
                                         .reshape(one_by_depth)
                                         .broadcast(rest_by_one)))
            .template cast<T>();
    scale_backprop.device(d) = scratch2 * scratch1;
  }
};
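
// A minimal scalar reference sketch (not part of the real kernels) of the
// gradient math the functor above evaluates with Eigen expressions. For
// simplicity it uses a single type T for both data and statistics; the
// function name and the flat [rest_size, depth] row-major layout are
// illustrative assumptions.
template <typename T>
inline void FusedBatchNormFreezeGradRef(
    const T* y_backprop, const T* x, const T* scale, const T* pop_mean,
    const T* pop_var, T epsilon, int rest_size, int depth, T* x_backprop,
    T* scale_backprop, T* offset_backprop) {
  for (int c = 0; c < depth; ++c) {
    const T inv_std = T(1) / Eigen::numext::sqrt(pop_var[c] + epsilon);
    T sum_dy = T(0);           // Accumulates offset_backprop = sum(dy).
    T sum_dy_centered = T(0);  // Accumulates sum(dy * (x - pop_mean)).
    for (int r = 0; r < rest_size; ++r) {
      const int i = r * depth + c;
      sum_dy += y_backprop[i];
      sum_dy_centered += y_backprop[i] * (x[i] - pop_mean[c]);
      // x_backprop = dy * scale * rsqrt(pop_var + epsilon)
      x_backprop[i] = y_backprop[i] * scale[c] * inv_std;
    }
    offset_backprop[c] = sum_dy;
    scale_backprop[c] = sum_dy_centered * inv_std;
  }
}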

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_KERNELS_FUSED_BATCH_NORM_OP_H_