/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_TRAINING_OPS_H_
#define TENSORFLOW_KERNELS_TRAINING_OPS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace functor {

// Each training algorithm has an ApplyXYZ functor struct declared in
// this header file. They are specialized for different devices
// (CPUDevice in training_ops.cc or GPUDevice in training_ops_gpu.cc).
//
// Conventions shared by every functor below:
//   - `d` is the Eigen device used to evaluate the expressions.
//   - TTypes<T>::Flat parameters (var, accum, m, v, ...) are mutable
//     optimizer state, updated in place.
//   - TTypes<T>::ConstScalar parameters (lr, rho, epsilon, ...) are
//     read-only scalar hyperparameters.
//   - TTypes<T>::ConstFlat grad/delta is the read-only gradient input.

// Plain gradient descent. Per the public ApplyGradientDescent op
// contract: var -= alpha * delta.
template <typename Device, typename T>
struct ApplyGradientDescent {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::ConstScalar alpha,
                  typename TTypes<T>::ConstFlat delta);
};

// Adadelta update. `accum` and `accum_update` are the running
// accumulators maintained alongside var; lr, rho (decay) and epsilon
// parameterize the step.
template <typename Device, typename T>
struct ApplyAdadelta {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat accum_update,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad);
};

// FOBOS with elastic-net regularization (l1 + l2 penalties).
// NOTE: interface is identical to ApplyProximalGradientDescent below;
// the two differ only in the device-specific definitions.
template <typename Device, typename T>
struct FobosElasticNet {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad);
};

// Proximal gradient descent: a gradient step on var followed by a
// proximal step for the l1/l2 regularizers.
template <typename Device, typename T>
struct ApplyProximalGradientDescent {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad);
};

// Adagrad update. `accum` holds the per-element accumulated squared
// gradients.
template <typename Device, typename T>
struct ApplyAdagrad {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad);
};

// Adagrad Dual Averaging. Maintains accumulated gradients and
// accumulated squared gradients; `global_step` (passed by value) is the
// current training step used by the dual-averaging schedule.
template <typename Device, typename T>
struct ApplyAdagradDA {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat gradient_accum,
                  typename TTypes<T>::Flat gradient_squared_accum,
                  typename TTypes<T>::ConstScalar lr, int64 global_step,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad);
};

// Adagrad combined with a proximal step for the l1/l2 regularizers.
template <typename Device, typename T>
struct ApplyProximalAdagrad {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad);
};

// FTRL-Proximal update. `accum` and `linear` are the optimizer's state
// tensors; lr_power controls the learning-rate decay exponent.
template <typename Device, typename T>
struct ApplyFtrl {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar lr_power);
};

// FTRL-Proximal, V2: same as ApplyFtrl plus an additional
// `l2_shrinkage` regularization term applied to the gradient.
template <typename Device, typename T>
struct ApplyFtrlV2 {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar l2_shrinkage,
                  typename TTypes<T>::ConstScalar lr_power);
};

// (Possibly Nesterov) momentum update. `accum` is the velocity;
// `use_nesterov` selects the Nesterov variant.
template <typename Device, typename T>
struct ApplyMomentum {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov);
};

// Adam update. `m` and `v` are the first- and second-moment estimates;
// beta1_power/beta2_power are the precomputed beta^t bias-correction
// terms, and `use_nesterov` selects the NAdam variant.
template <typename Device, typename T>
struct ApplyAdam {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  typename TTypes<T>::ConstScalar beta1_power,
                  typename TTypes<T>::ConstScalar beta2_power,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar beta1,
                  typename TTypes<T>::ConstScalar beta2,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad, bool use_nesterov);
};

// RMSProp update. `ms` is the running mean of squared gradients and
// `mom` the momentum buffer.
template <typename Device, typename T>
struct ApplyRMSProp {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad);
};

// Centered RMSProp: like ApplyRMSProp, with an extra running mean of
// the (uncentered) gradients in `mg` used to center the second moment.
template <typename Device, typename T>
struct ApplyCenteredRMSProp {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms,
                  typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad);
};

// AddSign update (Bello et al., "Neural Optimizer Search"). `m` is the
// gradient moving average; alpha/sign_decay/beta tune the sign-based
// step scaling.
template <typename Device, typename T>
struct ApplyAddSign {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar alpha,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad);
};

// PowerSign update: the exponential counterpart of AddSign, with
// `logbase` in place of the additive `alpha`.
template <typename Device, typename T>
struct ApplyPowerSign {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar logbase,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad);
};

}  // end namespace functor
}  // end namespace tensorflow

#endif  // TENSORFLOW_KERNELS_TRAINING_OPS_H_