# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Adam for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("train.AdamOptimizer")
class AdamOptimizer(optimizer.Optimizer):
  """Optimizer that implements the Adam algorithm.

  See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
  ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
  """

  def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
               use_locking=False, name="Adam"):
    """Construct a new Adam optimizer.

    Initialization:

    ```
    m_0 <- 0 (Initialize initial 1st moment vector)
    v_0 <- 0 (Initialize initial 2nd moment vector)
    t <- 0 (Initialize timestep)
    ```

    The update rule for `variable` with gradient `g` uses an optimization
    described at the end of section 2 of the paper:

    ```
    t <- t + 1
    lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)

    m_t <- beta1 * m_{t-1} + (1 - beta1) * g
    v_t <- beta2 * v_{t-1} + (1 - beta2) * g * g
    variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)
    ```

    The default value of 1e-8 for epsilon might not be a good default in
    general. For example, when training an Inception network on ImageNet a
    current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
    formulation just before Section 2.1 of the Kingma and Ba paper rather than
    the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
    hat" in the paper.

    The sparse implementation of this algorithm (used when the gradient is an
    IndexedSlices object, typically because of `tf.gather` or an embedding
    lookup in the forward pass) does apply momentum to variable slices even if
    they were not used in the forward pass (meaning they have a gradient equal
    to zero). Momentum decay (beta1) is also applied to the entire momentum
    accumulator. This means that the sparse behavior is equivalent to the dense
    behavior (in contrast to some momentum implementations which ignore
    momentum unless a variable slice was actually used).
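
    A minimal usage sketch (the `loss` tensor here is assumed to be defined
    elsewhere in the graph):

    ```
    opt = tf.train.AdamOptimizer(learning_rate=0.001)
    train_op = opt.minimize(loss)
    ```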

    Args:
      learning_rate: A Tensor or a floating point value. The learning rate.
      beta1: A float value or a constant float tensor.
        The exponential decay rate for the 1st moment estimates.
      beta2: A float value or a constant float tensor.
        The exponential decay rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper.
      use_locking: If True use locks for update operations.
      name: Optional name for the operations created when applying gradients.
        Defaults to "Adam".
    """
    super(AdamOptimizer, self).__init__(use_locking, name)
    self._lr = learning_rate
    self._beta1 = beta1
    self._beta2 = beta2
    self._epsilon = epsilon

    # Tensor versions of the constructor arguments, created in _prepare().
    self._lr_t = None
    self._beta1_t = None
    self._beta2_t = None
    self._epsilon_t = None

    # Created in SparseApply if needed.
    self._updated_lr = None

  def _get_beta_accumulators(self):
    if context.in_graph_mode():
      graph = ops.get_default_graph()
    else:
      graph = None
    return (self._get_non_slot_variable("beta1_power", graph=graph),
            self._get_non_slot_variable("beta2_power", graph=graph))

  def _create_slots(self, var_list):
    # Create the beta1 and beta2 accumulators on the same device as the first
    # variable. Sort the var_list to make sure this device is consistent across
    # workers (these need to go on the same PS, otherwise some updates are
    # silently ignored).
    first_var = min(var_list, key=lambda x: x.name)
    self._create_non_slot_variable(initial_value=self._beta1,
                                   name="beta1_power",
                                   colocate_with=first_var)
    self._create_non_slot_variable(initial_value=self._beta2,
                                   name="beta2_power",
                                   colocate_with=first_var)

    # Create slots for the first and second moments.
    for v in var_list:
      self._zeros_slot(v, "m", self._name)
      self._zeros_slot(v, "v", self._name)

  def _prepare(self):
    self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate")
    self._beta1_t = ops.convert_to_tensor(self._beta1, name="beta1")
    self._beta2_t = ops.convert_to_tensor(self._beta2, name="beta2")
    self._epsilon_t = ops.convert_to_tensor(self._epsilon, name="epsilon")

  def _apply_dense(self, grad, var):
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")
    beta1_power, beta2_power = self._get_beta_accumulators()
    return training_ops.apply_adam(
        var, m, v,
        math_ops.cast(beta1_power, var.dtype.base_dtype),
        math_ops.cast(beta2_power, var.dtype.base_dtype),
        math_ops.cast(self._lr_t, var.dtype.base_dtype),
        math_ops.cast(self._beta1_t, var.dtype.base_dtype),
        math_ops.cast(self._beta2_t, var.dtype.base_dtype),
        math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
        grad, use_locking=self._use_locking).op

  def _resource_apply_dense(self, grad, var):
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")
    beta1_power, beta2_power = self._get_beta_accumulators()
    return training_ops.resource_apply_adam(
        var.handle, m.handle, v.handle,
        math_ops.cast(beta1_power, grad.dtype.base_dtype),
        math_ops.cast(beta2_power, grad.dtype.base_dtype),
        math_ops.cast(self._lr_t, grad.dtype.base_dtype),
        math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
        math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
        math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
        grad, use_locking=self._use_locking)
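
  # Shared sparse-update logic for `_apply_sparse` and `_resource_apply_sparse`:
  # the caller supplies `scatter_add`, so the same code path serves both ref
  # variables and resource variables.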
  def _apply_sparse_shared(self, grad, var, indices, scatter_add):
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * (1 - beta1_t)
    m_t = state_ops.assign(m, m * beta1_t,
                           use_locking=self._use_locking)
    with ops.control_dependencies([m_t]):
      m_t = scatter_add(m, indices, m_scaled_g_values)
    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad * grad) * (1 - beta2_t)
    v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
      v_t = scatter_add(v, indices, v_scaled_g_values)
    v_sqrt = math_ops.sqrt(v_t)
    var_update = state_ops.assign_sub(var,
                                      lr * m_t / (v_sqrt + epsilon_t),
                                      use_locking=self._use_locking)
    return control_flow_ops.group(*[var_update, m_t, v_t])

  def _apply_sparse(self, grad, var):
    return self._apply_sparse_shared(
        grad.values, var, grad.indices,
        lambda x, i, v: state_ops.scatter_add(  # pylint: disable=g-long-lambda
            x, i, v, use_locking=self._use_locking))

  def _resource_scatter_add(self, x, i, v):
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_add(
            x.handle, i, v)]):
      return x.value()

  def _resource_apply_sparse(self, grad, var, indices):
    return self._apply_sparse_shared(
        grad, var, indices, self._resource_scatter_add)

  def _finish(self, update_ops, name_scope):
    # Update the power accumulators.
    with ops.control_dependencies(update_ops):
      beta1_power, beta2_power = self._get_beta_accumulators()
      with ops.colocate_with(beta1_power):
        update_beta1 = beta1_power.assign(
            beta1_power * self._beta1_t, use_locking=self._use_locking)
        update_beta2 = beta2_power.assign(
            beta2_power * self._beta2_t, use_locking=self._use_locking)
    return control_flow_ops.group(*update_ops + [update_beta1, update_beta2],
                                  name=name_scope)