# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.training.training."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np

from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.training.python.training import training
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_lib2
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib


def logistic_classifier(inputs):
  return layers.fully_connected(inputs, 1, activation_fn=math_ops.sigmoid)


def batchnorm_classifier(inputs):
  inputs = layers.batch_norm(inputs, decay=0.1, fused=False)
  return layers.fully_connected(inputs, 1, activation_fn=math_ops.sigmoid)


class ClipGradsTest(test.TestCase):

  def testClipGrads(self):
    xs = variables_lib2.Variable(0.0)
    ys = xs * 4.0
    grads = gradients_impl.gradients([ys], [xs])
    gradients_to_variables = list(zip(grads, [xs]))
    clipped_gradients_to_variables = training.clip_gradient_norms(
        gradients_to_variables, 3.0)

    with self.test_session() as session:
      session.run(variables_lib2.global_variables_initializer())
      self.assertAlmostEqual(4.0, gradients_to_variables[0][0].eval())
      self.assertAlmostEqual(3.0, clipped_gradients_to_variables[0][0].eval())

  def testClipGradsFn(self):
    xs = variables_lib2.Variable(0.0)
    ys = xs * 4.0
    grads = gradients_impl.gradients([ys], [xs])
    gradients_to_variables = list(zip(grads, [xs]))
    clipped_gradients_to_variables = training.clip_gradient_norms_fn(3.0)(
        gradients_to_variables)

    with self.test_session() as session:
      session.run(variables_lib2.global_variables_initializer())
      self.assertAlmostEqual(4.0, gradients_to_variables[0][0].eval())
      self.assertAlmostEqual(3.0, clipped_gradients_to_variables[0][0].eval())


class CreateTrainOpTest(test.TestCase):

  def setUp(self):
    np.random.seed(0)

    # Create an easy training set:
    self._inputs = np.random.rand(16, 4).astype(np.float32)
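    # Random binary labels, cast to float32 so they can feed log_loss
    # directly against the classifier's sigmoid output.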
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)

  def testTrainOpInCollection(self):
    with ops.Graph().as_default():
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      loss = losses.log_loss(tf_labels, tf_predictions)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(loss, optimizer)

      # Make sure the training op was recorded in the proper collection.
      self.assertTrue(train_op in ops.get_collection(ops.GraphKeys.TRAIN_OP))

  def testUseUpdateOps(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      expected_mean = np.mean(self._inputs, axis=0)
      expected_var = np.var(self._inputs, axis=0)

      tf_predictions = batchnorm_classifier(tf_inputs)
      loss = losses.log_loss(tf_labels, tf_predictions)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

      train_op = training.create_train_op(loss, optimizer)

      moving_mean = variables_lib.get_variables_by_name('moving_mean')[0]
      moving_variance = variables_lib.get_variables_by_name(
          'moving_variance')[0]

      with self.test_session() as session:
        # Initialize all variables.
        session.run(variables_lib2.global_variables_initializer())
        mean, variance = session.run([moving_mean, moving_variance])
        # After initialization, moving_mean == 0 and moving_variance == 1.
        self.assertAllClose(mean, [0] * 4)
        self.assertAllClose(variance, [1] * 4)

        for _ in range(10):
          session.run(train_op)

        mean = moving_mean.eval()
        variance = moving_variance.eval()
        # After 10 updates with decay 0.1, moving_mean == expected_mean and
        # moving_variance == expected_var.
        self.assertAllClose(mean, expected_mean)
        self.assertAllClose(variance, expected_var)

  def testEmptyUpdateOps(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      loss = losses.log_loss(tf_labels, tf_predictions)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(loss, optimizer, update_ops=[])

      moving_mean = variables_lib.get_variables_by_name('moving_mean')[0]
      moving_variance = variables_lib.get_variables_by_name(
          'moving_variance')[0]

      with self.test_session() as session:
        # Initialize all variables.
        session.run(variables_lib2.global_variables_initializer())
        mean, variance = session.run([moving_mean, moving_variance])
        # After initialization, moving_mean == 0 and moving_variance == 1.
        self.assertAllClose(mean, [0] * 4)
        self.assertAllClose(variance, [1] * 4)

        for _ in range(10):
          session.run(train_op)

        mean = moving_mean.eval()
        variance = moving_variance.eval()

        # Since we skip update_ops, the moving_vars are not updated.
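        # Both should retain their initial values: zeros and ones.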
        self.assertAllClose(mean, [0] * 4)
        self.assertAllClose(variance, [1] * 4)

  def testGlobalStepIsIncrementedByDefault(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      loss = losses.log_loss(tf_labels, tf_predictions)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(loss, optimizer)

      global_step = variables_lib.get_or_create_global_step()

      with self.test_session() as session:
        # Initialize all variables.
        session.run(variables_lib2.global_variables_initializer())

        for _ in range(10):
          session.run(train_op)

        # After 10 updates, global_step should be 10.
        self.assertAllClose(global_step.eval(), 10)

  def testGlobalStepNotIncrementedWhenSetToNone(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      loss = losses.log_loss(tf_labels, tf_predictions)
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(loss, optimizer, global_step=None)

      global_step = variables_lib.get_or_create_global_step()

      with self.test_session() as session:
        # Initialize all variables.
        session.run(variables_lib2.global_variables_initializer())

        for _ in range(10):
          session.run(train_op)

        # Since the train_op doesn't use the global_step, it shouldn't change.
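        # create_train_op was given global_step=None, so no increment op was
        # attached to train_op; the counter keeps its initial value of zero.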
        self.assertAllClose(global_step.eval(), 0)


class TrainBatchNormClassifierTest(test.TestCase):

  def setUp(self):
    # Create an easy training set:
    np.random.seed(0)

    self._inputs = np.zeros((16, 4))
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)

    for i in range(16):
      j = int(2 * self._labels[i] + np.random.randint(0, 2))
      self._inputs[i, j] = 1

  def testTrainWithNoInitAssignCanAchieveZeroLoss(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = batchnorm_classifier(tf_inputs)
      losses.log_loss(tf_labels, tf_predictions)
      total_loss = losses.get_total_loss()

      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

      train_op = training.create_train_op(total_loss, optimizer)

      loss = training.train(
          train_op,
          None,
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=300)],
          save_summaries_steps=None,
          save_checkpoint_secs=None)
      self.assertLess(loss, .1)


class TrainTest(test.TestCase):

  def setUp(self):
    # Create an easy training set:
    np.random.seed(0)

    self._inputs = np.zeros((16, 4))
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)

    for i in range(16):
      j = int(2 * self._labels[i] + np.random.randint(0, 2))
      self._inputs[i, j] = 1

  def testCanAchieveZeroLoss(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = logistic_classifier(tf_inputs)
      losses.log_loss(tf_labels, tf_predictions)
      total_loss = losses.get_total_loss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(total_loss, optimizer)

      loss = training.train(
          train_op,
          None,
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=300)],
          save_summaries_steps=None,
          save_checkpoint_secs=None)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)

  def testTrainWithLocalVariable(self):
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      local_multiplier = variables_lib.local_variable(1.0)

      tf_predictions = logistic_classifier(tf_inputs) * local_multiplier
      losses.log_loss(tf_labels, tf_predictions)
      total_loss = losses.get_total_loss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(total_loss, optimizer)

      loss = training.train(
          train_op,
          None,
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=300)],
          save_summaries_steps=None,
          save_checkpoint_secs=None)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)

  def testResumeTrainAchievesRoughlyTheSameLoss(self):
    number_of_steps = [300, 1, 5]
    logdir = os.path.join(self.get_temp_dir(), 'resume_train_same_loss')

    for i in range(len(number_of_steps)):
      with ops.Graph().as_default():
        random_seed.set_random_seed(i)
        tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
        tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

        tf_predictions = logistic_classifier(tf_inputs)
        losses.log_loss(tf_labels, tf_predictions)
        total_loss = losses.get_total_loss()

        optimizer = gradient_descent.GradientDescentOptimizer(
            learning_rate=1.0)

        train_op = training.create_train_op(total_loss, optimizer)

        saver = saver_lib.Saver()

        loss = training.train(
            train_op,
            logdir,
            hooks=[
                basic_session_run_hooks.StopAtStepHook(
                    num_steps=number_of_steps[i]),
                basic_session_run_hooks.CheckpointSaverHook(
                    logdir, save_steps=50, saver=saver),
            ],
            save_checkpoint_secs=None,
            save_summaries_steps=None)
        self.assertIsNotNone(loss)
        self.assertLess(loss, .015)

  def create_train_op(self, learning_rate=1.0, gradient_multiplier=1.0):
    tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    tf_predictions = logistic_classifier(tf_inputs)
    losses.log_loss(tf_labels, tf_predictions)
    total_loss = losses.get_total_loss()

    optimizer = gradient_descent.GradientDescentOptimizer(
        learning_rate=learning_rate)

    def transform_grads_fn(grads):
      if gradient_multiplier != 1.0:
        variables = variables_lib2.trainable_variables()
        gradient_multipliers = {var: gradient_multiplier for var in variables}

        with ops.name_scope('multiply_grads'):
          return training.multiply_gradients(grads, gradient_multipliers)
      else:
        return grads

    return training.create_train_op(
        total_loss, optimizer, transform_grads_fn=transform_grads_fn)

  def testTrainWithInitFromCheckpoint(self):
    logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
    logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs2/')

    if gfile.Exists(logdir1):  # For running on Jenkins.
      gfile.DeleteRecursively(logdir1)
    if gfile.Exists(logdir2):  # For running on Jenkins.
      gfile.DeleteRecursively(logdir2)

    # First, train the model one step (make sure the error is high).
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      train_op = self.create_train_op()
      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir1,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir1, save_steps=1, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=1),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertGreater(loss, .5)

    # Next, train the model to convergence.
    with ops.Graph().as_default():
      random_seed.set_random_seed(1)
      train_op = self.create_train_op()
      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir1,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir1, save_steps=300, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=300),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)

    # Finally, advance the model a single step and validate that the loss is
    # still low.
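    # Rather than passing a logdir for automatic restoring, restore the
    # trained weights explicitly through a Scaffold init_fn.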
    with ops.Graph().as_default():
      random_seed.set_random_seed(2)
      train_op = self.create_train_op()

      model_variables = variables_lib2.global_variables()
      model_path = saver_lib.latest_checkpoint(logdir1)

      assign_fn = variables_lib.assign_from_checkpoint_fn(
          model_path, model_variables)

      def init_fn(_, session):
        assign_fn(session)

      loss = training.train(
          train_op,
          None,
          scaffold=monitored_session.Scaffold(init_fn=init_fn),
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)],
          save_checkpoint_secs=None,
          save_summaries_steps=None)

      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)

  def ModelLoss(self):
    tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

    tf_predictions = logistic_classifier(tf_inputs)
    losses.log_loss(tf_labels, tf_predictions)
    return losses.get_total_loss()

  def testTrainAllVarsHasLowerLossThanTrainSubsetOfVars(self):
    logdir = os.path.join(self.get_temp_dir(), 'tmp_logs3/')
    if gfile.Exists(logdir):  # For running on Jenkins.
      gfile.DeleteRecursively(logdir)

    # First, train only the weights of the model.
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      weights = variables_lib.get_variables_by_name('weights')

      train_op = training.create_train_op(
          total_loss, optimizer, variables_to_train=weights)

      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir, save_steps=200, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=200),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertGreater(loss, .015)
      self.assertLess(loss, .05)

    # Next, train only the biases of the model.
    with ops.Graph().as_default():
      random_seed.set_random_seed(1)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      biases = variables_lib.get_variables_by_name('biases')

      train_op = training.create_train_op(
          total_loss, optimizer, variables_to_train=biases)

      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir, save_steps=300, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=300),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertGreater(loss, .015)
      self.assertLess(loss, .05)

    # Finally, train both weights and biases to get a lower loss.
    with ops.Graph().as_default():
      random_seed.set_random_seed(2)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

      train_op = training.create_train_op(total_loss, optimizer)
      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir,
          hooks=[
              basic_session_run_hooks.StopAtStepHook(num_steps=400),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)

  def testTrainingSubsetsOfVariablesOnlyUpdatesThoseVariables(self):
    # Create three train ops: one that trains all variables and one each
    # that trains only the weights or only the biases.
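    # `variables_to_train` limits which variables the optimizer updates;
    # every other variable is left untouched by that train op.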
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      weights, biases = variables_lib.get_variables()

      train_op = training.create_train_op(total_loss, optimizer)
      train_weights = training.create_train_op(
          total_loss, optimizer, variables_to_train=[weights])
      train_biases = training.create_train_op(
          total_loss, optimizer, variables_to_train=[biases])

      with self.test_session() as session:
        # Initialize the variables.
        session.run(variables_lib2.global_variables_initializer())

        # Get the initial weights and biases values.
        weights_values, biases_values = session.run([weights, biases])
        self.assertGreater(np.linalg.norm(weights_values), 0)
        self.assertAlmostEqual(np.linalg.norm(biases_values), 0)

        # Update both weights and biases.
        loss = session.run(train_op)
        self.assertGreater(loss, .5)
        new_weights, new_biases = session.run([weights, biases])

        # Check that the weights and biases have been updated.
        self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
        self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)

        weights_values, biases_values = new_weights, new_biases

        # Update only the weights.
        loss = session.run(train_weights)
        self.assertGreater(loss, .5)
        new_weights, new_biases = session.run([weights, biases])

        # Check that the weights have been updated, but the biases have not.
        self.assertGreater(np.linalg.norm(weights_values - new_weights), 0)
        self.assertAlmostEqual(np.linalg.norm(biases_values - new_biases), 0)
        weights_values = new_weights

        # Update only the biases.
        loss = session.run(train_biases)
        self.assertGreater(loss, .5)
        new_weights, new_biases = session.run([weights, biases])

        # Check that the biases have been updated, but the weights have not.
        self.assertAlmostEqual(np.linalg.norm(weights_values - new_weights), 0)
        self.assertGreater(np.linalg.norm(biases_values - new_biases), 0)

  def testTrainWithAlteredGradients(self):
    # Use the same learning rate but different gradient multipliers
    # to train two models. The model with the equivalently larger learning
    # rate (i.e., learning_rate * gradient_multiplier) has the smaller
    # training loss.
    multipliers = [1., 1000.]
    number_of_steps = 10
    learning_rate = 0.001

    # First, train the model with the equivalently smaller learning rate.
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      train_op = self.create_train_op(
          learning_rate=learning_rate, gradient_multiplier=multipliers[0])

      loss0 = training.train(
          train_op,
          None,
          hooks=[
              basic_session_run_hooks.StopAtStepHook(num_steps=number_of_steps),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertIsNotNone(loss0)
      self.assertGreater(loss0, .5)

    # Second, train the model with the equivalently larger learning rate.
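    # For plain gradient descent, multiplying the gradients by 1000 with a
    # learning rate of 0.001 gives an effective learning rate of 1.0.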
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      train_op = self.create_train_op(
          learning_rate=learning_rate, gradient_multiplier=multipliers[1])

      loss1 = training.train(
          train_op,
          None,
          hooks=[
              basic_session_run_hooks.StopAtStepHook(num_steps=number_of_steps),
          ],
          save_checkpoint_secs=None,
          save_summaries_steps=None)
      self.assertIsNotNone(loss1)
      self.assertLess(loss1, .5)

    # The loss of the model trained with the larger learning rate should
    # be smaller.
    self.assertGreater(loss0, loss1)


if __name__ == '__main__':
  test.main()