1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_KERNELS_BATCH_NORM_OP_H_ 17 #define TENSORFLOW_KERNELS_BATCH_NORM_OP_H_ 18 // Functor definition for BatchNormOp, must be compilable by nvcc. 19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 20 #include "tensorflow/core/framework/tensor_types.h" 21 22 namespace tensorflow { 23 namespace functor { 24 25 // Functor used by BatchNormOp to do the computations. 26 template <typename Device, typename T> 27 struct BatchNorm { 28 void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input, 29 typename TTypes<T>::ConstVec mean, 30 typename TTypes<T>::ConstVec var, 31 typename TTypes<T>::ConstVec beta, 32 typename TTypes<T>::ConstVec gamma, T variance_epsilon, 33 bool scale_after_normalization, 34 typename TTypes<T, 4>::Tensor output) { 35 const int depth = mean.dimension(0); 36 const int rest_size = input.size() / depth; 37 38 Eigen::DSizes<int, 2> rest_by_depth(rest_size, depth); 39 #if !defined(EIGEN_HAS_INDEX_LIST) 40 Eigen::DSizes<int, 2> rest_by_one(rest_size, 1); 41 Eigen::DSizes<int, 2> one_by_depth(1, depth); 42 Eigen::DSizes<int, 2> depth_by_one(depth, 1); 43 #else 44 Eigen::IndexList<int, Eigen::type2index<1> > rest_by_one; 45 rest_by_one.set(0, rest_size); 46 Eigen::IndexList<Eigen::type2index<1>, int> one_by_depth; 47 one_by_depth.set(1, depth); 48 Eigen::IndexList<int, Eigen::type2index<1> > depth_by_one; 49 depth_by_one.set(0, depth); 50 #endif 51 if (scale_after_normalization) { 52 output.reshape(rest_by_depth).device(d) = 53 (input.reshape(rest_by_depth) - 54 mean.reshape(one_by_depth).broadcast(rest_by_one)) * 55 ((var + var.constant(variance_epsilon)).rsqrt() * gamma) 56 .eval() 57 .reshape(one_by_depth) 58 .broadcast(rest_by_one) + 59 beta.reshape(one_by_depth).broadcast(rest_by_one); 60 } else { 61 output.reshape(rest_by_depth).device(d) = 62 (input.reshape(rest_by_depth) - 63 mean.reshape(one_by_depth).broadcast(rest_by_one)) * 64 ((var + var.constant(variance_epsilon)).rsqrt()) 65 .eval() 66 .reshape(one_by_depth) 67 .broadcast(rest_by_one) + 68 beta.reshape(one_by_depth).broadcast(rest_by_one); 69 } 70 } 71 }; 72 73 template <typename Device, typename T> 74 struct BatchNormGrad { 75 void operator()(const Device& d, typename TTypes<T, 4>::ConstTensor input, 76 typename TTypes<T>::ConstVec mean, 77 typename TTypes<T>::ConstVec var, 78 typename TTypes<T>::ConstVec gamma, 79 typename TTypes<T, 4>::ConstTensor out_backprop, 80 T variance_epsilon, bool scale_after_normalization, 81 typename TTypes<T, 4>::Tensor dx, typename TTypes<T>::Vec dm, 82 typename TTypes<T>::Vec dv, typename TTypes<T>::Vec db, 83 typename TTypes<T>::Vec dg, typename TTypes<T>::Vec scratch1, 84 typename TTypes<T>::Vec scratch2) { 85 const int depth = mean.dimension(0); 86 const int rest_size = input.size() / depth; 87 88 typedef typename TTypes<T>::ConstVec::Index Index; 89 90 Eigen::DSizes<Index, 2> rest_by_depth(rest_size, depth); 91 #if !defined(EIGEN_HAS_INDEX_LIST) 92 Eigen::DSizes<Index, 2> rest_by_one(rest_size, 1); 93 Eigen::DSizes<Index, 2> one_by_depth(1, depth); 94 Eigen::array<Index, 1> reduction_axis; 95 reduction_axis[0] = 0; // Reduces on first dimension. 96 #else 97 Eigen::IndexList<Index, Eigen::type2index<1> > rest_by_one; 98 rest_by_one.set(0, rest_size); 99 Eigen::IndexList<Eigen::type2index<1>, Index> one_by_depth; 100 one_by_depth.set(1, depth); 101 Eigen::IndexList<Eigen::type2index<0> > reduction_axis; 102 #endif 103 104 // db = out_backprop 105 // 106 // dg = out_backprop * ((x - m) * rsqrt(v + epsilon)) 107 // 108 // dv = sum_over_rest(out_backprop * gamma * (x - m)) * 109 // (-1/2) * (v + epsilon) ^ (-3/2) 110 // 111 // dm = sum_over_rest(out_backprop * gamma) * (-1 / rsqrt(v + epsilon)) 112 // 113 // dx = out_backprop * (gamma * rsqrt(v + epsilon)) 114 db.device(d) = out_backprop.reshape(rest_by_depth).sum(reduction_axis); 115 116 // scratch1 = rsqrt(v + epsilon) 117 scratch1.device(d) = (var + var.constant(variance_epsilon)).rsqrt(); 118 119 // scratch2 = sum_over_rest(out_backprop * (x - m)) 120 scratch2.device(d) = (out_backprop.reshape(rest_by_depth) * 121 (input.reshape(rest_by_depth) - 122 mean.reshape(one_by_depth).broadcast(rest_by_one))) 123 .sum(reduction_axis); 124 125 if (scale_after_normalization) { 126 dx.reshape(rest_by_depth).device(d) = 127 out_backprop.reshape(rest_by_depth) * ((scratch1 * gamma) 128 .eval() 129 .reshape(one_by_depth) 130 .broadcast(rest_by_one)); 131 dm.device(d) = -db * (scratch1 * gamma).eval(); 132 dg.device(d) = scratch2 * scratch1; 133 } else { 134 dx.reshape(rest_by_depth).device(d) = 135 out_backprop.reshape(rest_by_depth) * 136 scratch1.reshape(one_by_depth).broadcast(rest_by_one); 137 dm.device(d) = -db * scratch1; 138 dg.device(d) = dg.constant(static_cast<T>(0.0)); // Gamma is not learned. 139 } 140 141 // scratch1 = - 1/2 * (var + epsilon) ^ (-3/2) 142 scratch1.device(d) = scratch1 * scratch1.constant(static_cast<T>(-0.5f)) / 143 (var + var.constant(variance_epsilon)); 144 145 if (scale_after_normalization) { 146 dv.device(d) = scratch2 * (scratch1 * gamma).eval(); 147 } else { 148 dv.device(d) = scratch2 * scratch1; 149 } 150 } 151 }; 152 153 } // namespace functor 154 } // namespace tensorflow 155 156 #endif // TENSORFLOW_KERNELS_BATCH_NORM_OP_H_ 157