# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper optimizer for Model Average."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import optimizer
from tensorflow.python.training import session_run_hook

GLOBAL_VARIABLE_NAME = "global_center_variable"


class ModelAverageCustomGetter(object):
  """Custom getter that supports the Model Average optimizer.

  This custom getter is used to:

  1. Place trainable variables in the local collection and on the worker
     device.
  2. Generate the corresponding global center variables.

  Note that this class should be used together with
  tf.train.replica_device_setter so that the global center variables and the
  global step variable can be placed on the ps device. In addition, variables
  must be created with 'tf.get_variable' instead of 'tf.Variable' for this
  custom getter to take effect.

  For example,

    ma_custom_getter = ModelAverageCustomGetter(worker_device)
    with tf.device(
        tf.train.replica_device_setter(
            worker_device=worker_device,
            ps_device="/job:ps/cpu:0",
            cluster=cluster)), \
        tf.variable_scope('', custom_getter=ma_custom_getter):
      hid_w = tf.get_variable(
          initializer=tf.truncated_normal(
              [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
              stddev=1.0 / IMAGE_PIXELS),
          name="hid_w")
      hid_b = tf.get_variable(
          initializer=tf.zeros([FLAGS.hidden_units]), name="hid_b")
  """

  def __init__(self, worker_device):
    """Create a new `ModelAverageCustomGetter`.

    Args:
      worker_device: String. Device string of the `worker` job, e.g.
        "/job:worker/task:0".
    """
    self._worker_device = worker_device
    self._local_2_global = {}

  def __call__(self, getter, name, trainable, collections, *args, **kwargs):
    if trainable:
      # Trainable variables are created as local variables on the worker
      # device; each one gets a non-trainable global center copy, and the
      # pairing is recorded for use by ModelAverageOptimizer.
      with ops.device(self._worker_device):
        local_var = getter(
            name,
            trainable=True,
            collections=[ops.GraphKeys.LOCAL_VARIABLES],
            *args,
            **kwargs)

      global_variable = variable_scope.variable(
          name="%s/%s" % (GLOBAL_VARIABLE_NAME, name),
          initial_value=local_var.initialized_value(),
          trainable=False,
          collections=[ops.GraphKeys.GLOBAL_VARIABLES])

      self._local_2_global[local_var] = global_variable
      return local_var
    else:
      return getter(name, trainable, collections, *args, **kwargs)
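

# A minimal illustration of the bookkeeping above (assuming variables were
# created as in the class docstring): each trainable variable returned by the
# getter lives in tf.GraphKeys.LOCAL_VARIABLES on the worker device, while its
# center copy is a regular global variable named
# "global_center_variable/<name>". The internal mapping, e.g.
#
#   ma_custom_getter._local_2_global[hid_w]  # the center copy of `hid_w`
#
# is what ModelAverageOptimizer uses to pair local variables with their
# global counterparts.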


class ModelAverageOptimizer(optimizer.Optimizer):
  """Wrapper optimizer that implements the Model Average algorithm.

  This is a synchronous optimizer. During training, each worker updates its
  local variables and maintains its own local_step, which starts from 0 and
  is incremented by 1 after each update of the local variables. Whenever
  interval_steps divides the local step, the local variables from all the
  workers are averaged and assigned to the global center variables, and the
  local variables are then reset from the global center variables.
  """

  def __init__(self,
               opt,
               num_worker,
               is_chief,
               ma_custom_getter,
               interval_steps=100,
               use_locking=True,
               name="ModelAverageOptimizer"):
    """Construct a new model average optimizer.

    Args:
      opt: The actual optimizer that will be used to update local variables.
      num_worker: The number of workers.
      is_chief: Whether this worker is the chief.
      ma_custom_getter: The `ModelAverageCustomGetter` that created the local
        and global center variables.
      interval_steps: An int value that controls how frequently the local
        variables are averaged.
      use_locking: If True, use locks for update operations.
      name: String. Optional name of the returned operation.
    """
    super(ModelAverageOptimizer, self).__init__(use_locking, name)
    self._opt = opt
    self._num_worker = num_worker
    self._is_chief = is_chief
    self._local_2_global = ma_custom_getter._local_2_global  # pylint:disable=protected-access
    self._interval_steps = interval_steps
    self._accumulator_list = []
    self._chief_init_op = None

    self._local_step = variable_scope.get_variable(
        initializer=0,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="local_step")

    self._opt._prepare()  # pylint:disable=protected-access

  def compute_gradients(self, *args, **kwargs):
    """Compute gradients of "loss" for the variables in "var_list".

    This simply wraps the compute_gradients() of the real optimizer.

    Args:
      *args: Arguments for compute_gradients().
      **kwargs: Keyword arguments for compute_gradients().

    Returns:
      A list of (gradient, variable) pairs.
    """
    return self._opt.compute_gradients(*args, **kwargs)

  def _local_vars_update(self, var_list):
    """Get the update ops for the local variables in "var_list".

    Args:
      var_list: List or tuple of `tf.Variable` objects to update.

    Returns:
      An update op.

    Raises:
      ValueError: If `var_list` is empty.
    """
    if not var_list:
      raise ValueError("The list of local_variables should not be empty")
    update_ops = []
    global_center_vars = [self._local_2_global[var] for var in var_list]
    for lvar, gvar in zip(var_list, global_center_vars):
      with ops.device(lvar.device):
        update_ops.append(state_ops.assign(lvar, gvar.read_value()))
    return control_flow_ops.group(*update_ops)
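
  # An illustrative summary of the synchronization implemented in
  # `apply_gradients` below (descriptive only; the code is authoritative):
  # every `interval_steps` local updates, each worker pushes its local
  # variable values into per-variable ConditionalAccumulators; the chief
  # takes the element-wise average
  #     global_var = (1 / num_worker) * sum_i local_var_i,
  # assigns it to the global center variables and unblocks the remaining
  # `num_worker - 1` workers through a shared FIFO queue; finally, every
  # worker copies the center values back into its local variables.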

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This contains most of the synchronization implementation and also wraps
    the apply_gradients() of the real optimizer. The chief worker updates the
    global variables.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Optional Variable to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation. Defaults to the name
        passed to the Optimizer constructor.

    Returns:
      A conditional 'Operation' that updates both local and global variables,
      or just the local variables.

    Raises:
      ValueError: If `grads_and_vars` is empty.
      ValueError: If `global_step` is not provided, since the staleness
        cannot be checked without it.
    """

    # update local variables
    if not grads_and_vars:
      raise ValueError("Must supply at least one variable")
    if global_step is None:
      raise ValueError("Global step is required")

    apply_updates = self._opt.apply_gradients(grads_and_vars)
    with ops.control_dependencies([apply_updates]):
      local_update = state_ops.assign_add(
          self._local_step, 1, name="local_step_update").op

    # update global variables.
    def _update_global_variables():  # pylint: disable=missing-docstring
      local_vars = [v for g, v in grads_and_vars if g is not None]
      global_vars = [self._local_2_global[v] for v in local_vars]
      # sync queue
      with ops.colocate_with(global_step):
        sync_queue = data_flow_ops.FIFOQueue(
            -1, [dtypes.bool], shapes=[[]], shared_name="sync_queue")
      train_ops = []
      aggregated_vars = []
      with ops.name_scope(None, self._name + "/global"):
        for var, gvar in zip(local_vars, global_vars):
          # pylint: disable=protected-access
          with ops.device(gvar.device):
            if isinstance(var._ref(), ops.Tensor):
              var_accum = data_flow_ops.ConditionalAccumulator(
                  var.dtype,
                  shape=var.get_shape(),
                  shared_name=gvar.name + "/var_accum")
              train_ops.append(
                  var_accum.apply_grad(var._ref(), local_step=global_step))
              aggregated_vars.append(var_accum.take_grad(self._num_worker))
            else:
              raise ValueError("Unknown local variable type!")
            self._accumulator_list.append((var_accum, gvar.device))
      # chief worker updates global vars and enqueues tokens to the sync queue
      if self._is_chief:
        update_ops = []
        with ops.control_dependencies(train_ops):
          for avg_var, gvar in zip(aggregated_vars, global_vars):
            with ops.device(gvar.device):
              update_ops.append(state_ops.assign(gvar, avg_var))
          with ops.device(global_step.device):
            update_ops.append(state_ops.assign_add(global_step, 1))
        with ops.control_dependencies(update_ops), ops.device(
            global_step.device):
          tokens = array_ops.fill([self._num_worker - 1],
                                  constant_op.constant(False))
          sync_op = sync_queue.enqueue_many(tokens)
      else:
        with ops.control_dependencies(train_ops), ops.device(
            global_step.device):
          sync_op = sync_queue.dequeue()

      with ops.control_dependencies([sync_op]):
        local_update_op = self._local_vars_update(local_vars)
      return local_update_op

    with ops.control_dependencies([local_update]):
      condition = math_ops.equal(
          math_ops.mod(self._local_step, self._interval_steps), 0)
      conditional_update = control_flow_ops.cond(
          condition, _update_global_variables, control_flow_ops.no_op)

    chief_init_ops = []
    for accum, dev in self._accumulator_list:
      with ops.device(dev):
        chief_init_ops.append(
            accum.set_global_step(global_step, name="SetGlobalStep"))
    self._chief_init_op = control_flow_ops.group(*chief_init_ops)

    return conditional_update
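
  # Note: the per-variable accumulators created in `apply_gradients` need
  # their global step set before training starts so that stale contributions
  # are handled consistently; `_chief_init_op` groups those `set_global_step`
  # ops and is picked up by the session hook created in
  # `make_session_run_hook` below.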
282 """ 283 return self._local_vars_update(variables.trainable_variables()) 284 285 def make_session_run_hook(self): 286 """Creates a hook to handle ModelAverage ops such as initialization.""" 287 return _ModelAverageOptimizerHook(self, self._is_chief) 288 289 290 class _ModelAverageOptimizerHook(session_run_hook.SessionRunHook): # pylint: disable=missing-docstring 291 292 def __init__(self, ma_optimizer, is_chief): 293 """Creates hook to handle ModelAverageOptimizer initialization ops. 294 295 Args: 296 ma_optimizer: `ModelAverageOptimizer` which this hook will initialize. 297 is_chief: `Bool`, whether is this a chief replica or not. 298 """ 299 self._ma_optimizer = ma_optimizer 300 self._is_chief = is_chief 301 302 def begin(self): 303 self._local_init_op = variables.local_variables_initializer() 304 self._global_init_op = None 305 if self._is_chief: 306 self._global_init_op = variables.global_variables_initializer() 307 self._chief_init_op = self._ma_optimizer._chief_init_op # pylint: disable=protected-access 308 self._variable_init_op = self._ma_optimizer.get_init_op() 309