# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for slim.evaluation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import os
import shutil
import time

import numpy as np

from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.metrics.python.ops import metric_ops
from tensorflow.contrib.slim.python.slim import evaluation
from tensorflow.contrib.training.python.training import evaluation as evaluation_lib
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.debug.lib import debug_data
from tensorflow.python.debug.wrappers import hooks
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary_iterator
from tensorflow.python.training import input  # pylint: disable=redefined-builtin
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook


FLAGS = flags.FLAGS


def GenerateTestData(num_classes, batch_size):
  inputs = np.random.rand(batch_size, num_classes)

  np.random.seed(0)
  labels = np.random.randint(low=0, high=num_classes, size=batch_size)
  labels = labels.reshape((batch_size,))
  return inputs, labels


def TestModel(inputs):
  scale = variables.Variable(1.0, trainable=False)

  # Scaling the outputs won't change the result...
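  # (argmax is invariant under multiplication by a positive constant, so the
  # predicted class is unchanged).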
  outputs = math_ops.multiply(inputs, scale)
  return math_ops.argmax(outputs, 1), scale


def GroundTruthAccuracy(inputs, labels, batch_size):
  predictions = np.argmax(inputs, 1)
  num_correct = np.sum(predictions == labels)
  return float(num_correct) / batch_size


class EvaluationTest(test.TestCase):

  def setUp(self):
    super(EvaluationTest, self).setUp()

    num_classes = 8
    batch_size = 16
    inputs, labels = GenerateTestData(num_classes, batch_size)
    self._expected_accuracy = GroundTruthAccuracy(inputs, labels, batch_size)

    self._global_step = variables_lib.get_or_create_global_step()
    self._inputs = constant_op.constant(inputs, dtype=dtypes.float32)
    self._labels = constant_op.constant(labels, dtype=dtypes.int64)
    self._predictions, self._scale = TestModel(self._inputs)

  def testFinalOpsOnEvaluationLoop(self):
    value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
                                                        self._labels)
    init_op = control_flow_ops.group(variables.global_variables_initializer(),
                                     variables.local_variables_initializer())
    # Create checkpoint and log directories:
    chkpt_dir = os.path.join(self.get_temp_dir(), 'tmp_logs/')
    gfile.MakeDirs(chkpt_dir)
    logdir = os.path.join(self.get_temp_dir(), 'tmp_logs2/')
    gfile.MakeDirs(logdir)

    # Save initialized variables to a checkpoint directory:
    saver = saver_lib.Saver()
    with self.test_session() as sess:
      init_op.run()
      saver.save(sess, os.path.join(chkpt_dir, 'chkpt'))

    class Object(object):

      def __init__(self):
        self.hook_was_run = False

    obj = Object()

    # Create a custom session run hook.
    class CustomHook(session_run_hook.SessionRunHook):

      def __init__(self, obj):
        self.obj = obj

      def end(self, session):
        self.obj.hook_was_run = True

    # Now, run the evaluation loop:
    accuracy_value = evaluation.evaluation_loop(
        '',
        chkpt_dir,
        logdir,
        eval_op=update_op,
        final_op=value_op,
        hooks=[CustomHook(obj)],
        max_number_of_evaluations=1)
    self.assertAlmostEqual(accuracy_value, self._expected_accuracy)

    # Validate that the custom hook ran.
    self.assertTrue(obj.hook_was_run)

  def _create_names_to_metrics(self, predictions, labels):
    accuracy0, update_op0 = metric_ops.streaming_accuracy(predictions, labels)
    accuracy1, update_op1 = metric_ops.streaming_accuracy(predictions + 1,
                                                          labels)

    names_to_values = {'Accuracy': accuracy0, 'Another_accuracy': accuracy1}
    names_to_updates = {'Accuracy': update_op0, 'Another_accuracy': update_op1}
    return names_to_values, names_to_updates

  def _verify_summaries(self, output_dir, names_to_values):
    """Verifies that the given `names_to_values` are found in the summaries.

    Args:
      output_dir: An existing directory where summaries are found.
      names_to_values: A dictionary of strings to values.
    """
    # Check that the results were saved. The events file may have additional
    # entries, e.g. the event version stamp, so we have to parse things a bit.
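    # Exactly one events file is expected in `output_dir`.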
    output_filepath = glob.glob(os.path.join(output_dir, '*'))
    self.assertEqual(len(output_filepath), 1)

    events = summary_iterator.summary_iterator(output_filepath[0])
    summaries = [e.summary for e in events if e.summary.value]
    values = []
    for summary in summaries:
      for value in summary.value:
        values.append(value)
    saved_results = {v.tag: v.simple_value for v in values}
    for name in names_to_values:
      self.assertAlmostEqual(names_to_values[name], saved_results[name])

  def testLatestCheckpointReturnsNoneAfterTimeout(self):
    start = time.time()
    ret = evaluation_lib.wait_for_new_checkpoint(
        '/non-existent-dir', 'foo', timeout=1.0, seconds_to_sleep=0.5)
    end = time.time()
    self.assertIsNone(ret)
    # We've waited at least one sleep interval (seconds_to_sleep=0.5).
    self.assertGreater(end, start + 0.5)
    # The timeout kicked in.
    self.assertLess(end, start + 1.1)

  def testMonitorCheckpointsLoopTimeout(self):
    ret = list(
        evaluation_lib.checkpoints_iterator(
            '/non-existent-dir', timeout=0))
    self.assertEqual(ret, [])

  def testWithEpochLimit(self):
    predictions_limited = input.limit_epochs(self._predictions, num_epochs=1)
    labels_limited = input.limit_epochs(self._labels, num_epochs=1)

    value_op, update_op = metric_ops.streaming_accuracy(
        predictions_limited, labels_limited)

    init_op = control_flow_ops.group(variables.global_variables_initializer(),
                                     variables.local_variables_initializer())
    # Create checkpoint and log directories:
    chkpt_dir = os.path.join(self.get_temp_dir(), 'tmp_logs/')
    gfile.MakeDirs(chkpt_dir)
    logdir = os.path.join(self.get_temp_dir(), 'tmp_logs2/')
    gfile.MakeDirs(logdir)

    # Save initialized variables to a checkpoint directory:
    saver = saver_lib.Saver()
    with self.test_session() as sess:
      init_op.run()
      saver.save(sess, os.path.join(chkpt_dir, 'chkpt'))

    # Now, run the evaluation loop. `num_evals` is deliberately much larger
    # than one epoch of data; `limit_epochs` ends the evaluation early after a
    # single epoch, so the loop still terminates.
    accuracy_value = evaluation.evaluation_loop(
        '', chkpt_dir, logdir, eval_op=update_op, final_op=value_op,
        max_number_of_evaluations=1, num_evals=10000)
    self.assertAlmostEqual(accuracy_value, self._expected_accuracy)


class SingleEvaluationTest(test.TestCase):

  def setUp(self):
    super(SingleEvaluationTest, self).setUp()

    num_classes = 8
    batch_size = 16
    inputs, labels = GenerateTestData(num_classes, batch_size)
    self._expected_accuracy = GroundTruthAccuracy(inputs, labels, batch_size)

    self._global_step = variables_lib.get_or_create_global_step()
    self._inputs = constant_op.constant(inputs, dtype=dtypes.float32)
    self._labels = constant_op.constant(labels, dtype=dtypes.int64)
    self._predictions, self._scale = TestModel(self._inputs)

  def testErrorRaisedIfCheckpointDoesntExist(self):
    checkpoint_path = os.path.join(self.get_temp_dir(),
                                   'this_file_doesnt_exist')
    log_dir = os.path.join(self.get_temp_dir(), 'error_raised')
    with self.assertRaises(errors.NotFoundError):
      evaluation.evaluate_once('', checkpoint_path, log_dir)

  def _prepareCheckpoint(self, checkpoint_path):
    init_op = control_flow_ops.group(variables.global_variables_initializer(),
                                     variables.local_variables_initializer())
    saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1)
    with self.test_session() as sess:
      sess.run(init_op)
      saver.save(sess, checkpoint_path)

  def testRestoredModelPerformance(self):
    checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
    log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')

    # First, save out the current model to a checkpoint:
    self._prepareCheckpoint(checkpoint_path)

    # Next, determine the metric to evaluate:
    value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
                                                        self._labels)

    # Run the evaluation and verify the results:
    accuracy_value = evaluation.evaluate_once(
        '', checkpoint_path, log_dir, eval_op=update_op, final_op=value_op)
    self.assertAlmostEqual(accuracy_value, self._expected_accuracy)

  def testAdditionalHooks(self):
    checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
    log_dir = os.path.join(self.get_temp_dir(), 'log_dir1/')

    # First, save out the current model to a checkpoint:
    self._prepareCheckpoint(checkpoint_path)

    # Next, determine the metric to evaluate:
    value_op, update_op = metric_ops.streaming_accuracy(self._predictions,
                                                        self._labels)

    dumping_root = os.path.join(self.get_temp_dir(), 'tfdbg_dump_dir')
    dumping_hook = hooks.DumpingDebugHook(dumping_root, log_usage=False)
    try:
      # Run the evaluation and verify the results:
      accuracy_value = evaluation.evaluate_once(
          '', checkpoint_path, log_dir, eval_op=update_op, final_op=value_op,
          hooks=[dumping_hook])
      self.assertAlmostEqual(accuracy_value, self._expected_accuracy)

      dump = debug_data.DebugDumpDir(
          glob.glob(os.path.join(dumping_root, 'run_*'))[0])
      # Here we simply assert that the dumped data has been loaded and is
      # non-empty. We do not care about the detailed model-internal tensors or
      # their values.
      self.assertTrue(dump.dumped_tensor_data)
    finally:
      if os.path.isdir(dumping_root):
        shutil.rmtree(dumping_root)


if __name__ == '__main__':
  test.main()