Home | History | Annotate | Download | only in training
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for supervisor.py."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import glob
     22 import os
     23 import shutil
     24 import time
     25 import uuid
     26 
     27 from six.moves import xrange  # pylint: disable=redefined-builtin
     28 
     29 from tensorflow.core.framework import graph_pb2
     30 from tensorflow.core.protobuf import config_pb2
     31 from tensorflow.core.protobuf import meta_graph_pb2
     32 from tensorflow.core.util import event_pb2
     33 from tensorflow.python.framework import constant_op
     34 from tensorflow.python.framework import dtypes
     35 from tensorflow.python.framework import errors_impl
     36 from tensorflow.python.framework import meta_graph
     37 from tensorflow.python.framework import ops
     38 from tensorflow.python.ops import array_ops
     39 from tensorflow.python.ops import io_ops
     40 from tensorflow.python.ops import parsing_ops
     41 from tensorflow.python.ops import variables
     42 from tensorflow.python.platform import gfile
     43 from tensorflow.python.platform import test
     44 from tensorflow.python.summary import summary
     45 from tensorflow.python.summary import summary_iterator
     46 from tensorflow.python.summary.writer import writer
     47 from tensorflow.python.training import input as input_lib
     48 from tensorflow.python.training import saver as saver_lib
     49 from tensorflow.python.training import server_lib
     50 from tensorflow.python.training import session_manager as session_manager_lib
     51 from tensorflow.python.training import supervisor
     52 
     53 
     54 def _summary_iterator(test_dir):
     55   """Reads events from test_dir/events.
     56 
     57   Args:
     58     test_dir: Name of the test directory.
     59 
     60   Returns:
     61     A summary_iterator
     62   """
     63   event_paths = sorted(glob.glob(os.path.join(test_dir, "event*")))
     64   return summary_iterator.summary_iterator(event_paths[-1])
     65 
     66 
     67 class SupervisorTest(test.TestCase):
     68 
     69   def _test_dir(self, test_name):
     70     test_dir = os.path.join(self.get_temp_dir(), test_name)
     71     if os.path.exists(test_dir):
     72       shutil.rmtree(test_dir)
     73     return test_dir
     74 
     75   def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True):
     76     """Wait for a checkpoint file to appear.
     77 
     78     Args:
     79       pattern: A string.
     80       timeout_secs: How long to wait for in seconds.
     81       for_checkpoint: whether we're globbing for checkpoints.
     82     """
     83     end_time = time.time() + timeout_secs
     84     while time.time() < end_time:
     85       if for_checkpoint:
     86         if saver_lib.checkpoint_exists(pattern):
     87           return
     88       else:
     89         if len(gfile.Glob(pattern)) >= 1:
     90           return
     91       time.sleep(0.05)
     92     self.assertFalse(True, "Glob never matched any file: %s" % pattern)
     93 
     94   # This test does not test much.
     95   def testBasics(self):
     96     logdir = self._test_dir("basics")
     97     with ops.Graph().as_default():
     98       my_op = constant_op.constant(1.0)
     99       sv = supervisor.Supervisor(logdir=logdir)
    100       sess = sv.prepare_or_wait_for_session("")
    101       for _ in xrange(10):
    102         sess.run(my_op)
    103       sess.close()
    104       sv.stop()
    105 
    106   def testManagedSession(self):
    107     logdir = self._test_dir("managed_session")
    108     with ops.Graph().as_default():
    109       my_op = constant_op.constant(1.0)
    110       sv = supervisor.Supervisor(logdir=logdir)
    111       with sv.managed_session("") as sess:
    112         for _ in xrange(10):
    113           sess.run(my_op)
    114       # Supervisor has been stopped.
    115       self.assertTrue(sv.should_stop())
    116 
    117   def testManagedSessionUserError(self):
    118     logdir = self._test_dir("managed_user_error")
    119     with ops.Graph().as_default():
    120       my_op = constant_op.constant(1.0)
    121       sv = supervisor.Supervisor(logdir=logdir)
    122       last_step = None
    123       with self.assertRaisesRegexp(RuntimeError, "failing here"):
    124         with sv.managed_session("") as sess:
    125           for step in xrange(10):
    126             last_step = step
    127             if step == 1:
    128               raise RuntimeError("failing here")
    129             else:
    130               sess.run(my_op)
    131       # Supervisor has been stopped.
    132       self.assertTrue(sv.should_stop())
    133       self.assertEqual(1, last_step)
    134 
    135   def testManagedSessionIgnoreOutOfRangeError(self):
    136     logdir = self._test_dir("managed_out_of_range")
    137     with ops.Graph().as_default():
    138       my_op = constant_op.constant(1.0)
    139       sv = supervisor.Supervisor(logdir=logdir)
    140       last_step = None
    141       with sv.managed_session("") as sess:
    142         for step in xrange(10):
    143           last_step = step
    144           if step == 3:
    145             raise errors_impl.OutOfRangeError(my_op.op.node_def, my_op.op,
    146                                               "all done")
    147           else:
    148             sess.run(my_op)
    149       # Supervisor has been stopped.  OutOfRangeError was not thrown.
    150       self.assertTrue(sv.should_stop())
    151       self.assertEqual(3, last_step)
    152 
    153   def testManagedSessionDoNotKeepSummaryWriter(self):
    154     logdir = self._test_dir("managed_not_keep_summary_writer")
    155     with ops.Graph().as_default():
    156       summary.scalar("c1", constant_op.constant(1))
    157       summary.scalar("c2", constant_op.constant(2))
    158       summary.scalar("c3", constant_op.constant(3))
    159       summ = summary.merge_all()
    160       sv = supervisor.Supervisor(logdir=logdir, summary_op=None)
    161       with sv.managed_session(
    162           "", close_summary_writer=True, start_standard_services=False) as sess:
    163         sv.summary_computed(sess, sess.run(summ))
    164       # Sleep 1.2s to make sure that the next event file has a different name
    165       # than the current one.
    166       time.sleep(1.2)
    167       with sv.managed_session(
    168           "", close_summary_writer=True, start_standard_services=False) as sess:
    169         sv.summary_computed(sess, sess.run(summ))
    170     event_paths = sorted(glob.glob(os.path.join(logdir, "event*")))
    171     self.assertEquals(2, len(event_paths))
    172     # The two event files should have the same contents.
    173     for path in event_paths:
    174       # The summary iterator should report the summary once as we closed the
    175       # summary writer across the 2 sessions.
    176       rr = summary_iterator.summary_iterator(path)
    177       # The first event should list the file_version.
    178       ev = next(rr)
    179       self.assertEquals("brain.Event:2", ev.file_version)
    180 
    181       # The next one has the graph and metagraph.
    182       ev = next(rr)
    183       self.assertTrue(ev.graph_def)
    184 
    185       ev = next(rr)
    186       self.assertTrue(ev.meta_graph_def)
    187 
    188       # The next one should have the values from the summary.
    189       # But only once.
    190       ev = next(rr)
    191       self.assertProtoEquals("""
    192         value { tag: 'c1' simple_value: 1.0 }
    193         value { tag: 'c2' simple_value: 2.0 }
    194         value { tag: 'c3' simple_value: 3.0 }
    195         """, ev.summary)
    196 
    197       # The next one should be a stop message if we closed cleanly.
    198       ev = next(rr)
    199       self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
    200 
    201       # We should be done.
    202       with self.assertRaises(StopIteration):
    203         next(rr)
    204 
    205   def testManagedSessionKeepSummaryWriter(self):
    206     logdir = self._test_dir("managed_keep_summary_writer")
    207     with ops.Graph().as_default():
    208       summary.scalar("c1", constant_op.constant(1))
    209       summary.scalar("c2", constant_op.constant(2))
    210       summary.scalar("c3", constant_op.constant(3))
    211       summ = summary.merge_all()
    212       sv = supervisor.Supervisor(logdir=logdir)
    213       with sv.managed_session(
    214           "", close_summary_writer=False,
    215           start_standard_services=False) as sess:
    216         sv.summary_computed(sess, sess.run(summ))
    217       with sv.managed_session(
    218           "", close_summary_writer=False,
    219           start_standard_services=False) as sess:
    220         sv.summary_computed(sess, sess.run(summ))
    221     # Now close the summary writer to flush the events.
    222     sv.summary_writer.close()
    223     # The summary iterator should report the summary twice as we reused
    224     # the same summary writer across the 2 sessions.
    225     rr = _summary_iterator(logdir)
    226     # The first event should list the file_version.
    227     ev = next(rr)
    228     self.assertEquals("brain.Event:2", ev.file_version)
    229 
    230     # The next one has the graph.
    231     ev = next(rr)
    232     self.assertTrue(ev.graph_def)
    233 
    234     ev = next(rr)
    235     self.assertTrue(ev.meta_graph_def)
    236 
    237     # The next one should have the values from the summary.
    238     ev = next(rr)
    239     self.assertProtoEquals("""
    240       value { tag: 'c1' simple_value: 1.0 }
    241       value { tag: 'c2' simple_value: 2.0 }
    242       value { tag: 'c3' simple_value: 3.0 }
    243       """, ev.summary)
    244 
    245     # The next one should also have the values from the summary.
    246     ev = next(rr)
    247     self.assertProtoEquals("""
    248       value { tag: 'c1' simple_value: 1.0 }
    249       value { tag: 'c2' simple_value: 2.0 }
    250       value { tag: 'c3' simple_value: 3.0 }
    251       """, ev.summary)
    252 
    253     # We should be done.
    254     self.assertRaises(StopIteration, lambda: next(rr))
    255 
    256   def _csv_data(self, logdir):
    257     # Create a small data file with 3 CSV records.
    258     data_path = os.path.join(logdir, "data.csv")
    259     with open(data_path, "w") as f:
    260       f.write("1,2,3\n")
    261       f.write("4,5,6\n")
    262       f.write("7,8,9\n")
    263     return data_path
    264 
    265   def testManagedEndOfInputOneQueue(self):
    266     # Tests that the supervisor finishes without an error when using
    267     # a fixed number of epochs, reading from a single queue.
    268     logdir = self._test_dir("managed_end_of_input_one_queue")
    269     os.makedirs(logdir)
    270     data_path = self._csv_data(logdir)
    271     with ops.Graph().as_default():
    272       # Create an input pipeline that reads the file 3 times.
    273       filename_queue = input_lib.string_input_producer(
    274           [data_path], num_epochs=3)
    275       reader = io_ops.TextLineReader()
    276       _, csv = reader.read(filename_queue)
    277       rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
    278       sv = supervisor.Supervisor(logdir=logdir)
    279       with sv.managed_session("") as sess:
    280         while not sv.should_stop():
    281           sess.run(rec)
    282 
    283   def testManagedEndOfInputTwoQueues(self):
    284     # Tests that the supervisor finishes without an error when using
    285     # a fixed number of epochs, reading from two queues, the second
    286     # one producing a batch from the first one.
    287     logdir = self._test_dir("managed_end_of_input_two_queues")
    288     os.makedirs(logdir)
    289     data_path = self._csv_data(logdir)
    290     with ops.Graph().as_default():
    291       # Create an input pipeline that reads the file 3 times.
    292       filename_queue = input_lib.string_input_producer(
    293           [data_path], num_epochs=3)
    294       reader = io_ops.TextLineReader()
    295       _, csv = reader.read(filename_queue)
    296       rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
    297       shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4)
    298       sv = supervisor.Supervisor(logdir=logdir)
    299       with sv.managed_session("") as sess:
    300         while not sv.should_stop():
    301           sess.run(shuff_rec)
    302 
    303   def testManagedMainErrorTwoQueues(self):
    304     # Tests that the supervisor correctly raises a main loop
    305     # error even when using multiple queues for input.
    306     logdir = self._test_dir("managed_main_error_two_queues")
    307     os.makedirs(logdir)
    308     data_path = self._csv_data(logdir)
    309     with self.assertRaisesRegexp(RuntimeError, "fail at step 3"):
    310       with ops.Graph().as_default():
    311         # Create an input pipeline that reads the file 3 times.
    312         filename_queue = input_lib.string_input_producer(
    313             [data_path], num_epochs=3)
    314         reader = io_ops.TextLineReader()
    315         _, csv = reader.read(filename_queue)
    316         rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
    317         shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4)
    318         sv = supervisor.Supervisor(logdir=logdir)
    319         with sv.managed_session("") as sess:
    320           for step in range(9):
    321             if sv.should_stop():
    322               break
    323             elif step == 3:
    324               raise RuntimeError("fail at step 3")
    325             else:
    326               sess.run(shuff_rec)
    327 
    328   def testSessionConfig(self):
    329     logdir = self._test_dir("session_config")
    330     with ops.Graph().as_default():
    331       with ops.device("/cpu:1"):
    332         my_op = constant_op.constant([1.0])
    333       sv = supervisor.Supervisor(logdir=logdir)
    334       sess = sv.prepare_or_wait_for_session(
    335           "", config=config_pb2.ConfigProto(device_count={"CPU": 2}))
    336       for _ in xrange(10):
    337         sess.run(my_op)
    338       sess.close()
    339       sv.stop()
    340 
    341   def testChiefCanWriteEvents(self):
    342     logdir = self._test_dir("can_write")
    343     with ops.Graph().as_default():
    344       summary.scalar("c1", constant_op.constant(1))
    345       summary.scalar("c2", constant_op.constant(2))
    346       summary.scalar("c3", constant_op.constant(3))
    347       summ = summary.merge_all()
    348       sv = supervisor.Supervisor(is_chief=True, logdir=logdir, summary_op=None)
    349       meta_graph_def = meta_graph.create_meta_graph_def()
    350       sess = sv.prepare_or_wait_for_session("")
    351       sv.summary_computed(sess, sess.run(summ))
    352       sess.close()
    353       # Wait to make sure everything is written to file before stopping.
    354       time.sleep(1)
    355       sv.stop()
    356 
    357     rr = _summary_iterator(logdir)
    358 
    359     # The first event should list the file_version.
    360     ev = next(rr)
    361     self.assertEquals("brain.Event:2", ev.file_version)
    362 
    363     # The next one has the graph.
    364     ev = next(rr)
    365     ev_graph = graph_pb2.GraphDef()
    366     ev_graph.ParseFromString(ev.graph_def)
    367     self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)
    368 
    369     # Stored MetaGraphDef
    370     ev = next(rr)
    371     ev_meta_graph = meta_graph_pb2.MetaGraphDef()
    372     ev_meta_graph.ParseFromString(ev.meta_graph_def)
    373     self.assertProtoEquals(meta_graph_def, ev_meta_graph)
    374     self.assertProtoEquals(
    375         sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)
    376     # The next one should have the values from the summary.
    377     ev = next(rr)
    378     self.assertProtoEquals("""
    379       value { tag: 'c1' simple_value: 1.0 }
    380       value { tag: 'c2' simple_value: 2.0 }
    381       value { tag: 'c3' simple_value: 3.0 }
    382       """, ev.summary)
    383 
    384     # The next one should be a stop message if we closed cleanly.
    385     ev = next(rr)
    386     self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
    387 
    388     # We should be done.
    389     self.assertRaises(StopIteration, lambda: next(rr))
    390 
    391   def testNonChiefCannotWriteEvents(self):
    392 
    393     def _summary_computed():
    394       with ops.Graph().as_default():
    395         sv = supervisor.Supervisor(is_chief=False)
    396         sess = sv.prepare_or_wait_for_session("")
    397         summary.scalar("c1", constant_op.constant(1))
    398         summary.scalar("c2", constant_op.constant(2))
    399         summ = summary.merge_all()
    400         sv.summary_computed(sess, sess.run(summ))
    401 
    402     def _start_standard_services():
    403       with ops.Graph().as_default():
    404         sv = supervisor.Supervisor(is_chief=False)
    405         sess = sv.prepare_or_wait_for_session("")
    406         sv.start_standard_services(sess)
    407 
    408     self.assertRaises(RuntimeError, _summary_computed)
    409     self.assertRaises(RuntimeError, _start_standard_services)
    410 
    411   def testNoLogdirButWantSummary(self):
    412     with ops.Graph().as_default():
    413       summary.scalar("c1", constant_op.constant(1))
    414       summary.scalar("c2", constant_op.constant(2))
    415       summary.scalar("c3", constant_op.constant(3))
    416       summ = summary.merge_all()
    417       sv = supervisor.Supervisor(logdir="", summary_op=None)
    418       sess = sv.prepare_or_wait_for_session("")
    419       with self.assertRaisesRegexp(RuntimeError, "requires a summary writer"):
    420         sv.summary_computed(sess, sess.run(summ))
    421 
    422   def testLogdirButExplicitlyNoSummaryWriter(self):
    423     logdir = self._test_dir("explicit_no_summary_writer")
    424     with ops.Graph().as_default():
    425       variables.Variable([1.0], name="foo")
    426       summary.scalar("c1", constant_op.constant(1))
    427       summary.scalar("c2", constant_op.constant(2))
    428       summary.scalar("c3", constant_op.constant(3))
    429       summ = summary.merge_all()
    430       sv = supervisor.Supervisor(logdir=logdir, summary_writer=None)
    431       sess = sv.prepare_or_wait_for_session("")
    432       # Check that a checkpoint is still be generated.
    433       self._wait_for_glob(sv.save_path, 3.0)
    434       # Check that we cannot write a summary
    435       with self.assertRaisesRegexp(RuntimeError, "requires a summary writer"):
    436         sv.summary_computed(sess, sess.run(summ))
    437 
    438   def testNoLogdirButExplicitSummaryWriter(self):
    439     logdir = self._test_dir("explicit_summary_writer")
    440     with ops.Graph().as_default():
    441       summary.scalar("c1", constant_op.constant(1))
    442       summary.scalar("c2", constant_op.constant(2))
    443       summary.scalar("c3", constant_op.constant(3))
    444       summ = summary.merge_all()
    445       sw = writer.FileWriter(logdir)
    446       sv = supervisor.Supervisor(logdir="", summary_op=None, summary_writer=sw)
    447       meta_graph_def = meta_graph.create_meta_graph_def()
    448       sess = sv.prepare_or_wait_for_session("")
    449       sv.summary_computed(sess, sess.run(summ))
    450       sess.close()
    451       # Wait to make sure everything is written to file before stopping.
    452       time.sleep(1)
    453       sv.stop()
    454 
    455     # Check the summary was written to 'logdir'
    456     rr = _summary_iterator(logdir)
    457 
    458     # The first event should list the file_version.
    459     ev = next(rr)
    460     self.assertEquals("brain.Event:2", ev.file_version)
    461 
    462     # The next one has the graph.
    463     ev = next(rr)
    464     ev_graph = graph_pb2.GraphDef()
    465     ev_graph.ParseFromString(ev.graph_def)
    466     self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)
    467 
    468     # Stored MetaGraphDef
    469     ev = next(rr)
    470     ev_meta_graph = meta_graph_pb2.MetaGraphDef()
    471     ev_meta_graph.ParseFromString(ev.meta_graph_def)
    472     self.assertProtoEquals(meta_graph_def, ev_meta_graph)
    473     self.assertProtoEquals(
    474         sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)
    475 
    476     # The next one should have the values from the summary.
    477     ev = next(rr)
    478     self.assertProtoEquals("""
    479       value { tag: 'c1' simple_value: 1.0 }
    480       value { tag: 'c2' simple_value: 2.0 }
    481       value { tag: 'c3' simple_value: 3.0 }
    482       """, ev.summary)
    483 
    484     # The next one should be a stop message if we closed cleanly.
    485     ev = next(rr)
    486     self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
    487 
    488     # We should be done.
    489     self.assertRaises(StopIteration, lambda: next(rr))
    490 
    491   def testNoLogdirSucceeds(self):
    492     with ops.Graph().as_default():
    493       variables.Variable([1.0, 2.0, 3.0])
    494       sv = supervisor.Supervisor(logdir="", summary_op=None)
    495       sess = sv.prepare_or_wait_for_session("")
    496       sess.close()
    497       sv.stop()
    498 
    499   def testUseSessionManager(self):
    500     with ops.Graph().as_default():
    501       variables.Variable([1.0, 2.0, 3.0])
    502       sm = session_manager_lib.SessionManager()
    503       # Pass in session_manager. The additional init_op is ignored.
    504       sv = supervisor.Supervisor(logdir="", session_manager=sm)
    505       sv.prepare_or_wait_for_session("")
    506 
    507   def testInitOp(self):
    508     logdir = self._test_dir("default_init_op")
    509     with ops.Graph().as_default():
    510       v = variables.Variable([1.0, 2.0, 3.0])
    511       sv = supervisor.Supervisor(logdir=logdir)
    512       sess = sv.prepare_or_wait_for_session("")
    513       self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
    514       sv.stop()
    515 
    516   def testInitFn(self):
    517     logdir = self._test_dir("default_init_op")
    518     with ops.Graph().as_default():
    519       v = variables.Variable([1.0, 2.0, 3.0])
    520 
    521       def _init_fn(sess):
    522         sess.run(v.initializer)
    523 
    524       sv = supervisor.Supervisor(logdir=logdir, init_op=None, init_fn=_init_fn)
    525       sess = sv.prepare_or_wait_for_session("")
    526       self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
    527       sv.stop()
    528 
    529   def testInitOpWithFeedDict(self):
    530     logdir = self._test_dir("feed_dict_init_op")
    531     with ops.Graph().as_default():
    532       p = array_ops.placeholder(dtypes.float32, shape=(3,))
    533       v = variables.Variable(p, name="v")
    534       sv = supervisor.Supervisor(
    535           logdir=logdir,
    536           init_op=variables.global_variables_initializer(),
    537           init_feed_dict={p: [1.0, 2.0, 3.0]})
    538       sess = sv.prepare_or_wait_for_session("")
    539       self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
    540       sv.stop()
    541 
    542   def testReadyForLocalInitOp(self):
    543     server = server_lib.Server.create_local_server()
    544     logdir = self._test_dir("default_ready_for_local_init_op")
    545 
    546     uid = uuid.uuid4().hex
    547 
    548     def get_session(is_chief):
    549       g = ops.Graph()
    550       with g.as_default():
    551         with ops.device("/job:local"):
    552           v = variables.Variable(
    553               1, name="default_ready_for_local_init_op_v_" + str(uid))
    554           vadd = v.assign_add(1)
    555           w = variables.Variable(
    556               v,
    557               trainable=False,
    558               collections=[ops.GraphKeys.LOCAL_VARIABLES],
    559               name="default_ready_for_local_init_op_w_" + str(uid))
    560           ready_for_local_init_op = variables.report_uninitialized_variables(
    561               variables.global_variables())
    562       sv = supervisor.Supervisor(
    563           logdir=logdir,
    564           is_chief=is_chief,
    565           graph=g,
    566           recovery_wait_secs=1,
    567           init_op=v.initializer,
    568           ready_for_local_init_op=ready_for_local_init_op)
    569       sess = sv.prepare_or_wait_for_session(server.target)
    570 
    571       return sv, sess, v, vadd, w
    572 
    573     sv0, sess0, v0, _, w0 = get_session(True)
    574     sv1, sess1, _, vadd1, w1 = get_session(False)
    575 
    576     self.assertEqual(1, sess0.run(w0))
    577     self.assertEqual(2, sess1.run(vadd1))
    578     self.assertEqual(1, sess1.run(w1))
    579     self.assertEqual(2, sess0.run(v0))
    580 
    581     sv0.stop()
    582     sv1.stop()
    583 
    584   def testReadyForLocalInitOpRestoreFromCheckpoint(self):
    585     server = server_lib.Server.create_local_server()
    586     logdir = self._test_dir("ready_for_local_init_op_restore")
    587 
    588     uid = uuid.uuid4().hex
    589 
    590     # Create a checkpoint.
    591     with ops.Graph().as_default():
    592       v = variables.Variable(
    593           10.0, name="ready_for_local_init_op_restore_v_" + str(uid))
    594       summary.scalar("ready_for_local_init_op_restore_v_" + str(uid), v)
    595       sv = supervisor.Supervisor(logdir=logdir)
    596       sv.prepare_or_wait_for_session(server.target)
    597       save_path = sv.save_path
    598       self._wait_for_glob(save_path, 3.0)
    599       self._wait_for_glob(
    600           os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False)
    601       # Wait to make sure everything is written to file before stopping.
    602       time.sleep(1)
    603       sv.stop()
    604 
    605     def get_session(is_chief):
    606       g = ops.Graph()
    607       with g.as_default():
    608         with ops.device("/job:local"):
    609           v = variables.Variable(
    610               1.0, name="ready_for_local_init_op_restore_v_" + str(uid))
    611           vadd = v.assign_add(1)
    612           w = variables.Variable(
    613               v,
    614               trainable=False,
    615               collections=[ops.GraphKeys.LOCAL_VARIABLES],
    616               name="ready_for_local_init_op_restore_w_" + str(uid))
    617           ready_for_local_init_op = variables.report_uninitialized_variables(
    618               variables.global_variables())
    619       sv = supervisor.Supervisor(
    620           logdir=logdir,
    621           is_chief=is_chief,
    622           graph=g,
    623           recovery_wait_secs=1,
    624           ready_for_local_init_op=ready_for_local_init_op)
    625       sess = sv.prepare_or_wait_for_session(server.target)
    626 
    627       return sv, sess, v, vadd, w
    628 
    629     sv0, sess0, v0, _, w0 = get_session(True)
    630     sv1, sess1, _, vadd1, w1 = get_session(False)
    631 
    632     self.assertEqual(10, sess0.run(w0))
    633     self.assertEqual(11, sess1.run(vadd1))
    634     self.assertEqual(10, sess1.run(w1))
    635     self.assertEqual(11, sess0.run(v0))
    636 
    637     sv0.stop()
    638     sv1.stop()
    639 
    640   def testLocalInitOp(self):
    641     logdir = self._test_dir("default_local_init_op")
    642     with ops.Graph().as_default():
    643       # A local variable.
    644       v = variables.Variable(
    645           [1.0, 2.0, 3.0],
    646           trainable=False,
    647           collections=[ops.GraphKeys.LOCAL_VARIABLES])
    648 
    649       # An entity which is initialized through a TABLE_INITIALIZER.
    650       w = variables.Variable([4, 5, 6], trainable=False, collections=[])
    651       ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, w.initializer)
    652 
    653       # This shouldn't add a variable to the VARIABLES collection responsible
    654       # for variables that are saved/restored from checkpoints.
    655       self.assertEquals(len(variables.global_variables()), 0)
    656 
    657       # Suppress normal variable inits to make sure the local one is
    658       # initialized via local_init_op.
    659       sv = supervisor.Supervisor(logdir=logdir, init_op=None)
    660       sess = sv.prepare_or_wait_for_session("")
    661       self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
    662       self.assertAllClose([4, 5, 6], sess.run(w))
    663       sv.stop()
    664 
    665   def testLocalInitOpForNonChief(self):
    666     logdir = self._test_dir("default_local_init_op_non_chief")
    667     with ops.Graph().as_default():
    668       with ops.device("/job:localhost"):
    669         # A local variable.
    670         v = variables.Variable(
    671             [1.0, 2.0, 3.0],
    672             trainable=False,
    673             collections=[ops.GraphKeys.LOCAL_VARIABLES])
    674         # This shouldn't add a variable to the VARIABLES collection responsible
    675         # for variables that are saved/restored from checkpoints.
    676         self.assertEquals(len(variables.global_variables()), 0)
    677 
    678       # Suppress normal variable inits to make sure the local one is
    679       # initialized via local_init_op.
    680       sv = supervisor.Supervisor(logdir=logdir, init_op=None, is_chief=False)
    681       sess = sv.prepare_or_wait_for_session("")
    682       self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
    683       sv.stop()
    684 
    685   def testInitOpFails(self):
    686     server = server_lib.Server.create_local_server()
    687     logdir = self._test_dir("default_init_op_fails")
    688     with ops.Graph().as_default():
    689       v = variables.Variable([1.0, 2.0, 3.0], name="v")
    690       variables.Variable([4.0, 5.0, 6.0], name="w")
    691       # w will not be initialized.
    692       sv = supervisor.Supervisor(logdir=logdir, init_op=v.initializer)
    693       with self.assertRaisesRegexp(RuntimeError,
    694                                    "Variables not initialized: w"):
    695         sv.prepare_or_wait_for_session(server.target)
    696 
    697   def testInitOpFailsForTransientVariable(self):
    698     server = server_lib.Server.create_local_server()
    699     logdir = self._test_dir("default_init_op_fails_for_local_variable")
    700     with ops.Graph().as_default():
    701       v = variables.Variable(
    702           [1.0, 2.0, 3.0],
    703           name="v",
    704           collections=[ops.GraphKeys.LOCAL_VARIABLES])
    705       variables.Variable(
    706           [1.0, 2.0, 3.0],
    707           name="w",
    708           collections=[ops.GraphKeys.LOCAL_VARIABLES])
    709       # w will not be initialized.
    710       sv = supervisor.Supervisor(logdir=logdir, local_init_op=v.initializer)
    711       with self.assertRaisesRegexp(RuntimeError,
    712                                    "Variables not initialized: w"):
    713         sv.prepare_or_wait_for_session(server.target)
    714 
    715   def testSetupFail(self):
    716     logdir = self._test_dir("setup_fail")
    717     with ops.Graph().as_default():
    718       variables.Variable([1.0, 2.0, 3.0], name="v")
    719       with self.assertRaisesRegexp(ValueError, "must have their device set"):
    720         supervisor.Supervisor(logdir=logdir, is_chief=False)
    721     with ops.Graph().as_default(), ops.device("/job:ps"):
    722       variables.Variable([1.0, 2.0, 3.0], name="v")
    723       supervisor.Supervisor(logdir=logdir, is_chief=False)
    724 
    725   def testDefaultGlobalStep(self):
    726     logdir = self._test_dir("default_global_step")
    727     with ops.Graph().as_default():
    728       variables.Variable(287, name="global_step")
    729       sv = supervisor.Supervisor(logdir=logdir)
    730       sess = sv.prepare_or_wait_for_session("")
    731       self.assertEquals(287, sess.run(sv.global_step))
    732       sv.stop()
    733 
    734   def testRestoreFromMetaGraph(self):
    735     logdir = self._test_dir("restore_from_meta_graph")
    736     with ops.Graph().as_default():
    737       variables.Variable(1, name="v0")
    738       sv = supervisor.Supervisor(logdir=logdir)
    739       sess = sv.prepare_or_wait_for_session("")
    740       filename = sv.saver.save(sess, sv.save_path)
    741       sv.stop()
    742     # Create a new Graph and Supervisor and recover.
    743     with ops.Graph().as_default():
    744       new_saver = saver_lib.import_meta_graph(".".join([filename, "meta"]))
    745       self.assertIsNotNone(new_saver)
    746       sv2 = supervisor.Supervisor(logdir=logdir, saver=new_saver)
    747       sess = sv2.prepare_or_wait_for_session("")
    748       self.assertEquals(1, sess.run("v0:0"))
    749       sv2.saver.save(sess, sv2.save_path)
    750       sv2.stop()
    751 
    752   # This test is based on the fact that the standard services start
    753   # right away and get to run once before sv.stop() returns.
    754   # We still sleep a bit to make the test robust.
    755   def testStandardServicesWithoutGlobalStep(self):
    756     logdir = self._test_dir("standard_services_without_global_step")
    757     # Create a checkpoint.
    758     with ops.Graph().as_default():
    759       v = variables.Variable([1.0], name="foo")
    760       summary.scalar("v", v[0])
    761       sv = supervisor.Supervisor(logdir=logdir)
    762       meta_graph_def = meta_graph.create_meta_graph_def(
    763           saver_def=sv.saver.saver_def)
    764       sess = sv.prepare_or_wait_for_session("")
    765       save_path = sv.save_path
    766       self._wait_for_glob(save_path, 3.0)
    767       self._wait_for_glob(
    768           os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False)
    769       # Wait to make sure everything is written to file before stopping.
    770       time.sleep(1)
    771       sv.stop()
    772     # There should be an event file with a version number.
    773     rr = _summary_iterator(logdir)
    774     ev = next(rr)
    775     self.assertEquals("brain.Event:2", ev.file_version)
    776     ev = next(rr)
    777     ev_graph = graph_pb2.GraphDef()
    778     ev_graph.ParseFromString(ev.graph_def)
    779     self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)
    780 
    781     # Stored MetaGraphDef
    782     ev = next(rr)
    783     ev_meta_graph = meta_graph_pb2.MetaGraphDef()
    784     ev_meta_graph.ParseFromString(ev.meta_graph_def)
    785     self.assertProtoEquals(meta_graph_def, ev_meta_graph)
    786     self.assertProtoEquals(
    787         sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)
    788 
    789     ev = next(rr)
    790     self.assertProtoEquals("value { tag: 'v' simple_value: 1.0 }", ev.summary)
    791 
    792     ev = next(rr)
    793     self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
    794 
    795     self.assertRaises(StopIteration, lambda: next(rr))
    796     # There should be a checkpoint file with the variable "foo"
    797     with ops.Graph().as_default(), self.test_session() as sess:
    798       v = variables.Variable([10.10], name="foo")
    799       sav = saver_lib.Saver([v])
    800       sav.restore(sess, save_path)
    801       self.assertEqual(1.0, v.eval()[0])
    802 
    803   # Same as testStandardServicesNoGlobalStep but with a global step.
    804   # We should get a summary about the step time.
    805   def testStandardServicesWithGlobalStep(self):
    806     logdir = self._test_dir("standard_services_with_global_step")
    807     # Create a checkpoint.
    808     with ops.Graph().as_default():
    809       v = variables.Variable([123], name="global_step")
    810       sv = supervisor.Supervisor(logdir=logdir)
    811       meta_graph_def = meta_graph.create_meta_graph_def(
    812           saver_def=sv.saver.saver_def)
    813       sess = sv.prepare_or_wait_for_session("")
    814       # This is where the checkpoint will appear, with step number 123.
    815       save_path = "%s-123" % sv.save_path
    816       self._wait_for_glob(save_path, 3.0)
    817       self._wait_for_glob(
    818           os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False)
    819       # Wait to make sure everything is written to file before stopping.
    820       time.sleep(1)
    821       sv.stop()
    822     # There should be an event file with a version number.
    823     rr = _summary_iterator(logdir)
    824     ev = next(rr)
    825     self.assertEquals("brain.Event:2", ev.file_version)
    826     ev = next(rr)
    827     ev_graph = graph_pb2.GraphDef()
    828     ev_graph.ParseFromString(ev.graph_def)
    829     self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)
    830     ev = next(rr)
    831     ev_meta_graph = meta_graph_pb2.MetaGraphDef()
    832     ev_meta_graph.ParseFromString(ev.meta_graph_def)
    833     self.assertProtoEquals(meta_graph_def, ev_meta_graph)
    834     self.assertProtoEquals(
    835         sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)
    836     ev = next(rr)
    837     # It is actually undeterministic whether SessionLog.START gets written
    838     # before the summary or the checkpoint, but this works when run 10000 times.
    839     self.assertEquals(123, ev.step)
    840     self.assertEquals(event_pb2.SessionLog.START, ev.session_log.status)
    841     first = next(rr)
    842     second = next(rr)
    843     # It is undeterministic whether the value gets written before the checkpoint
    844     # since they are on separate threads, so we check for both conditions.
    845     if first.HasField("summary"):
    846       self.assertProtoEquals("""value { tag: 'global_step/sec'
    847                                         simple_value: 0.0 }""", first.summary)
    848       self.assertEquals(123, second.step)
    849       self.assertEquals(event_pb2.SessionLog.CHECKPOINT,
    850                         second.session_log.status)
    851     else:
    852       self.assertEquals(123, first.step)
    853       self.assertEquals(event_pb2.SessionLog.CHECKPOINT,
    854                         first.session_log.status)
    855       self.assertProtoEquals("""value { tag: 'global_step/sec'
    856                                         simple_value: 0.0 }""", second.summary)
    857     ev = next(rr)
    858     self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
    859     self.assertRaises(StopIteration, lambda: next(rr))
    860     # There should be a checkpoint file with the variable "foo"
    861     with ops.Graph().as_default(), self.test_session() as sess:
    862       v = variables.Variable([-12], name="global_step")
    863       sav = saver_lib.Saver([v])
    864       sav.restore(sess, save_path)
    865       self.assertEqual(123, v.eval()[0])
    866 
    867   def testNoQueueRunners(self):
    868     with ops.Graph().as_default(), self.test_session() as sess:
    869       sv = supervisor.Supervisor(logdir=self._test_dir("no_queue_runners"))
    870       self.assertEqual(0, len(sv.start_queue_runners(sess)))
    871       sv.stop()
    872 
    873   def testPrepareSessionAfterStopForChief(self):
    874     logdir = self._test_dir("prepare_after_stop_chief")
    875     with ops.Graph().as_default():
    876       sv = supervisor.Supervisor(logdir=logdir, is_chief=True)
    877 
    878       # Create a first session and then stop.
    879       sess = sv.prepare_or_wait_for_session("")
    880       sv.stop()
    881       sess.close()
    882       self.assertTrue(sv.should_stop())
    883 
    884       # Now create a second session and test that we don't stay stopped, until
    885       # we ask to stop again.
    886       sess2 = sv.prepare_or_wait_for_session("")
    887       self.assertFalse(sv.should_stop())
    888       sv.stop()
    889       sess2.close()
    890       self.assertTrue(sv.should_stop())
    891 
    892   def testPrepareSessionAfterStopForNonChief(self):
    893     logdir = self._test_dir("prepare_after_stop_nonchief")
    894     with ops.Graph().as_default():
    895       sv = supervisor.Supervisor(logdir=logdir, is_chief=False)
    896 
    897       # Create a first session and then stop.
    898       sess = sv.prepare_or_wait_for_session("")
    899       sv.stop()
    900       sess.close()
    901       self.assertTrue(sv.should_stop())
    902 
    903       # Now create a second session and test that we don't stay stopped, until
    904       # we ask to stop again.
    905       sess2 = sv.prepare_or_wait_for_session("")
    906       self.assertFalse(sv.should_stop())
    907       sv.stop()
    908       sess2.close()
    909       self.assertTrue(sv.should_stop())
    910 
    911 
    912 if __name__ == "__main__":
    913   test.main()
    914