1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tf.contrib.training.evaluation.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import glob 22 import os 23 import time 24 25 import numpy as np 26 27 from tensorflow.contrib.framework.python.ops import variables 28 from tensorflow.contrib.layers.python.layers import layers 29 from tensorflow.contrib.losses.python.losses import loss_ops 30 from tensorflow.contrib.metrics.python.ops import metric_ops 31 from tensorflow.contrib.training.python.training import evaluation 32 from tensorflow.contrib.training.python.training import training 33 from tensorflow.core.protobuf import config_pb2 34 from tensorflow.python.client import session as session_lib 35 from tensorflow.python.framework import constant_op 36 from tensorflow.python.framework import dtypes 37 from tensorflow.python.framework import ops 38 from tensorflow.python.framework import random_seed 39 from tensorflow.python.ops import array_ops 40 from tensorflow.python.ops import math_ops 41 from tensorflow.python.ops import state_ops 42 from tensorflow.python.ops import variables as variables_lib 43 from tensorflow.python.platform import gfile 44 from tensorflow.python.platform import test 45 from tensorflow.python.summary import summary as summary_lib 46 from tensorflow.python.summary import summary_iterator 47 from tensorflow.python.training import basic_session_run_hooks 48 from tensorflow.python.training import gradient_descent 49 from tensorflow.python.training import saver as saver_lib 50 51 52 class CheckpointIteratorTest(test.TestCase): 53 54 def testReturnsEmptyIfNoCheckpointsFound(self): 55 checkpoint_dir = os.path.join(self.get_temp_dir(), 'no_checkpoints_found') 56 57 num_found = 0 58 for _ in evaluation.checkpoints_iterator(checkpoint_dir, timeout=0): 59 num_found += 1 60 self.assertEqual(num_found, 0) 61 62 def testReturnsSingleCheckpointIfOneCheckpointFound(self): 63 checkpoint_dir = os.path.join(self.get_temp_dir(), 'one_checkpoint_found') 64 if not gfile.Exists(checkpoint_dir): 65 gfile.MakeDirs(checkpoint_dir) 66 67 global_step = variables.get_or_create_global_step() 68 saver = saver_lib.Saver() # Saves the global step. 69 70 with self.test_session() as session: 71 session.run(variables_lib.global_variables_initializer()) 72 save_path = os.path.join(checkpoint_dir, 'model.ckpt') 73 saver.save(session, save_path, global_step=global_step) 74 75 num_found = 0 76 for _ in evaluation.checkpoints_iterator(checkpoint_dir, timeout=0): 77 num_found += 1 78 self.assertEqual(num_found, 1) 79 80 def testReturnsSingleCheckpointIfOneShardedCheckpoint(self): 81 checkpoint_dir = os.path.join(self.get_temp_dir(), 82 'one_checkpoint_found_sharded') 83 if not gfile.Exists(checkpoint_dir): 84 gfile.MakeDirs(checkpoint_dir) 85 86 global_step = variables.get_or_create_global_step() 87 88 # This will result in 3 different checkpoint shard files. 89 with ops.device('/cpu:0'): 90 variables_lib.Variable(10, name='v0') 91 with ops.device('/cpu:1'): 92 variables_lib.Variable(20, name='v1') 93 94 saver = saver_lib.Saver(sharded=True) 95 96 with session_lib.Session( 97 target='', 98 config=config_pb2.ConfigProto(device_count={'CPU': 2})) as session: 99 100 session.run(variables_lib.global_variables_initializer()) 101 save_path = os.path.join(checkpoint_dir, 'model.ckpt') 102 saver.save(session, save_path, global_step=global_step) 103 104 num_found = 0 105 for _ in evaluation.checkpoints_iterator(checkpoint_dir, timeout=0): 106 num_found += 1 107 self.assertEqual(num_found, 1) 108 109 def testTimeoutFn(self): 110 timeout_fn_calls = [0] 111 def timeout_fn(): 112 timeout_fn_calls[0] += 1 113 return timeout_fn_calls[0] > 3 114 115 results = list( 116 evaluation.checkpoints_iterator( 117 '/non-existent-dir', timeout=0.1, timeout_fn=timeout_fn)) 118 self.assertEqual([], results) 119 self.assertEqual(4, timeout_fn_calls[0]) 120 121 122 class WaitForNewCheckpointTest(test.TestCase): 123 124 def testReturnsNoneAfterTimeout(self): 125 start = time.time() 126 ret = evaluation.wait_for_new_checkpoint( 127 '/non-existent-dir', 'foo', timeout=1.0, seconds_to_sleep=0.5) 128 end = time.time() 129 self.assertIsNone(ret) 130 131 # We've waited one second. 132 self.assertGreater(end, start + 0.5) 133 134 # The timeout kicked in. 135 self.assertLess(end, start + 1.1) 136 137 138 def logistic_classifier(inputs): 139 return layers.fully_connected(inputs, 1, activation_fn=math_ops.sigmoid) 140 141 142 class EvaluateOnceTest(test.TestCase): 143 144 def setUp(self): 145 super(EvaluateOnceTest, self).setUp() 146 147 # Create an easy training set: 148 np.random.seed(0) 149 150 self._inputs = np.zeros((16, 4)) 151 self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32) 152 153 for i in range(16): 154 j = int(2 * self._labels[i] + np.random.randint(0, 2)) 155 self._inputs[i, j] = 1 156 157 def _train_model(self, checkpoint_dir, num_steps): 158 """Trains a simple classification model. 159 160 Note that the data has been configured such that after around 300 steps, 161 the model has memorized the dataset (e.g. we can expect %100 accuracy). 162 163 Args: 164 checkpoint_dir: The directory where the checkpoint is written to. 165 num_steps: The number of steps to train for. 166 """ 167 with ops.Graph().as_default(): 168 random_seed.set_random_seed(0) 169 tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32) 170 tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32) 171 172 tf_predictions = logistic_classifier(tf_inputs) 173 loss = loss_ops.log_loss(tf_predictions, tf_labels) 174 175 optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0) 176 train_op = training.create_train_op(loss, optimizer) 177 178 loss = training.train( 179 train_op, 180 checkpoint_dir, 181 hooks=[basic_session_run_hooks.StopAtStepHook(num_steps)]) 182 183 if num_steps >= 300: 184 assert loss < .015 185 186 def testEvaluatePerfectModel(self): 187 checkpoint_dir = os.path.join(self.get_temp_dir(), 188 'evaluate_perfect_model_once') 189 190 # Train a Model to completion: 191 self._train_model(checkpoint_dir, num_steps=300) 192 193 # Run 194 inputs = constant_op.constant(self._inputs, dtype=dtypes.float32) 195 labels = constant_op.constant(self._labels, dtype=dtypes.float32) 196 logits = logistic_classifier(inputs) 197 predictions = math_ops.round(logits) 198 199 accuracy, update_op = metric_ops.streaming_accuracy(predictions, labels) 200 201 checkpoint_path = evaluation.wait_for_new_checkpoint(checkpoint_dir) 202 203 final_ops_values = evaluation.evaluate_once( 204 checkpoint_path=checkpoint_path, 205 eval_ops=update_op, 206 final_ops={'accuracy': accuracy}, 207 hooks=[ 208 evaluation.StopAfterNEvalsHook(1), 209 ]) 210 self.assertTrue(final_ops_values['accuracy'] > .99) 211 212 def testEvalOpAndFinalOp(self): 213 checkpoint_dir = os.path.join(self.get_temp_dir(), 'eval_ops_and_final_ops') 214 215 # Train a model for a single step to get a checkpoint. 216 self._train_model(checkpoint_dir, num_steps=1) 217 checkpoint_path = evaluation.wait_for_new_checkpoint(checkpoint_dir) 218 219 # Create the model so we have something to restore. 220 inputs = constant_op.constant(self._inputs, dtype=dtypes.float32) 221 logistic_classifier(inputs) 222 223 num_evals = 5 224 final_increment = 9.0 225 226 my_var = variables.local_variable(0.0, name='MyVar') 227 eval_ops = state_ops.assign_add(my_var, 1.0) 228 final_ops = array_ops.identity(my_var) + final_increment 229 230 final_ops_values = evaluation.evaluate_once( 231 checkpoint_path=checkpoint_path, 232 eval_ops=eval_ops, 233 final_ops={'value': final_ops}, 234 hooks=[ 235 evaluation.StopAfterNEvalsHook(num_evals), 236 ]) 237 self.assertEqual(final_ops_values['value'], num_evals + final_increment) 238 239 def testOnlyFinalOp(self): 240 checkpoint_dir = os.path.join(self.get_temp_dir(), 'only_final_ops') 241 242 # Train a model for a single step to get a checkpoint. 243 self._train_model(checkpoint_dir, num_steps=1) 244 checkpoint_path = evaluation.wait_for_new_checkpoint(checkpoint_dir) 245 246 # Create the model so we have something to restore. 247 inputs = constant_op.constant(self._inputs, dtype=dtypes.float32) 248 logistic_classifier(inputs) 249 250 final_increment = 9.0 251 252 my_var = variables.local_variable(0.0, name='MyVar') 253 final_ops = array_ops.identity(my_var) + final_increment 254 255 final_ops_values = evaluation.evaluate_once( 256 checkpoint_path=checkpoint_path, final_ops={'value': final_ops}) 257 self.assertEqual(final_ops_values['value'], final_increment) 258 259 260 class EvaluateRepeatedlyTest(test.TestCase): 261 262 def setUp(self): 263 super(EvaluateRepeatedlyTest, self).setUp() 264 265 # Create an easy training set: 266 np.random.seed(0) 267 268 self._inputs = np.zeros((16, 4)) 269 self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32) 270 271 for i in range(16): 272 j = int(2 * self._labels[i] + np.random.randint(0, 2)) 273 self._inputs[i, j] = 1 274 275 def _train_model(self, checkpoint_dir, num_steps): 276 """Trains a simple classification model. 277 278 Note that the data has been configured such that after around 300 steps, 279 the model has memorized the dataset (e.g. we can expect %100 accuracy). 280 281 Args: 282 checkpoint_dir: The directory where the checkpoint is written to. 283 num_steps: The number of steps to train for. 284 """ 285 with ops.Graph().as_default(): 286 random_seed.set_random_seed(0) 287 tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32) 288 tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32) 289 290 tf_predictions = logistic_classifier(tf_inputs) 291 loss = loss_ops.log_loss(tf_predictions, tf_labels) 292 293 optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0) 294 train_op = training.create_train_op(loss, optimizer) 295 296 loss = training.train( 297 train_op, 298 checkpoint_dir, 299 hooks=[basic_session_run_hooks.StopAtStepHook(num_steps)]) 300 301 def testEvaluatePerfectModel(self): 302 checkpoint_dir = os.path.join(self.get_temp_dir(), 303 'evaluate_perfect_model_repeated') 304 305 # Train a Model to completion: 306 self._train_model(checkpoint_dir, num_steps=300) 307 308 # Run 309 inputs = constant_op.constant(self._inputs, dtype=dtypes.float32) 310 labels = constant_op.constant(self._labels, dtype=dtypes.float32) 311 logits = logistic_classifier(inputs) 312 predictions = math_ops.round(logits) 313 314 accuracy, update_op = metric_ops.streaming_accuracy(predictions, labels) 315 316 final_values = evaluation.evaluate_repeatedly( 317 checkpoint_dir=checkpoint_dir, 318 eval_ops=update_op, 319 final_ops={'accuracy': accuracy}, 320 hooks=[ 321 evaluation.StopAfterNEvalsHook(1), 322 ], 323 max_number_of_evaluations=1) 324 self.assertTrue(final_values['accuracy'] > .99) 325 326 def testEvaluationLoopTimeout(self): 327 checkpoint_dir = os.path.join(self.get_temp_dir(), 328 'evaluation_loop_timeout') 329 if not gfile.Exists(checkpoint_dir): 330 gfile.MakeDirs(checkpoint_dir) 331 332 # We need a variable that the saver will try to restore. 333 variables.get_or_create_global_step() 334 335 # Run with placeholders. If we actually try to evaluate this, we'd fail 336 # since we're not using a feed_dict. 337 cant_run_op = array_ops.placeholder(dtype=dtypes.float32) 338 339 start = time.time() 340 final_values = evaluation.evaluate_repeatedly( 341 checkpoint_dir=checkpoint_dir, 342 eval_ops=cant_run_op, 343 hooks=[evaluation.StopAfterNEvalsHook(10)], 344 timeout=6) 345 end = time.time() 346 self.assertFalse(final_values) 347 348 # Assert that we've waited for the duration of the timeout (minus the sleep 349 # time). 350 self.assertGreater(end - start, 5.0) 351 352 # Then the timeout kicked in and stops the loop. 353 self.assertLess(end - start, 7) 354 355 def testEvaluationLoopTimeoutWithTimeoutFn(self): 356 checkpoint_dir = os.path.join(self.get_temp_dir(), 357 'evaluation_loop_timeout_with_timeout_fn') 358 359 # Train a Model to completion: 360 self._train_model(checkpoint_dir, num_steps=300) 361 362 # Run 363 inputs = constant_op.constant(self._inputs, dtype=dtypes.float32) 364 labels = constant_op.constant(self._labels, dtype=dtypes.float32) 365 logits = logistic_classifier(inputs) 366 predictions = math_ops.round(logits) 367 368 accuracy, update_op = metric_ops.streaming_accuracy(predictions, labels) 369 370 timeout_fn_calls = [0] 371 def timeout_fn(): 372 timeout_fn_calls[0] += 1 373 return timeout_fn_calls[0] > 3 374 375 final_values = evaluation.evaluate_repeatedly( 376 checkpoint_dir=checkpoint_dir, 377 eval_ops=update_op, 378 final_ops={'accuracy': accuracy}, 379 hooks=[ 380 evaluation.StopAfterNEvalsHook(1), 381 ], 382 eval_interval_secs=1, 383 max_number_of_evaluations=2, 384 timeout=0.1, 385 timeout_fn=timeout_fn) 386 # We should have evaluated once. 387 self.assertTrue(final_values['accuracy'] > .99) 388 # And called 4 times the timeout fn 389 self.assertEqual(4, timeout_fn_calls[0]) 390 391 def testEvaluateWithEvalFeedDict(self): 392 # Create a checkpoint. 393 checkpoint_dir = os.path.join(self.get_temp_dir(), 394 'evaluate_with_eval_feed_dict') 395 self._train_model(checkpoint_dir, num_steps=1) 396 397 # We need a variable that the saver will try to restore. 398 variables.get_or_create_global_step() 399 400 # Create a variable and an eval op that increments it with a placeholder. 401 my_var = variables.local_variable(0.0, name='my_var') 402 increment = array_ops.placeholder(dtype=dtypes.float32) 403 eval_ops = state_ops.assign_add(my_var, increment) 404 405 increment_value = 3 406 num_evals = 5 407 expected_value = increment_value * num_evals 408 final_values = evaluation.evaluate_repeatedly( 409 checkpoint_dir=checkpoint_dir, 410 eval_ops=eval_ops, 411 feed_dict={increment: 3}, 412 final_ops={'my_var': array_ops.identity(my_var)}, 413 hooks=[ 414 evaluation.StopAfterNEvalsHook(num_evals), 415 ], 416 max_number_of_evaluations=1) 417 self.assertEqual(final_values['my_var'], expected_value) 418 419 def _create_names_to_metrics(self, predictions, labels): 420 accuracy0, update_op0 = metric_ops.streaming_accuracy(predictions, labels) 421 accuracy1, update_op1 = metric_ops.streaming_accuracy( 422 predictions + 1, labels) 423 424 names_to_values = {'Accuracy': accuracy0, 'Another_accuracy': accuracy1} 425 names_to_updates = {'Accuracy': update_op0, 'Another_accuracy': update_op1} 426 return names_to_values, names_to_updates 427 428 def _verify_summaries(self, output_dir, names_to_values): 429 """Verifies that the given `names_to_values` are found in the summaries. 430 431 Args: 432 output_dir: An existing directory where summaries are found. 433 names_to_values: A dictionary of strings to values. 434 """ 435 # Check that the results were saved. The events file may have additional 436 # entries, e.g. the event version stamp, so have to parse things a bit. 437 output_filepath = glob.glob(os.path.join(output_dir, '*')) 438 self.assertEqual(len(output_filepath), 1) 439 440 events = summary_iterator.summary_iterator(output_filepath[0]) 441 summaries = [e.summary for e in events if e.summary.value] 442 values = [] 443 for summary in summaries: 444 for value in summary.value: 445 values.append(value) 446 saved_results = {v.tag: v.simple_value for v in values} 447 for name in names_to_values: 448 self.assertAlmostEqual(names_to_values[name], saved_results[name], 5) 449 450 def testSummariesAreFlushedToDisk(self): 451 checkpoint_dir = os.path.join(self.get_temp_dir(), 'summaries_are_flushed') 452 logdir = os.path.join(self.get_temp_dir(), 'summaries_are_flushed_eval') 453 if gfile.Exists(logdir): 454 gfile.DeleteRecursively(logdir) 455 456 # Train a Model to completion: 457 self._train_model(checkpoint_dir, num_steps=300) 458 459 # Create the model (which can be restored). 460 inputs = constant_op.constant(self._inputs, dtype=dtypes.float32) 461 logistic_classifier(inputs) 462 463 names_to_values = {'bread': 3.4, 'cheese': 4.5, 'tomato': 2.0} 464 465 for k in names_to_values: 466 v = names_to_values[k] 467 summary_lib.scalar(k, v) 468 469 evaluation.evaluate_repeatedly( 470 checkpoint_dir=checkpoint_dir, 471 hooks=[ 472 evaluation.SummaryAtEndHook(log_dir=logdir), 473 ], 474 max_number_of_evaluations=1) 475 476 self._verify_summaries(logdir, names_to_values) 477 478 479 if __name__ == '__main__': 480 test.main() 481