/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/training_ops.h"

namespace tensorflow {

using GPUDevice = Eigen::GpuDevice;

namespace functor {

// Every functor in this file uses the same device-side idiom: a rank-0
// ConstScalar hyper-parameter is reshaped to a rank-1 tensor of size 1 and
// broadcast to the gradient's length, so each whole update is a single fused
// element-wise Eigen expression evaluated on the GPU.

// Vanilla gradient descent: var -= lr * grad.
template <typename T>
struct ApplyGradientDescent<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bdims;
    bdims[0] = grad.dimension(0);
    Eigen::Sizes<1> shape1;
    var.device(d) -= lr.reshape(shape1).broadcast(bdims) * grad;
  }
};

// Adagrad: accum += grad^2; var -= lr * grad / sqrt(accum).
template <typename T>
struct ApplyAdagrad<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad) {
    accum.device(d) += grad.square();
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bdims;
    bdims[0] = grad.dimension(0);
    Eigen::Sizes<1> shape1;
    var.device(d) -= lr.reshape(shape1).broadcast(bdims) * grad * accum.rsqrt();
  }
};

// Adadelta:
//   accum        = rho * accum + (1 - rho) * grad^2
//   update       = sqrt(accum_update + eps) / sqrt(accum + eps) * grad
//   var         -= lr * update
//   accum_update = rho * accum_update + (1 - rho) * update^2
// NOTE: `update` is a lazy Eigen expression and is re-evaluated each time it
// appears below; the statement order is significant (accum is refreshed
// before `update` is consumed, accum_update only afterwards).
template <typename T>
struct ApplyAdadelta<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat accum_update,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bdims;
    bdims[0] = grad.dimension(0);
    Eigen::Sizes<1> shape1;

    accum.device(d) =
        accum * rho.reshape(shape1).broadcast(bdims) +
        grad.square() *
            (grad.constant(T(1)) - rho.reshape(shape1).broadcast(bdims));
    const auto update =
        (accum_update + epsilon.reshape(shape1).broadcast(bdims)).sqrt() *
        (accum + epsilon.reshape(shape1).broadcast(bdims)).rsqrt() * grad;
    var.device(d) -= update * lr.reshape(shape1).broadcast(bdims);
    accum_update.device(d) =
        accum_update * rho.reshape(shape1).broadcast(bdims) +
        update.square() *
            (grad.constant(T(1)) - rho.reshape(shape1).broadcast(bdims));
  }
};

// Momentum:
//   accum = momentum * accum + grad
//   var  -= lr * accum                                  (plain)
//   var  -= lr * grad + lr * momentum * accum           (Nesterov)
template <typename T>
struct ApplyMomentum<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bdims;
    bdims[0] = grad.dimension(0);
    Eigen::Sizes<1> shape1;
    accum.device(d) = accum * momentum.reshape(shape1).broadcast(bdims) + grad;
    if (use_nesterov) {
      var.device(d) -= grad * lr.reshape(shape1).broadcast(bdims) +
                       accum * momentum.reshape(shape1).broadcast(bdims) *
                           lr.reshape(shape1).broadcast(bdims);
    } else {
      var.device(d) -= lr.reshape(shape1).broadcast(bdims) * accum;
    }
  }
};

