1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Contains TF-Slim code for training models. 16 17 This script contains various functions for training models. These include 18 manipulating gradients, creating a `train_op` (an operation that computes the 19 loss and applies the gradients) and a training loop function. The training loop 20 allows the user to pass in the `train_op` and runs the optimization according 21 to user-specified arguments. Note that the training loop uses the 22 tf.train.Supervisor and its managed_session in its implementation to ensure the 23 ability of worker processes to recover from failures. 24 25 ************************************ 26 * A simple working training script * 27 ************************************ 28 29 # Load data and create the model: 30 images, labels = LoadData(...) 31 predictions = MyModel(images) 32 33 # Define the loss: 34 slim.losses.log_loss(predictions, labels) 35 total_loss = slim.losses.get_total_loss() 36 37 # Define the optimizer: 38 optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum) 39 40 # Create the train_op 41 train_op = slim.learning.create_train_op(total_loss, optimizer) 42 43 # Run training. 44 slim.learning.train(train_op, my_log_dir) 45 46 ************************* 47 * Creating the train_op * 48 ************************* 49 50 In order to train, TF-Slim's train loop needs a train_op: an `Operation` that 51 (a) computes the loss, (b) applies the gradients to update the weights and 52 (c) returns the value of the loss. slim.learning.create_train_op creates 53 such an `Operation`. This function also provides the ability to manipulate 54 the gradients using a few arguments: 55 56 # Create the train_op and clip the gradient norms: 57 train_op = slim.learning.create_train_op( 58 total_loss, 59 optimizer, 60 clip_gradient_norm=4) 61 62 # Create the train_op and scale the gradients by providing a map from variable 63 # name (or variable) to a scaling coefficient: 64 gradient_multipliers = { 65 'conv0/weights': 1.2, 66 'fc8/weights': 3.4, 67 } 68 train_op = slim.learning.create_train_op( 69 total_loss, 70 optimizer, 71 gradient_multipliers=gradient_multipliers) 72 73 **************************************************************** 74 * Performing additional (non-gradient) updates during training * 75 **************************************************************** 76 77 Many networks utilize modules, like BatchNorm, that require performing a series 78 of non-gradient updates during training. slim.learning.create_train_op allows 79 a user to pass in a list of update_ops to call along with the gradient updates. 80 81 train_op = slim.learning.create_train_op(total_loss, optimizer, update_ops) 82 83 By default, slim.learning.create_train_op includes all update ops that are 84 part of the `tf.GraphKeys.UPDATE_OPS` collection. Additionally, TF-Slim's 85 slim.batch_norm function adds the moving mean and moving variance updates to 86 this collection. Consequently, users who want to use slim.batch_norm will not 87 need to take any additional steps in order to have the moving mean and moving 88 variance updates be computed. 89 90 However, users with additional, specialized updates can either override the 91 default update ops or simply add additional update ops to the 92 `tf.GraphKeys.UPDATE_OPS` collection: 93 94 # Force TF-Slim NOT to use ANY update_ops: 95 train_op = slim.learning.create_train_op( 96 total_loss, 97 optimizer, 98 update_ops=[]) 99 100 # Use an alternative set of update ops: 101 train_op = slim.learning.create_train_op( 102 total_loss, 103 optimizer, 104 update_ops=my_other_update_ops) 105 106 # Use an alternative set of update ops in addition to the default updates: 107 tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, my_update0) 108 tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, my_update1) 109 110 train_op = slim.learning.create_train_op( 111 total_loss, 112 optimizer) 113 114 # Which is the same as: 115 train_op = slim.learning.create_train_op( 116 total_loss, 117 optimizer, 118 update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS)) 119 120 ****************************************** 121 * Initializing a model from a checkpoint * 122 ****************************************** 123 124 It is common to want to 'warm-start' a model from a pre-trained checkpoint. 125 TF-Slim provides a convenient mechanism for doing so: 126 127 ... 128 129 # Create the train_op 130 train_op = slim.learning.create_train_op(total_loss, optimizer) 131 132 # Create the initial assignment op 133 checkpoint_path = '/path/to/old_model_checkpoint' 134 variables_to_restore = slim.get_model_variables() 135 init_assign_op, init_feed_dict = slim.assign_from_checkpoint( 136 checkpoint_path, variables_to_restore) 137 138 # Create an initial assignment function. 139 def InitAssignFn(sess): 140 sess.run(init_assign_op, init_feed_dict) 141 142 # Run training. 143 slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn) 144 145 *************************************************************************** 146 * Initializing a model from a checkpoint whose variable names don't match * 147 *************************************************************************** 148 149 At times, a user may want to initialize a new model with values from a 150 checkpoint whose variable names do not match those of the current model. In this 151 case, one needs to create a mapping from the checkpoint variable names to the 152 current model variables. This requires only a small modification of the code 153 above: 154 ... 155 # Creates a model with two variables, var0 and var1 156 predictions = MyModel(images) 157 ... 158 159 # Create the train_op 160 train_op = slim.learning.create_train_op(total_loss, optimizer) 161 162 checkpoint_path = '/path/to/old_model_checkpoint' 163 164 # Create the mapping: 165 variables_to_restore = { 166 'name_var_0_in_checkpoint': slim.get_unique_variable('var0'), 167 'name_var_1_in_checkpoint': slim.get_unique_variable('var1') 168 } 169 init_assign_op, init_feed_dict = slim.assign_from_checkpoint( 170 checkpoint_path, variables_to_restore) 171 172 # Create an initial assignment function. 173 def InitAssignFn(sess): 174 sess.run(init_assign_op, init_feed_dict) 175 176 # Run training. 177 slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn) 178 179 180 ************************************************* 181 * Fine-Tuning Part of a model from a checkpoint * 182 ************************************************* 183 184 Rather than initializing all of the weights of a given model, we sometimes 185 only want to restore some of the weights from a checkpoint. To do this, one 186 need only filter those variables to initialize as follows: 187 188 ... 189 190 # Create the train_op 191 train_op = slim.learning.create_train_op(total_loss, optimizer) 192 193 checkpoint_path = '/path/to/old_model_checkpoint' 194 195 # Specify the variables to restore via a list of inclusion or exclusion 196 # patterns: 197 variables_to_restore = slim.get_variables_to_restore( 198 include=["conv"], exclude=["fc8", "fc9]) 199 # or 200 variables_to_restore = slim.get_variables_to_restore(exclude=["conv"]) 201 202 init_assign_op, init_feed_dict = slim.assign_from_checkpoint( 203 checkpoint_path, variables_to_restore) 204 205 # Create an initial assignment function. 206 def InitAssignFn(sess): 207 sess.run(init_assign_op, init_feed_dict) 208 209 # Run training. 210 slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn) 211 212 ****************************************************** 213 * Initializing model variables from values in memory * 214 ****************************************************** 215 216 One may want to initialize the weights of a model from values from an arbitrary 217 source (a text document, matlab file, etc). While this is technically feasible 218 using plain TensorFlow, it also results in the values of your weights being 219 stored in the graph. For large models, this becomes prohibitively large. TF-Slim 220 allows you to perform this initial assignment without having to store the values 221 of the initial model in the graph itself by using placeholders and a feed 222 dictionary: 223 224 ... 225 226 # Create the train_op 227 train_op = slim.learning.create_train_op(total_loss, optimizer) 228 229 # Create the mapping from variable names to values: 230 var0_initial_value = ReadFromDisk(...) 231 var1_initial_value = ReadFromDisk(...) 232 233 var_names_to_values = { 234 'var0': var0_initial_value, 235 'var1': var1_initial_value, 236 } 237 init_assign_op, init_feed_dict = slim.assign_from_values(var_names_to_values) 238 239 # Create an initial assignment function. 240 def InitAssignFn(sess): 241 sess.run(init_assign_op, init_feed_dict) 242 243 # Run training. 244 slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn) 245 """ 246 from __future__ import absolute_import 247 from __future__ import division 248 from __future__ import print_function 249 250 import os 251 import sys 252 import time 253 254 from tensorflow.contrib.training.python.training import training 255 from tensorflow.core.protobuf import config_pb2 256 from tensorflow.python.client import timeline 257 from tensorflow.python.framework import constant_op 258 from tensorflow.python.framework import errors 259 from tensorflow.python.framework import ops 260 from tensorflow.python.lib.io import file_io 261 from tensorflow.python.ops import clip_ops 262 from tensorflow.python.ops import control_flow_ops 263 from tensorflow.python.ops import lookup_ops 264 from tensorflow.python.ops import math_ops 265 from tensorflow.python.ops import variables 266 from tensorflow.python.platform import tf_logging as logging 267 from tensorflow.python.summary import summary 268 from tensorflow.python.training import optimizer as tf_optimizer 269 from tensorflow.python.training import saver as tf_saver 270 from tensorflow.python.training import supervisor 271 from tensorflow.python.training import sync_replicas_optimizer 272 from tensorflow.python.training import training_util 273 274 __all__ = [ 275 'add_gradients_summaries', 'clip_gradient_norms', 'multiply_gradients', 276 'create_train_op', 'train_step', 'train' 277 ] 278 279 280 def clip_gradient_norms(gradients_to_variables, max_norm): 281 """Clips the gradients by the given value. 282 283 Args: 284 gradients_to_variables: A list of gradient to variable pairs (tuples). 285 max_norm: the maximum norm value. 286 287 Returns: 288 A list of clipped gradient to variable pairs. 289 """ 290 clipped_grads_and_vars = [] 291 for grad, var in gradients_to_variables: 292 if grad is not None: 293 if isinstance(grad, ops.IndexedSlices): 294 tmp = clip_ops.clip_by_norm(grad.values, max_norm) 295 grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape) 296 else: 297 grad = clip_ops.clip_by_norm(grad, max_norm) 298 clipped_grads_and_vars.append((grad, var)) 299 return clipped_grads_and_vars 300 301 302 def multiply_gradients(grads_and_vars, gradient_multipliers): 303 """Multiply specified gradients. 304 305 Args: 306 grads_and_vars: A list of gradient to variable pairs (tuples). 307 gradient_multipliers: A map from either `Variables` or `Variable` op names 308 to the coefficient by which the associated gradient should be scaled. 309 310 Returns: 311 The updated list of gradient to variable pairs. 312 313 Raises: 314 ValueError: If `grads_and_vars` is not a list or if `gradient_multipliers` 315 is empty or None or if `gradient_multipliers` is not a dictionary. 316 """ 317 if not isinstance(grads_and_vars, list): 318 raise ValueError('`grads_and_vars` must be a list.') 319 if not gradient_multipliers: 320 raise ValueError('`gradient_multipliers` is empty.') 321 if not isinstance(gradient_multipliers, dict): 322 raise ValueError('`gradient_multipliers` must be a dict.') 323 324 multiplied_grads_and_vars = [] 325 for grad, var in grads_and_vars: 326 if var in gradient_multipliers or var.op.name in gradient_multipliers: 327 key = var if var in gradient_multipliers else var.op.name 328 if grad is None: 329 raise ValueError('Requested multiple of `None` gradient.') 330 331 multiplier = gradient_multipliers[key] 332 if not isinstance(multiplier, ops.Tensor): 333 multiplier = constant_op.constant(multiplier, dtype=grad.dtype) 334 335 if isinstance(grad, ops.IndexedSlices): 336 tmp = grad.values * multiplier 337 grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape) 338 else: 339 grad *= multiplier 340 multiplied_grads_and_vars.append((grad, var)) 341 return multiplied_grads_and_vars 342 343 344 def add_gradients_summaries(grads_and_vars): 345 """Add summaries to gradients. 346 347 Args: 348 grads_and_vars: A list of gradient to variable pairs (tuples). 349 350 Returns: 351 The list of created summaries. 352 """ 353 summaries = [] 354 for grad, var in grads_and_vars: 355 if grad is not None: 356 if isinstance(grad, ops.IndexedSlices): 357 grad_values = grad.values 358 else: 359 grad_values = grad 360 summaries.append( 361 summary.histogram(var.op.name + '/gradient', grad_values)) 362 summaries.append( 363 summary.scalar(var.op.name + '/gradient_norm', 364 clip_ops.global_norm([grad_values]))) 365 else: 366 logging.info('Var %s has no gradient', var.op.name) 367 368 return summaries 369 370 371 _USE_GLOBAL_STEP = 0 372 373 374 def create_train_op(total_loss, 375 optimizer, 376 global_step=_USE_GLOBAL_STEP, 377 update_ops=None, 378 variables_to_train=None, 379 clip_gradient_norm=0, 380 summarize_gradients=False, 381 gate_gradients=tf_optimizer.Optimizer.GATE_OP, 382 aggregation_method=None, 383 colocate_gradients_with_ops=False, 384 gradient_multipliers=None, 385 check_numerics=True): 386 """Creates an `Operation` that evaluates the gradients and returns the loss. 387 388 Args: 389 total_loss: A `Tensor` representing the total loss. 390 optimizer: A tf.Optimizer to use for computing the gradients. 391 global_step: A `Tensor` representing the global step variable. If left as 392 `_USE_GLOBAL_STEP`, then slim.variables.global_step() is used. 393 update_ops: An optional list of updates to execute. If `update_ops` is 394 `None`, then the update ops are set to the contents of the 395 `tf.GraphKeys.UPDATE_OPS` collection. If `update_ops` is not `None`, but 396 it doesn't contain all of the update ops in `tf.GraphKeys.UPDATE_OPS`, 397 a warning will be displayed. 398 variables_to_train: an optional list of variables to train. If None, it will 399 default to all tf.trainable_variables(). 400 clip_gradient_norm: If greater than 0 then the gradients would be clipped 401 by it. 402 summarize_gradients: Whether or not add summaries for each gradient. 403 gate_gradients: How to gate the computation of gradients. See tf.Optimizer. 404 aggregation_method: Specifies the method used to combine gradient terms. 405 Valid values are defined in the class `AggregationMethod`. 406 colocate_gradients_with_ops: Whether or not to try colocating the gradients 407 with the ops that generated them. 408 gradient_multipliers: A dictionary of either `Variables` or `Variable` op 409 names to the coefficient by which the associated gradient should be 410 scaled. 411 check_numerics: Whether or not we apply check_numerics. 412 413 Returns: 414 A `Tensor` that when evaluated, computes the gradients and returns the total 415 loss value. 416 """ 417 def transform_grads_fn(grads): 418 if gradient_multipliers: 419 with ops.name_scope('multiply_grads'): 420 grads = multiply_gradients(grads, gradient_multipliers) 421 422 # Clip gradients. 423 if clip_gradient_norm > 0: 424 with ops.name_scope('clip_grads'): 425 grads = clip_gradient_norms(grads, clip_gradient_norm) 426 return grads 427 428 return training.create_train_op( 429 total_loss=total_loss, 430 optimizer=optimizer, 431 global_step=global_step, 432 update_ops=update_ops, 433 variables_to_train=variables_to_train, 434 transform_grads_fn=transform_grads_fn, 435 summarize_gradients=summarize_gradients, 436 gate_gradients=gate_gradients, 437 aggregation_method=aggregation_method, 438 colocate_gradients_with_ops=colocate_gradients_with_ops, 439 check_numerics=check_numerics) 440 441 442 def _wait_for_step(sess, global_step, step): 443 """Wait till the global step has reached at least 'step'. 444 445 Args: 446 sess: A session. 447 global_step: A Tensor. 448 step: Int. The global step to reach. 449 """ 450 while True: 451 if training_util.global_step(sess, global_step) >= step: 452 break 453 time.sleep(1.0) 454 455 456 def train_step(sess, train_op, global_step, train_step_kwargs): 457 """Function that takes a gradient step and specifies whether to stop. 458 459 Args: 460 sess: The current session. 461 train_op: An `Operation` that evaluates the gradients and returns the 462 total loss. 463 global_step: A `Tensor` representing the global training step. 464 train_step_kwargs: A dictionary of keyword arguments. 465 466 Returns: 467 The total loss and a boolean indicating whether or not to stop training. 468 469 Raises: 470 ValueError: if 'should_trace' is in `train_step_kwargs` but `logdir` is not. 471 """ 472 start_time = time.time() 473 474 trace_run_options = None 475 run_metadata = None 476 if 'should_trace' in train_step_kwargs: 477 if 'logdir' not in train_step_kwargs: 478 raise ValueError('logdir must be present in train_step_kwargs when ' 479 'should_trace is present') 480 if sess.run(train_step_kwargs['should_trace']): 481 trace_run_options = config_pb2.RunOptions( 482 trace_level=config_pb2.RunOptions.FULL_TRACE) 483 run_metadata = config_pb2.RunMetadata() 484 485 total_loss, np_global_step = sess.run([train_op, global_step], 486 options=trace_run_options, 487 run_metadata=run_metadata) 488 time_elapsed = time.time() - start_time 489 490 if run_metadata is not None: 491 tl = timeline.Timeline(run_metadata.step_stats) 492 trace = tl.generate_chrome_trace_format() 493 trace_filename = os.path.join(train_step_kwargs['logdir'], 494 'tf_trace-%d.json' % np_global_step) 495 logging.info('Writing trace to %s', trace_filename) 496 file_io.write_string_to_file(trace_filename, trace) 497 if 'summary_writer' in train_step_kwargs: 498 train_step_kwargs['summary_writer'].add_run_metadata(run_metadata, 499 'run_metadata-%d' % 500 np_global_step) 501 502 if 'should_log' in train_step_kwargs: 503 if sess.run(train_step_kwargs['should_log']): 504 logging.info('global step %d: loss = %.4f (%.3f sec/step)', 505 np_global_step, total_loss, time_elapsed) 506 507 # TODO(nsilberman): figure out why we can't put this into sess.run. The 508 # issue right now is that the stop check depends on the global step. The 509 # increment of global step often happens via the train op, which used 510 # created using optimizer.apply_gradients. 511 # 512 # Since running `train_op` causes the global step to be incremented, one 513 # would expected that using a control dependency would allow the 514 # should_stop check to be run in the same session.run call: 515 # 516 # with ops.control_dependencies([train_op]): 517 # should_stop_op = ... 518 # 519 # However, this actually seems not to work on certain platforms. 520 if 'should_stop' in train_step_kwargs: 521 should_stop = sess.run(train_step_kwargs['should_stop']) 522 else: 523 should_stop = False 524 525 return total_loss, should_stop 526 527 528 _USE_DEFAULT = 0 529 530 531 def train(train_op, 532 logdir, 533 train_step_fn=train_step, 534 train_step_kwargs=_USE_DEFAULT, 535 log_every_n_steps=1, 536 graph=None, 537 master='', 538 is_chief=True, 539 global_step=None, 540 number_of_steps=None, 541 init_op=_USE_DEFAULT, 542 init_feed_dict=None, 543 local_init_op=_USE_DEFAULT, 544 init_fn=None, 545 ready_op=_USE_DEFAULT, 546 summary_op=_USE_DEFAULT, 547 save_summaries_secs=600, 548 summary_writer=_USE_DEFAULT, 549 startup_delay_steps=0, 550 saver=None, 551 save_interval_secs=600, 552 sync_optimizer=None, 553 session_config=None, 554 session_wrapper=None, 555 trace_every_n_steps=None, 556 ignore_live_threads=False): 557 """Runs a training loop using a TensorFlow supervisor. 558 559 When the sync_optimizer is supplied, gradient updates are applied 560 synchronously. Otherwise, gradient updates are applied asynchronous. 561 562 Args: 563 train_op: A `Tensor` that, when executed, will apply the gradients and 564 return the loss value. 565 logdir: The directory where training logs are written to. If None, model 566 checkpoints and summaries will not be written. 567 train_step_fn: The function to call in order to execute a single gradient 568 step. The function must have take exactly four arguments: the current 569 session, the `train_op` `Tensor`, a global step `Tensor` and a dictionary. 570 train_step_kwargs: A dictionary which is passed to the `train_step_fn`. By 571 default, two `Boolean`, scalar ops called "should_stop" and "should_log" 572 are provided. 573 log_every_n_steps: The frequency, in terms of global steps, that the loss 574 and global step and logged. 575 graph: The graph to pass to the supervisor. If no graph is supplied the 576 default graph is used. 577 master: The address of the tensorflow master. 578 is_chief: Specifies whether or not the training is being run by the primary 579 replica during replica training. 580 global_step: The `Tensor` representing the global step. If left as `None`, 581 then slim.variables.get_or_create_global_step() is used. 582 number_of_steps: The max number of gradient steps to take during training, 583 as measured by 'global_step': training will stop if global_step is 584 greater than 'number_of_steps'. If the value is left as None, training 585 proceeds indefinitely. 586 init_op: The initialization operation. If left to its default value, then 587 the session is initialized by calling `tf.global_variables_initializer()`. 588 init_feed_dict: A feed dictionary to use when executing the `init_op`. 589 local_init_op: The local initialization operation. If left to its default 590 value, then the session is initialized by calling 591 `tf.local_variables_initializer()` and `tf.tables_initializer()`. 592 init_fn: An optional callable to be executed after `init_op` is called. The 593 callable must accept one argument, the session being initialized. 594 ready_op: Operation to check if the model is ready to use. If left to its 595 default value, then the session checks for readiness by calling 596 `tf.report_uninitialized_variables()`. 597 summary_op: The summary operation. 598 save_summaries_secs: How often, in seconds, to save summaries. 599 summary_writer: `SummaryWriter` to use. Can be `None` 600 to indicate that no summaries should be written. If unset, we 601 create a SummaryWriter. 602 startup_delay_steps: The number of steps to wait for before beginning. Note 603 that this must be 0 if a sync_optimizer is supplied. 604 saver: Saver to save checkpoints. If None, a default one will be created 605 and used. 606 save_interval_secs: How often, in seconds, to save the model to `logdir`. 607 sync_optimizer: an instance of tf.train.SyncReplicasOptimizer, or a list of 608 them. If the argument is supplied, gradient updates will be synchronous. 609 If left as `None`, gradient updates will be asynchronous. 610 session_config: An instance of `tf.ConfigProto` that will be used to 611 configure the `Session`. If left as `None`, the default will be used. 612 session_wrapper: A function that takes a `tf.Session` object as the only 613 argument and returns a wrapped session object that has the same methods 614 that the original object has, or `None`. Iff not `None`, the wrapped 615 object will be used for training. 616 trace_every_n_steps: produce and save a `Timeline` in Chrome trace format 617 and add it to the summaries every `trace_every_n_steps`. If None, no trace 618 information will be produced or saved. 619 ignore_live_threads: If `True` ignores threads that remain running after 620 a grace period when stopping the supervisor, instead of raising a 621 RuntimeError. 622 623 Returns: 624 the value of the loss function after training. 625 626 Raises: 627 ValueError: if `train_op` is empty or if `startup_delay_steps` is 628 non-zero when `sync_optimizer` is supplied, if `number_of_steps` is 629 negative, or if `trace_every_n_steps` is not `None` and no `logdir` is 630 provided. 631 """ 632 if train_op is None: 633 raise ValueError('train_op cannot be None.') 634 635 if logdir is None: 636 if summary_op != _USE_DEFAULT: 637 raise ValueError('Cannot provide summary_op because logdir=None') 638 if saver is not None: 639 raise ValueError('Cannot provide saver because logdir=None') 640 if trace_every_n_steps is not None: 641 raise ValueError('Cannot provide trace_every_n_steps because ' 642 'logdir=None') 643 644 if isinstance(sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer): 645 sync_optimizer = [sync_optimizer] 646 if sync_optimizer is not None and startup_delay_steps > 0: 647 raise ValueError( 648 'startup_delay_steps must be zero when sync_optimizer is supplied.') 649 650 if number_of_steps is not None and number_of_steps <= 0: 651 raise ValueError( 652 '`number_of_steps` must be either None or a positive number.') 653 654 graph = graph or ops.get_default_graph() 655 with graph.as_default(): 656 if global_step is None: 657 global_step = training_util.get_or_create_global_step() 658 saver = saver or tf_saver.Saver() 659 660 if sync_optimizer is not None: 661 for opt in sync_optimizer: 662 if not isinstance(opt, sync_replicas_optimizer.SyncReplicasOptimizer): 663 raise ValueError( 664 '`sync_optimizer` must be a tf.train.SyncReplicasOptimizer.') 665 666 with ops.name_scope('init_ops'): 667 if init_op == _USE_DEFAULT: 668 init_op = variables.global_variables_initializer() 669 670 if ready_op == _USE_DEFAULT: 671 ready_op = variables.report_uninitialized_variables() 672 673 if local_init_op == _USE_DEFAULT: 674 local_init_op = control_flow_ops.group( 675 variables.local_variables_initializer(), 676 lookup_ops.tables_initializer()) 677 678 if sync_optimizer is not None and isinstance(sync_optimizer, list): 679 with ops.control_dependencies([local_init_op] if local_init_op is 680 not None else []): 681 if is_chief: 682 local_init_op = control_flow_ops.group( 683 *[opt.chief_init_op for opt in sync_optimizer]) 684 else: 685 local_init_op = control_flow_ops.group( 686 *[opt.local_step_init_op for opt in sync_optimizer]) 687 ready_for_local_init_op = control_flow_ops.group( 688 *[opt.ready_for_local_init_op for opt in sync_optimizer]) 689 else: 690 ready_for_local_init_op = None 691 692 if summary_op == _USE_DEFAULT: 693 summary_op = summary.merge_all() 694 695 if summary_writer == _USE_DEFAULT: 696 summary_writer = supervisor.Supervisor.USE_DEFAULT 697 698 if is_chief and sync_optimizer is not None: 699 # Need to create these BEFORE the supervisor finalizes the graph: 700 init_tokens_op = [opt.get_init_tokens_op() for opt in sync_optimizer] 701 chief_queue_runner = [ 702 opt.get_chief_queue_runner() for opt in sync_optimizer] 703 704 if train_step_kwargs == _USE_DEFAULT: 705 with ops.name_scope('train_step'): 706 train_step_kwargs = {} 707 708 if number_of_steps: 709 should_stop_op = math_ops.greater_equal(global_step, number_of_steps) 710 else: 711 should_stop_op = constant_op.constant(False) 712 train_step_kwargs['should_stop'] = should_stop_op 713 if log_every_n_steps > 0: 714 train_step_kwargs['should_log'] = math_ops.equal( 715 math_ops.mod(global_step, log_every_n_steps), 0) 716 if is_chief and trace_every_n_steps is not None: 717 train_step_kwargs['should_trace'] = math_ops.equal( 718 math_ops.mod(global_step, trace_every_n_steps), 0) 719 train_step_kwargs['logdir'] = logdir 720 721 sv = supervisor.Supervisor( 722 graph=graph, 723 is_chief=is_chief, 724 logdir=logdir, 725 init_op=init_op, 726 init_feed_dict=init_feed_dict, 727 local_init_op=local_init_op, 728 ready_for_local_init_op=ready_for_local_init_op, 729 ready_op=ready_op, 730 summary_op=summary_op, 731 summary_writer=summary_writer, 732 global_step=global_step, 733 saver=saver, 734 save_summaries_secs=save_summaries_secs, 735 save_model_secs=save_interval_secs, 736 init_fn=init_fn) 737 738 if summary_writer is not None: 739 train_step_kwargs['summary_writer'] = sv.summary_writer 740 741 total_loss = None 742 should_retry = True 743 while should_retry: 744 try: 745 should_retry = False 746 with sv.managed_session( 747 master, start_standard_services=False, config=session_config) as sess: 748 logging.info('Starting Session.') 749 if session_wrapper is not None: 750 logging.info( 751 'Wrapping session with wrapper function: %s', session_wrapper) 752 sess = session_wrapper(sess) 753 if is_chief: 754 if logdir: 755 sv.start_standard_services(sess) 756 elif startup_delay_steps > 0: 757 # (use sys.maxsize because sys.maxint doesn't exist in Python 3) 758 _wait_for_step(sess, global_step, 759 min(startup_delay_steps, number_of_steps or 760 sys.maxsize)) 761 threads = sv.start_queue_runners(sess) 762 logging.info('Starting Queues.') 763 if is_chief and sync_optimizer is not None: 764 sv.start_queue_runners(sess, chief_queue_runner) 765 sess.run(init_tokens_op) 766 try: 767 while not sv.should_stop(): 768 total_loss, should_stop = train_step_fn( 769 sess, train_op, global_step, train_step_kwargs) 770 if should_stop: 771 logging.info('Stopping Training.') 772 sv.request_stop() 773 break 774 except errors.OutOfRangeError as e: 775 # OutOfRangeError is thrown when epoch limit per 776 # tf.train.limit_epochs is reached. 777 logging.info('Caught OutOfRangeError. Stopping Training. %s', e) 778 if logdir and sv.is_chief: 779 logging.info('Finished training! Saving model to disk.') 780 sv.saver.save(sess, sv.save_path, global_step=sv.global_step) 781 sv.stop( 782 threads, 783 close_summary_writer=True, 784 ignore_live_threads=ignore_live_threads) 785 786 except errors.AbortedError: 787 # Always re-run on AbortedError as it indicates a restart of one of the 788 # distributed tensorflow servers. 789 logging.info('Retrying training!') 790 should_retry = True 791 792 return total_loss 793