# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper optimizer for the Elastic Average SGD algorithm."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import optimizer
from tensorflow.python.training import session_run_hook

LOCAL_VARIABLE_NAME = 'local_center_variable'
GLOBAL_VARIABLE_NAME = 'global_center_variable'


class ElasticAverageCustomGetter(object):
  """Custom getter that is used to:
  1. Place trainable variables in the local variable collection and on the
     worker device.
  2. Generate the global variables (global center variables).
  3. Generate the local variables (local center variables), which mirror the
     global center variables and are placed on the worker device.
  Note that this class should be used together with
  tf.train.replica_device_setter, so that the global center variables and the
  global step variable can be placed on the ps device. Also, use
  'tf.get_variable' instead of 'tf.Variable' so that this custom getter is
  invoked.

  For example,
  ea_custom_getter = ElasticAverageCustomGetter(worker_device)
  with tf.device(
      tf.train.replica_device_setter(
          worker_device=worker_device,
          ps_device="/job:ps/cpu:0",
          cluster=cluster)), \
      tf.variable_scope('', custom_getter=ea_custom_getter):
    hid_w = tf.get_variable(
        initializer=tf.truncated_normal(
            [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
            stddev=1.0 / IMAGE_PIXELS),
        name="hid_w")
    hid_b = tf.get_variable(initializer=tf.zeros([FLAGS.hidden_units]),
                            name="hid_b")
  """

  def __init__(self, worker_device):
    """Create a new `ElasticAverageCustomGetter`.

    Args:
      worker_device: String. Name of the `worker` job.
    """
    self._worker_device = worker_device
    self._local_map = {}
    self._global_map = {}

  def __call__(self, getter, name, trainable, collections, *args, **kwargs):
    if trainable:
      # The trainable variable itself lives on the worker and is only added
      # to the local variable collection.
      with ops.device(self._worker_device):
        local_var = getter(
            name,
            trainable=True,
            collections=[ops.GraphKeys.LOCAL_VARIABLES],
            *args,
            **kwargs)
      # The global center variable is placed by the enclosing
      # replica_device_setter (typically on a ps task).
      global_center_variable = variable_scope.variable(
          name='%s/%s' % (GLOBAL_VARIABLE_NAME, name),
          initial_value=local_var.initialized_value(),
          trainable=False,
          collections=[ops.GraphKeys.GLOBAL_VARIABLES])

      # The local center variable records the last fetched value of the
      # global center variable and lives on the worker.
      with ops.device(self._worker_device):
        local_center_variable = variable_scope.variable(
            name='%s/%s' % (LOCAL_VARIABLE_NAME, name),
            initial_value=local_var.initialized_value(),
            trainable=False,
            collections=[ops.GraphKeys.LOCAL_VARIABLES])

      self._local_map[local_var] = local_center_variable
      self._global_map[local_var] = global_center_variable
      return local_var
    else:
      return getter(name, trainable, collections, *args, **kwargs)


class ElasticAverageOptimizer(optimizer.Optimizer):
  """Wrapper optimizer that implements the Elastic Average SGD algorithm.

  This is an asynchronous optimizer. During training, each worker updates its
  local variables and maintains its own local_step, which starts from 0 and is
  incremented by 1 after each update of the local variables. Whenever the
  communication period divides the local step, the worker fetches the current
  global center variables, computes the elastic difference between the global
  center variables and its local variables, and uses this difference to update
  both the local variables and the global center variables.
  """

  # Default value as described in the paper.
  BETA = 0.9

  def __init__(self,
               opt,
               num_worker,
               ea_custom_getter,
               communication_period=10,
               moving_rate=None,
               rho=None,
               use_locking=True,
               name='ElasticAverageOptimizer'):
    """Construct a new `ElasticAverageOptimizer`.

    Args:
      opt: The actual optimizer that will be used to update local variables.
        Must be one of the Optimizer classes.
      num_worker: The number of workers.
      ea_custom_getter: The `ElasticAverageCustomGetter` used to create the
        local and center variables.
      communication_period: An int that controls how frequently each worker
        communicates with the ps.
      moving_rate: A floating point value that controls the size of the
        elastic difference update.
      rho: The amount of exploration we allow in the model. The default
        value is moving_rate / learning_rate.
      use_locking: If True, use locks for update operations.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "ElasticAverageOptimizer".
    """
    super(ElasticAverageOptimizer, self).__init__(use_locking, name)
    self._opt = opt
    self._num_worker = num_worker
    self._period = communication_period
    self._local_map = ea_custom_getter._local_map
    self._global_map = ea_custom_getter._global_map

    if moving_rate is None:
      self._moving_rate = self.BETA / communication_period / num_worker
    else:
      self._moving_rate = moving_rate
    if rho is None:
      self._rho = self._moving_rate / self._opt._learning_rate
    else:
      self._rho = rho

    self._local_step = variable_scope.get_variable(
        initializer=0,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name='local_step')
    self._opt._prepare()
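
  # Illustrative note (not executed by the library): with the paper default
  # BETA = 0.9 and, for example, communication_period = 10 and num_worker = 3,
  # the defaults chosen above work out to
  #   moving_rate = 0.9 / 10 / 3 = 0.03
  # and, assuming the wrapped optimizer was created with learning_rate = 0.1,
  #   rho = moving_rate / learning_rate = 0.3.
  # A larger rho penalizes local variables that drift far from the center.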

  def compute_gradients(self,
                        loss,
                        var_list=None,
                        gate_gradients=optimizer.Optimizer.GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    Adds rho * elastic_difference to the loss to control exploration.
    This is the first part of `minimize()`. It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable". Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A Tensor containing the value to minimize.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`. Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything else than `Variable` objects.
      ValueError: If some arguments are invalid.
    """
    if not var_list:
      var_list = variables.trainable_variables()

    elastic_difference = [
        math_ops.subtract(v, lv)
        for v, lv in zip(variables.trainable_variables(),
                         [self._local_map[var] for var in var_list])
    ]

    distance_loss = self._rho * math_ops.add_n(
        [gen_nn_ops.l2_loss(ed) for ed in elastic_difference])

    total_loss = loss + distance_loss
    return self._opt.compute_gradients(total_loss, var_list, gate_gradients,
                                       aggregation_method,
                                       colocate_gradients_with_ops, grad_loss)
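
  # Illustrative note (not executed): since l2_loss(t) computes sum(t ** 2) / 2,
  # the penalty added in `compute_gradients` above is
  #   distance_loss = rho * sum_i ||x_i - lc_i||^2 / 2
  # where x_i are the local trainable variables and lc_i the local center
  # variables, so the extra gradient term pulling each x_i toward its center
  # is rho * (x_i - lc_i).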

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to global variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Defaults to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    apply_updates = self._opt.apply_gradients(grads_and_vars)
    with ops.control_dependencies([apply_updates]):
      local_update = state_ops.assign_add(
          self._local_step, 1, name='local_step_update').op

    # Update global variables.
    def _Update_global_variables():
      local_vars = [v for g, v in grads_and_vars if g is not None]
      global_center_vars = [self._global_map[var] for var in local_vars]
      local_center_vars = [self._local_map[var] for var in local_vars]
      local_center_vars_update = []
      # Refresh the local center variables with the current global values.
      for lvar, var in zip(local_center_vars, global_center_vars):
        local_center_vars_update.append(lvar.assign(var))
      update_ops = []
      differences = []
      with ops.control_dependencies(local_center_vars_update):
        for v, lv in zip(local_vars, local_center_vars):
          with ops.device(v.device):
            differences.append(math_ops.subtract(v, lv))
        # Pull the local variables toward the center ...
        for lvar, diff in zip(local_vars, differences):
          with ops.device(lvar.device):
            update_ops.append(
                state_ops.assign_sub(lvar,
                                     math_ops.multiply(self._moving_rate,
                                                       diff)))
        # ... and push the global center variables toward the local variables.
        for var, diff in zip(global_center_vars, differences):
          with ops.device(var.device):
            update_ops.append(
                state_ops.assign_add(var,
                                     math_ops.multiply(self._moving_rate,
                                                       diff)))
        if global_step:
          with ops.colocate_with(global_step):
            update_ops.append(state_ops.assign_add(global_step, 1))
        variable_update = control_flow_ops.group(*(update_ops))
      return variable_update

    # Communicate with the ps only every `communication_period` local steps.
    with ops.control_dependencies([local_update]):
      condition = math_ops.equal(
          math_ops.mod(self._local_step, self._period), 0)
      conditional_update = control_flow_ops.cond(
          condition, _Update_global_variables, control_flow_ops.no_op)
    return conditional_update
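
  # Illustrative note (not executed): every `communication_period` local steps,
  # the conditional update built in `apply_gradients` above applies the elastic
  # averaging update. Writing alpha for moving_rate, x_i for a local variable
  # and x~ for the corresponding global center variable, the update is
  # approximately
  #   x_i <- x_i - alpha * (x_i - x~)
  #   x~  <- x~  + alpha * (x_i - x~)
  # i.e. the local variables are pulled toward the center while the center
  # moves toward the workers' variables.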

  def get_init_op(self, task_index):
    """Returns the op that initializes the local variables and the local
    center variables to the current values of the global center variables,
    before training begins."""

    def _Add_sync_queues_and_barrier(enqueue_after_list):
      """Adds ops to enqueue on all worker queues."""
      sync_queues = [
          data_flow_ops.FIFOQueue(
              self._num_worker, [dtypes.bool],
              shapes=[[]],
              shared_name='%s%s' % ('variable_init_sync_queue', i))
          for i in range(self._num_worker)
      ]
      queue_ops = []
      # For each other worker, add an entry in its queue, signaling that this
      # worker has finished its variable initialization.
      token = constant_op.constant(False)
      with ops.control_dependencies(enqueue_after_list):
        for i, q in enumerate(sync_queues):
          if i == task_index:
            queue_ops.append(control_flow_ops.no_op())
          else:
            queue_ops.append(q.enqueue(token))
      # Drain this worker's own queue; this blocks until every other worker
      # has enqueued its token, acting as a barrier.
      queue_ops.append(
          sync_queues[task_index].dequeue_many(len(sync_queues) - 1))
      return control_flow_ops.group(*queue_ops)

    init_ops = []
    local_vars = variables.trainable_variables()
    global_center_vars = [self._global_map[var] for var in local_vars]
    local_center_vars = [self._local_map[var] for var in local_vars]
    if not (local_vars and global_center_vars and local_center_vars):
      raise ValueError('The lists of local_variables, global_center_variables '
                       'and local_center_variables should not be empty.')
    for lvar, gc_var, lc_var in zip(local_vars, global_center_vars,
                                    local_center_vars):
      init_ops.append(state_ops.assign(lvar, gc_var))
      init_ops.append(state_ops.assign(lc_var, gc_var))

    init_op = control_flow_ops.group(*(init_ops))
    sync_queue_op = _Add_sync_queues_and_barrier([init_op])
    return sync_queue_op

  def make_session_run_hook(self, is_chief, task_index):
    """Creates a hook to handle ElasticAverageOptimizer initialization ops."""
    return _ElasticAverageOptimizerHook(self, is_chief, task_index)


class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook):

  def __init__(self, ea_optimizer, is_chief, task_index):
    """Creates a hook to handle ElasticAverageOptimizer initialization ops.

    Args:
      ea_optimizer: `ElasticAverageOptimizer` which this hook will initialize.
      is_chief: `Bool`, whether this is a chief replica or not.
      task_index: Int, the task index of this worker.
    """
    self._ea_optimizer = ea_optimizer
    self._is_chief = is_chief
    self._task_index = task_index

  def begin(self):
    self._local_init_op = variables.local_variables_initializer()
    self._global_init_op = None
    if self._is_chief:
      self._global_init_op = variables.global_variables_initializer()
    self._variable_init_op = self._ea_optimizer.get_init_op(self._task_index)
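
# A minimal usage sketch (commented out). Names such as `cluster`, `server`,
# `worker_device`, `task_index`, `is_chief`, `num_worker`, `build_loss` and the
# learning rate are placeholders supplied by the surrounding training program,
# not part of this module, and `tf` stands for `import tensorflow as tf`. It
# shows how the custom getter, the wrapper optimizer and the session run hook
# are typically wired together in between-graph replicated training:
#
#   ea_custom_getter = ElasticAverageCustomGetter(worker_device)
#   with tf.device(
#       tf.train.replica_device_setter(
#           worker_device=worker_device,
#           ps_device="/job:ps/cpu:0",
#           cluster=cluster)), \
#       tf.variable_scope('', custom_getter=ea_custom_getter):
#     loss = build_loss()
#     global_step = tf.train.get_or_create_global_step()
#     opt = ElasticAverageOptimizer(
#         opt=tf.train.GradientDescentOptimizer(0.01),
#         num_worker=num_worker,
#         ea_custom_getter=ea_custom_getter)
#     train_op = opt.minimize(loss, global_step=global_step)
#     hook = opt.make_session_run_hook(is_chief, task_index)
#
#   with tf.train.MonitoredTrainingSession(
#       master=server.target, is_chief=is_chief, hooks=[hook]) as sess:
#     while not sess.should_stop():
#       sess.run(train_op)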