// Adam:
//   m   += (1 - beta1) * (grad - m)
//   v   += (1 - beta2) * (grad^2 - v)
//   var -= lr * sqrt(1 - beta2_power) / (1 - beta1_power)
//          * m_hat / (epsilon + sqrt(v))
// where m_hat is m for the plain variant, or
// beta1 * m + (1 - beta1) * grad for the Nesterov (NAdam-style) variant.
template <typename T>
struct ApplyAdam<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  typename TTypes<T>::ConstScalar beta1_power,
                  typename TTypes<T>::ConstScalar beta2_power,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar beta1,
                  typename TTypes<T>::ConstScalar beta2,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad, bool use_nesterov) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bdims;
    bdims[0] = grad.dimension(0);
    Eigen::Sizes<1> shape1;
    const auto one = static_cast<T>(1.0);
    m.device(d) =
        m + (beta1.constant(one) - beta1).reshape(shape1).broadcast(bdims) *
                (grad - m);
    v.device(d) =
        v + (beta2.constant(one) - beta2).reshape(shape1).broadcast(bdims) *
                (grad.square() - v);

    if (use_nesterov) {
      var.device(d) -=
          (lr * (beta2_power.constant(one) - beta2_power).sqrt() /
           (beta1_power.constant(one) - beta1_power))
              .reshape(shape1)
              .broadcast(bdims) *
          (m * beta1.reshape(shape1).broadcast(bdims) +
           (beta1.constant(one) - beta1).reshape(shape1).broadcast(bdims) *
               grad) /
          (epsilon.reshape(shape1).broadcast(bdims) + v.sqrt());
    } else {
      var.device(d) -=
          (lr * (beta2_power.constant(one) - beta2_power).sqrt() /
           (beta1_power.constant(one) - beta1_power))
              .reshape(shape1)
              .broadcast(bdims) *
          m / (epsilon.reshape(shape1).broadcast(bdims) + v.sqrt());
    }
  }
};

// RMSProp:
//   ms  += (1 - rho) * (grad^2 - ms)
//   mom  = momentum * mom + lr * grad / sqrt(epsilon + ms)
//   var -= mom
template <typename T>
struct ApplyRMSProp<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bdims;
    bdims[0] = grad.dimension(0);
    Eigen::Sizes<1> shape1;
    const auto one = static_cast<T>(1.0);
    ms.device(d) =
        ms + (rho.constant(one) - rho).reshape(shape1).broadcast(bdims) *
                 (grad.square() - ms);
    mom.device(d) =
        mom * momentum.reshape(shape1).broadcast(bdims) +
        lr.reshape(shape1).broadcast(bdims) * grad /
            ((epsilon.reshape(shape1).broadcast(bdims) + ms).sqrt());
    var.device(d) -= mom;
  }
};

// Centered RMSProp: like RMSProp, but also tracks the running mean of the
// gradient (mg) and uses the centered second moment (ms - mg^2) as the
// denominator:
//   ms  += (1 - rho) * (grad^2 - ms)
//   mg  += (1 - rho) * (grad - mg)
//   mom  = momentum * mom + lr * grad / sqrt(ms - mg^2 + epsilon)
//   var -= mom
template <typename T>
struct ApplyCenteredRMSProp<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms,
                  typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bdims;
    bdims[0] = grad.dimension(0);
    Eigen::Sizes<1> shape1;
    const auto one = static_cast<T>(1.0);
    const auto one_minus_rho =
        (rho.constant(one) - rho).reshape(shape1).broadcast(bdims);
    ms.device(d) = ms + one_minus_rho * (grad.square() - ms);
    mg.device(d) = mg + one_minus_rho * (grad - mg);
    auto denom = (ms - mg.square()) + epsilon.reshape(shape1).broadcast(bdims);
    mom.device(d) = mom * momentum.reshape(shape1).broadcast(bdims) +
                    lr.reshape(shape1).broadcast(bdims) * grad / denom.sqrt();
    var.device(d) -= mom;
  }
};

