# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.training.evaluation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import os
import time

import numpy as np

from tensorflow.contrib.framework.python.ops import variables
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.losses.python.losses import loss_ops
from tensorflow.contrib.metrics.python.ops import metric_ops
from tensorflow.contrib.training.python.training import evaluation
from tensorflow.contrib.training.python.training import training
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary as summary_lib
from tensorflow.python.summary import summary_iterator
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import saver as saver_lib


class CheckpointIteratorTest(test.TestCase):

  def testReturnsEmptyIfNoCheckpointsFound(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'no_checkpoints_found')

    num_found = 0
    for _ in evaluation.checkpoints_iterator(checkpoint_dir, timeout=0):
      num_found += 1
    self.assertEqual(num_found, 0)

  def testReturnsSingleCheckpointIfOneCheckpointFound(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'one_checkpoint_found')
    if not gfile.Exists(checkpoint_dir):
      gfile.MakeDirs(checkpoint_dir)

    global_step = variables.get_or_create_global_step()
    saver = saver_lib.Saver()  # Saves the global step.

    with self.test_session() as session:
      session.run(variables_lib.global_variables_initializer())
      save_path = os.path.join(checkpoint_dir, 'model.ckpt')
      saver.save(session, save_path, global_step=global_step)

    num_found = 0
    for _ in evaluation.checkpoints_iterator(checkpoint_dir, timeout=0):
      num_found += 1
    self.assertEqual(num_found, 1)

  def testReturnsSingleCheckpointIfOneShardedCheckpoint(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(),
                                  'one_checkpoint_found_sharded')
    if not gfile.Exists(checkpoint_dir):
      gfile.MakeDirs(checkpoint_dir)

    global_step = variables.get_or_create_global_step()

    # This will result in 3 different checkpoint shard files.
    with ops.device('/cpu:0'):
      variables_lib.Variable(10, name='v0')
    with ops.device('/cpu:1'):
      variables_lib.Variable(20, name='v1')

    saver = saver_lib.Saver(sharded=True)

    with session_lib.Session(
        target='',
        config=config_pb2.ConfigProto(device_count={'CPU': 2})) as session:

      session.run(variables_lib.global_variables_initializer())
      save_path = os.path.join(checkpoint_dir, 'model.ckpt')
      saver.save(session, save_path, global_step=global_step)

    num_found = 0
    for _ in evaluation.checkpoints_iterator(checkpoint_dir, timeout=0):
      num_found += 1
    self.assertEqual(num_found, 1)

  def testTimeoutFn(self):
    timeout_fn_calls = [0]
    def timeout_fn():
      timeout_fn_calls[0] += 1
      return timeout_fn_calls[0] > 3

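    # checkpoints_iterator yields nothing here; after each `timeout` expiry it
    # calls timeout_fn and gives up once that returns True, so we expect four
    # calls and an empty result.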
    results = list(
        evaluation.checkpoints_iterator(
            '/non-existent-dir', timeout=0.1, timeout_fn=timeout_fn))
    self.assertEqual([], results)
    self.assertEqual(4, timeout_fn_calls[0])


class WaitForNewCheckpointTest(test.TestCase):

  def testReturnsNoneAfterTimeout(self):
    start = time.time()
    ret = evaluation.wait_for_new_checkpoint(
        '/non-existent-dir', 'foo', timeout=1.0, seconds_to_sleep=0.5)
    end = time.time()
    self.assertIsNone(ret)

    # We've waited at least one sleep interval (0.5 seconds).
    self.assertGreater(end, start + 0.5)

    # The timeout kicked in.
    self.assertLess(end, start + 1.1)


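# A minimal one-layer logistic-regression model (a single sigmoid unit),
# shared by the tests below.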
def logistic_classifier(inputs):
  return layers.fully_connected(inputs, 1, activation_fn=math_ops.sigmoid)


class EvaluateOnceTest(test.TestCase):

  def setUp(self):
    super(EvaluateOnceTest, self).setUp()

    # Create an easy training set:
    np.random.seed(0)

    self._inputs = np.zeros((16, 4))
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)

    for i in range(16):
      j = int(2 * self._labels[i] + np.random.randint(0, 2))
      self._inputs[i, j] = 1

  def _train_model(self, checkpoint_dir, num_steps):
    """Trains a simple classification model.

    Note that the data has been configured such that after around 300 steps,
    the model has memorized the dataset (i.e. we can expect 100% accuracy).

    Args:
      checkpoint_dir: The directory where the checkpoint is written to.
      num_steps: The number of steps to train for.
    """
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = logistic_classifier(tf_inputs)
      loss = loss_ops.log_loss(tf_predictions, tf_labels)

      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(loss, optimizer)

      loss = training.train(
          train_op,
          checkpoint_dir,
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps)])

      if num_steps >= 300:
        assert loss < .015

  def testEvaluatePerfectModel(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(),
                                  'evaluate_perfect_model_once')

    # Train a model to completion:
    self._train_model(checkpoint_dir, num_steps=300)

    # Run the evaluation:
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    labels = constant_op.constant(self._labels, dtype=dtypes.float32)
    logits = logistic_classifier(inputs)
    predictions = math_ops.round(logits)

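    # streaming_accuracy returns a (value, update_op) pair: each run of
    # update_op folds another batch into the running counts, while the
    # accuracy tensor reads back the accumulated result.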
    accuracy, update_op = metric_ops.streaming_accuracy(predictions, labels)

    checkpoint_path = evaluation.wait_for_new_checkpoint(checkpoint_dir)

    final_ops_values = evaluation.evaluate_once(
        checkpoint_path=checkpoint_path,
        eval_ops=update_op,
        final_ops={'accuracy': accuracy},
        hooks=[
            evaluation.StopAfterNEvalsHook(1),
        ])
    self.assertTrue(final_ops_values['accuracy'] > .99)

  def testEvalOpAndFinalOp(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'eval_ops_and_final_ops')

    # Train a model for a single step to get a checkpoint.
    self._train_model(checkpoint_dir, num_steps=1)
    checkpoint_path = evaluation.wait_for_new_checkpoint(checkpoint_dir)

    # Create the model so we have something to restore.
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    logistic_classifier(inputs)

    num_evals = 5
    final_increment = 9.0

    my_var = variables.local_variable(0.0, name='MyVar')
    eval_ops = state_ops.assign_add(my_var, 1.0)
    final_ops = array_ops.identity(my_var) + final_increment

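    # eval_ops run once per evaluation step (num_evals times, courtesy of
    # StopAfterNEvalsHook), while final_ops run a single time afterwards,
    # hence the expected value of num_evals + final_increment below.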
    final_ops_values = evaluation.evaluate_once(
        checkpoint_path=checkpoint_path,
        eval_ops=eval_ops,
        final_ops={'value': final_ops},
        hooks=[
            evaluation.StopAfterNEvalsHook(num_evals),
        ])
    self.assertEqual(final_ops_values['value'], num_evals + final_increment)

  def testOnlyFinalOp(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'only_final_ops')

    # Train a model for a single step to get a checkpoint.
    self._train_model(checkpoint_dir, num_steps=1)
    checkpoint_path = evaluation.wait_for_new_checkpoint(checkpoint_dir)

    # Create the model so we have something to restore.
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    logistic_classifier(inputs)

    final_increment = 9.0

    my_var = variables.local_variable(0.0, name='MyVar')
    final_ops = array_ops.identity(my_var) + final_increment

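    # Even with no eval_ops, evaluate_once restores the checkpoint and runs
    # final_ops exactly once; my_var keeps its initial value of 0.0, so only
    # final_increment shows up in the result.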
    final_ops_values = evaluation.evaluate_once(
        checkpoint_path=checkpoint_path, final_ops={'value': final_ops})
    self.assertEqual(final_ops_values['value'], final_increment)


class EvaluateRepeatedlyTest(test.TestCase):

  def setUp(self):
    super(EvaluateRepeatedlyTest, self).setUp()

    # Create an easy training set:
    np.random.seed(0)

    self._inputs = np.zeros((16, 4))
    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)

    for i in range(16):
      j = int(2 * self._labels[i] + np.random.randint(0, 2))
      self._inputs[i, j] = 1

  def _train_model(self, checkpoint_dir, num_steps):
    """Trains a simple classification model.

    Note that the data has been configured such that after around 300 steps,
    the model has memorized the dataset (i.e. we can expect 100% accuracy).

    Args:
      checkpoint_dir: The directory where the checkpoint is written to.
      num_steps: The number of steps to train for.
    """
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
      tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

      tf_predictions = logistic_classifier(tf_inputs)
      loss = loss_ops.log_loss(tf_predictions, tf_labels)

      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      train_op = training.create_train_op(loss, optimizer)

      loss = training.train(
          train_op,
          checkpoint_dir,
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps)])

  def testEvaluatePerfectModel(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(),
                                  'evaluate_perfect_model_repeated')

    # Train a model to completion:
    self._train_model(checkpoint_dir, num_steps=300)

    # Run the evaluation:
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    labels = constant_op.constant(self._labels, dtype=dtypes.float32)
    logits = logistic_classifier(inputs)
    predictions = math_ops.round(logits)

    accuracy, update_op = metric_ops.streaming_accuracy(predictions, labels)

    final_values = evaluation.evaluate_repeatedly(
        checkpoint_dir=checkpoint_dir,
        eval_ops=update_op,
        final_ops={'accuracy': accuracy},
        hooks=[
            evaluation.StopAfterNEvalsHook(1),
        ],
        max_number_of_evaluations=1)
    self.assertTrue(final_values['accuracy'] > .99)

  def testEvaluationLoopTimeout(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(),
                                  'evaluation_loop_timeout')
    if not gfile.Exists(checkpoint_dir):
      gfile.MakeDirs(checkpoint_dir)

    # We need a variable that the saver will try to restore.
    variables.get_or_create_global_step()

    # Run with placeholders. If we actually tried to evaluate this, we'd fail
    # since we're not using a feed_dict.
    cant_run_op = array_ops.placeholder(dtype=dtypes.float32)

    start = time.time()
    final_values = evaluation.evaluate_repeatedly(
        checkpoint_dir=checkpoint_dir,
        eval_ops=cant_run_op,
        hooks=[evaluation.StopAfterNEvalsHook(10)],
        timeout=6)
    end = time.time()
    self.assertFalse(final_values)

    # Assert that we've waited for the duration of the timeout (minus the sleep
    # time).
    self.assertGreater(end - start, 5.0)

    # Then the timeout kicked in and stopped the loop.
    self.assertLess(end - start, 7)

  def testEvaluationLoopTimeoutWithTimeoutFn(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(),
                                  'evaluation_loop_timeout_with_timeout_fn')

    # Train a model to completion:
    self._train_model(checkpoint_dir, num_steps=300)

    # Run the evaluation:
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    labels = constant_op.constant(self._labels, dtype=dtypes.float32)
    logits = logistic_classifier(inputs)
    predictions = math_ops.round(logits)

    accuracy, update_op = metric_ops.streaming_accuracy(predictions, labels)

    timeout_fn_calls = [0]
    def timeout_fn():
      timeout_fn_calls[0] += 1
      return timeout_fn_calls[0] > 3

    final_values = evaluation.evaluate_repeatedly(
        checkpoint_dir=checkpoint_dir,
        eval_ops=update_op,
        final_ops={'accuracy': accuracy},
        hooks=[
            evaluation.StopAfterNEvalsHook(1),
        ],
        eval_interval_secs=1,
        max_number_of_evaluations=2,
        timeout=0.1,
        timeout_fn=timeout_fn)
    # We should have evaluated once.
    self.assertTrue(final_values['accuracy'] > .99)
    # And the timeout_fn should have been called 4 times.
    self.assertEqual(4, timeout_fn_calls[0])

  def testEvaluateWithEvalFeedDict(self):
    # Create a checkpoint.
    checkpoint_dir = os.path.join(self.get_temp_dir(),
                                  'evaluate_with_eval_feed_dict')
    self._train_model(checkpoint_dir, num_steps=1)

    # We need a variable that the saver will try to restore.
    variables.get_or_create_global_step()

    # Create a variable and an eval op that increments it with a placeholder.
    my_var = variables.local_variable(0.0, name='my_var')
    increment = array_ops.placeholder(dtype=dtypes.float32)
    eval_ops = state_ops.assign_add(my_var, increment)

    increment_value = 3
    num_evals = 5
    expected_value = increment_value * num_evals
    final_values = evaluation.evaluate_repeatedly(
        checkpoint_dir=checkpoint_dir,
        eval_ops=eval_ops,
        feed_dict={increment: increment_value},
        final_ops={'my_var': array_ops.identity(my_var)},
        hooks=[
            evaluation.StopAfterNEvalsHook(num_evals),
        ],
        max_number_of_evaluations=1)
    self.assertEqual(final_values['my_var'], expected_value)

  def _create_names_to_metrics(self, predictions, labels):
    accuracy0, update_op0 = metric_ops.streaming_accuracy(predictions, labels)
    accuracy1, update_op1 = metric_ops.streaming_accuracy(
        predictions + 1, labels)

    names_to_values = {'Accuracy': accuracy0, 'Another_accuracy': accuracy1}
    names_to_updates = {'Accuracy': update_op0, 'Another_accuracy': update_op1}
    return names_to_values, names_to_updates

  def _verify_summaries(self, output_dir, names_to_values):
    """Verifies that the given `names_to_values` are found in the summaries.

    Args:
      output_dir: An existing directory where summaries are found.
      names_to_values: A dictionary of strings to values.
    """
    # Check that the results were saved. The events file may have additional
    # entries, e.g. the event version stamp, so we have to parse them a bit.
    output_filepath = glob.glob(os.path.join(output_dir, '*'))
    self.assertEqual(len(output_filepath), 1)

    events = summary_iterator.summary_iterator(output_filepath[0])
    summaries = [e.summary for e in events if e.summary.value]
    values = []
    for summary in summaries:
      for value in summary.value:
        values.append(value)
    saved_results = {v.tag: v.simple_value for v in values}
    for name in names_to_values:
      self.assertAlmostEqual(names_to_values[name], saved_results[name], 5)

  def testSummariesAreFlushedToDisk(self):
    checkpoint_dir = os.path.join(self.get_temp_dir(), 'summaries_are_flushed')
    logdir = os.path.join(self.get_temp_dir(), 'summaries_are_flushed_eval')
    if gfile.Exists(logdir):
      gfile.DeleteRecursively(logdir)

    # Train a model to completion:
    self._train_model(checkpoint_dir, num_steps=300)

    # Create the model (which can be restored).
    inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
    logistic_classifier(inputs)

    names_to_values = {'bread': 3.4, 'cheese': 4.5, 'tomato': 2.0}

    for k in names_to_values:
      v = names_to_values[k]
      summary_lib.scalar(k, v)

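    # SummaryAtEndHook evaluates the registered summaries once, at the end of
    # the evaluation, and flushes them to log_dir so they can be read back.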
    evaluation.evaluate_repeatedly(
        checkpoint_dir=checkpoint_dir,
        hooks=[
            evaluation.SummaryAtEndHook(log_dir=logdir),
        ],
        max_number_of_evaluations=1)

    self._verify_summaries(logdir, names_to_values)


if __name__ == '__main__':
  test.main()