# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for slim.evaluation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import os
import shutil
import time

import numpy as np

from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.metrics.python.ops import metric_ops
from tensorflow.contrib.slim.python.slim import evaluation
from tensorflow.contrib.training.python.training import evaluation as evaluation_lib
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import hooks
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary_iterator
from tensorflow.python.training import input  # pylint: disable=redefined-builtin
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook


FLAGS = flags.FLAGS


def GenerateTestData(num_classes, batch_size):
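  """Generates random test inputs and integer labels.

  Args:
    num_classes: The number of classes.
    batch_size: The number of samples to generate.

  Returns:
    A tuple (inputs, labels) where `inputs` has shape
    [batch_size, num_classes] and `labels` has shape [batch_size] with values
    in [0, num_classes).
  """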
  inputs = np.random.rand(batch_size, num_classes)

  np.random.seed(0)
  labels = np.random.randint(low=0, high=num_classes, size=batch_size)
  labels = labels.reshape((batch_size,))
  return inputs, labels


def TestModel(inputs):
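  """A trivial model that scales its inputs and predicts via argmax.

  Args:
    inputs: A [batch_size, num_classes] float tensor.

  Returns:
    A tuple (predictions, scale) of the per-row argmax and the non-trainable
    scale variable.
  """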
  scale = variables.Variable(1.0, trainable=False)

  # Scaling the outputs won't change the result...
  outputs = math_ops.multiply(inputs, scale)
  return math_ops.argmax(outputs, 1), scale


def GroundTruthAccuracy(inputs, labels, batch_size):
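  """Computes the accuracy of argmax predictions over `inputs`.

  Args:
    inputs: A [batch_size, num_classes] array of scores.
    labels: A [batch_size] array of integer labels.
    batch_size: The number of samples.

  Returns:
    The fraction of rows whose argmax matches the corresponding label.
  """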
  predictions = np.argmax(inputs, 1)
  num_correct = np.sum(predictions == labels)
  return float(num_correct) / batch_size


class EvaluationTest(test.TestCase):

  def setUp(self):
    super(EvaluationTest, self).setUp()

    num_classes = 8
    batch_size = 16
    inputs, labels = GenerateTestData(num_classes, batch_size)
    self._expected_accuracy = GroundTruthAccuracy(inputs, labels, batch_size)

    self._global_step = variables_lib.get_or_create_global_step()
    self._inputs = constant_op.constant(inputs, dtype=dtypes.float32)
    self._labels = constant_op.constant(labels, dtype=dtypes.int64)
    self._predictions, self._scale = TestModel(self._inputs)

  def testFinalOpsOnEvaluationLoop(self):
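    """Tests that `final_op` runs and custom hooks fire in the eval loop."""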
    value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
                                                        self._labels)
    init_op = control_flow_ops.group(variables.global_variables_initializer(),
                                     variables.local_variables_initializer())
    # Create checkpoint and log directories:
    chkpt_dir = os.path.join(self.get_temp_dir(), 'tmp_logs/')
    gfile.MakeDirs(chkpt_dir)
    logdir = os.path.join(self.get_temp_dir(), 'tmp_logs2/')
    gfile.MakeDirs(logdir)

    # Save initialized variables to a checkpoint directory:
    saver = saver_lib.Saver()
    with self.test_session() as sess:
      init_op.run()
      saver.save(sess, os.path.join(chkpt_dir, 'chkpt'))

    class Object(object):

      def __init__(self):
        self.hook_was_run = False

    obj = Object()

    # Create a custom session run hook.
    class CustomHook(session_run_hook.SessionRunHook):

      def __init__(self, obj):
        self.obj = obj

      def end(self, session):
        self.obj.hook_was_run = True

    # Now, run the evaluation loop:
    accuracy_value = evaluation.evaluation_loop(
        '',
        chkpt_dir,
        logdir,
        eval_op=update_op,
        final_op=value_op,
        hooks=[CustomHook(obj)],
        max_number_of_evaluations=1)
    self.assertAlmostEqual(accuracy_value, self._expected_accuracy)

    # Validate that the custom hook ran.
    self.assertTrue(obj.hook_was_run)

  def _create_names_to_metrics(self, predictions, labels):
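    """Creates two streaming accuracy metrics keyed by display name."""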
    accuracy0, update_op0 = metric_ops.streaming_accuracy(predictions, labels)
    accuracy1, update_op1 = metric_ops.streaming_accuracy(predictions + 1,
                                                          labels)

    names_to_values = {'Accuracy': accuracy0, 'Another_accuracy': accuracy1}
    names_to_updates = {'Accuracy': update_op0, 'Another_accuracy': update_op1}
    return names_to_values, names_to_updates

  def _verify_summaries(self, output_dir, names_to_values):
    """Verifies that the given `names_to_values` are found in the summaries.

    Args:
      output_dir: An existing directory where summaries are found.
      names_to_values: A dictionary of strings to values.
    """
    # Check that the results were saved. The events file may have additional
    # entries, e.g. the event version stamp, so we have to parse things a bit.
    output_filepath = glob.glob(os.path.join(output_dir, '*'))
    self.assertEqual(len(output_filepath), 1)

    events = summary_iterator.summary_iterator(output_filepath[0])
    summaries = [e.summary for e in events if e.summary.value]
    values = []
    for summary in summaries:
      for value in summary.value:
        values.append(value)
    saved_results = {v.tag: v.simple_value for v in values}
    for name in names_to_values:
      self.assertAlmostEqual(names_to_values[name], saved_results[name])

  def testLatestCheckpointReturnsNoneAfterTimeout(self):
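    """Tests that wait_for_new_checkpoint returns None after the timeout."""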
    start = time.time()
    ret = evaluation_lib.wait_for_new_checkpoint(
        '/non-existent-dir', 'foo', timeout=1.0, seconds_to_sleep=0.5)
    end = time.time()
    self.assertIsNone(ret)
    # We've waited for at least one sleep interval.
    self.assertGreater(end, start + 0.5)
    # The timeout kicked in.
    self.assertLess(end, start + 1.1)

  def testMonitorCheckpointsLoopTimeout(self):
    ret = list(
        evaluation_lib.checkpoints_iterator(
            '/non-existent-dir', timeout=0))
    self.assertEqual(ret, [])

  def testWithEpochLimit(self):
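    """Tests evaluation with one-epoch inputs and a large num_evals."""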
    predictions_limited = input.limit_epochs(self._predictions, num_epochs=1)
    labels_limited = input.limit_epochs(self._labels, num_epochs=1)

    value_op, update_op = metric_ops.streaming_accuracy(
        predictions_limited, labels_limited)

    init_op = control_flow_ops.group(variables.global_variables_initializer(),
                                     variables.local_variables_initializer())
    # Create checkpoint and log directories:
    chkpt_dir = os.path.join(self.get_temp_dir(), 'tmp_logs/')
    gfile.MakeDirs(chkpt_dir)
    logdir = os.path.join(self.get_temp_dir(), 'tmp_logs2/')
    gfile.MakeDirs(logdir)

    # Save initialized variables to a checkpoint directory:
    saver = saver_lib.Saver()
    with self.test_session() as sess:
      init_op.run()
      saver.save(sess, os.path.join(chkpt_dir, 'chkpt'))

    # Now, run the evaluation loop:
    accuracy_value = evaluation.evaluation_loop(
        '', chkpt_dir, logdir, eval_op=update_op, final_op=value_op,
        max_number_of_evaluations=1, num_evals=10000)
    self.assertAlmostEqual(accuracy_value, self._expected_accuracy)


class SingleEvaluationTest(test.TestCase):

  def setUp(self):
    super(SingleEvaluationTest, self).setUp()

    num_classes = 8
    batch_size = 16
    inputs, labels = GenerateTestData(num_classes, batch_size)
    self._expected_accuracy = GroundTruthAccuracy(inputs, labels, batch_size)

    self._global_step = variables_lib.get_or_create_global_step()
    self._inputs = constant_op.constant(inputs, dtype=dtypes.float32)
    self._labels = constant_op.constant(labels, dtype=dtypes.int64)
    self._predictions, self._scale = TestModel(self._inputs)

  def testErrorRaisedIfCheckpointDoesntExist(self):
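    """Tests that evaluate_once raises an error for a missing checkpoint."""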
    checkpoint_path = os.path.join(self.get_temp_dir(),
                                   'this_file_doesnt_exist')
    log_dir = os.path.join(self.get_temp_dir(), 'error_raised')
    with self.assertRaises(errors.NotFoundError):
      evaluation.evaluate_once('', checkpoint_path, log_dir)

  def _prepareCheckpoint(self, checkpoint_path):
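    """Initializes all variables and writes a V1-format checkpoint."""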
    init_op = control_flow_ops.group(variables.global_variables_initializer(),
                                     variables.local_variables_initializer())
    saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
    with self.test_session() as sess:
      sess.run(init_op)
      saver.save(sess, checkpoint_path)

  def testRestoredModelPerformance(self):
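    """Tests that a restored model reproduces the expected accuracy."""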
    checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
    log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')

    # First, save out the current model to a checkpoint:
    self._prepareCheckpoint(checkpoint_path)

    # Next, determine the metric to evaluate:
    value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
                                                        self._labels)

    # Run the evaluation and verify the results:
    accuracy_value = evaluation.evaluate_once(
        '', checkpoint_path, log_dir, eval_op=update_op, final_op=value_op)
    self.assertAlmostEqual(accuracy_value, self._expected_accuracy)

  def testAdditionalHooks(self):
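    """Tests that additional hooks (tfdbg dumping) run in evaluate_once."""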
    checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
    log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')

    # First, save out the current model to a checkpoint:
    self._prepareCheckpoint(checkpoint_path)

    # Next, determine the metric to evaluate:
    value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
                                                        self._labels)

    dumping_root = os.path.join(self.get_temp_dir(), 'tfdbg_dump_dir')
    dumping_hook = hooks.DumpingDebugHook(dumping_root, log_usage=False)
    try:
      # Run the evaluation and verify the results:
      accuracy_value = evaluation.evaluate_once(
          '', checkpoint_path, log_dir, eval_op=update_op, final_op=value_op,
          hooks=[dumping_hook])
      self.assertAlmostEqual(accuracy_value, self._expected_accuracy)

      dump = debug_data.DebugDumpDir(
          glob.glob(os.path.join(dumping_root, 'run_*'))[0])
      # Here we simply assert that the dumped data has been loaded and is
      # non-empty. We do not care about the detailed model-internal tensors or
      # their values.
      self.assertTrue(dump.dumped_tensor_data)
    finally:
      if os.path.isdir(dumping_root):
        shutil.rmtree(dumping_root)


if __name__ == '__main__':
  test.main()