# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An optimizer module for stochastic gradient Langevin dynamics."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope as varscope_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops


class SGLDOptimizer(optimizer.Optimizer):
  """An optimizer module for stochastic gradient Langevin dynamics.

  This implements the preconditioned Stochastic Gradient Langevin Dynamics
  optimizer [1]. The optimization variable is regarded as a sample from the
  posterior under Stochastic Gradient Langevin Dynamics with noise rescaled
  in each dimension according to RMSProp [2].

  Note: If a prior is included in the loss, it should be scaled by
  `1/num_pseudo_batches`, where `num_pseudo_batches` is the number of
  minibatches in the data. I.e., it should be divided by the
  `num_pseudo_batches` term described below.

  [1]: "Preconditioned Stochastic Gradient Langevin Dynamics for Deep Neural
       Networks." Chunyuan Li, Changyou Chen, David Carlson, Lawrence Carin.
       arXiv:1512.07666, 2015. https://arxiv.org/abs/1512.07666
  [2]: http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf

  Args:
    learning_rate: Scalar `float`-like `Tensor`. The base learning rate for
      the optimizer. Must be tuned to the specific function being minimized.
    preconditioner_decay_rate: Scalar `float`-like `Tensor`. The exponential
      decay rate of the rescaling of the preconditioner (RMSprop). (This is
      "alpha" in [1].) Should be close to, but strictly less than, `1` to
      approximate sampling from the posterior. (Default: `0.95`)
    num_pseudo_batches: Scalar `int`-like `Tensor`. The effective number of
      minibatches in the data set. Trades off noise and prior with the SGD
      likelihood term. Note: this assumes the loss is taken as the mean over
      a minibatch; if the sum is taken instead, divide this number by the
      batch size. (Default: `1`)
    burnin: Scalar `int`-like `Tensor`. The number of iterations to collect
      gradient statistics to update the preconditioner before starting to
      draw noisy samples. (Default: `25`)
    diagonal_bias: Scalar `float`-like `Tensor`. Term added to the diagonal
      of the preconditioner to prevent the preconditioner from degenerating.
      (Default: `1e-8`)
    name: Python `str` describing ops managed by this function.
      (Default: `"SGLDOptimizer"`)
    variable_scope: Variable scope used for calls to `tf.get_variable`.
      If `None`, a new variable scope is created using name
      `ops.get_default_graph().unique_name(name or default_name)`.

  Raises:
    InvalidArgumentError: If `preconditioner_decay_rate` is a `Tensor` not in
      `[0, 1]`.
  """

  def __init__(self,
               learning_rate,
               preconditioner_decay_rate=0.95,
               num_pseudo_batches=1,
               burnin=25,
               diagonal_bias=1e-8,
               name=None,
               variable_scope=None):
    default_name = 'SGLDOptimizer'
    with ops.name_scope(name, default_name, [
        learning_rate, preconditioner_decay_rate, num_pseudo_batches, burnin,
        diagonal_bias
    ]):
      if variable_scope is None:
        var_scope_name = ops.get_default_graph().unique_name(
            name or default_name)
        with varscope_ops.variable_scope(var_scope_name) as scope:
          self._variable_scope = scope
      else:
        self._variable_scope = variable_scope

      self._preconditioner_decay_rate = ops.convert_to_tensor(
          preconditioner_decay_rate, name='preconditioner_decay_rate')
      self._num_pseudo_batches = ops.convert_to_tensor(
          num_pseudo_batches, name='num_pseudo_batches')
      self._burnin = ops.convert_to_tensor(burnin, name='burnin')
      self._diagonal_bias = ops.convert_to_tensor(
          diagonal_bias, name='diagonal_bias')
      self._learning_rate = ops.convert_to_tensor(
          learning_rate, name='learning_rate')

      with varscope_ops.variable_scope(self._variable_scope):
        self._counter = varscope_ops.get_variable(
            'counter', initializer=0, trainable=False)

      self._preconditioner_decay_rate = control_flow_ops.with_dependencies([
          check_ops.assert_non_negative(
              self._preconditioner_decay_rate,
              message='`preconditioner_decay_rate` must be non-negative'),
          check_ops.assert_less_equal(
              self._preconditioner_decay_rate,
              1.,
              message='`preconditioner_decay_rate` must be at most 1.'),
      ], self._preconditioner_decay_rate)

      self._num_pseudo_batches = control_flow_ops.with_dependencies([
          check_ops.assert_greater(
              self._num_pseudo_batches,
              0,
              message='`num_pseudo_batches` must be greater than zero')
      ], self._num_pseudo_batches)

      self._burnin = control_flow_ops.with_dependencies([
          check_ops.assert_non_negative(
              self._burnin, message='`burnin` must be non-negative'),
          check_ops.assert_integer(
              self._burnin, message='`burnin` must be an integer')
      ], self._burnin)

      self._diagonal_bias = control_flow_ops.with_dependencies([
          check_ops.assert_non_negative(
              self._diagonal_bias,
              message='`diagonal_bias` must be non-negative')
      ], self._diagonal_bias)

      super(SGLDOptimizer, self).__init__(use_locking=False,
                                          name=name or default_name)
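
  # Each optimized variable gets an `rms` slot below: the exponential moving
  # average of its squared gradients, which defines the RMSProp-style
  # preconditioner. Initializing the slot to ones starts the preconditioner
  # near the identity, so pre-burnin steps resemble (rescaled) plain SGD.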
  def _create_slots(self, var_list):
    for v in var_list:
      init_rms = init_ops.ones_initializer(dtype=v.dtype)
      self._get_or_make_slot_with_initializer(v, init_rms, v.get_shape(),
                                              v.dtype, 'rms', self._name)

  def _prepare(self):
    # We need to put the conversion and check here because a user will likely
    # want to decay the learning rate dynamically.
    self._learning_rate_tensor = control_flow_ops.with_dependencies([
        check_ops.assert_non_negative(
            self._learning_rate,
            message='`learning_rate` must be non-negative')
    ], ops.convert_to_tensor(self._learning_rate,
                             name='learning_rate_tensor'))
    self._decay_tensor = ops.convert_to_tensor(
        self._preconditioner_decay_rate, name='preconditioner_decay_rate')

    super(SGLDOptimizer, self)._prepare()

  def _apply_dense(self, grad, var):
    rms = self.get_slot(var, 'rms')

    # Refresh the preconditioner statistic before computing the noisy update.
    with ops.control_dependencies([
        self._update_momentum(rms, grad,
                              math_ops.cast(self._decay_tensor,
                                            var.dtype.base_dtype))]):
      new_grad = self._apply_noisy_update(rms, grad)

    return training_ops.apply_gradient_descent(
        var,
        math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
        new_grad,
        use_locking=self._use_locking).op

  def _apply_sparse(self, grad, var):
    # Note: sparse gradients are treated the same way as dense ones here.
    rms = self.get_slot(var, 'rms')

    with ops.control_dependencies([
        self._update_momentum(rms, grad,
                              math_ops.cast(self._decay_tensor,
                                            var.dtype.base_dtype))]):
      new_grad = self._apply_noisy_update(rms, grad)

    return training_ops.apply_gradient_descent(
        var,
        math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
        new_grad,
        use_locking=self._use_locking).op

  def _finish(self, update_ops, name_scope):
    # Increment the step counter once all variable updates have been created;
    # `control_flow_ops.group` expects individual ops, not lists.
    update_ops.append(self._counter.assign_add(1))
    return control_flow_ops.group(*update_ops, name=name_scope)

  @property
  def variable_scope(self):
    """Variable scope of all calls to `tf.get_variable`."""
    return self._variable_scope
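
  # The value returned by `_apply_noisy_update` below, once multiplied by the
  # learning rate in `_apply_dense`/`_apply_sparse`, amounts to the
  # elementwise update
  #
  #   rms_t   = alpha * rms_{t-1} + (1 - alpha) * g_t**2
  #   G_t     = 1 / sqrt(rms_t + diagonal_bias)
  #   delta_t = (lr / 2) * G_t * N * g_t + sqrt(lr * G_t) * z_t,
  #
  # where g_t is the minibatch gradient, N = num_pseudo_batches, alpha is the
  # preconditioner decay rate, and z_t is standard normal noise (suppressed
  # during burnin). This is the preconditioned SGLD update of [1], omitting
  # the Gamma correction term.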
  def _apply_noisy_update(self, mom, grad):
    # Compute the gradient update following preconditioned Langevin dynamics.
    # Noise is injected only once the counter passes `burnin`; before that,
    # the preconditioner is warmed up with noiseless updates.
    stddev = array_ops.where(
        array_ops.squeeze(self._counter > self._burnin),
        math_ops.cast(math_ops.rsqrt(self._learning_rate), grad.dtype),
        array_ops.zeros([], grad.dtype))

    preconditioner = math_ops.rsqrt(
        mom + math_ops.cast(self._diagonal_bias, grad.dtype))
    return (
        0.5 * preconditioner * grad * math_ops.cast(self._num_pseudo_batches,
                                                    grad.dtype) +
        random_ops.random_normal(  # Zero-mean, unit-variance Gaussian noise.
            array_ops.shape(grad), stddev=1.0, dtype=grad.dtype) *
        stddev * math_ops.sqrt(preconditioner))

  def _update_momentum(self, mom, grad, decay):
    # Keep an exponentially weighted moving average of squared gradients.
    # Not thread-safe.
    return mom.assign_add((1.0 - decay) * (math_ops.square(grad) - mom))
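

if __name__ == '__main__':
  # A minimal, self-contained smoke test; this block is a sketch, not part of
  # the library API. All names, data, and hyperparameters below are
  # illustrative assumptions: we draw approximate posterior samples of the
  # mean `mu` of a unit-variance Gaussian, using the full data set as one
  # minibatch (TF 1.x graph mode).
  import numpy as np
  import tensorflow as tf

  data = np.random.randn(100).astype(np.float32) + 3.
  mu = tf.get_variable('mu', initializer=0.)
  # Mean negative log-likelihood over the batch (up to an additive constant).
  # A prior term, if added here, would be scaled by 1/num_pseudo_batches as
  # the class docstring prescribes.
  loss = 0.5 * tf.reduce_mean(tf.square(data - mu))

  train_op = SGLDOptimizer(
      learning_rate=1e-3,
      preconditioner_decay_rate=0.99,
      # The loss above is a mean, so this acts as the effective data size:
      # each of the 100 points is viewed as a pseudo-minibatch of size one.
      num_pseudo_batches=100,
      burnin=100).minimize(loss)

  samples = []
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(2000):
      sess.run(train_op)
      samples.append(sess.run(mu))
  # Iterates after burnin (plus some mixing time) approximate posterior draws.
  print('posterior mean estimate:', np.mean(samples[500:]))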