# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper optimizer for Model Average."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import optimizer
from tensorflow.python.training import session_run_hook

GLOBAL_VARIABLE_NAME = "global_center_variable"


class ModelAverageCustomGetter(object):
  """Custom getter that is used with `ModelAverageOptimizer`.

  The custom getter does two things:

  1. Changes trainable variables to the local collection and places them on
     the worker device.
  2. Generates the corresponding global center variables.

  Notice that the class should be used with tf.replica_device_setter, so that
  the global center variables and the global step variable can be placed on
  the ps device. Besides, use 'tf.get_variable' instead of 'tf.Variable' to
  use this custom getter.

  For example,
  ma_custom_getter = ModelAverageCustomGetter(worker_device)
  with tf.device(
      tf.train.replica_device_setter(
          worker_device=worker_device,
          ps_device="/job:ps/cpu:0",
          cluster=cluster)), \
      tf.variable_scope('', custom_getter=ma_custom_getter):
    hid_w = tf.get_variable(
        initializer=tf.truncated_normal(
            [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
            stddev=1.0 / IMAGE_PIXELS),
        name="hid_w")
    hid_b = tf.get_variable(initializer=tf.zeros([FLAGS.hidden_units]),
                            name="hid_b")
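
  The same getter instance must then be passed to `ModelAverageOptimizer`,
  which uses it to look up the global center variable of each local variable.
  A minimal sketch (assuming `opt`, `num_workers` and `is_chief` are defined
  by the caller):

  ma_opt = ModelAverageOptimizer(
      opt, num_workers, is_chief, ma_custom_getter, interval_steps=100)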
  """

  def __init__(self, worker_device):
    """Create a new `ModelAverageCustomGetter`.

    Args:
      worker_device: String.  Name of the `worker` job.
    """
    self._worker_device = worker_device
    self._local_2_global = {}

  def __call__(self, getter, name, trainable, collections, *args, **kwargs):
    if trainable:
      # Create the trainable variable as a local variable on the worker
      # device, and mirror it with a non-trainable global center variable.
      with ops.device(self._worker_device):
        local_var = getter(
            name,
            trainable=True,
            collections=[ops.GraphKeys.LOCAL_VARIABLES],
            *args,
            **kwargs)

      global_variable = variable_scope.variable(
          name="%s/%s" % (GLOBAL_VARIABLE_NAME, name),
          initial_value=local_var.initialized_value(),
          trainable=False,
          collections=[ops.GraphKeys.GLOBAL_VARIABLES])

      self._local_2_global[local_var] = global_variable
      return local_var
    else:
      return getter(name, trainable, collections, *args, **kwargs)


class ModelAverageOptimizer(optimizer.Optimizer):
  """Wrapper optimizer that implements the Model Average algorithm.

  This is a sync optimizer. During training, each worker updates its local
  variables and maintains its own local_step, which starts from 0 and is
  incremented by 1 after each update of the local variables. Whenever
  interval_steps divides the local step, the local variables from all the
  workers are averaged and assigned to the global center variables. Then the
  local variables are assigned the values of the global center variables.
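
  A rough usage sketch (here `worker_device`, `opt`, `num_workers`, `is_chief`,
  `loss` and `global_step` are placeholders assumed to be provided by the
  caller):

  ma_custom_getter = ModelAverageCustomGetter(worker_device)
  # ... build the model under the custom getter as described in
  # `ModelAverageCustomGetter` ...
  ma_opt = ModelAverageOptimizer(
      opt, num_workers, is_chief, ma_custom_getter, interval_steps=100)
  train_op = ma_opt.minimize(loss, global_step=global_step)
  ma_hook = ma_opt.make_session_run_hook()
  # Pass `ma_hook` to the training session (e.g. a MonitoredTrainingSession)
  # so the ModelAverage initialization ops are handled on each worker.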
  """

  def __init__(self,
               opt,
               num_worker,
               is_chief,
               ma_custom_getter,
               interval_steps=100,
               use_locking=True,
               name="ModelAverageOptimizer"):
    """Construct a new model average optimizer.

    Args:
      opt: The actual optimizer that will be used to update local variables.
      num_worker: The number of workers.
      is_chief: Whether this worker is the chief.
      ma_custom_getter: The `ModelAverageCustomGetter` used to create the
        variables.
      interval_steps: An int value that controls how frequently the local
        variables are averaged.
      use_locking: If True, use locks for update operations.
      name: Optional string name of the returned operation.
    """
    super(ModelAverageOptimizer, self).__init__(use_locking, name)
    self._opt = opt
    self._num_worker = num_worker
    self._is_chief = is_chief
    self._local_2_global = ma_custom_getter._local_2_global  # pylint:disable=protected-access
    self._interval_steps = interval_steps
    self._accumulator_list = []
    self._chief_init_op = None

    self._local_step = variable_scope.get_variable(
        initializer=0,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="local_step")

    self._opt._prepare()  # pylint:disable=protected-access

  def compute_gradients(self, *args, **kwargs):
    """Compute gradients of "loss" for the variables in "var_list".

    This simply wraps the compute_gradients() of the real optimizer.

    Args:
      *args: Arguments for compute_gradients().
      **kwargs: Keyword arguments for compute_gradients().

    Returns:
      A list of (gradient, variable) pairs.
    """
    return self._opt.compute_gradients(*args, **kwargs)

  def _local_vars_update(self, var_list):
    """Get the update ops for the local variables in "var_list".

    Args:
      var_list: Optional list or tuple of 'tf.Variable' to update.

    Returns:
      An update op.

    Raises:
      ValueError: if var_list is empty.
    """
    if not var_list:
      raise ValueError("The list of local_variables should not be empty")
    update_ops = []
    global_center_vars = [self._local_2_global[var] for var in var_list]
    for lvar, gvar in zip(var_list, global_center_vars):
      with ops.device(lvar.device):
        update_ops.append(state_ops.assign(lvar, gvar.read_value()))
    return control_flow_ops.group(*(update_ops))

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This contains most of the synchronization implementation and also wraps the
    apply_gradients() of the real optimizer. The chief worker updates the
    global variables.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Optional Variable to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Defaults to the
        name passed to the Optimizer constructor.

    Returns:
      A conditional 'Operation' that updates both local and global variables
      or just the local variables.

    Raises:
      ValueError: If grads_and_vars is empty.
      ValueError: If global_step is not provided, since the staleness cannot
        be checked otherwise.
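
    A minimal usage sketch (assuming `ma_opt` is this optimizer and `loss` and
    `global_step` already exist in the graph):

      grads_and_vars = ma_opt.compute_gradients(loss)
      train_op = ma_opt.apply_gradients(grads_and_vars, global_step=global_step)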
    """

    # Update the local variables.
    if not grads_and_vars:
      raise ValueError("Must supply at least one variable")
    if global_step is None:
      raise ValueError("Global step is required")

    apply_updates = self._opt.apply_gradients(grads_and_vars)
    with ops.control_dependencies([apply_updates]):
      local_update = state_ops.assign_add(
          self._local_step, 1, name="local_step_update").op

    # Update the global variables.
    def _update_global_variables():  # pylint: disable=missing-docstring
      local_vars = [v for g, v in grads_and_vars if g is not None]
      global_vars = [self._local_2_global[v] for v in local_vars]
      # sync queue
      with ops.colocate_with(global_step):
        sync_queue = data_flow_ops.FIFOQueue(
            -1, [dtypes.bool], shapes=[[]], shared_name="sync_queue")
      train_ops = []
      aggregated_vars = []
      with ops.name_scope(None, self._name + "/global"):
        for var, gvar in zip(local_vars, global_vars):
          # pylint: disable=protected-access
          with ops.device(gvar.device):
            if isinstance(var._ref(), ops.Tensor):
              var_accum = data_flow_ops.ConditionalAccumulator(
                  var.dtype,
                  shape=var.get_shape(),
                  shared_name=gvar.name + "/var_accum")
              train_ops.append(
                  var_accum.apply_grad(var._ref(), local_step=global_step))
              aggregated_vars.append(var_accum.take_grad(self._num_worker))
            else:
              raise ValueError("Unknown local variable type!")
            self._accumulator_list.append((var_accum, gvar.device))
      # The chief worker updates the global variables and enqueues tokens to
      # the sync queue.
      if self._is_chief:
        update_ops = []
        with ops.control_dependencies(train_ops):
          for avg_var, gvar in zip(aggregated_vars, global_vars):
            with ops.device(gvar.device):
              update_ops.append(state_ops.assign(gvar, avg_var))
          with ops.device(global_step.device):
            update_ops.append(state_ops.assign_add(global_step, 1))
        with ops.control_dependencies(update_ops), ops.device(
            global_step.device):
          tokens = array_ops.fill([self._num_worker - 1],
                                  constant_op.constant(False))
          sync_op = sync_queue.enqueue_many(tokens)
      else:
        # Other workers wait for the chief by dequeuing a token from the
        # sync queue.
        with ops.control_dependencies(train_ops), ops.device(
            global_step.device):
          sync_op = sync_queue.dequeue()

      with ops.control_dependencies([sync_op]):
        local_update_op = self._local_vars_update(local_vars)
      return local_update_op

    with ops.control_dependencies([local_update]):
      condition = math_ops.equal(
          math_ops.mod(self._local_step, self._interval_steps), 0)
      conditional_update = control_flow_ops.cond(
          condition, _update_global_variables, control_flow_ops.no_op)

    chief_init_ops = []
    for accum, dev in self._accumulator_list:
      with ops.device(dev):
        chief_init_ops.append(
            accum.set_global_step(global_step, name="SetGlobalStep"))
    self._chief_init_op = control_flow_ops.group(*(chief_init_ops))

    return conditional_update

  def get_init_op(self):
    """Returns the op that initializes the local variables.

    This op assigns the values of the global center variables to the local
    variables before training begins.
    """
    return self._local_vars_update(variables.trainable_variables())

  def make_session_run_hook(self):
    """Creates a hook to handle ModelAverage ops such as initialization."""
    return _ModelAverageOptimizerHook(self, self._is_chief)


class _ModelAverageOptimizerHook(session_run_hook.SessionRunHook):  # pylint: disable=missing-docstring

  def __init__(self, ma_optimizer, is_chief):
    """Creates hook to handle ModelAverageOptimizer initialization ops.

    Args:
      ma_optimizer: `ModelAverageOptimizer` which this hook will initialize.
      is_chief: `Bool`, whether this is a chief replica or not.
    """
    self._ma_optimizer = ma_optimizer
    self._is_chief = is_chief

  def begin(self):
    self._local_init_op = variables.local_variables_initializer()
    self._global_init_op = None
    if self._is_chief:
      self._global_init_op = variables.global_variables_initializer()
      self._chief_init_op = self._ma_optimizer._chief_init_op  # pylint: disable=protected-access
    self._variable_init_op = self._ma_optimizer.get_init_op()
    309