# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An optimizer module for stochastic gradient Langevin dynamics."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope as varscope_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops


class SGLDOptimizer(optimizer.Optimizer):
  """An optimizer module for stochastic gradient Langevin dynamics.

  This implements the preconditioned Stochastic Gradient Langevin Dynamics
  optimizer [1]. The optimization variable is regarded as a sample from the
  posterior under Stochastic Gradient Langevin Dynamics with noise rescaled in
  each dimension according to RMSProp [2].

  Note: If a prior is included in the loss, it should be scaled by
  `1/num_pseudo_batches`, where `num_pseudo_batches` is the number of
  minibatches in the data. I.e., it should be divided by the
  `num_pseudo_batches` term described below.

  [1]: "Preconditioned Stochastic Gradient Langevin Dynamics for Deep Neural
       Networks." Chunyuan Li, Changyou Chen, David Carlson, Lawrence Carin.
       ArXiv:1512.07666, 2015. https://arxiv.org/abs/1512.07666
  [2]: http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf

  Args:
    learning_rate: Scalar `float`-like `Tensor`. The base learning rate for the
      optimizer. Must be tuned to the specific function being minimized.
    preconditioner_decay_rate: Scalar `float`-like `Tensor`. The exponential
      decay rate of the rescaling of the preconditioner (RMSprop). (This is
      "alpha" in [1].) Should be smaller than, but close to, `1` to approximate
      sampling from the posterior. (Default: `0.95`)
    num_pseudo_batches: Scalar `int`-like `Tensor`. The effective number of
      minibatches in the data set. Trades off noise and prior with the SGD
      likelihood term. Note: Assumes the loss is taken as the mean over a
      minibatch. If the sum is taken instead, divide this number by the
      batch size. (Default: `1`)
    burnin: Scalar `int`-like `Tensor`. The number of iterations to collect
      gradient statistics to update the preconditioner before starting to draw
      noisy samples. (Default: `25`)
    diagonal_bias: Scalar `float`-like `Tensor`. Term added to the diagonal of
      the preconditioner to prevent the preconditioner from degenerating.
      (Default: `1e-8`)
    name: Python `str` describing ops managed by this function.
      (Default: `"SGLDOptimizer"`)
    variable_scope: Variable scope used for calls to `tf.get_variable`.
      If `None`, a new variable scope is created using name
      `ops.get_default_graph().unique_name(name or default_name)`.

  Raises:
    InvalidArgumentError: If `preconditioner_decay_rate` is a `Tensor` not in
      `[0, 1]`.
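
  Usage (a minimal sketch; the toy data, model, and hyperparameter values
  below are illustrative assumptions, not part of this module):

  ```python
  import tensorflow as tf

  # One hypothetical minibatch out of a data set split into 1000 minibatches.
  features = tf.random_normal([32, 5])
  labels = tf.random_normal([32, 1])
  weights = tf.get_variable('weights', shape=[5, 1])
  # Mean loss over the minibatch; a log-prior term, if added, should be scaled
  # by 1/num_pseudo_batches as noted above.
  loss = tf.reduce_mean(tf.square(tf.matmul(features, weights) - labels))
  train_op = SGLDOptimizer(
      learning_rate=1e-4, num_pseudo_batches=1000).minimize(loss)
  ```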
  """

  def __init__(self,
               learning_rate,
               preconditioner_decay_rate=0.95,
               num_pseudo_batches=1,
               burnin=25,
               diagonal_bias=1e-8,
               name=None,
               variable_scope=None):
    default_name = 'SGLDOptimizer'
    with ops.name_scope(name, default_name, [
        learning_rate, preconditioner_decay_rate, num_pseudo_batches, burnin,
        diagonal_bias
    ]):
      if variable_scope is None:
        var_scope_name = ops.get_default_graph().unique_name(
            name or default_name)
        with varscope_ops.variable_scope(var_scope_name) as scope:
          self._variable_scope = scope
      else:
        self._variable_scope = variable_scope

      self._preconditioner_decay_rate = ops.convert_to_tensor(
          preconditioner_decay_rate, name='preconditioner_decay_rate')
      self._num_pseudo_batches = ops.convert_to_tensor(
          num_pseudo_batches, name='num_pseudo_batches')
      self._burnin = ops.convert_to_tensor(burnin, name='burnin')
      self._diagonal_bias = ops.convert_to_tensor(
          diagonal_bias, name='diagonal_bias')
      self._learning_rate = ops.convert_to_tensor(
          learning_rate, name='learning_rate')

      with varscope_ops.variable_scope(self._variable_scope):
        self._counter = varscope_ops.get_variable(
            'counter', initializer=0, trainable=False)

      self._preconditioner_decay_rate = control_flow_ops.with_dependencies([
          check_ops.assert_non_negative(
              self._preconditioner_decay_rate,
              message='`preconditioner_decay_rate` must be non-negative'),
          check_ops.assert_less_equal(
              self._preconditioner_decay_rate,
              1.,
              message='`preconditioner_decay_rate` must be at most 1.'),
      ], self._preconditioner_decay_rate)

      self._num_pseudo_batches = control_flow_ops.with_dependencies([
          check_ops.assert_greater(
              self._num_pseudo_batches,
              0,
              message='`num_pseudo_batches` must be greater than zero')
      ], self._num_pseudo_batches)

      self._burnin = control_flow_ops.with_dependencies([
          check_ops.assert_non_negative(
              self._burnin, message='`burnin` must be non-negative'),
          check_ops.assert_integer(
              self._burnin, message='`burnin` must be an integer')
      ], self._burnin)

      self._diagonal_bias = control_flow_ops.with_dependencies([
          check_ops.assert_non_negative(
              self._diagonal_bias,
              message='`diagonal_bias` must be non-negative')
      ], self._diagonal_bias)

      super(SGLDOptimizer, self).__init__(use_locking=False,
                                          name=name or default_name)

  def _create_slots(self, var_list):
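    # One 'rms' slot per variable stores the running average of squared
    # gradients (the RMSProp preconditioner statistic), initialized to ones.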
    for v in var_list:
      init_rms = init_ops.ones_initializer(dtype=v.dtype)
      self._get_or_make_slot_with_initializer(v, init_rms, v.get_shape(),
                                              v.dtype, 'rms', self._name)

  def _prepare(self):
    # We need to put the conversion and check here because a user will likely
    # want to decay the learning rate dynamically.
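    # As a hypothetical example (not part of this module), a decaying schedule
    # can be passed as `learning_rate`:
    #   lr = tf.train.exponential_decay(0.01, global_step, 1000, 0.9)
    #   sgld = SGLDOptimizer(learning_rate=lr)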
    self._learning_rate_tensor = control_flow_ops.with_dependencies([
        check_ops.assert_non_negative(
            self._learning_rate, message='`learning_rate` must be non-negative')
    ], ops.convert_to_tensor(self._learning_rate, name='learning_rate_tensor'))
    self._decay_tensor = ops.convert_to_tensor(
        self._preconditioner_decay_rate, name='preconditioner_decay_rate')

    super(SGLDOptimizer, self)._prepare()

  def _apply_dense(self, grad, var):
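    # Update the RMSProp statistic for this variable first, then take a plain
    # gradient-descent step using the noise-adjusted gradient from
    # `_apply_noisy_update`.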
    rms = self.get_slot(var, 'rms')

    with ops.control_dependencies([
        self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor,
                                                       var.dtype.base_dtype))]):
      new_grad = self._apply_noisy_update(rms, grad)

    return training_ops.apply_gradient_descent(
        var,
        math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
        new_grad,
        use_locking=self._use_locking).op

  def _apply_sparse(self, grad, var):
    rms = self.get_slot(var, 'rms')

    with ops.control_dependencies([
        self._update_momentum(rms, grad, math_ops.cast(self._decay_tensor,
                                                       var.dtype.base_dtype))]):
      new_grad = self._apply_noisy_update(rms, grad)

    return training_ops.apply_gradient_descent(
        var,
        math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
        new_grad,
        use_locking=self._use_locking).op

  def _finish(self, update_ops, name_scope):
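    # Increment the burn-in counter once per `apply_gradients` call, alongside
    # the per-variable update ops.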
    update_ops.append([self._counter.assign_add(1)])
    return control_flow_ops.group(*update_ops, name=name_scope)

  @property
  def variable_scope(self):
    """Variable scope of all calls to `tf.get_variable`."""
    return self._variable_scope

  def _apply_noisy_update(self, mom, grad):
    # Compute the noise-adjusted gradient for the preconditioned Langevin
    # dynamics update; the step itself is applied by the caller.
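    #
    # The value returned here is consumed by a plain gradient-descent step, so
    # after scaling by the learning rate `eps` the parameter update is, up to
    # sign,
    #   0.5 * eps * G * num_pseudo_batches * grad + sqrt(eps * G) * z,
    # where G = 1 / sqrt(rms + diagonal_bias) is the RMSProp preconditioner
    # and z is unit-variance Gaussian noise, injected only once `counter`
    # exceeds `burnin`. The Gamma(theta) correction term of [1] is not
    # included here.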
    stddev = array_ops.where(
        array_ops.squeeze(self._counter > self._burnin),
        math_ops.cast(math_ops.rsqrt(self._learning_rate), grad.dtype),
        array_ops.zeros([], grad.dtype))

    preconditioner = math_ops.rsqrt(
        mom + math_ops.cast(self._diagonal_bias, grad.dtype))
    return (
        0.5 * preconditioner * grad * math_ops.cast(self._num_pseudo_batches,
                                                    grad.dtype) +
        random_ops.random_normal(
            array_ops.shape(grad), stddev=1.0, dtype=grad.dtype) *
        stddev * math_ops.sqrt(preconditioner))

  def _update_momentum(self, mom, grad, decay):
    # Keep an exponentially weighted moving average of squared gradients.
    # Not thread-safe.
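    # The in-place update below is algebraically equivalent to
    #   rms <- decay * rms + (1 - decay) * grad**2,
    # i.e. the standard RMSProp second-moment estimate.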
    return mom.assign_add((1.0 - decay) * (math_ops.square(grad) - mom))