# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper optimizer for Elastic Average SGD."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import optimizer
from tensorflow.python.training import session_run_hook

LOCAL_VARIABLE_NAME = 'local_center_variable'
GLOBAL_VARIABLE_NAME = 'global_center_variable'


class ElasticAverageCustomGetter(object):
  """Custom getter that is used to:

  1. Place trainable variables in the local-variable collection on the
     worker device.
  2. Create the global variables (global center variables).
  3. Create the local variables (local center variables), which record the
     global variables, and place them on the worker device.

  Note that this class should be used together with
  `tf.train.replica_device_setter` so that the global center variables and
  the global step variable are placed on the ps device. In addition,
  variables must be created with `tf.get_variable` instead of `tf.Variable`
  for this custom getter to take effect.

  For example:

  ea_custom_getter = ElasticAverageCustomGetter(worker_device)
  with tf.device(
      tf.train.replica_device_setter(
          worker_device=worker_device,
          ps_device="/job:ps/cpu:0",
          cluster=cluster)):
    with tf.variable_scope('', custom_getter=ea_custom_getter):
      hid_w = tf.get_variable(
          initializer=tf.truncated_normal(
              [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
              stddev=1.0 / IMAGE_PIXELS),
          name="hid_w")
      hid_b = tf.get_variable(initializer=tf.zeros([FLAGS.hidden_units]),
                              name="hid_b")
  """

  def __init__(self, worker_device):
    """Create a new `ElasticAverageCustomGetter`.

    Args:
      worker_device: String.  Device of the `worker` job, e.g.
        "/job:worker/task:0/cpu:0".
    """
    self._worker_device = worker_device
    self._local_map = {}
    self._global_map = {}

  def __call__(self, getter, name, trainable, collections, *args, **kwargs):
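    # For trainable variables, create the variable itself in the
    # LOCAL_VARIABLES collection on the worker device, together with two
    # non-trainable copies of it: a global center variable (left to
    # `tf.train.replica_device_setter` to place on the ps device) and a local
    # center variable on the worker device. Non-trainable variables fall
    # through to the default getter.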
    if trainable:
      with ops.device(self._worker_device):
        local_var = getter(
            name,
            trainable=True,
            collections=[ops.GraphKeys.LOCAL_VARIABLES],
            *args,
            **kwargs)
      global_center_variable = variable_scope.variable(
          name='%s/%s' % (GLOBAL_VARIABLE_NAME, name),
          initial_value=local_var.initialized_value(),
          trainable=False,
          collections=[ops.GraphKeys.GLOBAL_VARIABLES])

      with ops.device(self._worker_device):
        local_center_variable = variable_scope.variable(
            name='%s/%s' % (LOCAL_VARIABLE_NAME, name),
            initial_value=local_var.initialized_value(),
            trainable=False,
            collections=[ops.GraphKeys.LOCAL_VARIABLES])

      self._local_map[local_var] = local_center_variable
      self._global_map[local_var] = global_center_variable
      return local_var
    else:
      return getter(name, trainable, collections, *args, **kwargs)


class ElasticAverageOptimizer(optimizer.Optimizer):
  """Wrapper optimizer that implements the Elastic Average SGD algorithm.

  This is an asynchronous optimizer. During training, each worker updates its
  local variables and maintains its own `local_step`, which starts from 0 and
  is incremented by 1 after each update of the local variables. Whenever the
  communication period divides the local step, the worker requests the
  current global center variables and computes the elastic difference between
  the global center variables and its local variables. The elastic difference
  is then used to update both the local variables and the global center
  variables.
  """
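
  # Sketch of the elastic update that `apply_gradients` performs every
  # `communication_period` local steps. With local variables `x`, global
  # center variables `x_center`, and `diff = x - x_center`:
  #
  #   x        <- x        - moving_rate * diff
  #   x_center <- x_center + moving_rate * diff
  #
  # i.e. each worker is pulled toward the center while the center drifts
  # toward the workers.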

  # Default value as described in the paper.
  BETA = 0.9

  def __init__(self,
               opt,
               num_worker,
               ea_custom_getter,
               communication_period=10,
               moving_rate=None,
               rho=None,
               use_locking=True,
               name='ElasticAverageOptimizer'):
    """Construct a new `ElasticAverageOptimizer`.

    Args:
      opt: The actual optimizer that will be used to update local variables.
        Must be one of the Optimizer classes.
      num_worker: The number of workers.
      ea_custom_getter: The `ElasticAverageCustomGetter` used to create the
        local and center variables.
      communication_period: An int that controls the frequency of
        communication between every worker and the ps.
      moving_rate: A floating point value that controls the size of the
        elastic difference update. Defaults to
        `BETA / communication_period / num_worker`.
      rho: The amount of exploration allowed in the model. The default value
        is `moving_rate / learning_rate`.
      use_locking: If True, use locks for update operations.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "ElasticAverageOptimizer".
    """
    super(ElasticAverageOptimizer, self).__init__(use_locking, name)
    self._opt = opt
    self._num_worker = num_worker
    self._period = communication_period
    self._local_map = ea_custom_getter._local_map
    self._global_map = ea_custom_getter._global_map

    if moving_rate is None:
      self._moving_rate = self.BETA / communication_period / num_worker
    else:
      self._moving_rate = moving_rate
    if rho is None:
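      # Assumes the wrapped optimizer exposes a `_learning_rate` attribute,
      # as `tf.train.GradientDescentOptimizer` does.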
      self._rho = self._moving_rate / self._opt._learning_rate
    else:
      self._rho = rho

    self._local_step = variable_scope.get_variable(
        initializer=0,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name='local_step')
    self._opt._prepare()

  def compute_gradients(self,
                        loss,
                        var_list=None,
                        gate_gradients=optimizer.Optimizer.GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    Adds `rho * elastic_difference` to `loss` to control exploration.
    This is the first part of `minimize()`.  It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable".  Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A Tensor containing the value to minimize.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`.  Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients.  Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable`
        objects.
      ValueError: If some arguments are invalid.
    """
    if not var_list:
      var_list = variables.trainable_variables()

    elastic_difference = [
        math_ops.subtract(v, lv)
        for v, lv in zip(variables.trainable_variables(),
                         [self._local_map[var] for var in var_list])
    ]

    distance_loss = self._rho * math_ops.add_n(
        [gen_nn_ops.l2_loss(ed) for ed in elastic_difference])
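    # `gen_nn_ops.l2_loss(t)` computes `sum(t ** 2) / 2`, so the penalty is
    # rho / 2 * sum_j ||x_j - x_center_j||^2, which pulls the local variables
    # toward the local copy of the center variables.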

    total_loss = loss + distance_loss
    return self._opt.compute_gradients(total_loss, var_list, gate_gradients,
                                       aggregation_method,
                                       colocate_gradients_with_ops, grad_loss)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to global variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Defaults to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    apply_updates = self._opt.apply_gradients(grads_and_vars)
    with ops.control_dependencies([apply_updates]):
      local_update = state_ops.assign_add(
          self._local_step, 1, name='local_step_update').op

    # Update global variables.
    def _Update_global_variables():
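      # 1. Sync each local center variable from its global center variable.
      # 2. Compute diff = local_var - local_center_var.
      # 3. Move the local variables toward the center and the global center
      #    variables toward the local variables, each by moving_rate * diff.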
      local_vars = [v for g, v in grads_and_vars if g is not None]
      global_center_vars = [self._global_map[var] for var in local_vars]
      local_center_vars = [self._local_map[var] for var in local_vars]
      local_center_vars_update = []
      for lvar, var in zip(local_center_vars, global_center_vars):
        local_center_vars_update.append(lvar.assign(var))
      update_ops = []
      differences = []
      with ops.control_dependencies(local_center_vars_update):
        for v, lv in zip(local_vars, local_center_vars):
          with ops.device(v.device):
            differences.append(math_ops.subtract(v, lv))
        for lvar, diff in zip(local_vars, differences):
          with ops.device(lvar.device):
            update_ops.append(
                state_ops.assign_sub(lvar,
                                     math_ops.multiply(self._moving_rate,
                                                       diff)))
        for var, diff in zip(global_center_vars, differences):
          with ops.device(var.device):
            update_ops.append(
                state_ops.assign_add(var,
                                     math_ops.multiply(self._moving_rate,
                                                       diff)))
        if global_step:
          with ops.colocate_with(global_step):
            update_ops.append(state_ops.assign_add(global_step, 1))
      variable_update = control_flow_ops.group(*(update_ops))
      return variable_update

    with ops.control_dependencies([local_update]):
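      # Only communicate with the ps every `self._period` local steps;
      # otherwise the cond falls through to a no-op.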
      condition = math_ops.equal(
          math_ops.mod(self._local_step, self._period), 0)
      conditional_update = control_flow_ops.cond(
          condition, _Update_global_variables, control_flow_ops.no_op)
    return conditional_update

  def get_init_op(self, task_index):
    """Returns an op that sets all local variables and local center variables
    equal to the global center variables before training begins."""

    def _Add_sync_queues_and_barrier(enqueue_after_list):
      """Adds ops to enqueue a token on all worker queues."""
      sync_queues = [
          data_flow_ops.FIFOQueue(
              self._num_worker, [dtypes.bool],
              shapes=[[]],
              shared_name='%s%s' % ('variable_init_sync_queue', i))
          for i in range(self._num_worker)
      ]
      queue_ops = []
      # For each other worker, add an entry in its queue.
      token = constant_op.constant(False)
      with ops.control_dependencies(enqueue_after_list):
        for i, q in enumerate(sync_queues):
          if i == task_index:
            queue_ops.append(control_flow_ops.no_op())
          else:
            queue_ops.append(q.enqueue(token))
      queue_ops.append(
          sync_queues[task_index].dequeue_many(len(sync_queues) - 1))
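      # Dequeuing `num_worker - 1` tokens from this worker's own queue blocks
      # until every other worker has enqueued its token, forming a barrier.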
      return control_flow_ops.group(*queue_ops)

    init_ops = []
    local_vars = variables.trainable_variables()
    global_center_vars = [self._global_map[var] for var in local_vars]
    local_center_vars = [self._local_map[var] for var in local_vars]
    if not (local_vars and global_center_vars and local_center_vars):
      raise ValueError('The lists of local_variables, global_center_variables '
                       'and local_center_variables should not be empty.')
    for lvar, gc_var, lc_var in zip(local_vars, global_center_vars,
                                    local_center_vars):
      init_ops.append(state_ops.assign(lvar, gc_var))
      init_ops.append(state_ops.assign(lc_var, gc_var))

    init_op = control_flow_ops.group(*(init_ops))
    sync_queue_op = _Add_sync_queues_and_barrier([init_op])
    return sync_queue_op

  def make_session_run_hook(self, is_chief, task_index):
    """Creates a hook to handle ElasticAverageOptimizer initialization ops."""
    return _ElasticAverageOptimizerHook(self, is_chief, task_index)
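
  # A minimal usage sketch, assuming a standard between-graph replicated
  # setup; `worker_device`, `cluster`, `build_model`, `num_workers`,
  # `global_step`, `master`, `is_chief` and `task_index` are placeholders for
  # the caller's own setup:
  #
  #   ea_getter = ElasticAverageCustomGetter(worker_device=worker_device)
  #   with tf.device(tf.train.replica_device_setter(
  #       worker_device=worker_device, cluster=cluster)):
  #     with tf.variable_scope('', custom_getter=ea_getter):
  #       loss = build_model()
  #   opt = ElasticAverageOptimizer(
  #       opt=tf.train.GradientDescentOptimizer(0.01),
  #       num_worker=num_workers,
  #       ea_custom_getter=ea_getter)
  #   train_op = opt.apply_gradients(
  #       opt.compute_gradients(loss), global_step=global_step)
  #   hook = opt.make_session_run_hook(is_chief, task_index)
  #   with tf.train.MonitoredTrainingSession(
  #       master=master, is_chief=is_chief, hooks=[hook]) as sess:
  #     while not sess.should_stop():
  #       sess.run(train_op)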


class _ElasticAverageOptimizerHook(session_run_hook.SessionRunHook):

  def __init__(self, ea_optimizer, is_chief, task_index):
    """Creates hook to handle ElasticAverageOptimizer initialization ops.

    Args:
      ea_optimizer: `ElasticAverageOptimizer` which this hook will initialize.
      is_chief: `Bool`, whether this is a chief replica or not.
      task_index: Int, the task index of this worker.
    """
    self._ea_optimizer = ea_optimizer
    self._is_chief = is_chief
    self._task_index = task_index

  def begin(self):
    self._local_init_op = variables.local_variables_initializer()
    self._global_init_op = None
    if self._is_chief:
      self._global_init_op = variables.global_variables_initializer()
    self._variable_init_op = self._ea_optimizer.get_init_op(self._task_index)