/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_TRAINING_OPS_H_
#define TENSORFLOW_KERNELS_TRAINING_OPS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace functor {

// Each training algorithm has an ApplyXYZ functor struct declared in
// this header file. The functors are specialized for different devices
// (CPUDevice in training_ops.cc or GPUDevice in training_ops_gpu.cc).

template <typename Device, typename T>
struct ApplyGradientDescent {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::ConstScalar alpha,
                  typename TTypes<T>::ConstFlat delta);
};
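
// Specializations of ApplyGradientDescent perform the plain SGD update
//   var -= alpha * delta;
// As a rough illustration, a CPU specialization (the real definitions live in
// training_ops.cc) is typically written with Eigen device expressions along
// these lines:
//
//   template <typename T>
//   struct ApplyGradientDescent<CPUDevice, T> {
//     void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
//                     typename TTypes<T>::ConstScalar alpha,
//                     typename TTypes<T>::ConstFlat delta) {
//       var.device(d) -= delta * alpha();
//     }
//   };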

template <typename Device, typename T>
struct ApplyAdadelta {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat accum_update,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad);
};
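
// Specializations of ApplyAdadelta compute, roughly (see training_ops.cc for
// the exact expressions):
//   accum        = rho * accum + (1 - rho) * grad^2
//   update       = sqrt(accum_update + epsilon) / sqrt(accum + epsilon) * grad
//   accum_update = rho * accum_update + (1 - rho) * update^2
//   var         -= lr * update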

template <typename Device, typename T>
struct FobosElasticNet {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad);
};
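
// FOBOS (forward-backward splitting) with elastic-net (l1/l2) regularization.
// It takes the same arguments as ApplyProximalGradientDescent below and
// applies the same style of proximal update: a gradient step followed by
// shrinkage from the l1/l2 penalties.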

template <typename Device, typename T>
struct ApplyProximalGradientDescent {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad);
};
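
// Specializations perform a gradient step followed by the l1/l2 proximal
// operator, roughly:
//   prox_var = var - lr * grad
//   var      = sign(prox_var) * max(|prox_var| - lr * l1, 0) / (1 + lr * l2)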

template <typename Device, typename T>
struct ApplyAdagrad {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad);
};
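
// Specializations of ApplyAdagrad compute, roughly:
//   accum += grad^2
//   var   -= lr * grad / sqrt(accum)
// e.g. a CPU specialization is typically a pair of Eigen expressions such as
//   accum.device(d) += grad.square();
//   var.device(d) -= grad * lr() * accum.rsqrt();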

template <typename Device, typename T>
struct ApplyAdagradDA {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat gradient_accum,
                  typename TTypes<T>::Flat gradient_squared_accum,
                  typename TTypes<T>::ConstScalar lr, int64 global_step,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad);
};
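
// Adagrad with dual averaging: the raw and squared gradients are accumulated
// across steps and var is recomputed directly from the accumulators, with the
// l1/l2 penalties scaled by global_step. Roughly:
//   gradient_accum         += grad
//   gradient_squared_accum += grad^2
//   var = -lr * sign(gradient_accum) * max(|gradient_accum| - l1 * global_step, 0)
//         / (l2 * global_step * lr + sqrt(gradient_squared_accum))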

template <typename Device, typename T>
struct ApplyProximalAdagrad {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad);
};
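
// Adagrad step followed by the l1/l2 proximal operator, roughly:
//   accum   += grad^2
//   prox_var = var - lr * grad / sqrt(accum)
//   var      = sign(prox_var) * max(|prox_var| - lr * l1, 0) / (1 + lr * l2)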

template <typename Device, typename T>
struct ApplyFtrl {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar lr_power);
};
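
// FTRL-Proximal update. Roughly (see training_ops.cc for the exact form):
//   accum_new = accum + grad^2
//   linear   += grad - (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
//   quadratic = accum_new^(-lr_power) / lr + 2 * l2
//   var       = (sign(linear) * l1 - linear) / quadratic   if |linear| > l1
//               0                                          otherwise
//   accum     = accum_new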

template <typename Device, typename T>
struct ApplyFtrlV2 {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar l2_shrinkage,
                  typename TTypes<T>::ConstScalar lr_power);
};
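
// Same update as ApplyFtrl, except that the gradient fed into the linear
// accumulator is first shrunk by the l2_shrinkage penalty, roughly
//   grad_with_shrinkage = grad + 2 * l2_shrinkage * var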

template <typename Device, typename T>
struct ApplyMomentum {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov);
};
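
// Momentum update. Roughly:
//   accum = momentum * accum + grad
//   var  -= lr * accum                           (use_nesterov == false)
//   var  -= lr * grad + lr * momentum * accum    (use_nesterov == true)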

template <typename Device, typename T>
struct ApplyAdam {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  typename TTypes<T>::ConstScalar beta1_power,
                  typename TTypes<T>::ConstScalar beta2_power,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar beta1,
                  typename TTypes<T>::ConstScalar beta2,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad, bool use_nesterov);
};
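
// Adam update (with an optional Nesterov variant selected by use_nesterov).
// Roughly:
//   lr_t = lr * sqrt(1 - beta2_power) / (1 - beta1_power)
//   m    = beta1 * m + (1 - beta1) * grad
//   v    = beta2 * v + (1 - beta2) * grad^2
//   var -= lr_t * m / (sqrt(v) + epsilon)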

template <typename Device, typename T>
struct ApplyRMSProp {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad);
};
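
// RMSProp update. Roughly:
//   ms   = rho * ms + (1 - rho) * grad^2
//   mom  = momentum * mom + lr * grad / sqrt(ms + epsilon)
//   var -= mom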

template <typename Device, typename T>
struct ApplyCenteredRMSProp {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms,
                  typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad);
};
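
// Centered RMSProp: like ApplyRMSProp, but also tracks the first moment mg
// and divides by the centered second moment. Roughly:
//   mg   = rho * mg + (1 - rho) * grad
//   ms   = rho * ms + (1 - rho) * grad^2
//   mom  = momentum * mom + lr * grad / sqrt(ms - mg^2 + epsilon)
//   var -= mom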

template <typename Device, typename T>
struct ApplyAddSign {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar alpha,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad);
};
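
// AddSign update (from the Neural Optimizer Search family). Roughly:
//   m    = beta * m + (1 - beta) * grad
//   var -= lr * (alpha + sign_decay * sign(grad) * sign(m)) * grad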

template <typename Device, typename T>
struct ApplyPowerSign {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar logbase,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad);
};
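
// PowerSign update (from the Neural Optimizer Search family). Roughly:
//   m    = beta * m + (1 - beta) * grad
//   var -= lr * exp(logbase * sign_decay * sign(grad) * sign(m)) * grad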

}  // end namespace functor
}  // end namespace tensorflow

#endif  // TENSORFLOW_KERNELS_TRAINING_OPS_H_