// AddSign:
//   m    = beta * m + (1 - beta) * grad
//   var -= lr * (alpha + sign_decay * sign(grad) * sign(m)) * grad
template <typename T>
struct ApplyAddSign<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar alpha,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bdims;
    bdims[0] = grad.dimension(0);
    Eigen::Sizes<1> shape1;

    // GPU equivalent of the CPU version:
    //   m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
    const auto one = static_cast<T>(1.0);
    auto beta_bcast = beta.reshape(shape1).broadcast(bdims);
    auto one_minus_beta =
        (beta.constant(one) - beta).reshape(shape1).broadcast(bdims);
    m.device(d) = m * beta_bcast + grad * one_minus_beta;

    // GPU equivalent of the CPU version:
    //   var.device(d) -= lr() * (alpha() + sign_decay() * sign_gm) * grad;
    auto sign_gm = grad.sign() * m.sign();
    auto lr_bcast = lr.reshape(shape1).broadcast(bdims);
    auto alpha_bcast = alpha.reshape(shape1).broadcast(bdims);
    auto sign_decay_bcast = sign_decay.reshape(shape1).broadcast(bdims);
    var.device(d) -=
        lr_bcast * (alpha_bcast + sign_decay_bcast * sign_gm) * grad;
  }
};

// PowerSign:
//   m    = beta * m + (1 - beta) * grad
//   var -= lr * exp(logbase * sign_decay * sign(grad) * sign(m)) * grad
template <typename T>
struct ApplyPowerSign<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar logbase,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bdims;
    bdims[0] = grad.dimension(0);
    Eigen::Sizes<1> shape1;

    // GPU equivalent of the CPU version:
    //   m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
    const auto one = static_cast<T>(1.0);
    auto beta_bcast = beta.reshape(shape1).broadcast(bdims);
    auto one_minus_beta =
        (beta.constant(one) - beta).reshape(shape1).broadcast(bdims);
    m.device(d) = m * beta_bcast + grad * one_minus_beta;

    // GPU equivalent of the CPU version:
    //   auto grad_scale = (logbase() * sign_decay() * sign_gm).exp();
    //   var.device(d) -= lr() * grad_scale * grad;
    auto sign_gm = grad.sign() * m.sign();
    auto lr_bcast = lr.reshape(shape1).broadcast(bdims);
    auto logbase_bcast = logbase.reshape(shape1).broadcast(bdims);
    auto sign_decay_bcast = sign_decay.reshape(shape1).broadcast(bdims);
    auto grad_scale = (logbase_bcast * sign_decay_bcast * sign_gm).exp();
    var.device(d) -= lr_bcast * grad_scale * grad;
  }
};

}  // namespace functor

// Explicit instantiations for the floating-point types registered on GPU.
template struct functor::ApplyGradientDescent<GPUDevice, Eigen::half>;
template struct functor::ApplyGradientDescent<GPUDevice, float>;
template struct functor::ApplyGradientDescent<GPUDevice, double>;

template struct functor::ApplyAdagrad<GPUDevice, Eigen::half>;
template struct functor::ApplyAdagrad<GPUDevice, float>;
template struct functor::ApplyAdagrad<GPUDevice, double>;

template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
template struct functor::ApplyAdadelta<GPUDevice, float>;
template struct functor::ApplyAdadelta<GPUDevice, double>;

template struct functor::ApplyMomentum<GPUDevice, Eigen::half>;
template struct functor::ApplyMomentum<GPUDevice, float>;
template struct functor::ApplyMomentum<GPUDevice, double>;

template struct functor::ApplyAdam<GPUDevice, Eigen::half>;
template struct functor::ApplyAdam<GPUDevice, float>;
template struct functor::ApplyAdam<GPUDevice, double>;

template struct functor::ApplyRMSProp<GPUDevice, Eigen::half>;
template struct functor::ApplyRMSProp<GPUDevice, float>;
template struct functor::ApplyRMSProp<GPUDevice, double>;

template struct functor::ApplyCenteredRMSProp<GPUDevice, Eigen::half>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, float>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, double>;

template struct functor::ApplyAddSign<GPUDevice, Eigen::half>;
template struct functor::ApplyAddSign<GPUDevice, float>;
template struct functor::ApplyAddSign<GPUDevice, double>;

template struct functor::ApplyPowerSign<GPUDevice, Eigen::half>;
template struct functor::ApplyPowerSign<GPUDevice, float>;
template struct functor::ApplyPowerSign<GPUDevice, double>;

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA