1 # pylint: disable=g-bad-file-header 2 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 3 # 4 # Licensed under the Apache License, Version 2.0 (the "License"); 5 # you may not use this file except in compliance with the License. 6 # You may obtain a copy of the License at 7 # 8 # http://www.apache.org/licenses/LICENSE-2.0 9 # 10 # Unless required by applicable law or agreed to in writing, software 11 # distributed under the License is distributed on an "AS IS" BASIS, 12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 # See the License for the specific language governing permissions and 14 # limitations under the License. 15 # ============================================================================== 16 """Tests for basic_session_run_hooks.""" 17 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 22 import os.path 23 import shutil 24 import tempfile 25 import threading 26 import time 27 28 from tensorflow.contrib.framework.python.framework import checkpoint_utils 29 from tensorflow.contrib.framework.python.ops import variables 30 from tensorflow.contrib.testing.python.framework import fake_summary_writer 31 from tensorflow.python.client import session as session_lib 32 from tensorflow.python.framework import constant_op 33 from tensorflow.python.framework import dtypes 34 from tensorflow.python.framework import meta_graph 35 from tensorflow.python.framework import ops 36 from tensorflow.python.ops import array_ops 37 from tensorflow.python.ops import control_flow_ops 38 from tensorflow.python.ops import state_ops 39 from tensorflow.python.ops import variable_scope 40 from tensorflow.python.ops import variables as variables_lib 41 import tensorflow.python.ops.nn_grad # pylint: disable=unused-import 42 from tensorflow.python.platform import gfile 43 from tensorflow.python.platform import test 44 from tensorflow.python.platform import tf_logging 45 from tensorflow.python.summary import summary as summary_lib 46 from tensorflow.python.summary.writer import writer_cache 47 from tensorflow.python.training import basic_session_run_hooks 48 from tensorflow.python.training import monitored_session 49 from tensorflow.python.training import session_run_hook 50 from tensorflow.python.training import training_util 51 52 53 class MockCheckpointSaverListener( 54 basic_session_run_hooks.CheckpointSaverListener): 55 56 def __init__(self): 57 self.begin_count = 0 58 self.before_save_count = 0 59 self.after_save_count = 0 60 self.end_count = 0 61 62 def begin(self): 63 self.begin_count += 1 64 65 def before_save(self, session, global_step): 66 self.before_save_count += 1 67 68 def after_save(self, session, global_step): 69 self.after_save_count += 1 70 71 def end(self, session, global_step): 72 self.end_count += 1 73 74 def get_counts(self): 75 return { 76 'begin': self.begin_count, 77 'before_save': self.before_save_count, 78 'after_save': self.after_save_count, 79 'end': self.end_count 80 } 81 82 83 class SecondOrStepTimerTest(test.TestCase): 84 85 def test_raise_in_both_secs_and_steps(self): 86 with self.assertRaises(ValueError): 87 basic_session_run_hooks.SecondOrStepTimer(every_secs=2.0, every_steps=10) 88 89 def test_raise_in_none_secs_and_steps(self): 90 with self.assertRaises(ValueError): 91 basic_session_run_hooks.SecondOrStepTimer() 92 93 def test_every_secs(self): 94 timer = basic_session_run_hooks.SecondOrStepTimer(every_secs=1.0) 95 self.assertTrue(timer.should_trigger_for_step(1)) 96 97 timer.update_last_triggered_step(1) 98 self.assertFalse(timer.should_trigger_for_step(1)) 99 self.assertFalse(timer.should_trigger_for_step(2)) 100 101 time.sleep(1.0) 102 self.assertFalse(timer.should_trigger_for_step(1)) 103 self.assertTrue(timer.should_trigger_for_step(2)) 104 105 def test_every_steps(self): 106 timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=3) 107 self.assertTrue(timer.should_trigger_for_step(1)) 108 109 timer.update_last_triggered_step(1) 110 self.assertFalse(timer.should_trigger_for_step(1)) 111 self.assertFalse(timer.should_trigger_for_step(2)) 112 self.assertFalse(timer.should_trigger_for_step(3)) 113 self.assertTrue(timer.should_trigger_for_step(4)) 114 115 def test_update_last_triggered_step(self): 116 timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=1) 117 118 elapsed_secs, elapsed_steps = timer.update_last_triggered_step(1) 119 self.assertEqual(None, elapsed_secs) 120 self.assertEqual(None, elapsed_steps) 121 122 elapsed_secs, elapsed_steps = timer.update_last_triggered_step(5) 123 self.assertLess(0, elapsed_secs) 124 self.assertEqual(4, elapsed_steps) 125 126 elapsed_secs, elapsed_steps = timer.update_last_triggered_step(7) 127 self.assertLess(0, elapsed_secs) 128 self.assertEqual(2, elapsed_steps) 129 130 131 class StopAtStepTest(test.TestCase): 132 133 def test_raise_in_both_last_step_and_num_steps(self): 134 with self.assertRaises(ValueError): 135 basic_session_run_hooks.StopAtStepHook(num_steps=10, last_step=20) 136 137 def test_stop_based_on_last_step(self): 138 h = basic_session_run_hooks.StopAtStepHook(last_step=10) 139 with ops.Graph().as_default(): 140 global_step = variables.get_or_create_global_step() 141 no_op = control_flow_ops.no_op() 142 h.begin() 143 with session_lib.Session() as sess: 144 mon_sess = monitored_session._HookedSession(sess, [h]) 145 sess.run(state_ops.assign(global_step, 5)) 146 h.after_create_session(sess, None) 147 mon_sess.run(no_op) 148 self.assertFalse(mon_sess.should_stop()) 149 sess.run(state_ops.assign(global_step, 9)) 150 mon_sess.run(no_op) 151 self.assertFalse(mon_sess.should_stop()) 152 sess.run(state_ops.assign(global_step, 10)) 153 mon_sess.run(no_op) 154 self.assertTrue(mon_sess.should_stop()) 155 sess.run(state_ops.assign(global_step, 11)) 156 mon_sess._should_stop = False 157 mon_sess.run(no_op) 158 self.assertTrue(mon_sess.should_stop()) 159 160 def test_stop_based_on_num_step(self): 161 h = basic_session_run_hooks.StopAtStepHook(num_steps=10) 162 163 with ops.Graph().as_default(): 164 global_step = variables.get_or_create_global_step() 165 no_op = control_flow_ops.no_op() 166 h.begin() 167 with session_lib.Session() as sess: 168 mon_sess = monitored_session._HookedSession(sess, [h]) 169 sess.run(state_ops.assign(global_step, 5)) 170 h.after_create_session(sess, None) 171 mon_sess.run(no_op) 172 self.assertFalse(mon_sess.should_stop()) 173 sess.run(state_ops.assign(global_step, 13)) 174 mon_sess.run(no_op) 175 self.assertFalse(mon_sess.should_stop()) 176 sess.run(state_ops.assign(global_step, 14)) 177 mon_sess.run(no_op) 178 self.assertFalse(mon_sess.should_stop()) 179 sess.run(state_ops.assign(global_step, 15)) 180 mon_sess.run(no_op) 181 self.assertTrue(mon_sess.should_stop()) 182 sess.run(state_ops.assign(global_step, 16)) 183 mon_sess._should_stop = False 184 mon_sess.run(no_op) 185 self.assertTrue(mon_sess.should_stop()) 186 187 def test_stop_based_with_multiple_steps(self): 188 h = basic_session_run_hooks.StopAtStepHook(num_steps=10) 189 190 with ops.Graph().as_default(): 191 global_step = variables.get_or_create_global_step() 192 no_op = control_flow_ops.no_op() 193 h.begin() 194 with session_lib.Session() as sess: 195 mon_sess = monitored_session._HookedSession(sess, [h]) 196 sess.run(state_ops.assign(global_step, 5)) 197 h.after_create_session(sess, None) 198 mon_sess.run(no_op) 199 self.assertFalse(mon_sess.should_stop()) 200 sess.run(state_ops.assign(global_step, 15)) 201 mon_sess.run(no_op) 202 self.assertTrue(mon_sess.should_stop()) 203 204 205 class LoggingTensorHookTest(test.TestCase): 206 207 def setUp(self): 208 # Mock out logging calls so we can verify whether correct tensors are being 209 # monitored. 210 self._actual_log = tf_logging.info 211 self.logged_message = None 212 213 def mock_log(*args, **kwargs): 214 self.logged_message = args 215 self._actual_log(*args, **kwargs) 216 217 tf_logging.info = mock_log 218 219 def tearDown(self): 220 tf_logging.info = self._actual_log 221 222 def test_illegal_args(self): 223 with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'): 224 basic_session_run_hooks.LoggingTensorHook(tensors=['t'], every_n_iter=0) 225 with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'): 226 basic_session_run_hooks.LoggingTensorHook(tensors=['t'], every_n_iter=-10) 227 with self.assertRaisesRegexp(ValueError, 'xactly one of'): 228 basic_session_run_hooks.LoggingTensorHook( 229 tensors=['t'], every_n_iter=5, every_n_secs=5) 230 with self.assertRaisesRegexp(ValueError, 'xactly one of'): 231 basic_session_run_hooks.LoggingTensorHook(tensors=['t']) 232 233 def test_print_at_end_only(self): 234 with ops.Graph().as_default(), session_lib.Session() as sess: 235 t = constant_op.constant(42.0, name='foo') 236 train_op = constant_op.constant(3) 237 hook = basic_session_run_hooks.LoggingTensorHook( 238 tensors=[t.name], at_end=True) 239 hook.begin() 240 mon_sess = monitored_session._HookedSession(sess, [hook]) 241 sess.run(variables_lib.global_variables_initializer()) 242 self.logged_message = '' 243 for _ in range(3): 244 mon_sess.run(train_op) 245 # assertNotRegexpMatches is not supported by python 3.1 and later 246 self.assertEqual(str(self.logged_message).find(t.name), -1) 247 248 hook.end(sess) 249 self.assertRegexpMatches(str(self.logged_message), t.name) 250 251 def _validate_print_every_n_steps(self, sess, at_end): 252 t = constant_op.constant(42.0, name='foo') 253 254 train_op = constant_op.constant(3) 255 hook = basic_session_run_hooks.LoggingTensorHook( 256 tensors=[t.name], every_n_iter=10, at_end=at_end) 257 hook.begin() 258 mon_sess = monitored_session._HookedSession(sess, [hook]) 259 sess.run(variables_lib.global_variables_initializer()) 260 mon_sess.run(train_op) 261 self.assertRegexpMatches(str(self.logged_message), t.name) 262 for _ in range(3): 263 self.logged_message = '' 264 for _ in range(9): 265 mon_sess.run(train_op) 266 # assertNotRegexpMatches is not supported by python 3.1 and later 267 self.assertEqual(str(self.logged_message).find(t.name), -1) 268 mon_sess.run(train_op) 269 self.assertRegexpMatches(str(self.logged_message), t.name) 270 271 # Add additional run to verify proper reset when called multiple times. 272 self.logged_message = '' 273 mon_sess.run(train_op) 274 # assertNotRegexpMatches is not supported by python 3.1 and later 275 self.assertEqual(str(self.logged_message).find(t.name), -1) 276 277 self.logged_message = '' 278 hook.end(sess) 279 if at_end: 280 self.assertRegexpMatches(str(self.logged_message), t.name) 281 else: 282 # assertNotRegexpMatches is not supported by python 3.1 and later 283 self.assertEqual(str(self.logged_message).find(t.name), -1) 284 285 def test_print_every_n_steps(self): 286 with ops.Graph().as_default(), session_lib.Session() as sess: 287 self._validate_print_every_n_steps(sess, at_end=False) 288 # Verify proper reset. 289 self._validate_print_every_n_steps(sess, at_end=False) 290 291 def test_print_every_n_steps_and_end(self): 292 with ops.Graph().as_default(), session_lib.Session() as sess: 293 self._validate_print_every_n_steps(sess, at_end=True) 294 # Verify proper reset. 295 self._validate_print_every_n_steps(sess, at_end=True) 296 297 def test_print_first_step(self): 298 # if it runs every iteration, first iteration has None duration. 299 with ops.Graph().as_default(), session_lib.Session() as sess: 300 t = constant_op.constant(42.0, name='foo') 301 train_op = constant_op.constant(3) 302 hook = basic_session_run_hooks.LoggingTensorHook( 303 tensors={'foo': t}, every_n_iter=1) 304 hook.begin() 305 mon_sess = monitored_session._HookedSession(sess, [hook]) 306 sess.run(variables_lib.global_variables_initializer()) 307 mon_sess.run(train_op) 308 self.assertRegexpMatches(str(self.logged_message), 'foo') 309 # in first run, elapsed time is None. 310 self.assertEqual(str(self.logged_message).find('sec'), -1) 311 312 def _validate_print_every_n_secs(self, sess, at_end): 313 t = constant_op.constant(42.0, name='foo') 314 train_op = constant_op.constant(3) 315 316 hook = basic_session_run_hooks.LoggingTensorHook( 317 tensors=[t.name], every_n_secs=1.0, at_end=at_end) 318 hook.begin() 319 mon_sess = monitored_session._HookedSession(sess, [hook]) 320 sess.run(variables_lib.global_variables_initializer()) 321 322 mon_sess.run(train_op) 323 self.assertRegexpMatches(str(self.logged_message), t.name) 324 325 # assertNotRegexpMatches is not supported by python 3.1 and later 326 self.logged_message = '' 327 mon_sess.run(train_op) 328 self.assertEqual(str(self.logged_message).find(t.name), -1) 329 time.sleep(1.0) 330 331 self.logged_message = '' 332 mon_sess.run(train_op) 333 self.assertRegexpMatches(str(self.logged_message), t.name) 334 335 self.logged_message = '' 336 hook.end(sess) 337 if at_end: 338 self.assertRegexpMatches(str(self.logged_message), t.name) 339 else: 340 # assertNotRegexpMatches is not supported by python 3.1 and later 341 self.assertEqual(str(self.logged_message).find(t.name), -1) 342 343 def test_print_every_n_secs(self): 344 with ops.Graph().as_default(), session_lib.Session() as sess: 345 self._validate_print_every_n_secs(sess, at_end=False) 346 # Verify proper reset. 347 self._validate_print_every_n_secs(sess, at_end=False) 348 349 def test_print_every_n_secs_and_end(self): 350 with ops.Graph().as_default(), session_lib.Session() as sess: 351 self._validate_print_every_n_secs(sess, at_end=True) 352 # Verify proper reset. 353 self._validate_print_every_n_secs(sess, at_end=True) 354 355 def test_print_formatter(self): 356 with ops.Graph().as_default(), session_lib.Session() as sess: 357 t = constant_op.constant(42.0, name='foo') 358 train_op = constant_op.constant(3) 359 hook = basic_session_run_hooks.LoggingTensorHook( 360 tensors=[t.name], every_n_iter=10, 361 formatter=lambda items: 'qqq=%s' % items[t.name]) 362 hook.begin() 363 mon_sess = monitored_session._HookedSession(sess, [hook]) 364 sess.run(variables_lib.global_variables_initializer()) 365 mon_sess.run(train_op) 366 self.assertEqual(self.logged_message[0], 'qqq=42.0') 367 368 369 class CheckpointSaverHookTest(test.TestCase): 370 371 def setUp(self): 372 self.model_dir = tempfile.mkdtemp() 373 self.graph = ops.Graph() 374 with self.graph.as_default(): 375 self.scaffold = monitored_session.Scaffold() 376 self.global_step = variables.get_or_create_global_step() 377 self.train_op = training_util._increment_global_step(1) 378 379 def tearDown(self): 380 shutil.rmtree(self.model_dir, ignore_errors=True) 381 382 def test_saves_when_saver_and_scaffold_both_missing(self): 383 with self.graph.as_default(): 384 hook = basic_session_run_hooks.CheckpointSaverHook( 385 self.model_dir, save_steps=1) 386 hook.begin() 387 self.scaffold.finalize() 388 with session_lib.Session() as sess: 389 sess.run(self.scaffold.init_op) 390 mon_sess = monitored_session._HookedSession(sess, [hook]) 391 mon_sess.run(self.train_op) 392 self.assertEqual(1, 393 checkpoint_utils.load_variable(self.model_dir, 394 self.global_step.name)) 395 396 def test_raise_when_saver_and_scaffold_both_present(self): 397 with self.assertRaises(ValueError): 398 basic_session_run_hooks.CheckpointSaverHook( 399 self.model_dir, saver=self.scaffold.saver, scaffold=self.scaffold) 400 401 def test_raise_in_both_secs_and_steps(self): 402 with self.assertRaises(ValueError): 403 basic_session_run_hooks.CheckpointSaverHook( 404 self.model_dir, save_secs=10, save_steps=20) 405 406 def test_raise_in_none_secs_and_steps(self): 407 with self.assertRaises(ValueError): 408 basic_session_run_hooks.CheckpointSaverHook(self.model_dir) 409 410 def test_save_secs_saves_in_first_step(self): 411 with self.graph.as_default(): 412 hook = basic_session_run_hooks.CheckpointSaverHook( 413 self.model_dir, save_secs=2, scaffold=self.scaffold) 414 hook.begin() 415 self.scaffold.finalize() 416 with session_lib.Session() as sess: 417 sess.run(self.scaffold.init_op) 418 mon_sess = monitored_session._HookedSession(sess, [hook]) 419 mon_sess.run(self.train_op) 420 self.assertEqual(1, 421 checkpoint_utils.load_variable(self.model_dir, 422 self.global_step.name)) 423 424 def test_save_secs_calls_listeners_at_begin_and_end(self): 425 with self.graph.as_default(): 426 listener = MockCheckpointSaverListener() 427 hook = basic_session_run_hooks.CheckpointSaverHook( 428 self.model_dir, 429 save_secs=2, 430 scaffold=self.scaffold, 431 listeners=[listener]) 432 hook.begin() 433 self.scaffold.finalize() 434 with session_lib.Session() as sess: 435 sess.run(self.scaffold.init_op) 436 mon_sess = monitored_session._HookedSession(sess, [hook]) 437 mon_sess.run(self.train_op) # hook runs here 438 mon_sess.run(self.train_op) # hook won't run here, so it does at end 439 hook.end(sess) # hook runs here 440 self.assertEqual({ 441 'begin': 1, 442 'before_save': 2, 443 'after_save': 2, 444 'end': 1 445 }, listener.get_counts()) 446 447 def test_listener_with_monitored_session(self): 448 with ops.Graph().as_default(): 449 scaffold = monitored_session.Scaffold() 450 global_step = variables.get_or_create_global_step() 451 train_op = training_util._increment_global_step(1) 452 listener = MockCheckpointSaverListener() 453 hook = basic_session_run_hooks.CheckpointSaverHook( 454 self.model_dir, 455 save_steps=1, 456 scaffold=scaffold, 457 listeners=[listener]) 458 with monitored_session.SingularMonitoredSession( 459 hooks=[hook], 460 scaffold=scaffold, 461 checkpoint_dir=self.model_dir) as sess: 462 sess.run(train_op) 463 sess.run(train_op) 464 global_step_val = sess.raw_session().run(global_step) 465 listener_counts = listener.get_counts() 466 self.assertEqual(2, global_step_val) 467 self.assertEqual({ 468 'begin': 1, 469 'before_save': 2, 470 'after_save': 2, 471 'end': 1 472 }, listener_counts) 473 474 def test_listener_with_default_saver(self): 475 with ops.Graph().as_default(): 476 global_step = variables.get_or_create_global_step() 477 train_op = training_util._increment_global_step(1) 478 listener = MockCheckpointSaverListener() 479 hook = basic_session_run_hooks.CheckpointSaverHook( 480 self.model_dir, 481 save_steps=1, 482 listeners=[listener]) 483 with monitored_session.SingularMonitoredSession( 484 hooks=[hook], 485 checkpoint_dir=self.model_dir) as sess: 486 sess.run(train_op) 487 sess.run(train_op) 488 global_step_val = sess.raw_session().run(global_step) 489 listener_counts = listener.get_counts() 490 self.assertEqual(2, global_step_val) 491 self.assertEqual({ 492 'begin': 1, 493 'before_save': 2, 494 'after_save': 2, 495 'end': 1 496 }, listener_counts) 497 498 with ops.Graph().as_default(): 499 global_step = variables.get_or_create_global_step() 500 with monitored_session.SingularMonitoredSession( 501 checkpoint_dir=self.model_dir) as sess2: 502 global_step_saved_val = sess2.run(global_step) 503 self.assertEqual(2, global_step_saved_val) 504 505 def test_two_listeners_with_default_saver(self): 506 with ops.Graph().as_default(): 507 global_step = variables.get_or_create_global_step() 508 train_op = training_util._increment_global_step(1) 509 listener1 = MockCheckpointSaverListener() 510 listener2 = MockCheckpointSaverListener() 511 hook = basic_session_run_hooks.CheckpointSaverHook( 512 self.model_dir, 513 save_steps=1, 514 listeners=[listener1, listener2]) 515 with monitored_session.SingularMonitoredSession( 516 hooks=[hook], 517 checkpoint_dir=self.model_dir) as sess: 518 sess.run(train_op) 519 sess.run(train_op) 520 global_step_val = sess.raw_session().run(global_step) 521 listener1_counts = listener1.get_counts() 522 listener2_counts = listener2.get_counts() 523 self.assertEqual(2, global_step_val) 524 self.assertEqual({ 525 'begin': 1, 526 'before_save': 2, 527 'after_save': 2, 528 'end': 1 529 }, listener1_counts) 530 self.assertEqual(listener1_counts, listener2_counts) 531 532 with ops.Graph().as_default(): 533 global_step = variables.get_or_create_global_step() 534 with monitored_session.SingularMonitoredSession( 535 checkpoint_dir=self.model_dir) as sess2: 536 global_step_saved_val = sess2.run(global_step) 537 self.assertEqual(2, global_step_saved_val) 538 539 @test.mock.patch.object(time, 'time') 540 def test_save_secs_saves_periodically(self, mock_time): 541 # Let's have a realistic start time 542 current_time = 1484695987.209386 543 544 with self.graph.as_default(): 545 mock_time.return_value = current_time 546 hook = basic_session_run_hooks.CheckpointSaverHook( 547 self.model_dir, save_secs=2, scaffold=self.scaffold) 548 hook.begin() 549 self.scaffold.finalize() 550 551 with session_lib.Session() as sess: 552 sess.run(self.scaffold.init_op) 553 mon_sess = monitored_session._HookedSession(sess, [hook]) 554 555 mock_time.return_value = current_time 556 mon_sess.run(self.train_op) # Saved. 557 558 mock_time.return_value = current_time + 0.5 559 mon_sess.run(self.train_op) # Not saved. 560 561 self.assertEqual(1, 562 checkpoint_utils.load_variable(self.model_dir, 563 self.global_step.name)) 564 565 # Simulate 2.5 seconds of sleep. 566 mock_time.return_value = current_time + 2.5 567 mon_sess.run(self.train_op) # Saved. 568 569 mock_time.return_value = current_time + 2.6 570 mon_sess.run(self.train_op) # Not saved. 571 572 mock_time.return_value = current_time + 2.7 573 mon_sess.run(self.train_op) # Not saved. 574 575 self.assertEqual(3, 576 checkpoint_utils.load_variable(self.model_dir, 577 self.global_step.name)) 578 579 # Simulate 7.5 more seconds of sleep (10 seconds from start. 580 mock_time.return_value = current_time + 10 581 mon_sess.run(self.train_op) # Saved. 582 self.assertEqual(6, 583 checkpoint_utils.load_variable(self.model_dir, 584 self.global_step.name)) 585 586 @test.mock.patch.object(time, 'time') 587 def test_save_secs_calls_listeners_periodically(self, mock_time): 588 # Let's have a realistic start time 589 current_time = 1484695987.209386 590 591 with self.graph.as_default(): 592 mock_time.return_value = current_time 593 listener = MockCheckpointSaverListener() 594 hook = basic_session_run_hooks.CheckpointSaverHook( 595 self.model_dir, 596 save_secs=2, 597 scaffold=self.scaffold, 598 listeners=[listener]) 599 hook.begin() 600 self.scaffold.finalize() 601 with session_lib.Session() as sess: 602 sess.run(self.scaffold.init_op) 603 mon_sess = monitored_session._HookedSession(sess, [hook]) 604 605 mock_time.return_value = current_time + 0.5 606 mon_sess.run(self.train_op) # hook runs here 607 608 mock_time.return_value = current_time + 0.5 609 mon_sess.run(self.train_op) 610 611 mock_time.return_value = current_time + 3.0 612 mon_sess.run(self.train_op) # hook runs here 613 614 mock_time.return_value = current_time + 3.5 615 mon_sess.run(self.train_op) 616 617 mock_time.return_value = current_time + 4.0 618 mon_sess.run(self.train_op) 619 620 mock_time.return_value = current_time + 6.5 621 mon_sess.run(self.train_op) # hook runs here 622 623 mock_time.return_value = current_time + 7.0 624 mon_sess.run(self.train_op) # hook won't run here, so it does at end 625 626 mock_time.return_value = current_time + 7.5 627 hook.end(sess) # hook runs here 628 self.assertEqual({ 629 'begin': 1, 630 'before_save': 4, 631 'after_save': 4, 632 'end': 1 633 }, listener.get_counts()) 634 635 def test_save_steps_saves_in_first_step(self): 636 with self.graph.as_default(): 637 hook = basic_session_run_hooks.CheckpointSaverHook( 638 self.model_dir, save_steps=2, scaffold=self.scaffold) 639 hook.begin() 640 self.scaffold.finalize() 641 with session_lib.Session() as sess: 642 sess.run(self.scaffold.init_op) 643 mon_sess = monitored_session._HookedSession(sess, [hook]) 644 mon_sess.run(self.train_op) 645 self.assertEqual(1, 646 checkpoint_utils.load_variable(self.model_dir, 647 self.global_step.name)) 648 649 def test_save_steps_saves_periodically(self): 650 with self.graph.as_default(): 651 hook = basic_session_run_hooks.CheckpointSaverHook( 652 self.model_dir, save_steps=2, scaffold=self.scaffold) 653 hook.begin() 654 self.scaffold.finalize() 655 with session_lib.Session() as sess: 656 sess.run(self.scaffold.init_op) 657 mon_sess = monitored_session._HookedSession(sess, [hook]) 658 mon_sess.run(self.train_op) 659 mon_sess.run(self.train_op) 660 # Not saved 661 self.assertEqual(1, 662 checkpoint_utils.load_variable(self.model_dir, 663 self.global_step.name)) 664 mon_sess.run(self.train_op) 665 # saved 666 self.assertEqual(3, 667 checkpoint_utils.load_variable(self.model_dir, 668 self.global_step.name)) 669 mon_sess.run(self.train_op) 670 # Not saved 671 self.assertEqual(3, 672 checkpoint_utils.load_variable(self.model_dir, 673 self.global_step.name)) 674 mon_sess.run(self.train_op) 675 # saved 676 self.assertEqual(5, 677 checkpoint_utils.load_variable(self.model_dir, 678 self.global_step.name)) 679 680 def test_save_saves_at_end(self): 681 with self.graph.as_default(): 682 hook = basic_session_run_hooks.CheckpointSaverHook( 683 self.model_dir, save_secs=2, scaffold=self.scaffold) 684 hook.begin() 685 self.scaffold.finalize() 686 with session_lib.Session() as sess: 687 sess.run(self.scaffold.init_op) 688 mon_sess = monitored_session._HookedSession(sess, [hook]) 689 mon_sess.run(self.train_op) 690 mon_sess.run(self.train_op) 691 hook.end(sess) 692 self.assertEqual(2, 693 checkpoint_utils.load_variable(self.model_dir, 694 self.global_step.name)) 695 696 def test_summary_writer_defs(self): 697 fake_summary_writer.FakeSummaryWriter.install() 698 writer_cache.FileWriterCache.clear() 699 summary_writer = writer_cache.FileWriterCache.get(self.model_dir) 700 701 with self.graph.as_default(): 702 hook = basic_session_run_hooks.CheckpointSaverHook( 703 self.model_dir, save_steps=2, scaffold=self.scaffold) 704 hook.begin() 705 self.scaffold.finalize() 706 with session_lib.Session() as sess: 707 sess.run(self.scaffold.init_op) 708 mon_sess = monitored_session._HookedSession(sess, [hook]) 709 mon_sess.run(self.train_op) 710 summary_writer.assert_summaries( 711 test_case=self, 712 expected_logdir=self.model_dir, 713 expected_added_meta_graphs=[ 714 meta_graph.create_meta_graph_def( 715 graph_def=self.graph.as_graph_def(add_shapes=True), 716 saver_def=self.scaffold.saver.saver_def) 717 ]) 718 719 fake_summary_writer.FakeSummaryWriter.uninstall() 720 721 722 class ResourceCheckpointSaverHookTest(test.TestCase): 723 724 def setUp(self): 725 self.model_dir = tempfile.mkdtemp() 726 self.graph = ops.Graph() 727 with self.graph.as_default(): 728 self.scaffold = monitored_session.Scaffold() 729 with variable_scope.variable_scope('foo', use_resource=True): 730 self.global_step = training_util.get_or_create_global_step() 731 self.train_op = training_util._increment_global_step(1) 732 733 def test_save_steps_saves_periodically(self): 734 with self.graph.as_default(): 735 hook = basic_session_run_hooks.CheckpointSaverHook( 736 self.model_dir, save_steps=2, scaffold=self.scaffold) 737 hook.begin() 738 self.scaffold.finalize() 739 with session_lib.Session() as sess: 740 sess.run(self.scaffold.init_op) 741 mon_sess = monitored_session._HookedSession(sess, [hook]) 742 mon_sess.run(self.train_op) 743 mon_sess.run(self.train_op) 744 # Not saved 745 self.assertEqual(1, 746 checkpoint_utils.load_variable(self.model_dir, 747 self.global_step.name)) 748 mon_sess.run(self.train_op) 749 # saved 750 self.assertEqual(3, 751 checkpoint_utils.load_variable(self.model_dir, 752 self.global_step.name)) 753 mon_sess.run(self.train_op) 754 # Not saved 755 self.assertEqual(3, 756 checkpoint_utils.load_variable(self.model_dir, 757 self.global_step.name)) 758 mon_sess.run(self.train_op) 759 # saved 760 self.assertEqual(5, 761 checkpoint_utils.load_variable(self.model_dir, 762 self.global_step.name)) 763 764 765 class StepCounterHookTest(test.TestCase): 766 767 def setUp(self): 768 self.log_dir = tempfile.mkdtemp() 769 770 def tearDown(self): 771 shutil.rmtree(self.log_dir, ignore_errors=True) 772 773 def test_step_counter_every_n_steps(self): 774 with ops.Graph().as_default() as g, session_lib.Session() as sess: 775 variables.get_or_create_global_step() 776 train_op = training_util._increment_global_step(1) 777 summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g) 778 hook = basic_session_run_hooks.StepCounterHook( 779 summary_writer=summary_writer, every_n_steps=10) 780 hook.begin() 781 sess.run(variables_lib.global_variables_initializer()) 782 mon_sess = monitored_session._HookedSession(sess, [hook]) 783 with test.mock.patch.object(tf_logging, 'warning') as mock_log: 784 for _ in range(30): 785 time.sleep(0.01) 786 mon_sess.run(train_op) 787 # logging.warning should not be called. 788 self.assertIsNone(mock_log.call_args) 789 hook.end(sess) 790 summary_writer.assert_summaries( 791 test_case=self, 792 expected_logdir=self.log_dir, 793 expected_graph=g, 794 expected_summaries={}) 795 self.assertItemsEqual([11, 21], summary_writer.summaries.keys()) 796 for step in [11, 21]: 797 summary_value = summary_writer.summaries[step][0].value[0] 798 self.assertEqual('global_step/sec', summary_value.tag) 799 self.assertGreater(summary_value.simple_value, 0) 800 801 def test_step_counter_every_n_secs(self): 802 with ops.Graph().as_default() as g, session_lib.Session() as sess: 803 variables.get_or_create_global_step() 804 train_op = training_util._increment_global_step(1) 805 summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g) 806 hook = basic_session_run_hooks.StepCounterHook( 807 summary_writer=summary_writer, every_n_steps=None, every_n_secs=0.1) 808 809 hook.begin() 810 sess.run(variables_lib.global_variables_initializer()) 811 mon_sess = monitored_session._HookedSession(sess, [hook]) 812 mon_sess.run(train_op) 813 time.sleep(0.2) 814 mon_sess.run(train_op) 815 time.sleep(0.2) 816 mon_sess.run(train_op) 817 hook.end(sess) 818 819 summary_writer.assert_summaries( 820 test_case=self, 821 expected_logdir=self.log_dir, 822 expected_graph=g, 823 expected_summaries={}) 824 self.assertTrue(summary_writer.summaries, 'No summaries were created.') 825 self.assertItemsEqual([2, 3], summary_writer.summaries.keys()) 826 for summary in summary_writer.summaries.values(): 827 summary_value = summary[0].value[0] 828 self.assertEqual('global_step/sec', summary_value.tag) 829 self.assertGreater(summary_value.simple_value, 0) 830 831 def test_global_step_name(self): 832 with ops.Graph().as_default() as g, session_lib.Session() as sess: 833 with variable_scope.variable_scope('bar'): 834 variable_scope.get_variable( 835 'foo', 836 initializer=0, 837 trainable=False, 838 collections=[ 839 ops.GraphKeys.GLOBAL_STEP, ops.GraphKeys.GLOBAL_VARIABLES 840 ]) 841 train_op = training_util._increment_global_step(1) 842 summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g) 843 hook = basic_session_run_hooks.StepCounterHook( 844 summary_writer=summary_writer, every_n_steps=1, every_n_secs=None) 845 846 hook.begin() 847 sess.run(variables_lib.global_variables_initializer()) 848 mon_sess = monitored_session._HookedSession(sess, [hook]) 849 mon_sess.run(train_op) 850 mon_sess.run(train_op) 851 hook.end(sess) 852 853 summary_writer.assert_summaries( 854 test_case=self, 855 expected_logdir=self.log_dir, 856 expected_graph=g, 857 expected_summaries={}) 858 self.assertTrue(summary_writer.summaries, 'No summaries were created.') 859 self.assertItemsEqual([2], summary_writer.summaries.keys()) 860 summary_value = summary_writer.summaries[2][0].value[0] 861 self.assertEqual('bar/foo/sec', summary_value.tag) 862 863 def test_log_warning_if_global_step_not_increased(self): 864 with ops.Graph().as_default(), session_lib.Session() as sess: 865 variables.get_or_create_global_step() 866 train_op = training_util._increment_global_step(0) # keep same. 867 sess.run(variables_lib.global_variables_initializer()) 868 hook = basic_session_run_hooks.StepCounterHook( 869 every_n_steps=1, every_n_secs=None) 870 hook.begin() 871 mon_sess = monitored_session._HookedSession(sess, [hook]) 872 mon_sess.run(train_op) # Run one step to record global step. 873 with test.mock.patch.object(tf_logging, 'warning') as mock_log: 874 for _ in range(30): 875 mon_sess.run(train_op) 876 self.assertRegexpMatches( 877 str(mock_log.call_args), 878 'global step.*has not been increased') 879 hook.end(sess) 880 881 882 class SummarySaverHookTest(test.TestCase): 883 884 def setUp(self): 885 test.TestCase.setUp(self) 886 887 self.log_dir = 'log/dir' 888 self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir) 889 890 var = variables_lib.Variable(0.0) 891 tensor = state_ops.assign_add(var, 1.0) 892 tensor2 = tensor * 2 893 self.summary_op = summary_lib.scalar('my_summary', tensor) 894 self.summary_op2 = summary_lib.scalar('my_summary2', tensor2) 895 896 variables.get_or_create_global_step() 897 self.train_op = training_util._increment_global_step(1) 898 899 def test_raise_when_scaffold_and_summary_op_both_missing(self): 900 with self.assertRaises(ValueError): 901 basic_session_run_hooks.SummarySaverHook() 902 903 def test_raise_when_scaffold_and_summary_op_both_present(self): 904 with self.assertRaises(ValueError): 905 basic_session_run_hooks.SummarySaverHook( 906 scaffold=monitored_session.Scaffold(), summary_op=self.summary_op) 907 908 def test_raise_in_both_secs_and_steps(self): 909 with self.assertRaises(ValueError): 910 basic_session_run_hooks.SummarySaverHook( 911 save_secs=10, save_steps=20, summary_writer=self.summary_writer) 912 913 def test_raise_in_none_secs_and_steps(self): 914 with self.assertRaises(ValueError): 915 basic_session_run_hooks.SummarySaverHook( 916 save_secs=None, save_steps=None, summary_writer=self.summary_writer) 917 918 def test_save_steps(self): 919 hook = basic_session_run_hooks.SummarySaverHook( 920 save_steps=8, 921 summary_writer=self.summary_writer, 922 summary_op=self.summary_op) 923 924 with self.test_session() as sess: 925 hook.begin() 926 sess.run(variables_lib.global_variables_initializer()) 927 mon_sess = monitored_session._HookedSession(sess, [hook]) 928 for _ in range(30): 929 mon_sess.run(self.train_op) 930 hook.end(sess) 931 932 self.summary_writer.assert_summaries( 933 test_case=self, 934 expected_logdir=self.log_dir, 935 expected_summaries={ 936 1: { 937 'my_summary': 1.0 938 }, 939 9: { 940 'my_summary': 2.0 941 }, 942 17: { 943 'my_summary': 3.0 944 }, 945 25: { 946 'my_summary': 4.0 947 }, 948 }) 949 950 def test_multiple_summaries(self): 951 hook = basic_session_run_hooks.SummarySaverHook( 952 save_steps=8, 953 summary_writer=self.summary_writer, 954 summary_op=[self.summary_op, self.summary_op2]) 955 956 with self.test_session() as sess: 957 hook.begin() 958 sess.run(variables_lib.global_variables_initializer()) 959 mon_sess = monitored_session._HookedSession(sess, [hook]) 960 for _ in range(10): 961 mon_sess.run(self.train_op) 962 hook.end(sess) 963 964 self.summary_writer.assert_summaries( 965 test_case=self, 966 expected_logdir=self.log_dir, 967 expected_summaries={ 968 1: { 969 'my_summary': 1.0, 970 'my_summary2': 2.0 971 }, 972 9: { 973 'my_summary': 2.0, 974 'my_summary2': 4.0 975 }, 976 }) 977 978 def test_save_secs_saving_once_every_step(self): 979 hook = basic_session_run_hooks.SummarySaverHook( 980 save_secs=0.5, 981 summary_writer=self.summary_writer, 982 summary_op=self.summary_op) 983 984 with self.test_session() as sess: 985 hook.begin() 986 sess.run(variables_lib.global_variables_initializer()) 987 mon_sess = monitored_session._HookedSession(sess, [hook]) 988 for _ in range(4): 989 mon_sess.run(self.train_op) 990 time.sleep(0.5) 991 hook.end(sess) 992 993 self.summary_writer.assert_summaries( 994 test_case=self, 995 expected_logdir=self.log_dir, 996 expected_summaries={ 997 1: { 998 'my_summary': 1.0 999 }, 1000 2: { 1001 'my_summary': 2.0 1002 }, 1003 3: { 1004 'my_summary': 3.0 1005 }, 1006 4: { 1007 'my_summary': 4.0 1008 }, 1009 }) 1010 1011 @test.mock.patch.object(time, 'time') 1012 def test_save_secs_saving_once_every_three_steps(self, mock_time): 1013 mock_time.return_value = 1484695987.209386 1014 hook = basic_session_run_hooks.SummarySaverHook( 1015 save_secs=9., 1016 summary_writer=self.summary_writer, 1017 summary_op=self.summary_op) 1018 1019 with self.test_session() as sess: 1020 hook.begin() 1021 sess.run(variables_lib.global_variables_initializer()) 1022 mon_sess = monitored_session._HookedSession(sess, [hook]) 1023 for _ in range(8): 1024 mon_sess.run(self.train_op) 1025 mock_time.return_value += 3.1 1026 hook.end(sess) 1027 1028 # 24.8 seconds passed (3.1*8), it saves every 9 seconds starting from first: 1029 self.summary_writer.assert_summaries( 1030 test_case=self, 1031 expected_logdir=self.log_dir, 1032 expected_summaries={ 1033 1: { 1034 'my_summary': 1.0 1035 }, 1036 4: { 1037 'my_summary': 2.0 1038 }, 1039 7: { 1040 'my_summary': 3.0 1041 }, 1042 }) 1043 1044 1045 class GlobalStepWaiterHookTest(test.TestCase): 1046 1047 def test_not_wait_for_step_zero(self): 1048 with ops.Graph().as_default(): 1049 variables.get_or_create_global_step() 1050 hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=0) 1051 hook.begin() 1052 with session_lib.Session() as sess: 1053 # Before run should return without waiting gstep increment. 1054 hook.before_run( 1055 session_run_hook.SessionRunContext( 1056 original_args=None, session=sess)) 1057 1058 def test_wait_for_step(self): 1059 with ops.Graph().as_default(): 1060 gstep = variables.get_or_create_global_step() 1061 hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=1000) 1062 hook.begin() 1063 with session_lib.Session() as sess: 1064 sess.run(variables_lib.global_variables_initializer()) 1065 waiter = threading.Thread( 1066 target=hook.before_run, 1067 args=(session_run_hook.SessionRunContext( 1068 original_args=None, session=sess),)) 1069 waiter.daemon = True 1070 waiter.start() 1071 time.sleep(1.0) 1072 self.assertTrue(waiter.is_alive()) 1073 sess.run(state_ops.assign(gstep, 500)) 1074 time.sleep(1.0) 1075 self.assertTrue(waiter.is_alive()) 1076 sess.run(state_ops.assign(gstep, 1100)) 1077 time.sleep(1.2) 1078 self.assertFalse(waiter.is_alive()) 1079 1080 1081 class FinalOpsHookTest(test.TestCase): 1082 1083 def test_final_ops_is_scalar_tensor(self): 1084 with ops.Graph().as_default(): 1085 expected_value = 4 1086 final_ops = constant_op.constant(expected_value) 1087 1088 hook = basic_session_run_hooks.FinalOpsHook(final_ops) 1089 hook.begin() 1090 1091 with session_lib.Session() as session: 1092 hook.end(session) 1093 self.assertEqual(expected_value, 1094 hook.final_ops_values) 1095 1096 def test_final_ops_is_tensor(self): 1097 with ops.Graph().as_default(): 1098 expected_values = [1, 6, 3, 5, 2, 4] 1099 final_ops = constant_op.constant(expected_values) 1100 1101 hook = basic_session_run_hooks.FinalOpsHook(final_ops) 1102 hook.begin() 1103 1104 with session_lib.Session() as session: 1105 hook.end(session) 1106 self.assertListEqual(expected_values, 1107 hook.final_ops_values.tolist()) 1108 1109 def test_final_ops_with_dictionary(self): 1110 with ops.Graph().as_default(): 1111 expected_values = [4, -3] 1112 final_ops = array_ops.placeholder(dtype=dtypes.float32) 1113 final_ops_feed_dict = {final_ops: expected_values} 1114 1115 hook = basic_session_run_hooks.FinalOpsHook( 1116 final_ops, final_ops_feed_dict) 1117 hook.begin() 1118 1119 with session_lib.Session() as session: 1120 hook.end(session) 1121 self.assertListEqual(expected_values, 1122 hook.final_ops_values.tolist()) 1123 1124 1125 class ResourceSummarySaverHookTest(test.TestCase): 1126 1127 def setUp(self): 1128 test.TestCase.setUp(self) 1129 1130 self.log_dir = 'log/dir' 1131 self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir) 1132 1133 var = variable_scope.get_variable('var', initializer=0.0, use_resource=True) 1134 tensor = state_ops.assign_add(var, 1.0) 1135 self.summary_op = summary_lib.scalar('my_summary', tensor) 1136 1137 with variable_scope.variable_scope('foo', use_resource=True): 1138 variables.create_global_step() 1139 self.train_op = training_util._increment_global_step(1) 1140 1141 def test_save_steps(self): 1142 hook = basic_session_run_hooks.SummarySaverHook( 1143 save_steps=8, 1144 summary_writer=self.summary_writer, 1145 summary_op=self.summary_op) 1146 1147 with self.test_session() as sess: 1148 hook.begin() 1149 sess.run(variables_lib.global_variables_initializer()) 1150 mon_sess = monitored_session._HookedSession(sess, [hook]) 1151 for _ in range(30): 1152 mon_sess.run(self.train_op) 1153 hook.end(sess) 1154 1155 self.summary_writer.assert_summaries( 1156 test_case=self, 1157 expected_logdir=self.log_dir, 1158 expected_summaries={ 1159 1: { 1160 'my_summary': 1.0 1161 }, 1162 9: { 1163 'my_summary': 2.0 1164 }, 1165 17: { 1166 'my_summary': 3.0 1167 }, 1168 25: { 1169 'my_summary': 4.0 1170 }, 1171 }) 1172 1173 1174 class FeedFnHookTest(test.TestCase): 1175 1176 def test_feeding_placeholder(self): 1177 with ops.Graph().as_default(), session_lib.Session() as sess: 1178 x = array_ops.placeholder(dtype=dtypes.float32) 1179 y = x + 1 1180 hook = basic_session_run_hooks.FeedFnHook( 1181 feed_fn=lambda: {x: 1.0}) 1182 hook.begin() 1183 mon_sess = monitored_session._HookedSession(sess, [hook]) 1184 self.assertEqual(mon_sess.run(y), 2) 1185 1186 1187 class ProfilerHookTest(test.TestCase): 1188 1189 def setUp(self): 1190 super(ProfilerHookTest, self).setUp() 1191 self.output_dir = tempfile.mkdtemp() 1192 self.graph = ops.Graph() 1193 self.filepattern = os.path.join(self.output_dir, 'timeline-*.json') 1194 with self.graph.as_default(): 1195 self.global_step = variables.get_or_create_global_step() 1196 self.train_op = state_ops.assign_add(self.global_step, 1) 1197 1198 def tearDown(self): 1199 super(ProfilerHookTest, self).tearDown() 1200 shutil.rmtree(self.output_dir, ignore_errors=True) 1201 1202 def _count_timeline_files(self): 1203 return len(gfile.Glob(self.filepattern)) 1204 1205 def test_raise_in_both_secs_and_steps(self): 1206 with self.assertRaises(ValueError): 1207 basic_session_run_hooks.ProfilerHook(save_secs=10, save_steps=20) 1208 1209 def test_raise_in_none_secs_and_steps(self): 1210 with self.assertRaises(ValueError): 1211 basic_session_run_hooks.ProfilerHook(save_secs=None, save_steps=None) 1212 1213 def test_save_secs_saves_in_first_step(self): 1214 with self.graph.as_default(): 1215 hook = basic_session_run_hooks.ProfilerHook( 1216 save_secs=2, output_dir=self.output_dir) 1217 with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: 1218 sess.run(self.train_op) 1219 self.assertEqual(1, self._count_timeline_files()) 1220 1221 @test.mock.patch.object(time, 'time') 1222 def test_save_secs_saves_periodically(self, mock_time): 1223 # Pick a fixed start time. 1224 current_time = 1484863632.320497 1225 1226 with self.graph.as_default(): 1227 mock_time.return_value = current_time 1228 hook = basic_session_run_hooks.ProfilerHook( 1229 save_secs=2, output_dir=self.output_dir) 1230 with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: 1231 sess.run(self.train_op) # Saved. 1232 self.assertEqual(1, self._count_timeline_files()) 1233 sess.run(self.train_op) # Not saved. 1234 self.assertEqual(1, self._count_timeline_files()) 1235 # Simulate 2.5 seconds of sleep. 1236 mock_time.return_value = current_time + 2.5 1237 sess.run(self.train_op) # Saved. 1238 1239 # Pretend some small amount of time has passed. 1240 mock_time.return_value = current_time + 0.1 1241 sess.run(self.train_op) # Not saved. 1242 # Edge test just before we should save the timeline. 1243 mock_time.return_value = current_time + 1.9 1244 sess.run(self.train_op) # Not saved. 1245 self.assertEqual(2, self._count_timeline_files()) 1246 1247 mock_time.return_value = current_time + 4.5 1248 sess.run(self.train_op) # Saved. 1249 self.assertEqual(3, self._count_timeline_files()) 1250 1251 def test_save_steps_saves_in_first_step(self): 1252 with self.graph.as_default(): 1253 hook = basic_session_run_hooks.ProfilerHook( 1254 save_secs=2, output_dir=self.output_dir) 1255 with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: 1256 sess.run(self.train_op) # Saved. 1257 sess.run(self.train_op) # Not saved. 1258 self.assertEqual(1, self._count_timeline_files()) 1259 1260 def test_save_steps_saves_periodically(self): 1261 with self.graph.as_default(): 1262 hook = basic_session_run_hooks.ProfilerHook( 1263 save_steps=2, output_dir=self.output_dir) 1264 with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: 1265 self.assertEqual(0, self._count_timeline_files()) 1266 sess.run(self.train_op) # Saved. 1267 self.assertEqual(1, self._count_timeline_files()) 1268 sess.run(self.train_op) # Not saved. 1269 self.assertEqual(1, self._count_timeline_files()) 1270 sess.run(self.train_op) # Saved. 1271 self.assertEqual(2, self._count_timeline_files()) 1272 sess.run(self.train_op) # Not saved. 1273 self.assertEqual(2, self._count_timeline_files()) 1274 sess.run(self.train_op) # Saved. 1275 self.assertEqual(3, self._count_timeline_files()) 1276 1277 1278 if __name__ == '__main__': 1279 test.main() 1280