# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.training.training."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np

from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.training.python.training import training
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_lib2
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib
# pylint: enable=g-import-not-at-top


def logistic_classifier(inputs):
  return layers.fully_connected(inputs, 1, activation_fn=math_ops.sigmoid)


def batchnorm_classifier(inputs):
  inputs = layers.batch_norm(inputs, decay=0.1, fused=False)
  return layers.fully_connected(inputs, 1, activation_fn=math_ops.sigmoid)


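# clip_gradient_norms rescales each gradient whose norm exceeds the given
# maximum, which is what the tests below check: d(4 * x)/dx == 4.0, so
# clipping with a max norm of 3.0 yields 3.0 while the unclipped gradient
# stays at 4.0.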
class ClipGradsTest(test.TestCase):

  def testClipGrads(self):
    xs = variables_lib2.Variable(0.0)
    ys = xs * 4.0
    grads = gradients_impl.gradients([ys], [xs])
    gradients_to_variables = list(zip(grads, [xs]))
    clipped_gradients_to_variables = training.clip_gradient_norms(
        gradients_to_variables, 3.0)

    with self.test_session() as session:
      session.run(variables_lib2.global_variables_initializer())
      self.assertAlmostEqual(4.0, gradients_to_variables[0][0].eval())
      self.assertAlmostEqual(3.0, clipped_gradients_to_variables[0][0].eval())

  def testClipGradsFn(self):
    xs = variables_lib2.Variable(0.0)
    ys = xs * 4.0
    grads = gradients_impl.gradients([ys], [xs])
    gradients_to_variables = list(zip(grads, [xs]))
    clipped_gradients_to_variables = training.clip_gradient_norms_fn(3.0)(
        gradients_to_variables)

    with self.test_session() as session:
      session.run(variables_lib2.global_variables_initializer())
      self.assertAlmostEqual(4.0, gradients_to_variables[0][0].eval())
      self.assertAlmostEqual(3.0, clipped_gradients_to_variables[0][0].eval())


class CreateTrainOpTest(test.TestCase):

  def setUp(self):
    np.random.seed(0)

    # Create an easy training set:
    self._inputs = np.random.rand(16, 4).astype(np.float32)
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)

  def testTrainOpInCollection(self):
    with ops.Graph().as_default():
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      loss = losses.log_loss(tf_labels, tf_predictions)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(loss, optimizer)

      # Make sure the training op was recorded in the proper collection
      self.assertTrue(train_op in ops.get_collection(ops.GraphKeys.TRAIN_OP))

  def testUseUpdateOps(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      expected_mean = np.mean(self._inputs, axis=(0))
      expected_var = np.var(self._inputs, axis=(0))

      tf_predictions = batchnorm_classifier(tf_inputs)
      loss = losses.log_loss(tf_labels, tf_predictions)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

      train_op = training.create_train_op(loss, optimizer)

      moving_mean = variables_lib.get_variables_by_name('moving_mean')[0]
      moving_variance = variables_lib.get_variables_by_name('moving_variance')[
          0]

      with self.test_session() as session:
        # Initialize all variables
        session.run(variables_lib2.global_variables_initializer())
        mean, variance = session.run([moving_mean, moving_variance])
        # After initialization moving_mean == 0 and moving_variance == 1.
        self.assertAllClose(mean, [0] * 4)
        self.assertAllClose(variance, [1] * 4)

        for _ in range(10):
          session.run(train_op)

        mean = moving_mean.eval()
        variance = moving_variance.eval()
        # After 10 updates with decay 0.1 moving_mean == expected_mean and
        # moving_variance == expected_var.
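        # (tf_inputs is a constant, so every step sees the same batch
        # statistics. Assuming batch_norm's moving averages follow the usual
        # update v <- decay * v + (1 - decay) * batch_stat, after 10 steps
        # with decay=0.1 the initial values contribute only a factor of
        # 0.1**10 ~= 1e-10, well within assertAllClose's tolerance.)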
        self.assertAllClose(mean, expected_mean)
        self.assertAllClose(variance, expected_var)

  def testEmptyUpdateOps(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      loss = losses.log_loss(tf_labels, tf_predictions)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(loss, optimizer, update_ops=[])

      moving_mean = variables_lib.get_variables_by_name('moving_mean')[0]
      moving_variance = variables_lib.get_variables_by_name('moving_variance')[
          0]

      with self.test_session() as session:
        # Initialize all variables
        session.run(variables_lib2.global_variables_initializer())
        mean, variance = session.run([moving_mean, moving_variance])
        # After initialization moving_mean == 0 and moving_variance == 1.
        self.assertAllClose(mean, [0] * 4)
        self.assertAllClose(variance, [1] * 4)

        for _ in range(10):
          session.run(train_op)

        mean = moving_mean.eval()
        variance = moving_variance.eval()

        # Since we skip update_ops, the moving_vars are not updated.
        self.assertAllClose(mean, [0] * 4)
        self.assertAllClose(variance, [1] * 4)

  def testGlobalStepIsIncrementedByDefault(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      loss = losses.log_loss(tf_labels, tf_predictions)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(loss, optimizer)

      global_step = variables_lib.get_or_create_global_step()

      with self.test_session() as session:
        # Initialize all variables
        session.run(variables_lib2.global_variables_initializer())

        for _ in range(10):
          session.run(train_op)

        # After 10 updates global_step should be 10.
        self.assertAllClose(global_step.eval(), 10)

  def testGlobalStepNotIncrementedWhenSetToNone(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      loss = losses.log_loss(tf_labels, tf_predictions)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(loss, optimizer, global_step=None)

      global_step = variables_lib.get_or_create_global_step()

      with self.test_session() as session:
        # Initialize all variables
        session.run(variables_lib2.global_variables_initializer())

        for _ in range(10):
          session.run(train_op)

        # Since train_op doesn't use global_step, it shouldn't change.
        self.assertAllClose(global_step.eval(), 0)


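# The remaining tests drive training.train, which repeatedly runs the given
# train_op in a monitored session until a hook requests a stop and returns the
# op's final value; for ops built by create_train_op that value is the total
# training loss, which the assertions below compare against thresholds.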
class TrainBatchNormClassifierTest(test.TestCase):

  def setUp(self):
    # Create an easy training set:
    np.random.seed(0)

    self._inputs = np.zeros((16, 4))
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)

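    # Each example gets a single 1 in column 2 * label + {0, 1}, so columns
    # 0-1 can only fire for label 0 and columns 2-3 for label 1, which makes
    # the problem linearly separable.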
    for i in range(16):
      j = int(2 * self._labels[i] + np.random.randint(0, 2))
      self._inputs[i, j] = 1

  def testTrainWithNoInitAssignCanAchieveZeroLoss(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      losses.log_loss(tf_labels, tf_predictions)
      total_loss = losses.get_total_loss()

      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

      train_op = training.create_train_op(total_loss, optimizer)

      loss = training.train(
          train_op,
          None,
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=300)],
          save_summaries_steps=None,
          save_checkpoint_secs=None)
      self.assertLess(loss, .1)


class TrainTest(test.TestCase):

  def setUp(self):
    # Create an easy training set:
    np.random.seed(0)

    self._inputs = np.zeros((16, 4))
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)

    for i in range(16):
      j = int(2 * self._labels[i] + np.random.randint(0, 2))
      self._inputs[i, j] = 1

  def testCanAchieveZeroLoss(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = logistic_classifier(tf_inputs)
      losses.log_loss(tf_labels, tf_predictions)
      total_loss = losses.get_total_loss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(total_loss, optimizer)

      loss = training.train(
          train_op,
          None,
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=300)],
          save_summaries_steps=None,
          save_checkpoint_secs=None)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)

  def testTrainWithLocalVariable(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      local_multiplier = variables_lib.local_variable(1.0)

      tf_predictions = logistic_classifier(tf_inputs) * local_multiplier
      losses.log_loss(tf_labels, tf_predictions)
      total_loss = losses.get_total_loss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(total_loss, optimizer)

      loss = training.train(
          train_op,
          None,
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=300)],
          save_summaries_steps=None,
          save_checkpoint_secs=None)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)

  def testResumeTrainAchievesRoughlyTheSameLoss(self):
    number_of_steps = [300, 1, 5]
    logdir = os.path.join(self.get_temp_dir(), 'resume_train_same_loss')

    for i in range(len(number_of_steps)):
      with ops.Graph().as_default():
        random_seed.set_random_seed(i)
        tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
        tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

        tf_predictions = logistic_classifier(tf_inputs)
        losses.log_loss(tf_labels, tf_predictions)
        total_loss = losses.get_total_loss()

        optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

        train_op = training.create_train_op(total_loss, optimizer)

        saver = saver_lib.Saver()

        loss = training.train(
            train_op,
            logdir,
            hooks=[
                basic_session_run_hooks.StopAtStepHook(
                    num_steps=number_of_steps[i]),
                basic_session_run_hooks.CheckpointSaverHook(
                    logdir, save_steps=50, saver=saver),
            ],
            save_checkpoint_secs=None,
            save_summaries_steps=None)
        self.assertIsNotNone(loss)
        self.assertLess(loss, .015)

  def create_train_op(self, learning_rate=1.0, gradient_multiplier=1.0):
    tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    tf_predictions = logistic_classifier(tf_inputs)
    losses.log_loss(tf_labels, tf_predictions)
    total_loss = losses.get_total_loss()

    optimizer = gradient_descent.GradientDescentOptimizer(
        learning_rate=learning_rate)

    def transform_grads_fn(grads):
      if gradient_multiplier != 1.0:
        variables = variables_lib2.trainable_variables()
        gradient_multipliers = {var: gradient_multiplier for var in variables}

        with ops.name_scope('multiply_grads'):
          return training.multiply_gradients(grads, gradient_multipliers)
      else:
        return grads

    return training.create_train_op(
        total_loss, optimizer, transform_grads_fn=transform_grads_fn)

  def testTrainWithInitFromCheckpoint(self):
    logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
    logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs2/')

    if gfile.Exists(logdir1):  # For running on jenkins.
      gfile.DeleteRecursively(logdir1)
    if gfile.Exists(logdir2):  # For running on jenkins.
      gfile.DeleteRecursively(logdir2)

    # First, train the model one step (make sure the error is high).
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      train_op = self.create_train_op()
      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir1,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir1, save_steps=1, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=1),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertGreater(loss, .5)

    # Next, train the model to convergence.
    with ops.Graph().as_default():
      random_seed.set_random_seed(1)
      train_op = self.create_train_op()
      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir1,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir1, save_steps=300, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=300),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)

    # Finally, advance the model a single step and validate that the loss is
    # still low.
    with ops.Graph().as_default():
      random_seed.set_random_seed(2)
      train_op = self.create_train_op()

      model_variables = variables_lib2.global_variables()
      model_path = saver_lib.latest_checkpoint(logdir1)

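      # assign_from_checkpoint_fn returns a function that restores the listed
      # variables from the checkpoint; wiring it in as the Scaffold's init_fn
      # makes this run start from the converged weights instead of a fresh
      # initialization.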
      assign_fn = variables_lib.assign_from_checkpoint_fn(
          model_path, model_variables)

      def init_fn(_, session):
        assign_fn(session)

      loss = training.train(
          train_op,
          None,
          scaffold=monitored_session.Scaffold(init_fn=init_fn),
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)],
          save_checkpoint_secs=None,
          save_summaries_steps=None)

      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)

  def ModelLoss(self):
    tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    tf_predictions = logistic_classifier(tf_inputs)
    losses.log_loss(tf_labels, tf_predictions)
    return losses.get_total_loss()

  def testTrainAllVarsHasLowerLossThanTrainSubsetOfVars(self):
    logdir = os.path.join(self.get_temp_dir(), 'tmp_logs3/')
    if gfile.Exists(logdir):  # For running on jenkins.
      gfile.DeleteRecursively(logdir)

    # First, train only the weights of the model.
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      weights = variables_lib.get_variables_by_name('weights')

      train_op = training.create_train_op(
          total_loss, optimizer, variables_to_train=weights)

      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir, save_steps=200, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=200),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertGreater(loss, .015)
      self.assertLess(loss, .05)

    # Next, train the biases of the model.
    with ops.Graph().as_default():
      random_seed.set_random_seed(1)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      biases = variables_lib.get_variables_by_name('biases')

      train_op = training.create_train_op(
          total_loss, optimizer, variables_to_train=biases)

      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir, save_steps=300, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=300),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertGreater(loss, .015)
      self.assertLess(loss, .05)

    # Finally, train both weights and biases to get a lower loss.
    with ops.Graph().as_default():
      random_seed.set_random_seed(2)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

      train_op = training.create_train_op(total_loss, optimizer)
      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir,
          hooks=[
              basic_session_run_hooks.StopAtStepHook(num_steps=400),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)

  def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
    # First, train only the weights of the model.
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      weights, biases = variables_lib.get_variables()

      train_op = training.create_train_op(total_loss, optimizer)
      train_weights = training.create_train_op(
          total_loss, optimizer, variables_to_train=[weights])
      train_biases = training.create_train_op(
          total_loss, optimizer, variables_to_train=[biases])

      with self.test_session() as session:
        # Initialize the variables.
        session.run(variables_lib2.global_variables_initializer())

        # Get the initial weights and biases values.
        weights_values, biases_values = session.run([weights, biases])
        self.assertGreater(np.linalg.norm(weights_values), 0)
        self.assertAlmostEqual(np.linalg.norm(biases_values), 0)

        # Update weights and biases.
        loss = session.run(train_op)
        self.assertGreater(loss, .5)
        new_weights, new_biases = session.run([weights, biases])

        # Check that the weights and biases have been updated.
        self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
        self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)

        weights_values, biases_values = new_weights, new_biases

        # Update only weights.
        loss = session.run(train_weights)
        self.assertGreater(loss, .5)
        new_weights, new_biases = session.run([weights, biases])

        # Check that the weights have been updated, but biases have not.
        self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
        self.assertAlmostEqual(np.linalg.norm(biases_values - new_biases), 0)
        weights_values = new_weights

        # Update only biases.
        loss = session.run(train_biases)
        self.assertGreater(loss, .5)
        new_weights, new_biases = session.run([weights, biases])

        # Check that the biases have been updated, but weights have not.
        self.assertAlmostEqual(np.linalg.norm(weights_values - new_weights), 0)
        self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)

  def testTrainWithAlteredGradients(self):
    # Use the same learning rate but different gradient multipliers to train
    # two models. The model with the effectively larger learning rate (i.e.,
    # learning_rate * gradient_multiplier) should reach a smaller training
    # loss.
    multipliers = [1., 1000.]
    number_of_steps = 10
    learning_rate = 0.001

    # First, train the model with the effectively smaller learning rate.
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      train_op = self.create_train_op(
          learning_rate=learning_rate, gradient_multiplier=multipliers[0])

      loss0 = training.train(
          train_op,
          None,
          hooks=[
              basic_session_run_hooks.StopAtStepHook(num_steps=number_of_steps),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertIsNotNone(loss0)
      self.assertGreater(loss0, .5)

    # Second, train the model with the effectively larger learning rate.
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      train_op = self.create_train_op(
          learning_rate=learning_rate, gradient_multiplier=multipliers[1])

      loss1 = training.train(
          train_op,
          None,
          hooks=[
              basic_session_run_hooks.StopAtStepHook(num_steps=number_of_steps),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertIsNotNone(loss1)
      self.assertLess(loss1, .5)

    # The loss of the model trained with the larger effective learning rate
    # should be smaller.
    self.assertGreater(loss0, loss1)


if __name__ == '__main__':
  test.main()