1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_KERNELS_XENT_OP_H_ 17 #define TENSORFLOW_KERNELS_XENT_OP_H_ 18 // Functor definition for SparseXentOp, must be compilable by nvcc. 19 20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 21 #include "tensorflow/core/framework/tensor_types.h" 22 #include "tensorflow/core/kernels/bounds_check.h" 23 #include "tensorflow/core/platform/macros.h" 24 #include "tensorflow/core/platform/types.h" 25 26 namespace tensorflow { 27 28 namespace sparse_xent_helpers { 29 30 template <typename T> 31 typename TTypes<const T, 1>::Tensor32Bit To32BitConst( 32 typename TTypes<T>::Vec in) { 33 return To32Bit(typename TTypes<T>::ConstVec(in.data(), in.dimensions())); 34 } 35 36 template <typename T> 37 typename TTypes<const T, 2>::Tensor32Bit To32BitConst( 38 typename TTypes<T>::Matrix in) { 39 return To32Bit(typename TTypes<T>::ConstMatrix(in.data(), in.dimensions())); 40 } 41 42 } // namespace sparse_xent_helpers 43 44 namespace generator { 45 46 // Generator for calculation of the sparse Xent loss. 47 // This generator takes the logits, the sum of the exponentiated 48 // logits, and the label indices. For each minibatch entry, ignoring 49 // the batch index b, it calculates: 50 // 51 // loss[j] = (log(sum_exp_logits) - logits[j]) * 1{ j == label } 52 // 53 // for j = 0 .. num_classes. This value must be summed over all j for 54 // the final loss. 55 template <typename T, typename Index> 56 class SparseXentLossGenerator { 57 public: 58 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseXentLossGenerator( 59 typename TTypes<const T, 2>::Tensor32Bit logits, 60 typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits, 61 typename TTypes<const Index, 1>::Tensor32Bit labels, 62 const Index max_depth) 63 : logits_(logits), 64 sum_exp_logits_(sum_exp_logits), 65 labels_(labels), 66 max_depth_(max_depth) {} 67 68 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T 69 operator()(const Eigen::array<int, 2>& coords) const { 70 const int batch = coords[0]; 71 const int depth = coords[1]; 72 const Index label = tensorflow::internal::SubtleMustCopy(labels_(batch)); 73 if (!FastBoundsCheck(label, max_depth_)) { 74 return Eigen::NumTraits<T>::quiet_NaN(); 75 } 76 return TF_PREDICT_FALSE(label == depth) 77 ? (Eigen::numext::log(sum_exp_logits_(batch)) - logits_(coords)) 78 : T(0.0); 79 }; 80 81 private: 82 typename TTypes<const T, 2>::Tensor32Bit logits_; 83 typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits_; 84 typename TTypes<const Index, 1>::Tensor32Bit labels_; 85 const Index max_depth_; 86 }; 87 88 // Generator for calculation of the sparse Xent gradient. 89 // This generator takes the exponentiated logits, their sums, and the label 90 // indices. For each minibatch entry, ignoring the batch index b, it calculates: 91 // 92 // exp_logits[j] / sum_exp_logits - 1{ j == label } 93 // 94 // for j = 0 .. num_classes. 95 template <typename T, typename Index> 96 class SparseXentGradGenerator { 97 public: 98 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseXentGradGenerator( 99 typename TTypes<const T, 2>::Tensor32Bit exp_logits, 100 typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits, 101 typename TTypes<const Index, 1>::Tensor32Bit labels, 102 const Index max_depth) 103 : exp_logits_(exp_logits), 104 sum_exp_logits_(sum_exp_logits), 105 labels_(labels), 106 max_depth_(max_depth) {} 107 108 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T 109 operator()(const Eigen::array<int, 2>& coords) const { 110 const int batch = coords[0]; 111 const int depth = coords[1]; 112 const Index label = tensorflow::internal::SubtleMustCopy(labels_(batch)); 113 if (!FastBoundsCheck(label, max_depth_)) { 114 return Eigen::NumTraits<T>::quiet_NaN(); 115 } 116 T subtract = TF_PREDICT_FALSE(depth == label) ? T(1.0) : T(0.0); 117 return exp_logits_(coords) / sum_exp_logits_(batch) - subtract; 118 }; 119 120 private: 121 typename TTypes<const T, 2>::Tensor32Bit exp_logits_; 122 typename TTypes<const T, 1>::Tensor32Bit sum_exp_logits_; 123 typename TTypes<const Index, 1>::Tensor32Bit labels_; 124 const Index max_depth_; 125 }; 126 127 } // namespace generator 128 129 namespace functor { 130 131 // Functor used by SparseXentOp to do the computations. 132 template <typename Device, typename T, typename Index> 133 struct SparseXentFunctor { 134 // Computes Cross Entropy loss and backprop. 135 // 136 // logits: batch_size, num_classes. 137 // labels: num_classes. 138 // scratch: temporary tensor, dims: batch_size, 1 139 // loss: output tensor for the loss, dims: batch_size. 140 // backprop: output tensor for the backprop, dims: batch_size, num_classes. 141 void operator()(const Device& d, typename TTypes<T>::ConstMatrix logits, 142 typename TTypes<Index>::ConstVec labels, 143 typename TTypes<T>::Vec scratch, typename TTypes<T>::Vec loss, 144 typename TTypes<T>::Matrix backprop); 145 }; 146 147 // Eigen code implementing SparseXentFunctor::operator(). 148 // This code works for both CPU and GPU and is used by the functor 149 // specializations for both device types. 150 template <typename Device, typename T, typename Index> 151 struct SparseXentEigenImpl { 152 static void Compute(const Device& d, typename TTypes<T>::ConstMatrix logits, 153 typename TTypes<Index>::ConstVec labels, 154 typename TTypes<T>::Vec scratch, 155 typename TTypes<T>::Vec loss, 156 typename TTypes<T>::Matrix backprop) { 157 // NOTE(touts): This duplicates some of the computations in softmax_op 158 // because we need the intermediate (logits -max(logits)) values to 159 // avoid a log(exp()) in the computation of the loss. 160 161 const int kBatchDim = 0; 162 const int kClassDim = 1; 163 164 const int batch_size = logits.dimension(kBatchDim); 165 const int num_classes = logits.dimension(kClassDim); 166 167 // These arrays are used to reduce along the class dimension, and broadcast 168 // the resulting value to all classes. 169 #if !defined(EIGEN_HAS_INDEX_LIST) 170 Eigen::array<int, 1> along_class; 171 along_class[0] = kClassDim; 172 Eigen::array<int, 1> batch_only; 173 batch_only[0] = batch_size; 174 Eigen::array<int, 2> batch_by_one; 175 batch_by_one[0] = batch_size; 176 batch_by_one[1] = 1; 177 Eigen::array<int, 2> one_by_class; 178 one_by_class[0] = 1; 179 one_by_class[1] = num_classes; 180 #else 181 Eigen::IndexList<Eigen::type2index<kClassDim> > along_class; 182 Eigen::IndexList<int, Eigen::type2index<1> > batch_by_one; 183 batch_by_one.set(0, batch_size); 184 Eigen::IndexList<int> batch_only; 185 batch_only.set(0, batch_size); 186 Eigen::IndexList<Eigen::type2index<1>, int> one_by_class; 187 one_by_class.set(1, num_classes); 188 #endif 189 190 // scratch = max_logits along classes. 191 To32Bit(scratch).device(d) = To32Bit(logits).maximum(along_class); 192 193 // backprop = logits - max_logits. 194 To32Bit(backprop).device(d) = 195 To32Bit(logits) - 196 To32Bit(scratch).reshape(batch_by_one).broadcast(one_by_class); 197 198 // scratch = sum(exp(logits - max_logits)) along classes. 199 To32Bit(scratch).device(d) = To32Bit(backprop).exp().sum(along_class); 200 201 // sum(-labels * 202 // ((logits - max_logits) - log(sum(exp(logits - max_logits))))) 203 // along classes 204 generator::SparseXentLossGenerator<T, Index> sparse_xent_loss_gen( 205 sparse_xent_helpers::To32BitConst<T>(backprop), 206 sparse_xent_helpers::To32BitConst<T>(scratch), To32Bit(labels), 207 backprop.dimension(1) /* max_depth */); 208 To32Bit(loss).device(d) = 209 To32Bit(backprop).generate(sparse_xent_loss_gen).sum(along_class); 210 211 // backprop: prob - labels, where 212 // prob = exp(logits - max_logits) / sum(exp(logits - max_logits)) 213 To32Bit(backprop).device(d) = To32Bit(backprop).exp(); 214 generator::SparseXentGradGenerator<T, Index> sparse_xent_grad_gen( 215 sparse_xent_helpers::To32BitConst<T>(backprop), 216 sparse_xent_helpers::To32BitConst<T>(scratch), To32Bit(labels), 217 backprop.dimension(1) /* max_depth */); 218 To32Bit(backprop).device(d) = 219 To32Bit(backprop).generate(sparse_xent_grad_gen); 220 } 221 }; 222 223 } // namespace functor 224 225 } // namespace tensorflow 226 227 #endif // TENSORFLOW_KERNELS_XENT_OP_H_ 228