# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Adam for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("train.AdamOptimizer")
class AdamOptimizer(optimizer.Optimizer):
  """Optimizer that implements the Adam algorithm.

  See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
  ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
  """

  def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
               use_locking=False, name="Adam"):
    """Construct a new Adam optimizer.

    Initialization:

    ```
    m_0 <- 0 (Initialize initial 1st moment vector)
    v_0 <- 0 (Initialize initial 2nd moment vector)
    t <- 0 (Initialize timestep)
    ```

    The update rule for `variable` with gradient `g` uses an optimization
    described at the end of section 2 of the paper:

    ```
    t <- t + 1
    lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)

    m_t <- beta1 * m_{t-1} + (1 - beta1) * g
    v_t <- beta2 * v_{t-1} + (1 - beta2) * g * g
    variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon)
    ```
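
    For illustration only, the same update written as a small self-contained
    Python sketch (the helper name `adam_step` is made up, not part of this
    class):

    ```python
    def adam_step(variable, g, m, v, t,
                  learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8):
      # One scalar Adam step mirroring the update rule above.
      t += 1
      lr_t = learning_rate * (1 - beta2**t) ** 0.5 / (1 - beta1**t)
      m = beta1 * m + (1 - beta1) * g
      v = beta2 * v + (1 - beta2) * g * g
      variable -= lr_t * m / (v ** 0.5 + epsilon)
      return variable, m, v, t
    ```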

    The default value of 1e-8 for epsilon might not be a good default in
    general. For example, when training an Inception network on ImageNet a
    current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
    formulation just before Section 2.1 of the Kingma and Ba paper rather than
    the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
    hat" in the paper.
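
    A minimal usage sketch under the note above (a toy scalar loss in
    graph-mode TF 1.x style; the larger `epsilon` value mirrors the
    Inception/ImageNet suggestion):

    ```python
    import tensorflow as tf

    w = tf.Variable(3.0)
    loss = tf.square(w - 1.0)
    opt = tf.train.AdamOptimizer(learning_rate=0.1, epsilon=1.0)
    train_op = opt.minimize(loss)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      for _ in range(100):
        sess.run(train_op)
    ```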

    The sparse implementation of this algorithm (used when the gradient is an
    IndexedSlices object, typically because of `tf.gather` or an embedding
    lookup in the forward pass) does apply momentum to variable slices even if
    they were not used in the forward pass (meaning they have a gradient equal
    to zero). Momentum decay (beta1) is also applied to the entire momentum
    accumulator. This means that the sparse behavior is equivalent to the dense
    behavior (in contrast to some momentum implementations which ignore momentum
    unless a variable slice was actually used).
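
    As a quick check of when the sparse path applies, the gradient of a
    `tf.gather` with respect to the gathered parameters arrives as an
    `IndexedSlices` (a sketch; `params` and `ids` are made-up names):

    ```python
    import tensorflow as tf

    params = tf.Variable(tf.ones([10, 4]))
    ids = tf.constant([0, 3])
    loss = tf.reduce_sum(tf.gather(params, ids))
    grad = tf.gradients(loss, params)[0]
    print(isinstance(grad, tf.IndexedSlices))  # True, so the sparse path runs
    ```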

    Args:
      learning_rate: A Tensor or a floating point value.  The learning rate.
      beta1: A float value or a constant float tensor.
        The exponential decay rate for the 1st moment estimates.
      beta2: A float value or a constant float tensor.
        The exponential decay rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper.
      use_locking: If True use locks for update operations.
      name: Optional name for the operations created when applying gradients.
        Defaults to "Adam".
    """
    super(AdamOptimizer, self).__init__(use_locking, name)
    self._lr = learning_rate
    self._beta1 = beta1
    self._beta2 = beta2
    self._epsilon = epsilon

    # Tensor versions of the constructor arguments, created in _prepare().
    self._lr_t = None
    self._beta1_t = None
    self._beta2_t = None
    self._epsilon_t = None

    # Created in SparseApply if needed.
    self._updated_lr = None

  def _get_beta_accumulators(self):
    if context.in_graph_mode():
      graph = ops.get_default_graph()
    else:
      graph = None
    return (self._get_non_slot_variable("beta1_power", graph=graph),
            self._get_non_slot_variable("beta2_power", graph=graph))

  def _create_slots(self, var_list):
    # Create the beta1 and beta2 accumulators on the same device as the first
    # variable. Sort the var_list to make sure this device is consistent across
    # workers (these need to go on the same PS, otherwise some updates are
    # silently ignored).
    first_var = min(var_list, key=lambda x: x.name)
    self._create_non_slot_variable(initial_value=self._beta1,
                                   name="beta1_power",
                                   colocate_with=first_var)
    self._create_non_slot_variable(initial_value=self._beta2,
                                   name="beta2_power",
                                   colocate_with=first_var)

    # Create slots for the first and second moments.
    for v in var_list:
      self._zeros_slot(v, "m", self._name)
      self._zeros_slot(v, "v", self._name)

  def _prepare(self):
    self._lr_t = ops.convert_to_tensor(self._lr, name="learning_rate")
    self._beta1_t = ops.convert_to_tensor(self._beta1, name="beta1")
    self._beta2_t = ops.convert_to_tensor(self._beta2, name="beta2")
    self._epsilon_t = ops.convert_to_tensor(self._epsilon, name="epsilon")

  def _apply_dense(self, grad, var):
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")
    beta1_power, beta2_power = self._get_beta_accumulators()
    # The fused kernel updates m, v and the variable in a single op.
    return training_ops.apply_adam(
        var, m, v,
        math_ops.cast(beta1_power, var.dtype.base_dtype),
        math_ops.cast(beta2_power, var.dtype.base_dtype),
        math_ops.cast(self._lr_t, var.dtype.base_dtype),
        math_ops.cast(self._beta1_t, var.dtype.base_dtype),
        math_ops.cast(self._beta2_t, var.dtype.base_dtype),
        math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
        grad, use_locking=self._use_locking).op

  def _resource_apply_dense(self, grad, var):
    m = self.get_slot(var, "m")
    v = self.get_slot(var, "v")
    beta1_power, beta2_power = self._get_beta_accumulators()
    return training_ops.resource_apply_adam(
        var.handle, m.handle, v.handle,
        math_ops.cast(beta1_power, grad.dtype.base_dtype),
        math_ops.cast(beta2_power, grad.dtype.base_dtype),
        math_ops.cast(self._lr_t, grad.dtype.base_dtype),
        math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
        math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
        math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
        grad, use_locking=self._use_locking)

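  # The sparse path below decays the full m and v accumulators first, then
  # scatter-adds the (1 - beta) scaled gradient contributions for the touched
  # indices, and finally applies a dense update to the variable, matching the
  # sparse-behavior note in the class docstring.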
  def _apply_sparse_shared(self, grad, var, indices, scatter_add):
    beta1_power, beta2_power = self._get_beta_accumulators()
    beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
    beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
    beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
    epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
    # Bias-corrected learning rate, lr_t in the update rule of the docstring.
    lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
    # m_t = beta1 * m + (1 - beta1) * g_t
    m = self.get_slot(var, "m")
    m_scaled_g_values = grad * (1 - beta1_t)
    m_t = state_ops.assign(m, m * beta1_t,
                           use_locking=self._use_locking)
    with ops.control_dependencies([m_t]):
      m_t = scatter_add(m, indices, m_scaled_g_values)
    # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
    v = self.get_slot(var, "v")
    v_scaled_g_values = (grad * grad) * (1 - beta2_t)
    v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
    with ops.control_dependencies([v_t]):
      v_t = scatter_add(v, indices, v_scaled_g_values)
    v_sqrt = math_ops.sqrt(v_t)
    var_update = state_ops.assign_sub(var,
                                      lr * m_t / (v_sqrt + epsilon_t),
                                      use_locking=self._use_locking)
    return control_flow_ops.group(*[var_update, m_t, v_t])

  def _apply_sparse(self, grad, var):
    return self._apply_sparse_shared(
        grad.values, var, grad.indices,
        lambda x, i, v: state_ops.scatter_add(  # pylint: disable=g-long-lambda
            x, i, v, use_locking=self._use_locking))

  def _resource_scatter_add(self, x, i, v):
    with ops.control_dependencies(
        [resource_variable_ops.resource_scatter_add(
            x.handle, i, v)]):
      return x.value()

  def _resource_apply_sparse(self, grad, var, indices):
    return self._apply_sparse_shared(
        grad, var, indices, self._resource_scatter_add)

  def _finish(self, update_ops, name_scope):
    # Update the power accumulators.
    with ops.control_dependencies(update_ops):
      beta1_power, beta2_power = self._get_beta_accumulators()
      with ops.colocate_with(beta1_power):
        update_beta1 = beta1_power.assign(
            beta1_power * self._beta1_t, use_locking=self._use_locking)
        update_beta2 = beta2_power.assign(
            beta2_power * self._beta2_t, use_locking=self._use_locking)
    return control_flow_ops.group(*update_ops + [update_beta1, update_beta2],
                                  name=name_scope)