# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Maintain moving averages of parameters."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import slot_creator
from tensorflow.python.util.tf_export import tf_export


# TODO(touts): switch to variables.Variable.
def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
  """Compute the moving average of a variable.

  The moving average of 'variable' updated with 'value' is:
    variable * decay + value * (1 - decay)

  The returned Operation sets 'variable' to the newly computed moving average.

  The new value of 'variable' can be set with the 'AssignSub' op as:
     variable -= (1 - decay) * (variable - value)
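
  For example, with `decay = 0.9`, `variable = 1.0` and `value = 0.0`, both
  forms give the same result: `1.0 - 0.1 * (1.0 - 0.0) = 0.9 * 1.0 + 0.1 * 0.0
  = 0.9`.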

  Since variables that are initialized to a `0` value will be `0` biased,
  `zero_debias` optionally enables scaling by the mathematically correct
  debiasing factor of
    1 - decay ** num_updates
  See `ADAM: A Method for Stochastic Optimization` Section 3 for more details
  (https://arxiv.org/abs/1412.6980).

  The names of the debias shadow variables, by default, include both the scope
  they were created in and the scope of the variables they debias. They are
  also given a uniquifying suffix.

  Ex:
    with tf.variable_scope('scope1'):
      with tf.variable_scope('scope2'):
        var = tf.get_variable('foo')
        assign_moving_average(var, 0.0, 1.0)
        assign_moving_average(var, 0.0, 0.9)

    var.name: 'scope1/scope2/foo'
    shadow var names: 'scope1/scope2/scope1/scope2/foo/biased'
                      'scope1/scope2/scope1/scope2/foo/biased_1'

  Args:
    variable: A Variable.
    value: A tensor with the same shape as 'variable'.
    decay: A float Tensor or float value.  The moving average decay.
    zero_debias: A python bool. If true, assume the variable is 0-initialized
      and unbias it, as in https://arxiv.org/abs/1412.6980. See docstring in
      `_zero_debias` for more details.
    name: Optional name of the returned operation.

  Returns:
    A reference to the input 'variable' tensor with the newly computed
    moving average.
  """
  with ops.name_scope(name, "AssignMovingAvg",
                      [variable, value, decay]) as scope:
    with ops.colocate_with(variable):
      decay = ops.convert_to_tensor(1.0 - decay, name="decay")
      if decay.dtype != variable.dtype.base_dtype:
        decay = math_ops.cast(decay, variable.dtype.base_dtype)
      if zero_debias:
        update_delta = _zero_debias(variable, value, decay)
      else:
        update_delta = (variable - value) * decay
      return state_ops.assign_sub(variable, update_delta, name=scope)


def weighted_moving_average(value,
                            decay,
                            weight,
                            truediv=True,
                            collections=None,
                            name=None):
  """Compute the weighted moving average of `value`.

  Conceptually, the weighted moving average is:
    `moving_average(value * weight) / moving_average(weight)`,
  where a moving average updates by the rule
    `new_value = decay * old_value + (1 - decay) * update`
  Internally, this Op keeps moving average variables of both `value * weight`
  and `weight`.
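
  For example, a minimal sketch of tracking a batch-size-weighted average of a
  loss, where `loss` and `batch_size` are hypothetical float tensors from the
  surrounding training code:

  ```python
  # Weighted moving average of the loss, weighted by the batch size.
  avg_loss = weighted_moving_average(loss, decay=0.99, weight=batch_size)
  # Each evaluation of `avg_loss` updates both internal moving averages and
  # returns their ratio.
  ```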

  Args:
    value: A numeric `Tensor`.
    decay: A float `Tensor` or float value.  The moving average decay.
    weight:  `Tensor` that keeps the current value of a weight. Its shape
      must be such that `value * weight` is a valid (broadcastable) product.
    truediv:  Boolean, if `True`, dividing by `moving_average(weight)` is
      floating point division.  If `False`, use division implied by dtypes.
    collections:  List of graph collections keys to add the internal variables
      `value * weight` and `weight` to.
      Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    name: Optional name of the returned operation.
      Defaults to "WeightedMovingAvg".

  Returns:
    An Operation that updates and returns the weighted moving average.
  """
  # Unlike assign_moving_average, the weighted moving average doesn't modify
  # user-visible variables. It is the ratio of two internal variables, which
  # are moving averages of the updates.  Thus, the signature of this function
  # is quite different from that of assign_moving_average.
  if collections is None:
    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
  with variable_scope.variable_scope(name, "WeightedMovingAvg",
                                     [value, weight, decay]) as scope:
    value_x_weight_var = variable_scope.get_variable(
        "value_x_weight",
        shape=value.get_shape(),
        dtype=value.dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=collections)
    weight_var = variable_scope.get_variable(
        "weight",
        shape=weight.get_shape(),
        dtype=weight.dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=collections)
    numerator = assign_moving_average(
        value_x_weight_var, value * weight, decay, zero_debias=False)
    denominator = assign_moving_average(
        weight_var, weight, decay, zero_debias=False)

    if truediv:
      return math_ops.truediv(numerator, denominator, name=scope.name)
    else:
      return math_ops.div(numerator, denominator, name=scope.name)


def _zero_debias(unbiased_var, value, decay):
  """Compute the delta required for a debiased Variable.

  All exponential moving averages initialized with Tensors are initialized to 0,
  and therefore are biased to 0. Variables initialized to 0 and used as EMAs are
  similarly biased. This function computes the debias update amount according to
  a scale factor, as in https://arxiv.org/abs/1412.6980.

  To demonstrate the bias that results from 0-initialization, take an EMA that
  was initialized to `0` with decay `b`. After `t` timesteps of seeing the
  constant `c`, the variable will have the following value:

  ```
    EMA = 0*b^(t) + c*(1 - b)*b^(t-1) + c*(1 - b)*b^(t-2) + ...
        = c*(1 - b^t)
  ```

  To have the true value `c`, we would divide by the scale factor `1 - b^t`.
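
  As a concrete check: with `b = 0.9` and `c = 1.0`, two updates give
  `EMA = 0.9*0.1 + 0.1*1.0 = 0.19 = c*(1 - b^2)`, and dividing by the factor
  `1 - b^2 = 0.19` recovers `c = 1.0`.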

  In order to perform debiasing, we use two shadow variables. One keeps track of
  the biased estimate, and the other keeps track of the number of updates that
  have occurred.

  Args:
    unbiased_var: A Variable representing the current value of the unbiased EMA.
    value: A Tensor representing the most recent value.
    decay: A Tensor representing `1-decay` for the EMA.

  Returns:
    The amount that the unbiased variable should be updated. Computing this
    tensor will also update the shadow variables appropriately.
  """
  with variable_scope.variable_scope(
      unbiased_var.op.name, values=[unbiased_var, value, decay]) as scope:
    with ops.colocate_with(unbiased_var):
      with ops.init_scope():
        biased_initializer = init_ops.zeros_initializer(
            dtype=unbiased_var.dtype)(unbiased_var.get_shape())
        local_step_initializer = init_ops.zeros_initializer()
      def _maybe_get_unique(name):
        """Get name for a unique variable, if not `reuse=True`."""
        if variable_scope.get_variable_scope().reuse:
          return name
        vs_vars = [x.op.name for x in
                   variable_scope.get_variable_scope().global_variables()]
        full_name = variable_scope.get_variable_scope().name + "/" + name
        if full_name not in vs_vars: return name
        idx = 1
        while full_name + ("_%d" % idx) in vs_vars:
          idx += 1
        return name + ("_%d" % idx)
      biased_var = variable_scope.get_variable(
          _maybe_get_unique("biased"), initializer=biased_initializer,
          trainable=False)
      local_step = variable_scope.get_variable(
          _maybe_get_unique("local_step"),
          shape=[],
          dtype=unbiased_var.dtype,
          initializer=local_step_initializer,
          trainable=False)

      # Get update ops for both shadow variables.
      update_biased = state_ops.assign_sub(biased_var,
                                           (biased_var - value) * decay,
                                           name=scope.name)
      update_local_step = local_step.assign_add(1)

      # Compute the value of the delta to update the unbiased EMA. Make sure to
      # use the new values of the biased variable and the local step.
      with ops.control_dependencies([update_biased, update_local_step]):
        # This function gets `1 - decay`, so use `1.0 - decay` in the exponent.
        unbiased_ema_delta = (unbiased_var - biased_var.read_value() /
                              (1 - math_ops.pow(
                                  1.0 - decay, local_step.read_value())))

      return unbiased_ema_delta


@tf_export("train.ExponentialMovingAverage")
class ExponentialMovingAverage(object):
  """Maintains moving averages of variables by employing an exponential decay.

  When training a model, it is often beneficial to maintain moving averages of
  the trained parameters.  Evaluations that use averaged parameters sometimes
  produce significantly better results than the final trained values.

  The `apply()` method adds shadow copies of trained variables and adds ops
  that maintain a moving average of the trained variables in their shadow
  copies. It is used when building the training model.  The ops that maintain
  moving averages are typically run after each training step.
  The `average()` and `average_name()` methods give access to the shadow
  variables and their names.  They are useful when building an evaluation
  model, or when restoring a model from a checkpoint file.  They make it
  possible to use the moving averages in place of the last trained values for
  evaluations.

  The moving averages are computed using exponential decay.  You specify the
  decay value when creating the `ExponentialMovingAverage` object.  The shadow
  variables are initialized with the same initial values as the trained
  variables.  When you run the ops to maintain the moving averages, each
  shadow variable is updated with the formula:

    `shadow_variable -= (1 - decay) * (shadow_variable - variable)`

  This is mathematically equivalent to the classic formula below, but the use
  of an `assign_sub` op (the `"-="` in the formula) allows concurrent lockless
  updates to the variables:

    `shadow_variable = decay * shadow_variable + (1 - decay) * variable`

  Reasonable values for `decay` are close to 1.0, typically in the
  multiple-nines range: 0.999, 0.9999, etc.

  Example usage when creating a training model:

  ```python
  # Create variables.
  var0 = tf.Variable(...)
  var1 = tf.Variable(...)
  # ... use the variables to build a training model...
  ...
  # Create an op that applies the optimizer.  This is what we usually
  # would use as a training op.
  opt_op = opt.minimize(my_loss, [var0, var1])

  # Create an ExponentialMovingAverage object
  ema = tf.train.ExponentialMovingAverage(decay=0.9999)

  with tf.control_dependencies([opt_op]):
      # Create the shadow variables, and add ops to maintain moving averages
      # of var0 and var1. This also creates an op that will update the moving
      # averages after each training step.  This is what we will use in place
      # of the usual training op.
      training_op = ema.apply([var0, var1])

  ...train the model by running training_op...
  ```

  There are two ways to use the moving averages for evaluations:

  *  Build a model that uses the shadow variables instead of the variables.
     For this, use the `average()` method which returns the shadow variable
     for a given variable, as in the first sketch below.
  *  Build a model normally but load the checkpoint files to evaluate by using
     the shadow variable names.  For this, use the `average_name()` method.
     See @{tf.train.Saver} for more information on restoring saved variables.

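  Example of using the shadow variables directly in an evaluation graph (a
  minimal sketch; it assumes `ema`, `var0` and `var1` from the training
  example above and that `apply()` has already been called):

  ```python
  # Fetch the shadow variables holding the moving averages.
  avg_var0 = ema.average(var0)
  avg_var1 = ema.average(var1)
  # ... build the eval model using avg_var0 and avg_var1 instead of var0/var1 ...
  ```
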
  Example of restoring the shadow variable values:

  ```python
  # Create a Saver that loads variables from their saved shadow values.
  shadow_var0_name = ema.average_name(var0)
  shadow_var1_name = ema.average_name(var1)
  saver = tf.train.Saver({shadow_var0_name: var0, shadow_var1_name: var1})
  saver.restore(...checkpoint filename...)
  # var0 and var1 now hold the moving average values
  ```
  """

  def __init__(self, decay, num_updates=None, zero_debias=False,
               name="ExponentialMovingAverage"):
    """Creates a new ExponentialMovingAverage object.

    The `apply()` method has to be called to create shadow variables and add
    ops to maintain moving averages.

    The optional `num_updates` parameter allows one to tweak the decay rate
    dynamically. It is typical to pass the count of training steps, usually
    kept in a variable that is incremented at each step, in which case the
    decay rate is lower at the start of training.  This makes moving averages
    move faster.  If passed, the actual decay rate used is:

      `min(decay, (1 + num_updates) / (10 + num_updates))`

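    For example, with `decay = 0.999` the effective decay rate is
    `1 / 10 = 0.1` at `num_updates = 0` and `91 / 100 = 0.91` at
    `num_updates = 90`; it approaches the cap of `0.999` as `num_updates`
    grows.
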
    Args:
      decay: Float.  The decay to use.
      num_updates: Optional count of updates applied to variables.
      zero_debias: If `True`, zero debias moving-averages that are initialized
        with tensors.
      name: String. Optional prefix name to use for the name of ops added in
        `apply()`.
    """
    self._decay = decay
    self._num_updates = num_updates
    self._zero_debias = zero_debias
    self._name = name
    self._averages = {}

  def apply(self, var_list=None):
    """Maintains moving averages of variables.

    `var_list` must be a list of `Variable` or `Tensor` objects.  This method
    creates shadow variables for all elements of `var_list`.  Shadow variables
    for `Variable` objects are initialized to the variable's initial value.
    They will be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection.
    For `Tensor` objects, the shadow variables are initialized to 0 and zero
    debiased (see docstring in `assign_moving_average` for more details).

    Shadow variables are created with `trainable=False` and added to the
    `GraphKeys.ALL_VARIABLES` collection.  They will be returned by calls to
    `tf.global_variables()`.

    Returns an op that updates all shadow variables as described above.

    Note that `apply()` can be called multiple times with different lists of
    variables.
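
    For example, a minimal sketch of maintaining the average of a non-Variable
    tensor (`batch_mean` is a hypothetical tensor from the surrounding graph;
    `zero_debias=True` is passed so the zeros-initialized shadow variable is
    debiased):

    ```python
    ema = tf.train.ExponentialMovingAverage(decay=0.99, zero_debias=True)
    update_op = ema.apply([batch_mean])    # creates the shadow variable
    moving_mean = ema.average(batch_mean)  # the shadow variable for batch_mean
    ```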

    Args:
      var_list: A list of Variable or Tensor objects. The variables
        and Tensors must be of types float16, float32, or float64.

    Returns:
      An Operation that updates the moving averages.

    Raises:
      TypeError: If the arguments are not all float16, float32, or float64.
      ValueError: If the moving average of one of the variables is already
        being computed.
    """
    # TODO(touts): op_scope
    if var_list is None:
      var_list = variables.trainable_variables()
    zero_debias_true = set()  # set of vars to set `zero_debias=True`
    for var in var_list:
      if var.dtype.base_dtype not in [dtypes.float16, dtypes.float32,
                                      dtypes.float64]:
        raise TypeError("The variables must be half, float, or double: %s" %
                        var.name)
      if var in self._averages:
        raise ValueError("Moving average already computed for: %s" % var.name)

      # For variables: to lower communication bandwidth across devices we keep
      # the moving averages on the same device as the variables. For other
      # tensors, we rely on the existing device allocation mechanism.
      with ops.init_scope():
        if isinstance(var, variables.Variable):
          avg = slot_creator.create_slot(var,
                                         var.initialized_value(),
                                         self._name,
                                         colocate_with_primary=True)
          # NOTE(mrry): We only add `tf.Variable` objects to the
          # `MOVING_AVERAGE_VARIABLES` collection.
          ops.add_to_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, var)
        else:
          avg = slot_creator.create_zeros_slot(
              var,
              self._name,
              colocate_with_primary=(var.op.type in ["Variable",
                                                     "VariableV2",
                                                     "VarHandleOp"]))
          if self._zero_debias:
            zero_debias_true.add(avg)
      self._averages[var] = avg

    with ops.name_scope(self._name) as scope:
      decay = ops.convert_to_tensor(self._decay, name="decay")
      if self._num_updates is not None:
        num_updates = math_ops.cast(self._num_updates,
                                    dtypes.float32,
                                    name="num_updates")
        decay = math_ops.minimum(decay,
                                 (1.0 + num_updates) / (10.0 + num_updates))
      updates = []
      for var in var_list:
        zero_debias = self._averages[var] in zero_debias_true
        updates.append(assign_moving_average(
            self._averages[var], var, decay, zero_debias=zero_debias))
      return control_flow_ops.group(*updates, name=scope)

  def average(self, var):
    """Returns the `Variable` holding the average of `var`.

    Args:
      var: A `Variable` object.

    Returns:
      A `Variable` object or `None` if the moving average of `var`
      is not maintained.
    """
    return self._averages.get(var, None)

  def average_name(self, var):
    """Returns the name of the `Variable` holding the average for `var`.

    The typical scenario for `ExponentialMovingAverage` is to compute moving
    averages of variables during training, and restore the variables from the
    computed moving averages during evaluations.

    To restore variables, you have to know the name of the shadow variables.
    That name and the original variable can then be passed to a `Saver()` object
    to restore the variable from the moving average value with:
      `saver = tf.train.Saver({ema.average_name(var): var})`

    `average_name()` can be called whether or not `apply()` has been called.

    Args:
      var: A `Variable` object.

    Returns:
      A string: The name of the variable that will be used or was used
      by the `ExponentialMovingAverage` class to hold the moving average of
      `var`.
    """
    if var in self._averages:
      return self._averages[var].op.name
    return ops.get_default_graph().unique_name(
        var.op.name + "/" + self._name, mark_as_used=False)

  def variables_to_restore(self, moving_avg_variables=None):
    """Returns a map of names to `Variables` to restore.

    If a variable has a moving average, use the moving average variable name as
    the restore name; otherwise, use the variable name.

    For example,

    ```python
      variables_to_restore = ema.variables_to_restore()
      saver = tf.train.Saver(variables_to_restore)
    ```

    Below is an example of such a mapping:

    ```
      conv/batchnorm/gamma/ExponentialMovingAverage: conv/batchnorm/gamma,
      conv_4/conv2d_params/ExponentialMovingAverage: conv_4/conv2d_params,
      global_step: global_step
    ```

    Args:
      moving_avg_variables: a list of variables that require the use of the
        moving average variable name to be restored. If None, it will default
        to variables.moving_average_variables() +
        variables.trainable_variables().

    Returns:
      A map from restore_names to variables. The restore_name can be the
      moving_average version of the variable name if it exists, or the
      original variable name.
    """
    name_map = {}
    if moving_avg_variables is None:
      # Include trainable variables and variables which have been explicitly
      # added to the moving_average_variables collection.
      moving_avg_variables = variables.trainable_variables()
      moving_avg_variables += variables.moving_average_variables()
    # Remove duplicates.
    moving_avg_variables = set(moving_avg_variables)
    # Collect all the variables with a moving average.
    for v in moving_avg_variables:
      name_map[self.average_name(v)] = v
    # Make sure we restore variables without moving averages as well.
    moving_avg_variable_names = set([v.name for v in moving_avg_variables])
    for v in list(set(variables.global_variables())):
      if v.name not in moving_avg_variable_names and v.op.name not in name_map:
        name_map[v.op.name] = v
    return name_map