# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Moving average optimizer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import moving_averages
from tensorflow.python.training import optimizer
from tensorflow.python.training import saver


class MovingAverageOptimizer(optimizer.Optimizer):
  """Optimizer that computes a moving average of the variables.

  Empirically it has been found that using the moving average of the trained
  parameters of a deep network is better than using its trained parameters
  directly. This optimizer allows you to compute this moving average and swap
  the variables at save time so that any code outside of the training loop
  will by default use the averaged values instead of the original ones.

  Example of usage:

  ```python
  # Encapsulate your favorite optimizer (here the momentum one)
  # inside the MovingAverageOptimizer.
  opt = tf.train.MomentumOptimizer(learning_rate, FLAGS.momentum)
  opt = tf.contrib.opt.MovingAverageOptimizer(opt)
  # Then create your model and all its variables.
  model = build_model()
  # Add the training op that optimizes using opt.
  # This needs to be called before swapping_saver().
  opt.minimize(cost, var_list)
  # Then create your saver like this:
  saver = opt.swapping_saver()
  # Pass it to your training loop.
  slim.learning.train(
      model,
      ...
      saver=saver)
  ```

  Note that for evaluation, the normal saver should be used instead of
  swapping_saver().
  """

  def __init__(self, opt, average_decay=0.9999, num_updates=None,
               sequential_update=True):
    """Construct a new MovingAverageOptimizer.

    Args:
      opt: A tf.Optimizer that will be used to compute and apply gradients.
      average_decay: Float. Decay to use to maintain the moving averages
        of trained variables.
        See tf.train.ExponentialMovingAverage for details.
      num_updates: Optional count of number of updates applied to variables.
        See tf.train.ExponentialMovingAverage for details.
      sequential_update: Bool. If False, will compute the moving average at the
        same time as the model is updated, potentially introducing benign data
        races. If True, will update the moving average after gradient updates.
82 """ 83 self._optimizer = opt 84 self._ema = moving_averages.ExponentialMovingAverage( 85 average_decay, num_updates=num_updates) 86 self._swapped_variable_name_map = None 87 self._sequential_update = sequential_update 88 89 def compute_gradients(self, *args, **kwargs): 90 return self._optimizer.compute_gradients(*args, **kwargs) 91 92 def apply_gradients(self, grads_and_vars, global_step=None, name=None): 93 train_op = self._optimizer.apply_gradients( 94 grads_and_vars, global_step=global_step, name=name) 95 var_list = [x[1] for x in grads_and_vars if x[0] is not None] 96 self._swapped_variable_name_map = {} 97 if self._sequential_update: 98 with ops.control_dependencies([train_op]): 99 ma_op = self._ema.apply(var_list) 100 else: 101 ma_op = self._ema.apply(var_list) 102 103 for v in var_list: 104 v_avg = self._ema.average(v) 105 self._swapped_variable_name_map[v.op.name] = v_avg.op.name 106 self._swapped_variable_name_map[v_avg.op.name] = v.op.name 107 return control_flow_ops.group(train_op, ma_op, name='train_with_avg') 108 109 def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs): 110 """Create a saver swapping moving averages and variables. 111 112 You should use this saver during training. It will save the moving averages 113 of the trained parameters under the original parameter names. For 114 evaluations or inference you should use a regular saver and it will 115 automatically use the moving averages for the trained variable. 116 117 You must call this function after all variables have been created and after 118 you have called Optimizer.minimize(). 119 120 Args: 121 var_list: List of variables to save, as per `Saver()`. 122 If set to None, will save all the variables that have been 123 created before this call. 124 name: The name of the saver. 125 **kwargs: Keyword arguments of `Saver()`. 126 127 Returns: 128 A `tf.train.Saver` object. 129 130 Raises: 131 RuntimeError: If apply_gradients or minimize has not been called before. 132 ValueError: If var_list is provided and contains some variables but not 133 their moving average counterpart. 134 """ 135 136 if self._swapped_variable_name_map is None: 137 raise RuntimeError('Must call apply_gradients or minimize before ' 138 'creating the swapping_saver') 139 if var_list is None: 140 var_list = variables.global_variables() 141 if not isinstance(var_list, dict): 142 var_list = saver.BaseSaverBuilder.OpListToDict(var_list) 143 144 # OpListToDict converts variables to tensors. We make sure we can get 145 # the unique variable name for normal and resource vaiables. 146 def get_v_name(tensor): 147 if tensor.op.type == 'ReadVariableOp': 148 return tensor.op.inputs[0].op.name 149 else: 150 return tensor.op.name 151 152 v_name_to_tensor = {} 153 for tensor in six.itervalues(var_list): 154 v_name = get_v_name(tensor) 155 v_name_to_tensor[v_name] = tensor 156 157 # Now swap variables and moving averages 158 swapped_var_list = {} 159 for k, tensor in six.iteritems(var_list): 160 v_name = get_v_name(tensor) 161 swapped_v_name = self._swapped_variable_name_map.get(v_name, None) 162 tensor_to_save = tensor 163 if swapped_v_name is not None: 164 if swapped_v_name in v_name_to_tensor: 165 tensor_to_save = v_name_to_tensor[swapped_v_name] 166 else: 167 raise ValueError( 168 ('Variable to swap %s is not part of variables to save. ' 169 'This breaks MovingAverageOptimizer.') % swapped_v_name) 170 swapped_var_list[k] = tensor_to_save 171 172 # Build the swapping saver. 173 return saver.Saver(swapped_var_list, name=name, **kwargs) 174