      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Contains TF-Slim code for training models.
     16 
     17 This script contains various functions for training models. These include
     18 manipulating gradients, creating a `train_op` (an operation that computes the
     19 loss and applies the gradients) and a training loop function. The training loop
     20 allows the user to pass in the `train_op` and runs the optimization according
     21 to user-specified arguments. Note that the training loop uses the
     22 tf.train.Supervisor and its managed_session in its implementation to ensure the
     23 ability of worker processes to recover from failures.
     24 
     25 ************************************
     26 * A simple working training script *
     27 ************************************
     28 
     29   # Load data and create the model:
     30   images, labels = LoadData(...)
     31   predictions = MyModel(images)
     32 
     33   # Define the loss:
     34   slim.losses.log_loss(predictions, labels)
     35   total_loss = slim.losses.get_total_loss()
     36 
     37   # Define the optimizer:
     38   optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum)
     39 
     40   # Create the train_op
     41   train_op = slim.learning.create_train_op(total_loss, optimizer)
     42 
     43   # Run training.
     44   slim.learning.train(train_op, my_log_dir)
     45 
     46 *************************
     47 * Creating the train_op *
     48 *************************
     49 
     50 In order to train, TF-Slim's train loop needs a train_op: an `Operation` that
     51 (a) computes the loss, (b) applies the gradients to update the weights and
     52 (c) returns the value of the loss. slim.learning.create_train_op creates
     53 such an `Operation`. This function also provides the ability to manipulate
     54 the gradients using a few arguments:
     55 
     56   # Create the train_op and clip the gradient norms:
     57   train_op = slim.learning.create_train_op(
     58       total_loss,
     59       optimizer,
     60       clip_gradient_norm=4)
     61 
     62   # Create the train_op and scale the gradients by providing a map from variable
     63   # name (or variable) to a scaling coefficient:
     64   gradient_multipliers = {
     65     'conv0/weights': 1.2,
     66     'fc8/weights': 3.4,
     67   }
     68   train_op = slim.learning.create_train_op(
     69       total_loss,
     70       optimizer,
     71       gradient_multipliers=gradient_multipliers)
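
As a rough illustration of what these two arguments do to each gradient, the
following standalone sketch reproduces the arithmetic in plain Python, without
TensorFlow. The helper names `clip_by_norm` and `scale_gradients` and the
list-of-floats gradient representation are assumptions made for this example;
they are not part of TF-Slim. In this module, multipliers are applied before
clipping, and each gradient is clipped separately rather than by a global norm.

```python
import math


def clip_by_norm(grad, max_norm):
    # Rescale `grad` so that its L2 norm is at most `max_norm`.
    norm = math.sqrt(sum(g * g for g in grad))
    if norm <= max_norm:
        return list(grad)
    return [g * max_norm / norm for g in grad]


def scale_gradients(grads, multipliers):
    # Multiply each named gradient by its coefficient; names that are absent
    # from `multipliers` are left unchanged.
    return {
        name: [g * multipliers.get(name, 1.0) for g in grad]
        for name, grad in grads.items()
    }


# Hypothetical gradients keyed by variable name.
grads = {'conv0/weights': [3.0, 4.0], 'fc8/weights': [0.3, 0.4]}

# Scale first, then clip each gradient on its own.
scaled = scale_gradients(grads, {'conv0/weights': 2.0})
clipped = {name: clip_by_norm(g, 4.0) for name, g in scaled.items()}
```

Here the scaled 'conv0/weights' gradient has norm 10 and is shrunk to norm 4,
while 'fc8/weights' is already below the threshold and passes through as is.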
     72 
     73 ****************************************************************
     74 * Performing additional (non-gradient) updates during training *
     75 ****************************************************************
     76 
     77 Many networks utilize modules, like BatchNorm, that require performing a series
     78 of non-gradient updates during training. slim.learning.create_train_op allows
     79 a user to pass in a list of update_ops to call along with the gradient updates.
     80 
     81   train_op = slim.learning.create_train_op(total_loss, optimizer, update_ops)
     82 
     83 By default, slim.learning.create_train_op includes all update ops that are
     84 part of the `tf.GraphKeys.UPDATE_OPS` collection. Additionally, TF-Slim's
     85 slim.batch_norm function adds the moving mean and moving variance updates to
     86 this collection. Consequently, users who want to use slim.batch_norm will not
     87 need to take any additional steps in order to have the moving mean and moving
     88 variance updates be computed.
     89 
     90 However, users with additional, specialized updates can either override the
     91 default update ops or simply add additional update ops to the
     92 `tf.GraphKeys.UPDATE_OPS` collection:
     93 
     94   # Force TF-Slim NOT to use ANY update_ops:
     95   train_op = slim.learning.create_train_op(
     96      total_loss,
     97      optimizer,
     98      update_ops=[])
     99 
    100   # Use an alternative set of update ops:
    101   train_op = slim.learning.create_train_op(
    102      total_loss,
    103      optimizer,
    104      update_ops=my_other_update_ops)
    105 
    106   # Use an alternative set of update ops in addition to the default updates:
    107   tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, my_update0)
    108   tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, my_update1)
    109 
    110   train_op = slim.learning.create_train_op(
    111      total_loss,
    112      optimizer)
    113 
    114   # Which is the same as:
    115   train_op = slim.learning.create_train_op(
    116      total_loss,
    117      optimizer,
    118      update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS))
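
The selection rule described in this section can be summarized in a few lines
of plain Python. This is only a sketch of the documented behavior, not the
library's implementation: `resolve_update_ops` is a hypothetical helper, and
the `tf.GraphKeys.UPDATE_OPS` collection is modeled as an ordinary list of
names.

```python
import warnings


def resolve_update_ops(update_ops, collection):
    # `None` means "use everything in the collection".
    if update_ops is None:
        return list(collection)
    # An explicit list, even an empty one, replaces the default. Warn when it
    # leaves out ops that the collection contains.
    missing = [op for op in collection if op not in update_ops]
    if missing:
        warnings.warn('update_ops is missing ops from the collection: %s' %
                      missing)
    return list(update_ops)


# Hypothetical collection contents, as slim.batch_norm might register them.
collection = ['bn/moving_mean', 'bn/moving_variance']

default_ops = resolve_update_ops(None, collection)
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    no_ops = resolve_update_ops([], collection)
```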
    119 
    120 ******************************************
    121 * Initializing a model from a checkpoint *
    122 ******************************************
    123 
    124 It is common to want to 'warm-start' a model from a pre-trained checkpoint.
    125 TF-Slim provides a convenient mechanism for doing so:
    126 
    127   ...
    128 
    129   # Create the train_op
    130   train_op = slim.learning.create_train_op(total_loss, optimizer)
    131 
    132   # Create the initial assignment op
    133   checkpoint_path = '/path/to/old_model_checkpoint'
    134   variables_to_restore = slim.get_model_variables()
    135   init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
    136       checkpoint_path, variables_to_restore)
    137 
    138   # Create an initial assignment function.
    139   def InitAssignFn(sess):
    140       sess.run(init_assign_op, init_feed_dict)
    141 
    142   # Run training.
    143   slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn)
    144 
    145 ***************************************************************************
    146 * Initializing a model from a checkpoint whose variable names don't match *
    147 ***************************************************************************
    148 
    149 At times, a user may want to initialize a new model with values from a
    150 checkpoint whose variable names do not match those of the current model. In this
    151 case, one needs to create a mapping from the checkpoint variable names to the
    152 current model variables. This requires only a small modification of the code
    153 above:
    154   ...
    155   # Creates a model with two variables, var0 and var1
    156   predictions = MyModel(images)
    157   ...
    158 
    159   # Create the train_op
    160   train_op = slim.learning.create_train_op(total_loss, optimizer)
    161 
    162   checkpoint_path = '/path/to/old_model_checkpoint'
    163 
    164   # Create the mapping:
    165   variables_to_restore = {
    166       'name_var_0_in_checkpoint': slim.get_unique_variable('var0'),
    167       'name_var_1_in_checkpoint': slim.get_unique_variable('var1')
    168   }
    169   init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
    170       checkpoint_path, variables_to_restore)
    171 
    172   # Create an initial assignment function.
    173   def InitAssignFn(sess):
    174       sess.run(init_assign_op, init_feed_dict)
    175 
    176   # Run training.
    177   slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn)
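
The idea behind the mapping can be pictured without TensorFlow. In the sketch
below the checkpoint and the model are plain dictionaries, and
`restore_with_mapping` is a hypothetical helper written for this example; it
only illustrates how checkpoint names are translated to model variables, not
how slim.assign_from_checkpoint is implemented.

```python
def restore_with_mapping(checkpoint_values, names_to_variables, model_state):
    # Copy each checkpoint entry into the model variable it is mapped to.
    restored = dict(model_state)
    for checkpoint_name, variable_name in names_to_variables.items():
        restored[variable_name] = checkpoint_values[checkpoint_name]
    return restored


# Hypothetical checkpoint contents and an uninitialized model.
checkpoint_values = {'name_var_0_in_checkpoint': [1.0, 2.0],
                     'name_var_1_in_checkpoint': [0.5]}
model_state = {'var0': None, 'var1': None}

restored = restore_with_mapping(
    checkpoint_values,
    {'name_var_0_in_checkpoint': 'var0', 'name_var_1_in_checkpoint': 'var1'},
    model_state)
```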
    178 
    179 
    180 *************************************************
    181 * Fine-Tuning Part of a model from a checkpoint *
    182 *************************************************
    183 
    184 Rather than initializing all of the weights of a given model, we sometimes
    185 only want to restore some of the weights from a checkpoint. To do this, one
    186 need only filter those variables to initialize as follows:
    187 
    188   ...
    189 
    190   # Create the train_op
    191   train_op = slim.learning.create_train_op(total_loss, optimizer)
    192 
    193   checkpoint_path = '/path/to/old_model_checkpoint'
    194 
    195   # Specify the variables to restore via a list of inclusion or exclusion
    196   # patterns:
    197   variables_to_restore = slim.get_variables_to_restore(
    198       include=["conv"], exclude=["fc8", "fc9"])
    199   # or
    200   variables_to_restore = slim.get_variables_to_restore(exclude=["conv"])
    201 
    202   init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
    203       checkpoint_path, variables_to_restore)
    204 
    205   # Create an initial assignment function.
    206   def InitAssignFn(sess):
    207       sess.run(init_assign_op, init_feed_dict)
    208 
    209   # Run training.
    210   slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn)
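
The effect of the `include` and `exclude` patterns can be pictured with
ordinary string matching. The sketch below is a simplification that assumes
patterns match variable-name prefixes; `filter_variable_names` is a
hypothetical helper, and the exact matching rules of
slim.get_variables_to_restore may differ.

```python
def filter_variable_names(names, include=None, exclude=None):
    # With no `include` patterns, start from every name.
    selected = [n for n in names
                if include is None or any(n.startswith(p) for p in include)]
    # Then drop anything that matches an `exclude` pattern.
    exclude = exclude or []
    return [n for n in selected
            if not any(n.startswith(p) for p in exclude)]


# Hypothetical variable names in the current model.
names = ['conv1/weights', 'conv2/weights', 'fc8/weights', 'fc9/biases']

only_conv = filter_variable_names(names, include=['conv'])
without_conv = filter_variable_names(names, exclude=['conv'])
```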
    211 
    212 ******************************************************
    213 * Initializing model variables from values in memory *
    214 ******************************************************
    215 
    216 One may want to initialize the weights of a model from values from an arbitrary
    217 source (a text document, MATLAB file, etc.). While this is technically feasible
    218 using plain TensorFlow, it also results in the values of your weights being
    219 stored in the graph. For large models, this becomes prohibitively large. TF-Slim
    220 allows you to perform this initial assignment without having to store the values
    221 of the initial model in the graph itself by using placeholders and a feed
    222 dictionary:
    223 
    224   ...
    225 
    226   # Create the train_op
    227   train_op = slim.learning.create_train_op(total_loss, optimizer)
    228 
    229   # Create the mapping from variable names to values:
    230   var0_initial_value = ReadFromDisk(...)
    231   var1_initial_value = ReadFromDisk(...)
    232 
    233   var_names_to_values = {
    234     'var0': var0_initial_value,
    235     'var1': var1_initial_value,
    236   }
    237   init_assign_op, init_feed_dict = slim.assign_from_values(var_names_to_values)
    238 
    239   # Create an initial assignment function.
    240   def InitAssignFn(sess):
    241       sess.run(init_assign_op, init_feed_dict)
    242 
    243   # Run training.
    244   slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn)
    245 """
    246 from __future__ import absolute_import
    247 from __future__ import division
    248 from __future__ import print_function
    249 
    250 import os
    251 import sys
    252 import time
    253 
    254 from tensorflow.contrib.training.python.training import training
    255 from tensorflow.core.protobuf import config_pb2
    256 from tensorflow.python.client import timeline
    257 from tensorflow.python.framework import constant_op
    258 from tensorflow.python.framework import errors
    259 from tensorflow.python.framework import ops
    260 from tensorflow.python.lib.io import file_io
    261 from tensorflow.python.ops import clip_ops
    262 from tensorflow.python.ops import control_flow_ops
    263 from tensorflow.python.ops import lookup_ops
    264 from tensorflow.python.ops import math_ops
    265 from tensorflow.python.ops import variables
    266 from tensorflow.python.platform import tf_logging as logging
    267 from tensorflow.python.summary import summary
    268 from tensorflow.python.training import optimizer as tf_optimizer
    269 from tensorflow.python.training import saver as tf_saver
    270 from tensorflow.python.training import supervisor
    271 from tensorflow.python.training import sync_replicas_optimizer
    272 from tensorflow.python.training import training_util
    273 
    274 __all__ = [
    275     'add_gradients_summaries', 'clip_gradient_norms', 'multiply_gradients',
    276     'create_train_op', 'train_step', 'train'
    277 ]
    278 
    279 
    280 def clip_gradient_norms(gradients_to_variables, max_norm):
    281   """Clips each gradient to the given maximum norm.
    282 
    283   Args:
    284     gradients_to_variables: A list of gradient to variable pairs (tuples).
    285     max_norm: the maximum norm value.
    286 
    287   Returns:
    288     A list of clipped gradient to variable pairs.
    289   """
    290   clipped_grads_and_vars = []
    291   for grad, var in gradients_to_variables:
    292     if grad is not None:
    293       if isinstance(grad, ops.IndexedSlices):
    294         tmp = clip_ops.clip_by_norm(grad.values, max_norm)
    295         grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
    296       else:
    297         grad = clip_ops.clip_by_norm(grad, max_norm)
    298     clipped_grads_and_vars.append((grad, var))
    299   return clipped_grads_and_vars
    300 
    301 
    302 def multiply_gradients(grads_and_vars, gradient_multipliers):
    303   """Multiply specified gradients.
    304 
    305   Args:
    306     grads_and_vars: A list of gradient to variable pairs (tuples).
    307     gradient_multipliers: A map from either `Variables` or `Variable` op names
    308       to the coefficient by which the associated gradient should be scaled.
    309 
    310   Returns:
    311     The updated list of gradient to variable pairs.
    312 
    313   Raises:
    314     ValueError: If `grads_and_vars` is not a list or if `gradient_multipliers`
    315     is empty or None or if `gradient_multipliers` is not a dictionary.
    316   """
    317   if not isinstance(grads_and_vars, list):
    318     raise ValueError('`grads_and_vars` must be a list.')
    319   if not gradient_multipliers:
    320     raise ValueError('`gradient_multipliers` is empty.')
    321   if not isinstance(gradient_multipliers, dict):
    322     raise ValueError('`gradient_multipliers` must be a dict.')
    323 
    324   multiplied_grads_and_vars = []
    325   for grad, var in grads_and_vars:
    326     if var in gradient_multipliers or var.op.name in gradient_multipliers:
    327       key = var if var in gradient_multipliers else var.op.name
    328       if grad is None:
    329         raise ValueError('Requested multiple of `None` gradient.')
    330 
    331       multiplier = gradient_multipliers[key]
    332       if not isinstance(multiplier, ops.Tensor):
    333         multiplier = constant_op.constant(multiplier, dtype=grad.dtype)
    334 
    335       if isinstance(grad, ops.IndexedSlices):
    336         tmp = grad.values * multiplier
    337         grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
    338       else:
    339         grad *= multiplier
    340     multiplied_grads_and_vars.append((grad, var))
    341   return multiplied_grads_and_vars
    342 
    343 
    344 def add_gradients_summaries(grads_and_vars):
    345   """Add summaries to gradients.
    346 
    347   Args:
    348     grads_and_vars: A list of gradient to variable pairs (tuples).
    349 
    350   Returns:
    351     The list of created summaries.
    352   """
    353   summaries = []
    354   for grad, var in grads_and_vars:
    355     if grad is not None:
    356       if isinstance(grad, ops.IndexedSlices):
    357         grad_values = grad.values
    358       else:
    359         grad_values = grad
    360       summaries.append(
    361           summary.histogram(var.op.name + '/gradient', grad_values))
    362       summaries.append(
    363           summary.scalar(var.op.name + '/gradient_norm',
    364                          clip_ops.global_norm([grad_values])))
    365     else:
    366       logging.info('Var %s has no gradient', var.op.name)
    367 
    368   return summaries
    369 
    370 
    371 _USE_GLOBAL_STEP = 0
    372 
    373 
    374 def create_train_op(total_loss,
    375                     optimizer,
    376                     global_step=_USE_GLOBAL_STEP,
    377                     update_ops=None,
    378                     variables_to_train=None,
    379                     clip_gradient_norm=0,
    380                     summarize_gradients=False,
    381                     gate_gradients=tf_optimizer.Optimizer.GATE_OP,
    382                     aggregation_method=None,
    383                     colocate_gradients_with_ops=False,
    384                     gradient_multipliers=None,
    385                     check_numerics=True):
    386   """Creates an `Operation` that evaluates the gradients and returns the loss.
    387 
    388   Args:
    389     total_loss: A `Tensor` representing the total loss.
    390     optimizer: A tf.Optimizer to use for computing the gradients.
    391     global_step: A `Tensor` representing the global step variable. If left as
    392       `_USE_GLOBAL_STEP`, then slim.variables.global_step() is used.
    393     update_ops: An optional list of updates to execute. If `update_ops` is
    394       `None`, then the update ops are set to the contents of the
    395       `tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but
    396       it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`,
    397       a warning will be displayed.
    398     variables_to_train: an optional list of variables to train. If None, it will
    399       default to all tf.trainable_variables().
    400     clip_gradient_norm: If greater than 0, each gradient is clipped to this
    401       norm.
    402     summarize_gradients: Whether or not to add summaries for each gradient.
    403     gate_gradients: How to gate the computation of gradients. See tf.Optimizer.
    404     aggregation_method: Specifies the method used to combine gradient terms.
    405       Valid values are defined in the class `AggregationMethod`.
    406     colocate_gradients_with_ops: Whether or not to try colocating the gradients
    407       with the ops that generated them.
    408     gradient_multipliers: A dictionary of either `Variables` or `Variable` op
    409       names to the coefficient by which the associated gradient should be
    410       scaled.
    411     check_numerics: Whether or not we apply check_numerics.
    412 
    413   Returns:
    414     A `Tensor` that when evaluated, computes the gradients and returns the total
    415       loss value.
    416   """
    417   def transform_grads_fn(grads):
    418     if gradient_multipliers:
    419       with ops.name_scope('multiply_grads'):
    420         grads = multiply_gradients(grads, gradient_multipliers)
    421 
    422     # Clip gradients.
    423     if clip_gradient_norm > 0:
    424       with ops.name_scope('clip_grads'):
    425         grads = clip_gradient_norms(grads, clip_gradient_norm)
    426     return grads
    427 
    428   return training.create_train_op(
    429       total_loss=total_loss,
    430       optimizer=optimizer,
    431       global_step=global_step,
    432       update_ops=update_ops,
    433       variables_to_train=variables_to_train,
    434       transform_grads_fn=transform_grads_fn,
    435       summarize_gradients=summarize_gradients,
    436       gate_gradients=gate_gradients,
    437       aggregation_method=aggregation_method,
    438       colocate_gradients_with_ops=colocate_gradients_with_ops,
    439       check_numerics=check_numerics)
    440 
    441 
    442 def _wait_for_step(sess, global_step, step):
    443   """Wait till the global step has reached at least 'step'.
    444 
    445   Args:
    446     sess: A session.
    447     global_step: A Tensor.
    448     step: Int.  The global step to reach.
    449   """
    450   while True:
    451     if training_util.global_step(sess, global_step) >= step:
    452       break
    453     time.sleep(1.0)
    454 
    455 
    456 def train_step(sess, train_op, global_step, train_step_kwargs):
    457   """Function that takes a gradient step and specifies whether to stop.
    458 
    459   Args:
    460     sess: The current session.
    461     train_op: An `Operation` that evaluates the gradients and returns the
    462       total loss.
    463     global_step: A `Tensor` representing the global training step.
    464     train_step_kwargs: A dictionary of keyword arguments.
    465 
    466   Returns:
    467     The total loss and a boolean indicating whether or not to stop training.
    468 
    469   Raises:
    470     ValueError: if 'should_trace' is in `train_step_kwargs` but `logdir` is not.
    471   """
    472   start_time = time.time()
    473 
    474   trace_run_options = None
    475   run_metadata = None
    476   if 'should_trace' in train_step_kwargs:
    477     if 'logdir' not in train_step_kwargs:
    478       raise ValueError('logdir must be present in train_step_kwargs when '
    479                        'should_trace is present')
    480     if sess.run(train_step_kwargs['should_trace']):
    481       trace_run_options = config_pb2.RunOptions(
    482           trace_level=config_pb2.RunOptions.FULL_TRACE)
    483       run_metadata = config_pb2.RunMetadata()
    484 
    485   total_loss, np_global_step = sess.run([train_op, global_step],
    486                                         options=trace_run_options,
    487                                         run_metadata=run_metadata)
    488   time_elapsed = time.time() - start_time
    489 
    490   if run_metadata is not None:
    491     tl = timeline.Timeline(run_metadata.step_stats)
    492     trace = tl.generate_chrome_trace_format()
    493     trace_filename = os.path.join(train_step_kwargs['logdir'],
    494                                   'tf_trace-%d.json' % np_global_step)
    495     logging.info('Writing trace to %s', trace_filename)
    496     file_io.write_string_to_file(trace_filename, trace)
    497     if 'summary_writer' in train_step_kwargs:
    498       train_step_kwargs['summary_writer'].add_run_metadata(run_metadata,
    499                                                            'run_metadata-%d' %
    500                                                            np_global_step)
    501 
    502   if 'should_log' in train_step_kwargs:
    503     if sess.run(train_step_kwargs['should_log']):
    504       logging.info('global step %d: loss = %.4f (%.3f sec/step)',
    505                    np_global_step, total_loss, time_elapsed)
    506 
    507   # TODO(nsilberman): figure out why we can't put this into sess.run. The
    508   # issue right now is that the stop check depends on the global step. The
    509   # increment of global step often happens via the train op, which is
    510   # created using optimizer.apply_gradients.
    511   #
    512   # Since running `train_op` causes the global step to be incremented, one
    513   # would expect that using a control dependency would allow the
    514   # should_stop check to be run in the same session.run call:
    515   #
    516   #   with ops.control_dependencies([train_op]):
    517   #     should_stop_op = ...
    518   #
    519   # However, this actually seems not to work on certain platforms.
    520   if 'should_stop' in train_step_kwargs:
    521     should_stop = sess.run(train_step_kwargs['should_stop'])
    522   else:
    523     should_stop = False
    524 
    525   return total_loss, should_stop
    526 
    527 
    528 _USE_DEFAULT = 0
    529 
    530 
    531 def train(train_op,
    532           logdir,
    533           train_step_fn=train_step,
    534           train_step_kwargs=_USE_DEFAULT,
    535           log_every_n_steps=1,
    536           graph=None,
    537           master='',
    538           is_chief=True,
    539           global_step=None,
    540           number_of_steps=None,
    541           init_op=_USE_DEFAULT,
    542           init_feed_dict=None,
    543           local_init_op=_USE_DEFAULT,
    544           init_fn=None,
    545           ready_op=_USE_DEFAULT,
    546           summary_op=_USE_DEFAULT,
    547           save_summaries_secs=600,
    548           summary_writer=_USE_DEFAULT,
    549           startup_delay_steps=0,
    550           saver=None,
    551           save_interval_secs=600,
    552           sync_optimizer=None,
    553           session_config=None,
    554           session_wrapper=None,
    555           trace_every_n_steps=None,
    556           ignore_live_threads=False):
    557   """Runs a training loop using a TensorFlow supervisor.
    558 
    559   When the sync_optimizer is supplied, gradient updates are applied
    560   synchronously. Otherwise, gradient updates are applied asynchronously.
    561 
    562   Args:
    563     train_op: A `Tensor` that, when executed, will apply the gradients and
    564       return the loss value.
    565     logdir: The directory where training logs are written to. If None, model
    566       checkpoints and summaries will not be written.
    567     train_step_fn: The function to call in order to execute a single gradient
    568       step. The function must take exactly four arguments: the current
    569       session, the `train_op` `Tensor`, a global step `Tensor` and a dictionary.
    570     train_step_kwargs: A dictionary which is passed to the `train_step_fn`. By
    571       default, two `Boolean`, scalar ops called "should_stop" and "should_log"
    572       are provided.
    573     log_every_n_steps: The frequency, in terms of global steps, at which the
    574       loss and global step are logged.
    575     graph: The graph to pass to the supervisor. If no graph is supplied the
    576       default graph is used.
    577     master: The address of the tensorflow master.
    578     is_chief: Specifies whether or not the training is being run by the primary
    579       replica during replica training.
    580     global_step: The `Tensor` representing the global step. If left as `None`,
    581       then slim.variables.get_or_create_global_step() is used.
    582     number_of_steps: The max number of gradient steps to take during training,
    583       as measured by 'global_step': training will stop once global_step is
    584       greater than or equal to 'number_of_steps'. If the value is left as
    585       None, training proceeds indefinitely.
    586     init_op: The initialization operation. If left to its default value, then
    587       the session is initialized by calling `tf.global_variables_initializer()`.
    588     init_feed_dict: A feed dictionary to use when executing the `init_op`.
    589     local_init_op: The local initialization operation. If left to its default
    590       value, then the session is initialized by calling
    591       `tf.local_variables_initializer()` and `tf.tables_initializer()`.
    592     init_fn: An optional callable to be executed after `init_op` is called. The
    593       callable must accept one argument, the session being initialized.
    594     ready_op: Operation to check if the model is ready to use. If left to its
    595       default value, then the session checks for readiness by calling
    596       `tf.report_uninitialized_variables()`.
    597     summary_op: The summary operation.
    598     save_summaries_secs: How often, in seconds, to save summaries.
    599     summary_writer: `SummaryWriter` to use.  Can be `None`
    600       to indicate that no summaries should be written. If unset, we
    601       create a SummaryWriter.
    602     startup_delay_steps: The number of steps to wait for before beginning. Note
    603       that this must be 0 if a sync_optimizer is supplied.
    604     saver: Saver to save checkpoints. If None, a default one will be created
    605       and used.
    606     save_interval_secs: How often, in seconds, to save the model to `logdir`.
    607     sync_optimizer: an instance of tf.train.SyncReplicasOptimizer, or a list of
    608       them. If the argument is supplied, gradient updates will be synchronous.
    609       If left as `None`, gradient updates will be asynchronous.
    610     session_config: An instance of `tf.ConfigProto` that will be used to
    611       configure the `Session`. If left as `None`, the default will be used.
    612     session_wrapper: A function that takes a `tf.Session` object as the only
    613       argument and returns a wrapped session object that has the same methods
    614       that the original object has, or `None`. Iff not `None`, the wrapped
    615       object will be used for training.
    616     trace_every_n_steps: produce and save a `Timeline` in Chrome trace format
    617       and add it to the summaries every `trace_every_n_steps`. If None, no trace
    618       information will be produced or saved.
    619     ignore_live_threads: If `True` ignores threads that remain running after
    620       a grace period when stopping the supervisor, instead of raising a
    621       RuntimeError.
    622 
    623   Returns:
    624     the value of the loss function after training.
    625 
    626   Raises:
    627     ValueError: if `train_op` is None, if `startup_delay_steps` is non-zero
    628       when `sync_optimizer` is supplied, if `number_of_steps` is not
    629       positive, or if `trace_every_n_steps` is not `None` and no `logdir` is
    630       provided.
    631   """
    632   if train_op is None:
    633     raise ValueError('train_op cannot be None.')
    634 
  # Without a logdir there is nowhere to write summaries, checkpoints or traces.
  if logdir is None:
    if summary_op != _USE_DEFAULT:
      raise ValueError('Cannot provide summary_op because logdir=None')
    if saver is not None:
      raise ValueError('Cannot provide saver because logdir=None')
    if trace_every_n_steps is not None:
      raise ValueError('Cannot provide trace_every_n_steps because '
                       'logdir=None')

  # Normalize a single SyncReplicasOptimizer to a one-element list.
  if isinstance(sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer):
    sync_optimizer = [sync_optimizer]
  if sync_optimizer is not None and startup_delay_steps > 0:
    raise ValueError(
        'startup_delay_steps must be zero when sync_optimizer is supplied.')

  if number_of_steps is not None and number_of_steps <= 0:
    raise ValueError(
        '`number_of_steps` must be either None or a positive number.')

  graph = graph or ops.get_default_graph()
  with graph.as_default():
    if global_step is None:
      global_step = training_util.get_or_create_global_step()
    saver = saver or tf_saver.Saver()

    if sync_optimizer is not None:
      for opt in sync_optimizer:
        if not isinstance(opt, sync_replicas_optimizer.SyncReplicasOptimizer):
          raise ValueError(
              '`sync_optimizer` must be a tf.train.SyncReplicasOptimizer.')

    with ops.name_scope('init_ops'):
      if init_op == _USE_DEFAULT:
        init_op = variables.global_variables_initializer()

      if ready_op == _USE_DEFAULT:
        ready_op = variables.report_uninitialized_variables()

      if local_init_op == _USE_DEFAULT:
        local_init_op = control_flow_ops.group(
            variables.local_variables_initializer(),
            lookup_ops.tables_initializer())

      if sync_optimizer is not None and isinstance(sync_optimizer, list):
        # Run the optimizers' init ops only after the existing local_init_op.
        with ops.control_dependencies([local_init_op] if local_init_op is
                                      not None else []):
          if is_chief:
            local_init_op = control_flow_ops.group(
                *[opt.chief_init_op for opt in sync_optimizer])
          else:
            local_init_op = control_flow_ops.group(
                *[opt.local_step_init_op for opt in sync_optimizer])
        ready_for_local_init_op = control_flow_ops.group(
            *[opt.ready_for_local_init_op for opt in sync_optimizer])
      else:
        ready_for_local_init_op = None

    if summary_op == _USE_DEFAULT:
      summary_op = summary.merge_all()

    if summary_writer == _USE_DEFAULT:
      summary_writer = supervisor.Supervisor.USE_DEFAULT

    if is_chief and sync_optimizer is not None:
      # Need to create these BEFORE the supervisor finalizes the graph:
      init_tokens_op = [opt.get_init_tokens_op() for opt in sync_optimizer]
      chief_queue_runner = [
          opt.get_chief_queue_runner() for opt in sync_optimizer]

    if train_step_kwargs == _USE_DEFAULT:
      with ops.name_scope('train_step'):
        train_step_kwargs = {}

        if number_of_steps:
          should_stop_op = math_ops.greater_equal(global_step, number_of_steps)
        else:
          should_stop_op = constant_op.constant(False)
        train_step_kwargs['should_stop'] = should_stop_op
        if log_every_n_steps > 0:
          train_step_kwargs['should_log'] = math_ops.equal(
              math_ops.mod(global_step, log_every_n_steps), 0)
        if is_chief and trace_every_n_steps is not None:
          train_step_kwargs['should_trace'] = math_ops.equal(
              math_ops.mod(global_step, trace_every_n_steps), 0)
          train_step_kwargs['logdir'] = logdir

  sv = supervisor.Supervisor(
      graph=graph,
      is_chief=is_chief,
      logdir=logdir,
      init_op=init_op,
      init_feed_dict=init_feed_dict,
      local_init_op=local_init_op,
      ready_for_local_init_op=ready_for_local_init_op,
      ready_op=ready_op,
      summary_op=summary_op,
      summary_writer=summary_writer,
      global_step=global_step,
      saver=saver,
      save_summaries_secs=save_summaries_secs,
      save_model_secs=save_interval_secs,
      init_fn=init_fn)

  if summary_writer is not None:
    train_step_kwargs['summary_writer'] = sv.summary_writer

  total_loss = None
  should_retry = True
  while should_retry:
    try:
      should_retry = False
      with sv.managed_session(
          master, start_standard_services=False, config=session_config) as sess:
        logging.info('Starting Session.')
        if session_wrapper is not None:
          logging.info(
              'Wrapping session with wrapper function: %s', session_wrapper)
          sess = session_wrapper(sess)
        if is_chief:
          if logdir:
            sv.start_standard_services(sess)
        elif startup_delay_steps > 0:
          # Use sys.maxsize because sys.maxint does not exist in Python 3.
          _wait_for_step(sess, global_step,
                         min(startup_delay_steps, number_of_steps or
                             sys.maxsize))
        threads = sv.start_queue_runners(sess)
        logging.info('Starting Queues.')
        if is_chief and sync_optimizer is not None:
          sv.start_queue_runners(sess, chief_queue_runner)
          sess.run(init_tokens_op)
        try:
          while not sv.should_stop():
            total_loss, should_stop = train_step_fn(
                sess, train_op, global_step, train_step_kwargs)
            if should_stop:
              logging.info('Stopping Training.')
              sv.request_stop()
              break
        except errors.OutOfRangeError as e:
          # OutOfRangeError is raised when the epoch limit set via
          # tf.train.limit_epochs is reached.
          logging.info('Caught OutOfRangeError. Stopping Training. %s', e)
        if logdir and sv.is_chief:
          logging.info('Finished training! Saving model to disk.')
          sv.saver.save(sess, sv.save_path, global_step=sv.global_step)
          sv.stop(
              threads,
              close_summary_writer=True,
              ignore_live_threads=ignore_live_threads)

    except errors.AbortedError:
      # Always re-run on AbortedError, as it indicates a restart of one of the
      # distributed TensorFlow servers.
      logging.info('Retrying training!')
      should_retry = True

  return total_loss
