# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of PowerSign."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops


class PowerSignOptimizer(optimizer.Optimizer):
  """Optimizer that implements the PowerSign update.

  See [Bello et al., ICML 2017],
  [Neural Optimizer Search with RL](https://arxiv.org/abs/1709.07417).
  """

  def __init__(self,
               learning_rate=0.1,
               base=math.e,
               beta=0.9,
               sign_decay_fn=None,
               use_locking=False,
               name='PowerSignOptimizer'):
    """Constructs a new PowerSignOptimizer object.

    Initialization:

    ```
    m_0 <- 0 (Initialize 1st moment vector)
    t <- 0 (Initialize timestep)
    ```

    Update:

    ```
    t <- t + 1
    m_t <- beta * m_{t-1} + (1 - beta) * g
    sign_decay <- sign_decay_fn(t)
    update <- base ** (sign_decay * sign(g) * sign(m_t)) * g
    variable <- variable - lr_t * update
    ```

    Example usage for PowerSign-cd (PowerSign with cosine sign decay):

    ```
    decay_steps = 1000
    cosine_decay_fn = sign_decay.get_cosine_decay_fn(decay_steps)
    opt = PowerSignOptimizer(learning_rate=0.1, sign_decay_fn=cosine_decay_fn)
    ```

    Args:
      learning_rate: learning rate used when taking a step.
      base: base used in the optimizer update.
      beta: decay rate used for computing the moving average m.
      sign_decay_fn: decay function applied to the sign(g) * sign(m) quantity.
        Takes global_step as an argument. See sign_decay.py for some examples.
      use_locking: If True, use locks for update operations.
      name: Optional name for the operations created when applying gradients.
        Defaults to "PowerSignOptimizer".
    """
    super(PowerSignOptimizer, self).__init__(use_locking, name)
    self._lr = learning_rate
    self._beta = beta
    self._logbase = math.log(base)

    self._sign_decay_fn = sign_decay_fn

    # Tensor versions of the constructor arguments, created in _prepare().
    self._lr_t = None
    self._beta_t = None
    self._logbase_t = None

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    if self._sign_decay_fn is not None:
      self._sign_decay_t = ops.convert_to_tensor(
          self._sign_decay_fn(global_step), name='sign_decay')
    return super(PowerSignOptimizer, self).apply_gradients(
        grads_and_vars, global_step=global_step, name=name)

  def _create_slots(self, var_list):
    # Create slots for the first moment.
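    # Each variable gets a same-shaped 'm' slot holding the exponential
    # moving average of its gradients (m_0 <- 0 in the docstring above);
    # _zeros_slot creates it initialized to zeros.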
    for v in var_list:
      self._zeros_slot(v, 'm', self._name)

  def _prepare(self):
    self._lr_t = ops.convert_to_tensor(self._lr, name='learning_rate')
    self._beta_t = ops.convert_to_tensor(self._beta, name='beta')
    self._logbase_t = ops.convert_to_tensor(self._logbase, name='logbase')
    if self._sign_decay_fn is None:
      # Without a decay function, the sign term is left undecayed.
      self._sign_decay_t = ops.convert_to_tensor(1.0, name='sign_decay')

  def _apply_dense(self, grad, var):
    m = self.get_slot(var, 'm')
    return training_ops.apply_power_sign(
        var,
        m,
        math_ops.cast(self._lr_t, var.dtype.base_dtype),
        math_ops.cast(self._logbase_t, var.dtype.base_dtype),
        math_ops.cast(self._sign_decay_t, var.dtype.base_dtype),
        math_ops.cast(self._beta_t, var.dtype.base_dtype),
        grad,
        use_locking=self._use_locking).op

  def _resource_apply_dense(self, grad, var):
    m = self.get_slot(var, 'm')
    return training_ops.resource_apply_power_sign(
        var.handle,
        m.handle,
        math_ops.cast(self._lr_t, var.dtype.base_dtype),
        math_ops.cast(self._logbase_t, var.dtype.base_dtype),
        math_ops.cast(self._sign_decay_t, var.dtype.base_dtype),
        math_ops.cast(self._beta_t, var.dtype.base_dtype),
        grad,
        use_locking=self._use_locking)

  def _apply_sparse(self, grad, var):
    lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
    beta_t = math_ops.cast(self._beta_t, var.dtype.base_dtype)
    logbase_t = math_ops.cast(self._logbase_t, var.dtype.base_dtype)
    e_t = math_ops.cast(math.e, var.dtype.base_dtype)

    # Update the moving average m_t <- beta * m_{t-1} + (1 - beta) * g.
    m = self.get_slot(var, 'm')
    m_t = state_ops.assign(
        m, (m * beta_t) + (grad * (1 - beta_t)), use_locking=self._use_locking)

    # Compute sign(g) * sign(m_t) on the rows touched by the sparse gradient.
    sign_g = ops.IndexedSlices(
        math_ops.sign(grad.values), grad.indices, dense_shape=grad.dense_shape)
    sign_gm = ops.IndexedSlices(
        array_ops.gather(math_ops.sign(m_t), sign_g.indices) * sign_g.values,
        sign_g.indices,
        dense_shape=sign_g.dense_shape)

    # base ** (sign_decay * sign(g) * sign(m_t)), computed as
    # e ** (log(base) * ...) since only log(base) is stored.
    sign_decayed = math_ops.cast(
        self._sign_decay_t, var.dtype.base_dtype)
    multiplier_values = math_ops.pow(
        e_t, logbase_t * sign_decayed * sign_gm.values)
    multiplier = ops.IndexedSlices(
        multiplier_values, sign_gm.indices, dense_shape=sign_gm.dense_shape)

    final_update = ops.IndexedSlices(
        lr_t * multiplier.values * grad.values,
        multiplier.indices,
        dense_shape=multiplier.dense_shape)

    var_update = state_ops.scatter_sub(
        var,
        final_update.indices,
        final_update.values,
        use_locking=self._use_locking)

    return control_flow_ops.group(var_update, m_t)
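
# A minimal usage sketch (not part of the original module). It assumes
# `import tensorflow as tf`, an existing scalar `loss` tensor, and the
# companion sign_decay module from tensorflow/contrib/opt/python/training,
# which provides get_cosine_decay_fn:
#
#   from tensorflow.contrib.opt.python.training import sign_decay
#
#   global_step = tf.train.get_or_create_global_step()
#   cosine_decay_fn = sign_decay.get_cosine_decay_fn(decay_steps=1000)
#   opt = PowerSignOptimizer(learning_rate=0.1, sign_decay_fn=cosine_decay_fn)
#   train_op = opt.minimize(loss, global_step=global_step)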