# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Moving average optimizer."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import six
     22 
     23 from tensorflow.python.framework import ops
     24 from tensorflow.python.ops import control_flow_ops
     25 from tensorflow.python.ops import variables
     26 from tensorflow.python.training import moving_averages
     27 from tensorflow.python.training import optimizer
     28 from tensorflow.python.training import saver
     29 
     30 
class MovingAverageOptimizer(optimizer.Optimizer):
  """Optimizer that computes a moving average of the variables.

  Empirically it has been found that using the moving average of the trained
  parameters of a deep network is better than using its trained parameters
  directly. This optimizer allows you to compute this moving average and swap
  the variables at save time so that any code outside of the training loop
  will use the averaged values by default instead of the original ones.

  Example of usage:

  ```python
  # Encapsulate your favorite optimizer (here the momentum one)
  # inside the MovingAverageOptimizer.
  opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
  opt = tf.contrib.opt.MovingAverageOptimizer(opt)
  # Then create your model and all its variables.
  model = build_model()
  # Add the training op that optimizes using opt.
  # This needs to be called before swapping_saver().
  opt.minimize(cost, var_list=var_list)
  # Then create your saver like this:
  saver = opt.swapping_saver()
  # Pass it to your training loop.
  slim.learning.train(
      model,
      ...
      saver=saver)
  ```

  Note that for evaluation, the normal saver should be used instead of
  swapping_saver().
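
  For example, at evaluation time a plain saver picks up the averaged values
  automatically, because swapping_saver() writes them to the checkpoint under
  the original variable names. A minimal sketch, where `build_model` and
  `checkpoint_path` are placeholders for your own model and checkpoint:

  ```python
  # Rebuild the same model graph that was used for training.
  model = build_model()
  # A regular saver suffices: the checkpoint written by swapping_saver()
  # stores the moving averages under the original variable names, so a
  # plain restore loads the averaged values.
  eval_saver = tf.train.Saver()
  with tf.Session() as sess:
    eval_saver.restore(sess, checkpoint_path)
    # ... run evaluation ...
  ```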
  """

  def __init__(self, opt, average_decay=0.9999, num_updates=None,
               sequential_update=True):
    """Construct a new MovingAverageOptimizer.

    Args:
      opt: A tf.Optimizer that will be used to compute and apply gradients.
      average_decay: Float.  Decay to use to maintain the moving averages
                     of trained variables.
                     See tf.train.ExponentialMovingAverage for details.
      num_updates: Optional count of number of updates applied to variables.
                   See tf.train.ExponentialMovingAverage for details.
      sequential_update: Bool. If False, will compute the moving average at the
                         same time as the model is updated, potentially
                         resulting in benign data races.
                         If True, will update the moving average after gradient
                         updates.
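
    For example, a minimal sketch with non-default arguments (the optimizer
    and decay values here are arbitrary choices, not recommendations):

    ```python
    base_opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
    # Use a faster-moving average and let it update concurrently with the
    # gradient updates (accepting benign data races).
    opt = tf.contrib.opt.MovingAverageOptimizer(
        base_opt, average_decay=0.99, sequential_update=False)
    ```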
    """
    self._optimizer = opt
    self._ema = moving_averages.ExponentialMovingAverage(
        average_decay, num_updates=num_updates)
    # Mapping between original variable names and their moving average
    # counterparts, built by apply_gradients() and used by swapping_saver().
    self._swapped_variable_name_map = None
    self._sequential_update = sequential_update

  def compute_gradients(self, *args, **kwargs):
    return self._optimizer.compute_gradients(*args, **kwargs)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    train_op = self._optimizer.apply_gradients(
        grads_and_vars, global_step=global_step, name=name)
    var_list = [x[1] for x in grads_and_vars if x[0] is not None]
    self._swapped_variable_name_map = {}
    if self._sequential_update:
      # Update the moving averages only after the gradient update has run.
      with ops.control_dependencies([train_op]):
        ma_op = self._ema.apply(var_list)
    else:
      # Update the moving averages concurrently with the gradient update.
      ma_op = self._ema.apply(var_list)

    # Record the mapping in both directions so swapping_saver() can swap
    # each variable with its moving average counterpart.
    for v in var_list:
      v_avg = self._ema.average(v)
      self._swapped_variable_name_map[v.op.name] = v_avg.op.name
      self._swapped_variable_name_map[v_avg.op.name] = v.op.name
    return control_flow_ops.group(train_op, ma_op, name='train_with_avg')

  def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs):
    """Create a saver swapping moving averages and variables.

    You should use this saver during training.  It will save the moving
    averages of the trained parameters under the original parameter names.
    For evaluation or inference you should use a regular saver; it will
    automatically use the moving averages for the trained variables.

    You must call this function after all variables have been created and
    after you have called Optimizer.minimize().
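
    A minimal sketch of the intended call order (`cost` and `var_list` are
    placeholders for your loss tensor and trainable variables):

    ```python
    opt = tf.contrib.opt.MovingAverageOptimizer(
        tf.train.GradientDescentOptimizer(learning_rate=0.1))
    train_op = opt.minimize(cost, var_list=var_list)
    # Only valid after minimize()/apply_gradients() has built the map
    # between variables and their moving averages.
    train_saver = opt.swapping_saver()
    ```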

    Args:
      var_list: List of variables to save, as per `Saver()`.
                If set to None, will save all the variables that have been
                created before this call.
      name: The name of the saver.
      **kwargs: Keyword arguments of `Saver()`.

    Returns:
      A `tf.train.Saver` object.

    Raises:
      RuntimeError: If apply_gradients or minimize has not been called before.
      ValueError: If var_list is provided and contains some variables but not
        their moving average counterpart.
    """

    if self._swapped_variable_name_map is None:
      raise RuntimeError('Must call apply_gradients or minimize before '
                         'creating the swapping_saver')
    if var_list is None:
      var_list = variables.global_variables()
    if not isinstance(var_list, dict):
      var_list = saver.BaseSaverBuilder.OpListToDict(var_list)

    # OpListToDict converts variables to tensors. We make sure we can get
    # the unique variable name for normal and resource variables.
    def get_v_name(tensor):
      if tensor.op.type == 'ReadVariableOp':
        # Resource variables are read through a ReadVariableOp; the variable
        # itself is the op's first input.
        return tensor.op.inputs[0].op.name
      else:
        return tensor.op.name

    v_name_to_tensor = {}
    for tensor in six.itervalues(var_list):
      v_name = get_v_name(tensor)
      v_name_to_tensor[v_name] = tensor

    # Now swap variables and moving averages.
    swapped_var_list = {}
    for k, tensor in six.iteritems(var_list):
      v_name = get_v_name(tensor)
      swapped_v_name = self._swapped_variable_name_map.get(v_name, None)
      tensor_to_save = tensor
      if swapped_v_name is not None:
        if swapped_v_name in v_name_to_tensor:
          tensor_to_save = v_name_to_tensor[swapped_v_name]
        else:
          raise ValueError(
              ('Variable to swap %s is not part of variables to save. '
               'This breaks MovingAverageOptimizer.') % swapped_v_name)
      swapped_var_list[k] = tensor_to_save

    # Build the swapping saver.
    return saver.Saver(swapped_var_list, name=name, **kwargs)