# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Implementation of PowerSign."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import math
     22 
     23 from tensorflow.python.framework import ops
     24 from tensorflow.python.ops import array_ops
     25 from tensorflow.python.ops import control_flow_ops
     26 from tensorflow.python.ops import math_ops
     27 from tensorflow.python.ops import state_ops
     28 from tensorflow.python.training import optimizer
     29 from tensorflow.python.training import training_ops
     30 
     31 
     32 class PowerSignOptimizer(optimizer.Optimizer):
     33   """Optimizer that implements the PowerSign update.
     34 
     35   See [Bello et al., ICML2017],
     36   [Neural Optimizer Search with RL](https://arxiv.org/abs/1709.07417).
     37   """
     38 
  def __init__(self,
               learning_rate=0.1,
               base=math.e,
               beta=0.9,
               sign_decay_fn=None,
               use_locking=False,
               name='PowerSignOptimizer'):
    """Constructs a new PowerSignOptimizer object.

    Initialization:

    ```
    m_0 <- 0 (Initialize the 1st moment vector)
    t <- 0 (Initialize the timestep)
    ```

    Update:

    ```
    t <- t + 1
    m_t <- beta * m_{t-1} + (1 - beta) * g
    sign_decay <- sign_decay_fn(t)
    update <- base ** (sign_decay * sign(g) * sign(m_t)) * g
    variable <- variable - lr_t * update
    ```
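
    For intuition, with the default base e and sign_decay = 1 the update
    reduces to:

    ```
    sign(g) == sign(m_t):  update = e * g        (step amplified)
    sign(g) != sign(m_t):  update = (1 / e) * g  (step damped)
    sign_decay -> 0:       update -> g           (plain gradient step)
    ```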

    Example usage for PowerSign-ld (PowerSign with linear sign decay):
    ```
    decay_steps = 1000
    linear_decay_fn = sign_decays.get_linear_decay_fn(decay_steps)
    opt = PowerSignOptimizer(learning_rate=0.1, sign_decay_fn=linear_decay_fn)
    ```
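
    More generally, sign_decay_fn can be any callable that takes global_step
    and returns a scalar decay value for the sign term (typically decaying
    from 1 towards 0). As an illustrative sketch only (custom_decay_fn below
    is hypothetical, not part of sign_decay.py):
    ```
    decay_steps = 1000

    def custom_decay_fn(global_step):
      # Linearly decay the sign term from 1 to 0 over decay_steps steps.
      step = tf.cast(global_step, tf.float32)
      return tf.maximum(1.0 - step / decay_steps, 0.0)

    opt = PowerSignOptimizer(learning_rate=0.1, sign_decay_fn=custom_decay_fn)
    ```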

    Args:
      learning_rate: learning rate used when taking a step.
      base: base used in the optimizer.
      beta: decay rate used for computing the moving average m.
      sign_decay_fn: decay function applied to the sign(g) * sign(m) quantity.
          Takes global_step as an argument. See sign_decay.py for some examples.
      use_locking: If True, use locks for update operations.
      name: Optional name for the operations created when applying gradients.
        Defaults to "PowerSignOptimizer".
    """
    super(PowerSignOptimizer, self).__init__(use_locking, name)
    self._lr = learning_rate
    self._beta = beta
    self._logbase = math.log(base)

    self._sign_decay_fn = sign_decay_fn

    # Tensor versions of the constructor arguments, created in _prepare().
    self._lr_t = None
    self._beta_t = None
    self._logbase_t = None

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    # If a decay schedule was provided, evaluate it at the current global step
    # so each application of the gradients uses the right sign decay value.
    if self._sign_decay_fn is not None:
      self._sign_decay_t = ops.convert_to_tensor(
          self._sign_decay_fn(global_step), name='sign_decay')
    return super(PowerSignOptimizer, self).apply_gradients(
        grads_and_vars, global_step=global_step, name=name)

  def _create_slots(self, var_list):
    # Create slots for the first moment.
    for v in var_list:
      self._zeros_slot(v, 'm', self._name)

  def _prepare(self):
    self._lr_t = ops.convert_to_tensor(self._lr, name='learning_rate')
    self._beta_t = ops.convert_to_tensor(self._beta, name='beta')
    self._logbase_t = ops.convert_to_tensor(self._logbase, name='logbase')
    if self._sign_decay_fn is None:
      # Without a decay schedule, apply the sign term at full strength.
      self._sign_decay_t = ops.convert_to_tensor(1.0, name='sign_decay')

  def _apply_dense(self, grad, var):
    m = self.get_slot(var, 'm')
    # Dense updates go through the apply_power_sign training op.
    return training_ops.apply_power_sign(
        var,
        m,
        math_ops.cast(self._lr_t, var.dtype.base_dtype),
        math_ops.cast(self._logbase_t, var.dtype.base_dtype),
        math_ops.cast(self._sign_decay_t, var.dtype.base_dtype),
        math_ops.cast(self._beta_t, var.dtype.base_dtype),
        grad,
        use_locking=self._use_locking).op

  def _resource_apply_dense(self, grad, var):
    m = self.get_slot(var, 'm')
    # Resource variables go through the resource variant of the same op.
    return training_ops.resource_apply_power_sign(
        var.handle,
        m.handle,
        math_ops.cast(self._lr_t, var.dtype.base_dtype),
        math_ops.cast(self._logbase_t, var.dtype.base_dtype),
        math_ops.cast(self._sign_decay_t, var.dtype.base_dtype),
        math_ops.cast(self._beta_t, var.dtype.base_dtype),
        grad,
        use_locking=self._use_locking)

  def _apply_sparse(self, grad, var):
    # Sparse gradients are handled by composing the update from primitive ops.
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta_t = math_ops.cast(self._beta_t, var.dtype.base_dtype)
    logbase_t = math_ops.cast(self._logbase_t, var.dtype.base_dtype)
    e_t = math_ops.cast(math.e, var.dtype.base_dtype)

    # Moving average of the gradient: m_t = beta * m + (1 - beta) * grad.
    m = self.get_slot(var, 'm')
    m_t = state_ops.assign(
        m, (m * beta_t) + (grad * (1 - beta_t)), use_locking=self._use_locking)

    # sign(g) * sign(m_t), restricted to the rows touched by this gradient.
    sign_g = ops.IndexedSlices(
        math_ops.sign(grad.values), grad.indices, dense_shape=grad.dense_shape)
    sign_gm = ops.IndexedSlices(
        array_ops.gather(math_ops.sign(m_t), sign_g.indices) * sign_g.values,
        sign_g.indices,
        dense_shape=sign_g.dense_shape)

    # multiplier = base ** (sign_decay * sign(g) * sign(m_t)), computed as
    # e ** (log(base) * sign_decay * sign(g) * sign(m_t)).
    sign_decayed = math_ops.cast(
        self._sign_decay_t, var.dtype.base_dtype)
    multiplier_values = math_ops.pow(
        e_t, logbase_t * sign_decayed * sign_gm.values)
    multiplier = ops.IndexedSlices(
        multiplier_values, sign_gm.indices, dense_shape=sign_gm.dense_shape)

    # Scatter-subtract lr * multiplier * grad from the touched rows only.
    final_update = ops.IndexedSlices(
        lr_t * multiplier.values * grad.values,
        multiplier.indices,
        dense_shape=multiplier.dense_shape)

    var_update = state_ops.scatter_sub(
        var,
        final_update.indices,
        final_update.values,
        use_locking=self._use_locking)

    return control_flow_ops.group(*[var_update, m_t])