// Source: tensorflow/core/kernels/batch_norm_op.h (code-browser navigation line removed)
      1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #ifndef TENSORFLOW_KERNELS_BATCH_NORM_OP_H_
     17 #define TENSORFLOW_KERNELS_BATCH_NORM_OP_H_
     18 // Functor definition for BatchNormOp, must be compilable by nvcc.
     19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
     20 #include "tensorflow/core/framework/tensor_types.h"
     21 
     22 namespace tensorflow {
     23 namespace functor {
     24 
     25 // Functor used by BatchNormOp to do the computations.
     26 template <typename Device, typename T>
     27 struct BatchNorm {
     28   void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
     29                   typename TTypes<T>::ConstVec mean,
     30                   typename TTypes<T>::ConstVec var,
     31                   typename TTypes<T>::ConstVec beta,
     32                   typename TTypes<T>::ConstVec gamma, T variance_epsilon,
     33                   bool scale_after_normalization,
     34                   typename TTypes<T, 4>::Tensor output) {
     35     const int depth = mean.dimension(0);
     36     const int rest_size = input.size() / depth;
     37 
     38     Eigen::DSizes<int, 2> rest_by_depth(rest_size, depth);
     39 #if !defined(EIGEN_HAS_INDEX_LIST)
     40     Eigen::DSizes<int, 2> rest_by_one(rest_size, 1);
     41     Eigen::DSizes<int, 2> one_by_depth(1, depth);
     42     Eigen::DSizes<int, 2> depth_by_one(depth, 1);
     43 #else
     44     Eigen::IndexList<int, Eigen::type2index<1> > rest_by_one;
     45     rest_by_one.set(0, rest_size);
     46     Eigen::IndexList<Eigen::type2index<1>, int> one_by_depth;
     47     one_by_depth.set(1, depth);
     48     Eigen::IndexList<int, Eigen::type2index<1> > depth_by_one;
     49     depth_by_one.set(0, depth);
     50 #endif
     51     if (scale_after_normalization) {
     52       output.reshape(rest_by_depth).device(d) =
     53           (input.reshape(rest_by_depth) -
     54            mean.reshape(one_by_depth).broadcast(rest_by_one)) *
     55               ((var + var.constant(variance_epsilon)).rsqrt() * gamma)
     56                   .eval()
     57                   .reshape(one_by_depth)
     58                   .broadcast(rest_by_one) +
     59           beta.reshape(one_by_depth).broadcast(rest_by_one);
     60     } else {
     61       output.reshape(rest_by_depth).device(d) =
     62           (input.reshape(rest_by_depth) -
     63            mean.reshape(one_by_depth).broadcast(rest_by_one)) *
     64               ((var + var.constant(variance_epsilon)).rsqrt())
     65                   .eval()
     66                   .reshape(one_by_depth)
     67                   .broadcast(rest_by_one) +
     68           beta.reshape(one_by_depth).broadcast(rest_by_one);
     69     }
     70   }
     71 };
     72 
template <typename Device, typename T>
struct BatchNormGrad {
  // Computes the gradients of batch normalization: given the incoming
  // gradient out_backprop and the forward-pass inputs (input, mean, var,
  // gamma), writes the gradients w.r.t. the input (dx), mean (dm),
  // variance (dv), beta (db) and gamma (dg).
  //
  // scratch1 and scratch2 are caller-provided temporaries (presumably
  // depth-length, matching mean/var -- confirm against the calling op).
  // Their contents on entry are ignored; both are overwritten here.
  void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input,
                  typename TTypes<T>::ConstVec mean,
                  typename TTypes<T>::ConstVec var,
                  typename TTypes<T>::ConstVec gamma,
                  typename TTypes<T, 4>::ConstTensor out_backprop,
                  T variance_epsilon, bool scale_after_normalization,
                  typename TTypes<T, 4>::Tensor dx, typename TTypes<T>::Vec dm,
                  typename TTypes<T>::Vec dv, typename TTypes<T>::Vec db,
                  typename TTypes<T>::Vec dg, typename TTypes<T>::Vec scratch1,
                  typename TTypes<T>::Vec scratch2) {
    const int depth = mean.dimension(0);
    // Collapse all non-depth dimensions into one; the reductions below then
    // sum over that "rest" dimension, producing one value per depth channel.
    const int rest_size = input.size() / depth;

    typedef typename TTypes<T>::ConstVec::Index Index;

    Eigen::DSizes<Index, 2> rest_by_depth(rest_size, depth);
#if !defined(EIGEN_HAS_INDEX_LIST)
    Eigen::DSizes<Index, 2> rest_by_one(rest_size, 1);
    Eigen::DSizes<Index, 2> one_by_depth(1, depth);
    Eigen::array<Index, 1> reduction_axis;
    reduction_axis[0] = 0;  // Reduces on first dimension.
#else
    // The IndexList variants encode the compile-time-known dimensions (the
    // 1s and the reduction axis 0) as types, enabling better Eigen codegen.
    Eigen::IndexList<Index, Eigen::type2index<1> > rest_by_one;
    rest_by_one.set(0, rest_size);
    Eigen::IndexList<Eigen::type2index<1>, Index> one_by_depth;
    one_by_depth.set(1, depth);
    Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
#endif

    // db = out_backprop
    //
    // dg = out_backprop * ((x - m) * rsqrt(v + epsilon))
    //
    // dv = sum_over_rest(out_backprop * gamma * (x - m)) *
    //      (-1/2) * (v + epsilon) ^ (-3/2)
    //
    // dm = sum_over_rest(out_backprop * gamma) * (-1 / rsqrt(v + epsilon))
    //
    // dx = out_backprop * (gamma * rsqrt(v + epsilon))
    db.device(d) = out_backprop.reshape(rest_by_depth).sum(reduction_axis);

    // scratch1 = rsqrt(v + epsilon)
    scratch1.device(d) = (var + var.constant(variance_epsilon)).rsqrt();

    // scratch2 = sum_over_rest(out_backprop * (x - m))
    scratch2.device(d) = (out_backprop.reshape(rest_by_depth) *
                          (input.reshape(rest_by_depth) -
                           mean.reshape(one_by_depth).broadcast(rest_by_one)))
                             .sum(reduction_axis);

    // dx, dm and dg depend on whether gamma participated in the forward
    // pass.  The .eval() materializes the small depth-length product once
    // before it is broadcast across all rest_size rows.
    if (scale_after_normalization) {
      dx.reshape(rest_by_depth).device(d) =
          out_backprop.reshape(rest_by_depth) * ((scratch1 * gamma)
                                                     .eval()
                                                     .reshape(one_by_depth)
                                                     .broadcast(rest_by_one));
      dm.device(d) = -db * (scratch1 * gamma).eval();
      dg.device(d) = scratch2 * scratch1;
    } else {
      dx.reshape(rest_by_depth).device(d) =
          out_backprop.reshape(rest_by_depth) *
          scratch1.reshape(one_by_depth).broadcast(rest_by_one);
      dm.device(d) = -db * scratch1;
      dg.device(d) = dg.constant(static_cast<T>(0.0));  // Gamma is not learned.
    }

    // scratch1 is reused: rsqrt(v+eps) from above is folded into
    // scratch1 = - 1/2 * (var + epsilon) ^ (-3/2)
    // (i.e. rsqrt(v+eps) * (-1/2) / (v+eps)).  It must not be read as
    // rsqrt(v+eps) after this point.
    scratch1.device(d) = scratch1 * scratch1.constant(static_cast<T>(-0.5f)) /
                         (var + var.constant(variance_epsilon));

    if (scale_after_normalization) {
      dv.device(d) = scratch2 * (scratch1 * gamma).eval();
    } else {
      dv.device(d) = scratch2 * scratch1;
    }
  }
};
    152 
    153 }  // namespace functor
    154 }  // namespace tensorflow
    155 
    156 #endif  // TENSORFLOW_KERNELS_BATCH_NORM_OP_H_
    157