Home | History | Annotate | Download | only in training
      1 # pylint: disable=g-bad-file-header
      2 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #     http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 # ==============================================================================
     16 """Tests for basic_session_run_hooks."""
     17 
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import os.path
     23 import shutil
     24 import tempfile
     25 import threading
     26 import time
     27 
     28 from tensorflow.contrib.framework.python.framework import checkpoint_utils
     29 from tensorflow.contrib.framework.python.ops import variables
     30 from tensorflow.contrib.testing.python.framework import fake_summary_writer
     31 from tensorflow.python.client import session as session_lib
     32 from tensorflow.python.framework import constant_op
     33 from tensorflow.python.framework import dtypes
     34 from tensorflow.python.framework import meta_graph
     35 from tensorflow.python.framework import ops
     36 from tensorflow.python.ops import array_ops
     37 from tensorflow.python.ops import control_flow_ops
     38 from tensorflow.python.ops import state_ops
     39 from tensorflow.python.ops import variable_scope
     40 from tensorflow.python.ops import variables as variables_lib
     41 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
     42 from tensorflow.python.platform import gfile
     43 from tensorflow.python.platform import test
     44 from tensorflow.python.platform import tf_logging
     45 from tensorflow.python.summary import summary as summary_lib
     46 from tensorflow.python.summary.writer import writer_cache
     47 from tensorflow.python.training import basic_session_run_hooks
     48 from tensorflow.python.training import monitored_session
     49 from tensorflow.python.training import session_run_hook
     50 from tensorflow.python.training import training_util
     51 
     52 
     53 class MockCheckpointSaverListener(
     54     basic_session_run_hooks.CheckpointSaverListener):
     55 
     56   def __init__(self):
     57     self.begin_count = 0
     58     self.before_save_count = 0
     59     self.after_save_count = 0
     60     self.end_count = 0
     61 
     62   def begin(self):
     63     self.begin_count += 1
     64 
     65   def before_save(self, session, global_step):
     66     self.before_save_count += 1
     67 
     68   def after_save(self, session, global_step):
     69     self.after_save_count += 1
     70 
     71   def end(self, session, global_step):
     72     self.end_count += 1
     73 
     74   def get_counts(self):
     75     return {
     76         'begin': self.begin_count,
     77         'before_save': self.before_save_count,
     78         'after_save': self.after_save_count,
     79         'end': self.end_count
     80     }
     81 
     82 
     83 class SecondOrStepTimerTest(test.TestCase):
     84 
     85   def test_raise_in_both_secs_and_steps(self):
     86     with self.assertRaises(ValueError):
     87       basic_session_run_hooks.SecondOrStepTimer(every_secs=2.0, every_steps=10)
     88 
     89   def test_raise_in_none_secs_and_steps(self):
     90     with self.assertRaises(ValueError):
     91       basic_session_run_hooks.SecondOrStepTimer()
     92 
     93   def test_every_secs(self):
     94     timer = basic_session_run_hooks.SecondOrStepTimer(every_secs=1.0)
     95     self.assertTrue(timer.should_trigger_for_step(1))
     96 
     97     timer.update_last_triggered_step(1)
     98     self.assertFalse(timer.should_trigger_for_step(1))
     99     self.assertFalse(timer.should_trigger_for_step(2))
    100 
    101     time.sleep(1.0)
    102     self.assertFalse(timer.should_trigger_for_step(1))
    103     self.assertTrue(timer.should_trigger_for_step(2))
    104 
    105   def test_every_steps(self):
    106     timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=3)
    107     self.assertTrue(timer.should_trigger_for_step(1))
    108 
    109     timer.update_last_triggered_step(1)
    110     self.assertFalse(timer.should_trigger_for_step(1))
    111     self.assertFalse(timer.should_trigger_for_step(2))
    112     self.assertFalse(timer.should_trigger_for_step(3))
    113     self.assertTrue(timer.should_trigger_for_step(4))
    114 
    115   def test_update_last_triggered_step(self):
    116     timer = basic_session_run_hooks.SecondOrStepTimer(every_steps=1)
    117 
    118     elapsed_secs, elapsed_steps = timer.update_last_triggered_step(1)
    119     self.assertEqual(None, elapsed_secs)
    120     self.assertEqual(None, elapsed_steps)
    121 
    122     elapsed_secs, elapsed_steps = timer.update_last_triggered_step(5)
    123     self.assertLess(0, elapsed_secs)
    124     self.assertEqual(4, elapsed_steps)
    125 
    126     elapsed_secs, elapsed_steps = timer.update_last_triggered_step(7)
    127     self.assertLess(0, elapsed_secs)
    128     self.assertEqual(2, elapsed_steps)
    129 
    130 
    131 class StopAtStepTest(test.TestCase):
    132 
    133   def test_raise_in_both_last_step_and_num_steps(self):
    134     with self.assertRaises(ValueError):
    135       basic_session_run_hooks.StopAtStepHook(num_steps=10, last_step=20)
    136 
    137   def test_stop_based_on_last_step(self):
    138     h = basic_session_run_hooks.StopAtStepHook(last_step=10)
    139     with ops.Graph().as_default():
    140       global_step = variables.get_or_create_global_step()
    141       no_op = control_flow_ops.no_op()
    142       h.begin()
    143       with session_lib.Session() as sess:
    144         mon_sess = monitored_session._HookedSession(sess, [h])
    145         sess.run(state_ops.assign(global_step, 5))
    146         h.after_create_session(sess, None)
    147         mon_sess.run(no_op)
    148         self.assertFalse(mon_sess.should_stop())
    149         sess.run(state_ops.assign(global_step, 9))
    150         mon_sess.run(no_op)
    151         self.assertFalse(mon_sess.should_stop())
    152         sess.run(state_ops.assign(global_step, 10))
    153         mon_sess.run(no_op)
    154         self.assertTrue(mon_sess.should_stop())
    155         sess.run(state_ops.assign(global_step, 11))
    156         mon_sess._should_stop = False
    157         mon_sess.run(no_op)
    158         self.assertTrue(mon_sess.should_stop())
    159 
    160   def test_stop_based_on_num_step(self):
    161     h = basic_session_run_hooks.StopAtStepHook(num_steps=10)
    162 
    163     with ops.Graph().as_default():
    164       global_step = variables.get_or_create_global_step()
    165       no_op = control_flow_ops.no_op()
    166       h.begin()
    167       with session_lib.Session() as sess:
    168         mon_sess = monitored_session._HookedSession(sess, [h])
    169         sess.run(state_ops.assign(global_step, 5))
    170         h.after_create_session(sess, None)
    171         mon_sess.run(no_op)
    172         self.assertFalse(mon_sess.should_stop())
    173         sess.run(state_ops.assign(global_step, 13))
    174         mon_sess.run(no_op)
    175         self.assertFalse(mon_sess.should_stop())
    176         sess.run(state_ops.assign(global_step, 14))
    177         mon_sess.run(no_op)
    178         self.assertFalse(mon_sess.should_stop())
    179         sess.run(state_ops.assign(global_step, 15))
    180         mon_sess.run(no_op)
    181         self.assertTrue(mon_sess.should_stop())
    182         sess.run(state_ops.assign(global_step, 16))
    183         mon_sess._should_stop = False
    184         mon_sess.run(no_op)
    185         self.assertTrue(mon_sess.should_stop())
    186 
    187   def test_stop_based_with_multiple_steps(self):
    188     h = basic_session_run_hooks.StopAtStepHook(num_steps=10)
    189 
    190     with ops.Graph().as_default():
    191       global_step = variables.get_or_create_global_step()
    192       no_op = control_flow_ops.no_op()
    193       h.begin()
    194       with session_lib.Session() as sess:
    195         mon_sess = monitored_session._HookedSession(sess, [h])
    196         sess.run(state_ops.assign(global_step, 5))
    197         h.after_create_session(sess, None)
    198         mon_sess.run(no_op)
    199         self.assertFalse(mon_sess.should_stop())
    200         sess.run(state_ops.assign(global_step, 15))
    201         mon_sess.run(no_op)
    202         self.assertTrue(mon_sess.should_stop())
    203 
    204 
    205 class LoggingTensorHookTest(test.TestCase):
    206 
    207   def setUp(self):
    208     # Mock out logging calls so we can verify whether correct tensors are being
    209     # monitored.
    210     self._actual_log = tf_logging.info
    211     self.logged_message = None
    212 
    213     def mock_log(*args, **kwargs):
    214       self.logged_message = args
    215       self._actual_log(*args, **kwargs)
    216 
    217     tf_logging.info = mock_log
    218 
    219   def tearDown(self):
    220     tf_logging.info = self._actual_log
    221 
    222   def test_illegal_args(self):
    223     with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'):
    224       basic_session_run_hooks.LoggingTensorHook(tensors=['t'], every_n_iter=0)
    225     with self.assertRaisesRegexp(ValueError, 'nvalid every_n_iter'):
    226       basic_session_run_hooks.LoggingTensorHook(tensors=['t'], every_n_iter=-10)
    227     with self.assertRaisesRegexp(ValueError, 'xactly one of'):
    228       basic_session_run_hooks.LoggingTensorHook(
    229           tensors=['t'], every_n_iter=5, every_n_secs=5)
    230     with self.assertRaisesRegexp(ValueError, 'xactly one of'):
    231       basic_session_run_hooks.LoggingTensorHook(tensors=['t'])
    232 
    233   def test_print_at_end_only(self):
    234     with ops.Graph().as_default(), session_lib.Session() as sess:
    235       t = constant_op.constant(42.0, name='foo')
    236       train_op = constant_op.constant(3)
    237       hook = basic_session_run_hooks.LoggingTensorHook(
    238           tensors=[t.name], at_end=True)
    239       hook.begin()
    240       mon_sess = monitored_session._HookedSession(sess, [hook])
    241       sess.run(variables_lib.global_variables_initializer())
    242       self.logged_message = ''
    243       for _ in range(3):
    244         mon_sess.run(train_op)
    245         # assertNotRegexpMatches is not supported by python 3.1 and later
    246         self.assertEqual(str(self.logged_message).find(t.name), -1)
    247 
    248       hook.end(sess)
    249       self.assertRegexpMatches(str(self.logged_message), t.name)
    250 
    251   def _validate_print_every_n_steps(self, sess, at_end):
    252     t = constant_op.constant(42.0, name='foo')
    253 
    254     train_op = constant_op.constant(3)
    255     hook = basic_session_run_hooks.LoggingTensorHook(
    256         tensors=[t.name], every_n_iter=10, at_end=at_end)
    257     hook.begin()
    258     mon_sess = monitored_session._HookedSession(sess, [hook])
    259     sess.run(variables_lib.global_variables_initializer())
    260     mon_sess.run(train_op)
    261     self.assertRegexpMatches(str(self.logged_message), t.name)
    262     for _ in range(3):
    263       self.logged_message = ''
    264       for _ in range(9):
    265         mon_sess.run(train_op)
    266         # assertNotRegexpMatches is not supported by python 3.1 and later
    267         self.assertEqual(str(self.logged_message).find(t.name), -1)
    268       mon_sess.run(train_op)
    269       self.assertRegexpMatches(str(self.logged_message), t.name)
    270 
    271     # Add additional run to verify proper reset when called multiple times.
    272     self.logged_message = ''
    273     mon_sess.run(train_op)
    274     # assertNotRegexpMatches is not supported by python 3.1 and later
    275     self.assertEqual(str(self.logged_message).find(t.name), -1)
    276 
    277     self.logged_message = ''
    278     hook.end(sess)
    279     if at_end:
    280       self.assertRegexpMatches(str(self.logged_message), t.name)
    281     else:
    282       # assertNotRegexpMatches is not supported by python 3.1 and later
    283       self.assertEqual(str(self.logged_message).find(t.name), -1)
    284 
    285   def test_print_every_n_steps(self):
    286     with ops.Graph().as_default(), session_lib.Session() as sess:
    287       self._validate_print_every_n_steps(sess, at_end=False)
    288       # Verify proper reset.
    289       self._validate_print_every_n_steps(sess, at_end=False)
    290 
    291   def test_print_every_n_steps_and_end(self):
    292     with ops.Graph().as_default(), session_lib.Session() as sess:
    293       self._validate_print_every_n_steps(sess, at_end=True)
    294       # Verify proper reset.
    295       self._validate_print_every_n_steps(sess, at_end=True)
    296 
    297   def test_print_first_step(self):
    298     # if it runs every iteration, first iteration has None duration.
    299     with ops.Graph().as_default(), session_lib.Session() as sess:
    300       t = constant_op.constant(42.0, name='foo')
    301       train_op = constant_op.constant(3)
    302       hook = basic_session_run_hooks.LoggingTensorHook(
    303           tensors={'foo': t}, every_n_iter=1)
    304       hook.begin()
    305       mon_sess = monitored_session._HookedSession(sess, [hook])
    306       sess.run(variables_lib.global_variables_initializer())
    307       mon_sess.run(train_op)
    308       self.assertRegexpMatches(str(self.logged_message), 'foo')
    309       # in first run, elapsed time is None.
    310       self.assertEqual(str(self.logged_message).find('sec'), -1)
    311 
    312   def _validate_print_every_n_secs(self, sess, at_end):
    313     t = constant_op.constant(42.0, name='foo')
    314     train_op = constant_op.constant(3)
    315 
    316     hook = basic_session_run_hooks.LoggingTensorHook(
    317         tensors=[t.name], every_n_secs=1.0, at_end=at_end)
    318     hook.begin()
    319     mon_sess = monitored_session._HookedSession(sess, [hook])
    320     sess.run(variables_lib.global_variables_initializer())
    321 
    322     mon_sess.run(train_op)
    323     self.assertRegexpMatches(str(self.logged_message), t.name)
    324 
    325     # assertNotRegexpMatches is not supported by python 3.1 and later
    326     self.logged_message = ''
    327     mon_sess.run(train_op)
    328     self.assertEqual(str(self.logged_message).find(t.name), -1)
    329     time.sleep(1.0)
    330 
    331     self.logged_message = ''
    332     mon_sess.run(train_op)
    333     self.assertRegexpMatches(str(self.logged_message), t.name)
    334 
    335     self.logged_message = ''
    336     hook.end(sess)
    337     if at_end:
    338       self.assertRegexpMatches(str(self.logged_message), t.name)
    339     else:
    340       # assertNotRegexpMatches is not supported by python 3.1 and later
    341       self.assertEqual(str(self.logged_message).find(t.name), -1)
    342 
    343   def test_print_every_n_secs(self):
    344     with ops.Graph().as_default(), session_lib.Session() as sess:
    345       self._validate_print_every_n_secs(sess, at_end=False)
    346       # Verify proper reset.
    347       self._validate_print_every_n_secs(sess, at_end=False)
    348 
    349   def test_print_every_n_secs_and_end(self):
    350     with ops.Graph().as_default(), session_lib.Session() as sess:
    351       self._validate_print_every_n_secs(sess, at_end=True)
    352       # Verify proper reset.
    353       self._validate_print_every_n_secs(sess, at_end=True)
    354 
    355   def test_print_formatter(self):
    356     with ops.Graph().as_default(), session_lib.Session() as sess:
    357       t = constant_op.constant(42.0, name='foo')
    358       train_op = constant_op.constant(3)
    359       hook = basic_session_run_hooks.LoggingTensorHook(
    360           tensors=[t.name], every_n_iter=10,
    361           formatter=lambda items: 'qqq=%s' % items[t.name])
    362       hook.begin()
    363       mon_sess = monitored_session._HookedSession(sess, [hook])
    364       sess.run(variables_lib.global_variables_initializer())
    365       mon_sess.run(train_op)
    366       self.assertEqual(self.logged_message[0], 'qqq=42.0')
    367 
    368 
    369 class CheckpointSaverHookTest(test.TestCase):
    370 
    371   def setUp(self):
    372     self.model_dir = tempfile.mkdtemp()
    373     self.graph = ops.Graph()
    374     with self.graph.as_default():
    375       self.scaffold = monitored_session.Scaffold()
    376       self.global_step = variables.get_or_create_global_step()
    377       self.train_op = training_util._increment_global_step(1)
    378 
    379   def tearDown(self):
    380     shutil.rmtree(self.model_dir, ignore_errors=True)
    381 
    382   def test_saves_when_saver_and_scaffold_both_missing(self):
    383     with self.graph.as_default():
    384       hook = basic_session_run_hooks.CheckpointSaverHook(
    385           self.model_dir, save_steps=1)
    386       hook.begin()
    387       self.scaffold.finalize()
    388       with session_lib.Session() as sess:
    389         sess.run(self.scaffold.init_op)
    390         mon_sess = monitored_session._HookedSession(sess, [hook])
    391         mon_sess.run(self.train_op)
    392         self.assertEqual(1,
    393                          checkpoint_utils.load_variable(self.model_dir,
    394                                                         self.global_step.name))
    395 
    396   def test_raise_when_saver_and_scaffold_both_present(self):
    397     with self.assertRaises(ValueError):
    398       basic_session_run_hooks.CheckpointSaverHook(
    399           self.model_dir, saver=self.scaffold.saver, scaffold=self.scaffold)
    400 
    401   def test_raise_in_both_secs_and_steps(self):
    402     with self.assertRaises(ValueError):
    403       basic_session_run_hooks.CheckpointSaverHook(
    404           self.model_dir, save_secs=10, save_steps=20)
    405 
    406   def test_raise_in_none_secs_and_steps(self):
    407     with self.assertRaises(ValueError):
    408       basic_session_run_hooks.CheckpointSaverHook(self.model_dir)
    409 
    410   def test_save_secs_saves_in_first_step(self):
    411     with self.graph.as_default():
    412       hook = basic_session_run_hooks.CheckpointSaverHook(
    413           self.model_dir, save_secs=2, scaffold=self.scaffold)
    414       hook.begin()
    415       self.scaffold.finalize()
    416       with session_lib.Session() as sess:
    417         sess.run(self.scaffold.init_op)
    418         mon_sess = monitored_session._HookedSession(sess, [hook])
    419         mon_sess.run(self.train_op)
    420         self.assertEqual(1,
    421                          checkpoint_utils.load_variable(self.model_dir,
    422                                                         self.global_step.name))
    423 
    424   def test_save_secs_calls_listeners_at_begin_and_end(self):
    425     with self.graph.as_default():
    426       listener = MockCheckpointSaverListener()
    427       hook = basic_session_run_hooks.CheckpointSaverHook(
    428           self.model_dir,
    429           save_secs=2,
    430           scaffold=self.scaffold,
    431           listeners=[listener])
    432       hook.begin()
    433       self.scaffold.finalize()
    434       with session_lib.Session() as sess:
    435         sess.run(self.scaffold.init_op)
    436         mon_sess = monitored_session._HookedSession(sess, [hook])
    437         mon_sess.run(self.train_op)  # hook runs here
    438         mon_sess.run(self.train_op)  # hook won't run here, so it does at end
    439         hook.end(sess)  # hook runs here
    440       self.assertEqual({
    441           'begin': 1,
    442           'before_save': 2,
    443           'after_save': 2,
    444           'end': 1
    445       }, listener.get_counts())
    446 
    447   def test_listener_with_monitored_session(self):
    448     with ops.Graph().as_default():
    449       scaffold = monitored_session.Scaffold()
    450       global_step = variables.get_or_create_global_step()
    451       train_op = training_util._increment_global_step(1)
    452       listener = MockCheckpointSaverListener()
    453       hook = basic_session_run_hooks.CheckpointSaverHook(
    454           self.model_dir,
    455           save_steps=1,
    456           scaffold=scaffold,
    457           listeners=[listener])
    458       with monitored_session.SingularMonitoredSession(
    459           hooks=[hook],
    460           scaffold=scaffold,
    461           checkpoint_dir=self.model_dir) as sess:
    462         sess.run(train_op)
    463         sess.run(train_op)
    464         global_step_val = sess.raw_session().run(global_step)
    465       listener_counts = listener.get_counts()
    466     self.assertEqual(2, global_step_val)
    467     self.assertEqual({
    468         'begin': 1,
    469         'before_save': 2,
    470         'after_save': 2,
    471         'end': 1
    472     }, listener_counts)
    473 
    474   def test_listener_with_default_saver(self):
    475     with ops.Graph().as_default():
    476       global_step = variables.get_or_create_global_step()
    477       train_op = training_util._increment_global_step(1)
    478       listener = MockCheckpointSaverListener()
    479       hook = basic_session_run_hooks.CheckpointSaverHook(
    480           self.model_dir,
    481           save_steps=1,
    482           listeners=[listener])
    483       with monitored_session.SingularMonitoredSession(
    484           hooks=[hook],
    485           checkpoint_dir=self.model_dir) as sess:
    486         sess.run(train_op)
    487         sess.run(train_op)
    488         global_step_val = sess.raw_session().run(global_step)
    489       listener_counts = listener.get_counts()
    490     self.assertEqual(2, global_step_val)
    491     self.assertEqual({
    492         'begin': 1,
    493         'before_save': 2,
    494         'after_save': 2,
    495         'end': 1
    496     }, listener_counts)
    497 
    498     with ops.Graph().as_default():
    499       global_step = variables.get_or_create_global_step()
    500       with monitored_session.SingularMonitoredSession(
    501           checkpoint_dir=self.model_dir) as sess2:
    502         global_step_saved_val = sess2.run(global_step)
    503     self.assertEqual(2, global_step_saved_val)
    504 
    505   def test_two_listeners_with_default_saver(self):
    506     with ops.Graph().as_default():
    507       global_step = variables.get_or_create_global_step()
    508       train_op = training_util._increment_global_step(1)
    509       listener1 = MockCheckpointSaverListener()
    510       listener2 = MockCheckpointSaverListener()
    511       hook = basic_session_run_hooks.CheckpointSaverHook(
    512           self.model_dir,
    513           save_steps=1,
    514           listeners=[listener1, listener2])
    515       with monitored_session.SingularMonitoredSession(
    516           hooks=[hook],
    517           checkpoint_dir=self.model_dir) as sess:
    518         sess.run(train_op)
    519         sess.run(train_op)
    520         global_step_val = sess.raw_session().run(global_step)
    521       listener1_counts = listener1.get_counts()
    522       listener2_counts = listener2.get_counts()
    523     self.assertEqual(2, global_step_val)
    524     self.assertEqual({
    525         'begin': 1,
    526         'before_save': 2,
    527         'after_save': 2,
    528         'end': 1
    529     }, listener1_counts)
    530     self.assertEqual(listener1_counts, listener2_counts)
    531 
    532     with ops.Graph().as_default():
    533       global_step = variables.get_or_create_global_step()
    534       with monitored_session.SingularMonitoredSession(
    535           checkpoint_dir=self.model_dir) as sess2:
    536         global_step_saved_val = sess2.run(global_step)
    537     self.assertEqual(2, global_step_saved_val)
    538 
    539   @test.mock.patch.object(time, 'time')
    540   def test_save_secs_saves_periodically(self, mock_time):
    541     # Let's have a realistic start time
    542     current_time = 1484695987.209386
    543 
    544     with self.graph.as_default():
    545       mock_time.return_value = current_time
    546       hook = basic_session_run_hooks.CheckpointSaverHook(
    547           self.model_dir, save_secs=2, scaffold=self.scaffold)
    548       hook.begin()
    549       self.scaffold.finalize()
    550 
    551       with session_lib.Session() as sess:
    552         sess.run(self.scaffold.init_op)
    553         mon_sess = monitored_session._HookedSession(sess, [hook])
    554 
    555         mock_time.return_value = current_time
    556         mon_sess.run(self.train_op)  # Saved.
    557 
    558         mock_time.return_value = current_time + 0.5
    559         mon_sess.run(self.train_op)  # Not saved.
    560 
    561         self.assertEqual(1,
    562                          checkpoint_utils.load_variable(self.model_dir,
    563                                                         self.global_step.name))
    564 
    565         # Simulate 2.5 seconds of sleep.
    566         mock_time.return_value = current_time + 2.5
    567         mon_sess.run(self.train_op)  # Saved.
    568 
    569         mock_time.return_value = current_time + 2.6
    570         mon_sess.run(self.train_op)  # Not saved.
    571 
    572         mock_time.return_value = current_time + 2.7
    573         mon_sess.run(self.train_op)  # Not saved.
    574 
    575         self.assertEqual(3,
    576                          checkpoint_utils.load_variable(self.model_dir,
    577                                                         self.global_step.name))
    578 
    579         # Simulate 7.5 more seconds of sleep (10 seconds from start.
    580         mock_time.return_value = current_time + 10
    581         mon_sess.run(self.train_op)  # Saved.
    582         self.assertEqual(6,
    583                          checkpoint_utils.load_variable(self.model_dir,
    584                                                         self.global_step.name))
    585 
    586   @test.mock.patch.object(time, 'time')
    587   def test_save_secs_calls_listeners_periodically(self, mock_time):
    588     # Let's have a realistic start time
    589     current_time = 1484695987.209386
    590 
    591     with self.graph.as_default():
    592       mock_time.return_value = current_time
    593       listener = MockCheckpointSaverListener()
    594       hook = basic_session_run_hooks.CheckpointSaverHook(
    595           self.model_dir,
    596           save_secs=2,
    597           scaffold=self.scaffold,
    598           listeners=[listener])
    599       hook.begin()
    600       self.scaffold.finalize()
    601       with session_lib.Session() as sess:
    602         sess.run(self.scaffold.init_op)
    603         mon_sess = monitored_session._HookedSession(sess, [hook])
    604 
    605         mock_time.return_value = current_time + 0.5
    606         mon_sess.run(self.train_op)  # hook runs here
    607 
    608         mock_time.return_value = current_time + 0.5
    609         mon_sess.run(self.train_op)
    610 
    611         mock_time.return_value = current_time + 3.0
    612         mon_sess.run(self.train_op)  # hook runs here
    613 
    614         mock_time.return_value = current_time + 3.5
    615         mon_sess.run(self.train_op)
    616 
    617         mock_time.return_value = current_time + 4.0
    618         mon_sess.run(self.train_op)
    619 
    620         mock_time.return_value = current_time + 6.5
    621         mon_sess.run(self.train_op)  # hook runs here
    622 
    623         mock_time.return_value = current_time + 7.0
    624         mon_sess.run(self.train_op)  # hook won't run here, so it does at end
    625 
    626         mock_time.return_value = current_time + 7.5
    627         hook.end(sess)  # hook runs here
    628       self.assertEqual({
    629           'begin': 1,
    630           'before_save': 4,
    631           'after_save': 4,
    632           'end': 1
    633       }, listener.get_counts())
    634 
    635   def test_save_steps_saves_in_first_step(self):
    636     with self.graph.as_default():
    637       hook = basic_session_run_hooks.CheckpointSaverHook(
    638           self.model_dir, save_steps=2, scaffold=self.scaffold)
    639       hook.begin()
    640       self.scaffold.finalize()
    641       with session_lib.Session() as sess:
    642         sess.run(self.scaffold.init_op)
    643         mon_sess = monitored_session._HookedSession(sess, [hook])
    644         mon_sess.run(self.train_op)
    645         self.assertEqual(1,
    646                          checkpoint_utils.load_variable(self.model_dir,
    647                                                         self.global_step.name))
    648 
    649   def test_save_steps_saves_periodically(self):
    650     with self.graph.as_default():
    651       hook = basic_session_run_hooks.CheckpointSaverHook(
    652           self.model_dir, save_steps=2, scaffold=self.scaffold)
    653       hook.begin()
    654       self.scaffold.finalize()
    655       with session_lib.Session() as sess:
    656         sess.run(self.scaffold.init_op)
    657         mon_sess = monitored_session._HookedSession(sess, [hook])
    658         mon_sess.run(self.train_op)
    659         mon_sess.run(self.train_op)
    660         # Not saved
    661         self.assertEqual(1,
    662                          checkpoint_utils.load_variable(self.model_dir,
    663                                                         self.global_step.name))
    664         mon_sess.run(self.train_op)
    665         # saved
    666         self.assertEqual(3,
    667                          checkpoint_utils.load_variable(self.model_dir,
    668                                                         self.global_step.name))
    669         mon_sess.run(self.train_op)
    670         # Not saved
    671         self.assertEqual(3,
    672                          checkpoint_utils.load_variable(self.model_dir,
    673                                                         self.global_step.name))
    674         mon_sess.run(self.train_op)
    675         # saved
    676         self.assertEqual(5,
    677                          checkpoint_utils.load_variable(self.model_dir,
    678                                                         self.global_step.name))
    679 
    680   def test_save_saves_at_end(self):
    681     with self.graph.as_default():
    682       hook = basic_session_run_hooks.CheckpointSaverHook(
    683           self.model_dir, save_secs=2, scaffold=self.scaffold)
    684       hook.begin()
    685       self.scaffold.finalize()
    686       with session_lib.Session() as sess:
    687         sess.run(self.scaffold.init_op)
    688         mon_sess = monitored_session._HookedSession(sess, [hook])
    689         mon_sess.run(self.train_op)
    690         mon_sess.run(self.train_op)
    691         hook.end(sess)
    692         self.assertEqual(2,
    693                          checkpoint_utils.load_variable(self.model_dir,
    694                                                         self.global_step.name))
    695 
    696   def test_summary_writer_defs(self):
    697     fake_summary_writer.FakeSummaryWriter.install()
    698     writer_cache.FileWriterCache.clear()
    699     summary_writer = writer_cache.FileWriterCache.get(self.model_dir)
    700 
    701     with self.graph.as_default():
    702       hook = basic_session_run_hooks.CheckpointSaverHook(
    703           self.model_dir, save_steps=2, scaffold=self.scaffold)
    704       hook.begin()
    705       self.scaffold.finalize()
    706       with session_lib.Session() as sess:
    707         sess.run(self.scaffold.init_op)
    708         mon_sess = monitored_session._HookedSession(sess, [hook])
    709         mon_sess.run(self.train_op)
    710       summary_writer.assert_summaries(
    711           test_case=self,
    712           expected_logdir=self.model_dir,
    713           expected_added_meta_graphs=[
    714               meta_graph.create_meta_graph_def(
    715                   graph_def=self.graph.as_graph_def(add_shapes=True),
    716                   saver_def=self.scaffold.saver.saver_def)
    717           ])
    718 
    719     fake_summary_writer.FakeSummaryWriter.uninstall()
    720 
    721 
    722 class ResourceCheckpointSaverHookTest(test.TestCase):
    723 
    724   def setUp(self):
    725     self.model_dir = tempfile.mkdtemp()
    726     self.graph = ops.Graph()
    727     with self.graph.as_default():
    728       self.scaffold = monitored_session.Scaffold()
    729       with variable_scope.variable_scope('foo', use_resource=True):
    730         self.global_step = training_util.get_or_create_global_step()
    731       self.train_op = training_util._increment_global_step(1)
    732 
    733   def test_save_steps_saves_periodically(self):
    734     with self.graph.as_default():
    735       hook = basic_session_run_hooks.CheckpointSaverHook(
    736           self.model_dir, save_steps=2, scaffold=self.scaffold)
    737       hook.begin()
    738       self.scaffold.finalize()
    739       with session_lib.Session() as sess:
    740         sess.run(self.scaffold.init_op)
    741         mon_sess = monitored_session._HookedSession(sess, [hook])
    742         mon_sess.run(self.train_op)
    743         mon_sess.run(self.train_op)
    744         # Not saved
    745         self.assertEqual(1,
    746                          checkpoint_utils.load_variable(self.model_dir,
    747                                                         self.global_step.name))
    748         mon_sess.run(self.train_op)
    749         # saved
    750         self.assertEqual(3,
    751                          checkpoint_utils.load_variable(self.model_dir,
    752                                                         self.global_step.name))
    753         mon_sess.run(self.train_op)
    754         # Not saved
    755         self.assertEqual(3,
    756                          checkpoint_utils.load_variable(self.model_dir,
    757                                                         self.global_step.name))
    758         mon_sess.run(self.train_op)
    759         # saved
    760         self.assertEqual(5,
    761                          checkpoint_utils.load_variable(self.model_dir,
    762                                                         self.global_step.name))
    763 
    764 
    765 class StepCounterHookTest(test.TestCase):
    766 
    767   def setUp(self):
    768     self.log_dir = tempfile.mkdtemp()
    769 
    770   def tearDown(self):
    771     shutil.rmtree(self.log_dir, ignore_errors=True)
    772 
    773   def test_step_counter_every_n_steps(self):
    774     with ops.Graph().as_default() as g, session_lib.Session() as sess:
    775       variables.get_or_create_global_step()
    776       train_op = training_util._increment_global_step(1)
    777       summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
    778       hook = basic_session_run_hooks.StepCounterHook(
    779           summary_writer=summary_writer, every_n_steps=10)
    780       hook.begin()
    781       sess.run(variables_lib.global_variables_initializer())
    782       mon_sess = monitored_session._HookedSession(sess, [hook])
    783       with test.mock.patch.object(tf_logging, 'warning') as mock_log:
    784         for _ in range(30):
    785           time.sleep(0.01)
    786           mon_sess.run(train_op)
    787         # logging.warning should not be called.
    788         self.assertIsNone(mock_log.call_args)
    789       hook.end(sess)
    790       summary_writer.assert_summaries(
    791           test_case=self,
    792           expected_logdir=self.log_dir,
    793           expected_graph=g,
    794           expected_summaries={})
    795       self.assertItemsEqual([11, 21], summary_writer.summaries.keys())
    796       for step in [11, 21]:
    797         summary_value = summary_writer.summaries[step][0].value[0]
    798         self.assertEqual('global_step/sec', summary_value.tag)
    799         self.assertGreater(summary_value.simple_value, 0)
    800 
    801   def test_step_counter_every_n_secs(self):
    802     with ops.Graph().as_default() as g, session_lib.Session() as sess:
    803       variables.get_or_create_global_step()
    804       train_op = training_util._increment_global_step(1)
    805       summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
    806       hook = basic_session_run_hooks.StepCounterHook(
    807           summary_writer=summary_writer, every_n_steps=None, every_n_secs=0.1)
    808 
    809       hook.begin()
    810       sess.run(variables_lib.global_variables_initializer())
    811       mon_sess = monitored_session._HookedSession(sess, [hook])
    812       mon_sess.run(train_op)
    813       time.sleep(0.2)
    814       mon_sess.run(train_op)
    815       time.sleep(0.2)
    816       mon_sess.run(train_op)
    817       hook.end(sess)
    818 
    819       summary_writer.assert_summaries(
    820           test_case=self,
    821           expected_logdir=self.log_dir,
    822           expected_graph=g,
    823           expected_summaries={})
    824       self.assertTrue(summary_writer.summaries, 'No summaries were created.')
    825       self.assertItemsEqual([2, 3], summary_writer.summaries.keys())
    826       for summary in summary_writer.summaries.values():
    827         summary_value = summary[0].value[0]
    828         self.assertEqual('global_step/sec', summary_value.tag)
    829         self.assertGreater(summary_value.simple_value, 0)
    830 
    831   def test_global_step_name(self):
    832     with ops.Graph().as_default() as g, session_lib.Session() as sess:
    833       with variable_scope.variable_scope('bar'):
    834         variable_scope.get_variable(
    835             'foo',
    836             initializer=0,
    837             trainable=False,
    838             collections=[
    839                 ops.GraphKeys.GLOBAL_STEP, ops.GraphKeys.GLOBAL_VARIABLES
    840             ])
    841       train_op = training_util._increment_global_step(1)
    842       summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir, g)
    843       hook = basic_session_run_hooks.StepCounterHook(
    844           summary_writer=summary_writer, every_n_steps=1, every_n_secs=None)
    845 
    846       hook.begin()
    847       sess.run(variables_lib.global_variables_initializer())
    848       mon_sess = monitored_session._HookedSession(sess, [hook])
    849       mon_sess.run(train_op)
    850       mon_sess.run(train_op)
    851       hook.end(sess)
    852 
    853       summary_writer.assert_summaries(
    854           test_case=self,
    855           expected_logdir=self.log_dir,
    856           expected_graph=g,
    857           expected_summaries={})
    858       self.assertTrue(summary_writer.summaries, 'No summaries were created.')
    859       self.assertItemsEqual([2], summary_writer.summaries.keys())
    860       summary_value = summary_writer.summaries[2][0].value[0]
    861       self.assertEqual('bar/foo/sec', summary_value.tag)
    862 
    863   def test_log_warning_if_global_step_not_increased(self):
    864     with ops.Graph().as_default(), session_lib.Session() as sess:
    865       variables.get_or_create_global_step()
    866       train_op = training_util._increment_global_step(0)  # keep same.
    867       sess.run(variables_lib.global_variables_initializer())
    868       hook = basic_session_run_hooks.StepCounterHook(
    869           every_n_steps=1, every_n_secs=None)
    870       hook.begin()
    871       mon_sess = monitored_session._HookedSession(sess, [hook])
    872       mon_sess.run(train_op)  # Run one step to record global step.
    873       with test.mock.patch.object(tf_logging, 'warning') as mock_log:
    874         for _ in range(30):
    875           mon_sess.run(train_op)
    876         self.assertRegexpMatches(
    877             str(mock_log.call_args),
    878             'global step.*has not been increased')
    879       hook.end(sess)
    880 
    881 
    882 class SummarySaverHookTest(test.TestCase):
    883 
    884   def setUp(self):
    885     test.TestCase.setUp(self)
    886 
    887     self.log_dir = 'log/dir'
    888     self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)
    889 
    890     var = variables_lib.Variable(0.0)
    891     tensor = state_ops.assign_add(var, 1.0)
    892     tensor2 = tensor * 2
    893     self.summary_op = summary_lib.scalar('my_summary', tensor)
    894     self.summary_op2 = summary_lib.scalar('my_summary2', tensor2)
    895 
    896     variables.get_or_create_global_step()
    897     self.train_op = training_util._increment_global_step(1)
    898 
    899   def test_raise_when_scaffold_and_summary_op_both_missing(self):
    900     with self.assertRaises(ValueError):
    901       basic_session_run_hooks.SummarySaverHook()
    902 
    903   def test_raise_when_scaffold_and_summary_op_both_present(self):
    904     with self.assertRaises(ValueError):
    905       basic_session_run_hooks.SummarySaverHook(
    906           scaffold=monitored_session.Scaffold(), summary_op=self.summary_op)
    907 
    908   def test_raise_in_both_secs_and_steps(self):
    909     with self.assertRaises(ValueError):
    910       basic_session_run_hooks.SummarySaverHook(
    911           save_secs=10, save_steps=20, summary_writer=self.summary_writer)
    912 
    913   def test_raise_in_none_secs_and_steps(self):
    914     with self.assertRaises(ValueError):
    915       basic_session_run_hooks.SummarySaverHook(
    916           save_secs=None, save_steps=None, summary_writer=self.summary_writer)
    917 
    918   def test_save_steps(self):
    919     hook = basic_session_run_hooks.SummarySaverHook(
    920         save_steps=8,
    921         summary_writer=self.summary_writer,
    922         summary_op=self.summary_op)
    923 
    924     with self.test_session() as sess:
    925       hook.begin()
    926       sess.run(variables_lib.global_variables_initializer())
    927       mon_sess = monitored_session._HookedSession(sess, [hook])
    928       for _ in range(30):
    929         mon_sess.run(self.train_op)
    930       hook.end(sess)
    931 
    932     self.summary_writer.assert_summaries(
    933         test_case=self,
    934         expected_logdir=self.log_dir,
    935         expected_summaries={
    936             1: {
    937                 'my_summary': 1.0
    938             },
    939             9: {
    940                 'my_summary': 2.0
    941             },
    942             17: {
    943                 'my_summary': 3.0
    944             },
    945             25: {
    946                 'my_summary': 4.0
    947             },
    948         })
    949 
    950   def test_multiple_summaries(self):
    951     hook = basic_session_run_hooks.SummarySaverHook(
    952         save_steps=8,
    953         summary_writer=self.summary_writer,
    954         summary_op=[self.summary_op, self.summary_op2])
    955 
    956     with self.test_session() as sess:
    957       hook.begin()
    958       sess.run(variables_lib.global_variables_initializer())
    959       mon_sess = monitored_session._HookedSession(sess, [hook])
    960       for _ in range(10):
    961         mon_sess.run(self.train_op)
    962       hook.end(sess)
    963 
    964     self.summary_writer.assert_summaries(
    965         test_case=self,
    966         expected_logdir=self.log_dir,
    967         expected_summaries={
    968             1: {
    969                 'my_summary': 1.0,
    970                 'my_summary2': 2.0
    971             },
    972             9: {
    973                 'my_summary': 2.0,
    974                 'my_summary2': 4.0
    975             },
    976         })
    977 
    978   def test_save_secs_saving_once_every_step(self):
    979     hook = basic_session_run_hooks.SummarySaverHook(
    980         save_secs=0.5,
    981         summary_writer=self.summary_writer,
    982         summary_op=self.summary_op)
    983 
    984     with self.test_session() as sess:
    985       hook.begin()
    986       sess.run(variables_lib.global_variables_initializer())
    987       mon_sess = monitored_session._HookedSession(sess, [hook])
    988       for _ in range(4):
    989         mon_sess.run(self.train_op)
    990         time.sleep(0.5)
    991       hook.end(sess)
    992 
    993     self.summary_writer.assert_summaries(
    994         test_case=self,
    995         expected_logdir=self.log_dir,
    996         expected_summaries={
    997             1: {
    998                 'my_summary': 1.0
    999             },
   1000             2: {
   1001                 'my_summary': 2.0
   1002             },
   1003             3: {
   1004                 'my_summary': 3.0
   1005             },
   1006             4: {
   1007                 'my_summary': 4.0
   1008             },
   1009         })
   1010 
   1011   @test.mock.patch.object(time, 'time')
   1012   def test_save_secs_saving_once_every_three_steps(self, mock_time):
   1013     mock_time.return_value = 1484695987.209386
   1014     hook = basic_session_run_hooks.SummarySaverHook(
   1015         save_secs=9.,
   1016         summary_writer=self.summary_writer,
   1017         summary_op=self.summary_op)
   1018 
   1019     with self.test_session() as sess:
   1020       hook.begin()
   1021       sess.run(variables_lib.global_variables_initializer())
   1022       mon_sess = monitored_session._HookedSession(sess, [hook])
   1023       for _ in range(8):
   1024         mon_sess.run(self.train_op)
   1025         mock_time.return_value += 3.1
   1026       hook.end(sess)
   1027 
   1028     # 24.8 seconds passed (3.1*8), it saves every 9 seconds starting from first:
   1029     self.summary_writer.assert_summaries(
   1030         test_case=self,
   1031         expected_logdir=self.log_dir,
   1032         expected_summaries={
   1033             1: {
   1034                 'my_summary': 1.0
   1035             },
   1036             4: {
   1037                 'my_summary': 2.0
   1038             },
   1039             7: {
   1040                 'my_summary': 3.0
   1041             },
   1042         })
   1043 
   1044 
   1045 class GlobalStepWaiterHookTest(test.TestCase):
   1046 
   1047   def test_not_wait_for_step_zero(self):
   1048     with ops.Graph().as_default():
   1049       variables.get_or_create_global_step()
   1050       hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=0)
   1051       hook.begin()
   1052       with session_lib.Session() as sess:
   1053         # Before run should return without waiting gstep increment.
   1054         hook.before_run(
   1055             session_run_hook.SessionRunContext(
   1056                 original_args=None, session=sess))
   1057 
   1058   def test_wait_for_step(self):
   1059     with ops.Graph().as_default():
   1060       gstep = variables.get_or_create_global_step()
   1061       hook = basic_session_run_hooks.GlobalStepWaiterHook(wait_until_step=1000)
   1062       hook.begin()
   1063       with session_lib.Session() as sess:
   1064         sess.run(variables_lib.global_variables_initializer())
   1065         waiter = threading.Thread(
   1066             target=hook.before_run,
   1067             args=(session_run_hook.SessionRunContext(
   1068                 original_args=None, session=sess),))
   1069         waiter.daemon = True
   1070         waiter.start()
   1071         time.sleep(1.0)
   1072         self.assertTrue(waiter.is_alive())
   1073         sess.run(state_ops.assign(gstep, 500))
   1074         time.sleep(1.0)
   1075         self.assertTrue(waiter.is_alive())
   1076         sess.run(state_ops.assign(gstep, 1100))
   1077         time.sleep(1.2)
   1078         self.assertFalse(waiter.is_alive())
   1079 
   1080 
   1081 class FinalOpsHookTest(test.TestCase):
   1082 
   1083   def test_final_ops_is_scalar_tensor(self):
   1084     with ops.Graph().as_default():
   1085       expected_value = 4
   1086       final_ops = constant_op.constant(expected_value)
   1087 
   1088       hook = basic_session_run_hooks.FinalOpsHook(final_ops)
   1089       hook.begin()
   1090 
   1091       with session_lib.Session() as session:
   1092         hook.end(session)
   1093         self.assertEqual(expected_value,
   1094                          hook.final_ops_values)
   1095 
   1096   def test_final_ops_is_tensor(self):
   1097     with ops.Graph().as_default():
   1098       expected_values = [1, 6, 3, 5, 2, 4]
   1099       final_ops = constant_op.constant(expected_values)
   1100 
   1101       hook = basic_session_run_hooks.FinalOpsHook(final_ops)
   1102       hook.begin()
   1103 
   1104       with session_lib.Session() as session:
   1105         hook.end(session)
   1106         self.assertListEqual(expected_values,
   1107                              hook.final_ops_values.tolist())
   1108 
   1109   def test_final_ops_with_dictionary(self):
   1110     with ops.Graph().as_default():
   1111       expected_values = [4, -3]
   1112       final_ops = array_ops.placeholder(dtype=dtypes.float32)
   1113       final_ops_feed_dict = {final_ops: expected_values}
   1114 
   1115       hook = basic_session_run_hooks.FinalOpsHook(
   1116           final_ops, final_ops_feed_dict)
   1117       hook.begin()
   1118 
   1119       with session_lib.Session() as session:
   1120         hook.end(session)
   1121         self.assertListEqual(expected_values,
   1122                              hook.final_ops_values.tolist())
   1123 
   1124 
   1125 class ResourceSummarySaverHookTest(test.TestCase):
   1126 
   1127   def setUp(self):
   1128     test.TestCase.setUp(self)
   1129 
   1130     self.log_dir = 'log/dir'
   1131     self.summary_writer = fake_summary_writer.FakeSummaryWriter(self.log_dir)
   1132 
   1133     var = variable_scope.get_variable('var', initializer=0.0, use_resource=True)
   1134     tensor = state_ops.assign_add(var, 1.0)
   1135     self.summary_op = summary_lib.scalar('my_summary', tensor)
   1136 
   1137     with variable_scope.variable_scope('foo', use_resource=True):
   1138       variables.create_global_step()
   1139     self.train_op = training_util._increment_global_step(1)
   1140 
   1141   def test_save_steps(self):
   1142     hook = basic_session_run_hooks.SummarySaverHook(
   1143         save_steps=8,
   1144         summary_writer=self.summary_writer,
   1145         summary_op=self.summary_op)
   1146 
   1147     with self.test_session() as sess:
   1148       hook.begin()
   1149       sess.run(variables_lib.global_variables_initializer())
   1150       mon_sess = monitored_session._HookedSession(sess, [hook])
   1151       for _ in range(30):
   1152         mon_sess.run(self.train_op)
   1153       hook.end(sess)
   1154 
   1155     self.summary_writer.assert_summaries(
   1156         test_case=self,
   1157         expected_logdir=self.log_dir,
   1158         expected_summaries={
   1159             1: {
   1160                 'my_summary': 1.0
   1161             },
   1162             9: {
   1163                 'my_summary': 2.0
   1164             },
   1165             17: {
   1166                 'my_summary': 3.0
   1167             },
   1168             25: {
   1169                 'my_summary': 4.0
   1170             },
   1171         })
   1172 
   1173 
   1174 class FeedFnHookTest(test.TestCase):
   1175 
   1176   def test_feeding_placeholder(self):
   1177     with ops.Graph().as_default(), session_lib.Session() as sess:
   1178       x = array_ops.placeholder(dtype=dtypes.float32)
   1179       y = x + 1
   1180       hook = basic_session_run_hooks.FeedFnHook(
   1181           feed_fn=lambda: {x: 1.0})
   1182       hook.begin()
   1183       mon_sess = monitored_session._HookedSession(sess, [hook])
   1184       self.assertEqual(mon_sess.run(y), 2)
   1185 
   1186 
   1187 class ProfilerHookTest(test.TestCase):
   1188 
   1189   def setUp(self):
   1190     super(ProfilerHookTest, self).setUp()
   1191     self.output_dir = tempfile.mkdtemp()
   1192     self.graph = ops.Graph()
   1193     self.filepattern = os.path.join(self.output_dir, 'timeline-*.json')
   1194     with self.graph.as_default():
   1195       self.global_step = variables.get_or_create_global_step()
   1196       self.train_op = state_ops.assign_add(self.global_step, 1)
   1197 
   1198   def tearDown(self):
   1199     super(ProfilerHookTest, self).tearDown()
   1200     shutil.rmtree(self.output_dir, ignore_errors=True)
   1201 
   1202   def _count_timeline_files(self):
   1203     return len(gfile.Glob(self.filepattern))
   1204 
   1205   def test_raise_in_both_secs_and_steps(self):
   1206     with self.assertRaises(ValueError):
   1207       basic_session_run_hooks.ProfilerHook(save_secs=10, save_steps=20)
   1208 
   1209   def test_raise_in_none_secs_and_steps(self):
   1210     with self.assertRaises(ValueError):
   1211       basic_session_run_hooks.ProfilerHook(save_secs=None, save_steps=None)
   1212 
   1213   def test_save_secs_saves_in_first_step(self):
   1214     with self.graph.as_default():
   1215       hook = basic_session_run_hooks.ProfilerHook(
   1216           save_secs=2, output_dir=self.output_dir)
   1217       with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
   1218         sess.run(self.train_op)
   1219         self.assertEqual(1, self._count_timeline_files())
   1220 
   1221   @test.mock.patch.object(time, 'time')
   1222   def test_save_secs_saves_periodically(self, mock_time):
   1223     # Pick a fixed start time.
   1224     current_time = 1484863632.320497
   1225 
   1226     with self.graph.as_default():
   1227       mock_time.return_value = current_time
   1228       hook = basic_session_run_hooks.ProfilerHook(
   1229           save_secs=2, output_dir=self.output_dir)
   1230       with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
   1231         sess.run(self.train_op)  # Saved.
   1232         self.assertEqual(1, self._count_timeline_files())
   1233         sess.run(self.train_op)  # Not saved.
   1234         self.assertEqual(1, self._count_timeline_files())
   1235         # Simulate 2.5 seconds of sleep.
   1236         mock_time.return_value = current_time + 2.5
   1237         sess.run(self.train_op)  # Saved.
   1238 
   1239         # Pretend some small amount of time has passed.
   1240         mock_time.return_value = current_time + 0.1
   1241         sess.run(self.train_op)  # Not saved.
   1242         # Edge test just before we should save the timeline.
   1243         mock_time.return_value = current_time + 1.9
   1244         sess.run(self.train_op)  # Not saved.
   1245         self.assertEqual(2, self._count_timeline_files())
   1246 
   1247         mock_time.return_value = current_time + 4.5
   1248         sess.run(self.train_op)  # Saved.
   1249         self.assertEqual(3, self._count_timeline_files())
   1250 
   1251   def test_save_steps_saves_in_first_step(self):
   1252     with self.graph.as_default():
   1253       hook = basic_session_run_hooks.ProfilerHook(
   1254           save_secs=2, output_dir=self.output_dir)
   1255       with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
   1256         sess.run(self.train_op)  # Saved.
   1257         sess.run(self.train_op)  # Not saved.
   1258         self.assertEqual(1, self._count_timeline_files())
   1259 
   1260   def test_save_steps_saves_periodically(self):
   1261     with self.graph.as_default():
   1262       hook = basic_session_run_hooks.ProfilerHook(
   1263           save_steps=2, output_dir=self.output_dir)
   1264       with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess:
   1265         self.assertEqual(0, self._count_timeline_files())
   1266         sess.run(self.train_op)  # Saved.
   1267         self.assertEqual(1, self._count_timeline_files())
   1268         sess.run(self.train_op)  # Not saved.
   1269         self.assertEqual(1, self._count_timeline_files())
   1270         sess.run(self.train_op)  # Saved.
   1271         self.assertEqual(2, self._count_timeline_files())
   1272         sess.run(self.train_op)  # Not saved.
   1273         self.assertEqual(2, self._count_timeline_files())
   1274         sess.run(self.train_op)  # Saved.
   1275         self.assertEqual(3, self._count_timeline_files())
   1276 
   1277 
if __name__ == '__main__':
  # Run this module's test cases through TensorFlow's test entry point.
  test.main()
   1280