/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/training_ops.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace functor {
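// Each functor below implements one optimizer update on flat GPU tensors.
// Scalar hyperparameters (lr, rho, ...) live in device memory, so they are
// not read back on the host; instead each scalar is reshaped to a size-1
// vector and broadcast to the gradient's shape so it can participate in the
// elementwise Eigen expressions.

// Vanilla gradient descent: var <- var - lr * grad.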
template <typename T>
struct ApplyGradientDescent<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    var.device(d) -= lr.reshape(single).broadcast(bcast) * grad;
  }
};

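// Adagrad:
//   accum <- accum + grad^2
//   var   <- var - lr * grad / sqrt(accum)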
template <typename T>
struct ApplyAdagrad<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad) {
    accum.device(d) += grad.square();
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    var.device(d) -= lr.reshape(single).broadcast(bcast) * grad * accum.rsqrt();
  }
};

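// Adadelta (Zeiler, 2012):
//   accum        <- rho * accum + (1 - rho) * grad^2
//   update       <- sqrt(accum_update + epsilon) / sqrt(accum + epsilon) * grad
//   var          <- var - lr * update
//   accum_update <- rho * accum_update + (1 - rho) * update^2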
template <typename T>
struct ApplyAdadelta<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat accum_update,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    accum.device(d) = accum * rho.reshape(single).broadcast(bcast) +
                      grad.square() * (grad.constant(T(1)) -
                                       rho.reshape(single).broadcast(bcast));
    const auto update =
        (accum_update + epsilon.reshape(single).broadcast(bcast)).sqrt() *
        (accum + epsilon.reshape(single).broadcast(bcast)).rsqrt() * grad;
    var.device(d) -= update * lr.reshape(single).broadcast(bcast);
    accum_update.device(d) =
        accum_update * rho.reshape(single).broadcast(bcast) +
        update.square() *
            (grad.constant(T(1)) - rho.reshape(single).broadcast(bcast));
  }
};

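// Momentum:
//   accum <- momentum * accum + grad
//   var   <- var - lr * accum                      (standard)
//   var   <- var - lr * (grad + momentum * accum)  (Nesterov)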
template <typename T>
struct ApplyMomentum<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    accum.device(d) = accum * momentum.reshape(single).broadcast(bcast) + grad;
    if (use_nesterov) {
      var.device(d) -= grad * lr.reshape(single).broadcast(bcast) +
                       accum * momentum.reshape(single).broadcast(bcast) *
                           lr.reshape(single).broadcast(bcast);
    } else {
      var.device(d) -= lr.reshape(single).broadcast(bcast) * accum;
    }
  }
};

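// Adam (Kingma & Ba, 2015), with the bias correction folded into the
// learning rate:
//   alpha <- lr * sqrt(1 - beta2^t) / (1 - beta1^t)
//   m     <- m + (1 - beta1) * (grad - m)
//   v     <- v + (1 - beta2) * (grad^2 - v)
//   var   <- var - alpha * m / (sqrt(v) + epsilon)
// With use_nesterov, the m in the variable update is replaced by
// beta1 * m + (1 - beta1) * grad.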
template <typename T>
struct ApplyAdam<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  typename TTypes<T>::ConstScalar beta1_power,
                  typename TTypes<T>::ConstScalar beta2_power,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar beta1,
                  typename TTypes<T>::ConstScalar beta2,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad, bool use_nesterov) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    const auto one = static_cast<T>(1.0);
    m.device(d) =
        m + (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
                (grad - m);
    v.device(d) =
        v + (beta2.constant(one) - beta2).reshape(single).broadcast(bcast) *
                (grad.square() - v);

    if (use_nesterov) {
      var.device(d) -=
          (lr * (beta2_power.constant(one) - beta2_power).sqrt() /
           (beta1_power.constant(one) - beta1_power))
              .reshape(single)
              .broadcast(bcast) *
          (m * beta1.reshape(single).broadcast(bcast) +
           (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
               grad) /
          (epsilon.reshape(single).broadcast(bcast) + v.sqrt());
    } else {
      var.device(d) -= (lr * (beta2_power.constant(one) - beta2_power).sqrt() /
                        (beta1_power.constant(one) - beta1_power))
                           .reshape(single)
                           .broadcast(bcast) *
                       m /
                       (epsilon.reshape(single).broadcast(bcast) + v.sqrt());
    }
  }
};

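// RMSProp:
//   ms  <- rho * ms + (1 - rho) * grad^2
//   mom <- momentum * mom + lr * grad / sqrt(epsilon + ms)
//   var <- var - mom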
template <typename T>
struct ApplyRMSProp<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    const auto one = static_cast<T>(1.0);
    ms.device(d) =
        ms + (rho.constant(one) - rho).reshape(single).broadcast(bcast) *
                 (grad.square() - ms);
    mom.device(d) =
        mom * momentum.reshape(single).broadcast(bcast) +
        lr.reshape(single).broadcast(bcast) * grad /
            ((epsilon.reshape(single).broadcast(bcast) + ms).sqrt());
    var.device(d) -= mom;
  }
};

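// Centered RMSProp: like RMSProp, but the second moment is centered by a
// running estimate mg of the gradient mean:
//   ms  <- rho * ms + (1 - rho) * grad^2
//   mg  <- rho * mg + (1 - rho) * grad
//   mom <- momentum * mom + lr * grad / sqrt(ms - mg^2 + epsilon)
//   var <- var - mom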
template <typename T>
struct ApplyCenteredRMSProp<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms,
                  typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;
    const auto one = static_cast<T>(1.0);
    const auto one_minus_rho =
        (rho.constant(one) - rho).reshape(single).broadcast(bcast);
    ms.device(d) = ms + one_minus_rho * (grad.square() - ms);
    mg.device(d) = mg + one_minus_rho * (grad - mg);
    auto denom = (ms - mg.square()) + epsilon.reshape(single).broadcast(bcast);
    mom.device(d) = mom * momentum.reshape(single).broadcast(bcast) +
                    lr.reshape(single).broadcast(bcast) * grad / denom.sqrt();
    var.device(d) -= mom;
  }
};

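// AddSign (Bello et al., 2017):
//   m   <- beta * m + (1 - beta) * grad
//   var <- var - lr * (alpha + sign_decay * sign(grad) * sign(m)) * grad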
template <typename T>
struct ApplyAddSign<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar alpha,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    // The following is the GPU equivalent of the CPU version:
    // m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
    const auto one = static_cast<T>(1.0);
    auto beta_bcast = beta.reshape(single).broadcast(bcast);
    auto one_minus_beta =
        (beta.constant(one) - beta).reshape(single).broadcast(bcast);
    m.device(d) = m * beta_bcast + grad * one_minus_beta;

    // The following is the GPU equivalent of the CPU version:
    // var.device(d) -= lr() * (alpha() + sign_decay() * sign_gm) * grad;
    auto sign_gm = grad.sign() * m.sign();
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto alpha_bcast = alpha.reshape(single).broadcast(bcast);
    auto sign_decay_bcast = sign_decay.reshape(single).broadcast(bcast);
    var.device(d) -=
        lr_bcast * (alpha_bcast + sign_decay_bcast * sign_gm) * grad;
  }
};

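// PowerSign (Bello et al., 2017):
//   m   <- beta * m + (1 - beta) * grad
//   var <- var - lr * exp(logbase * sign_decay * sign(grad) * sign(m)) * grad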
template <typename T>
struct ApplyPowerSign<GPUDevice, T> {
  void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar logbase,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad) {
    Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
    bcast[0] = grad.dimension(0);
    Eigen::Sizes<1> single;

    // The following is the GPU equivalent of the CPU version:
    // m.device(d) = m * beta() + grad * (static_cast<T>(1) - beta());
    const auto one = static_cast<T>(1.0);
    auto beta_bcast = beta.reshape(single).broadcast(bcast);
    auto one_minus_beta =
        (beta.constant(one) - beta).reshape(single).broadcast(bcast);
    m.device(d) = m * beta_bcast + grad * one_minus_beta;

    // The following is the GPU equivalent of the CPU version:
    // auto grad_scale = (logbase() * sign_decay() * sign_gm).exp();
    // var.device(d) -= lr() * grad_scale * grad;
    auto sign_gm = grad.sign() * m.sign();
    auto lr_bcast = lr.reshape(single).broadcast(bcast);
    auto logbase_bcast = logbase.reshape(single).broadcast(bcast);
    auto sign_decay_bcast = sign_decay.reshape(single).broadcast(bcast);
    auto grad_scale = (logbase_bcast * sign_decay_bcast * sign_gm).exp();
    var.device(d) -= lr_bcast * grad_scale * grad;
  }
};

}  // namespace functor

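// Explicit instantiations of the functors above for each floating-point
// type supported by the GPU training kernels.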
template struct functor::ApplyGradientDescent<GPUDevice, Eigen::half>;
template struct functor::ApplyGradientDescent<GPUDevice, float>;
template struct functor::ApplyGradientDescent<GPUDevice, double>;

template struct functor::ApplyAdagrad<GPUDevice, Eigen::half>;
template struct functor::ApplyAdagrad<GPUDevice, float>;
template struct functor::ApplyAdagrad<GPUDevice, double>;

template struct functor::ApplyAdadelta<GPUDevice, Eigen::half>;
template struct functor::ApplyAdadelta<GPUDevice, float>;
template struct functor::ApplyAdadelta<GPUDevice, double>;

template struct functor::ApplyMomentum<GPUDevice, Eigen::half>;
template struct functor::ApplyMomentum<GPUDevice, float>;
template struct functor::ApplyMomentum<GPUDevice, double>;

template struct functor::ApplyAdam<GPUDevice, Eigen::half>;
template struct functor::ApplyAdam<GPUDevice, float>;
template struct functor::ApplyAdam<GPUDevice, double>;

template struct functor::ApplyRMSProp<GPUDevice, Eigen::half>;
template struct functor::ApplyRMSProp<GPUDevice, float>;
template struct functor::ApplyRMSProp<GPUDevice, double>;

template struct functor::ApplyCenteredRMSProp<GPUDevice, Eigen::half>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, float>;
template struct functor::ApplyCenteredRMSProp<GPUDevice, double>;

template struct functor::ApplyAddSign<GPUDevice, Eigen::half>;
template struct functor::ApplyAddSign<GPUDevice, float>;
template struct functor::ApplyAddSign<GPUDevice, double>;

template struct functor::ApplyPowerSign<GPUDevice, Eigen::half>;
template struct functor::ApplyPowerSign<GPUDevice, float>;
template struct functor::ApplyPowerSign<GPUDevice, double>;

}  // end namespace tensorflow

#endif  // GOOGLE_CUDA