1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for supervisor.py.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import glob 22 import os 23 import shutil 24 import time 25 import uuid 26 27 from six.moves import xrange # pylint: disable=redefined-builtin 28 29 from tensorflow.core.framework import graph_pb2 30 from tensorflow.core.protobuf import config_pb2 31 from tensorflow.core.protobuf import meta_graph_pb2 32 from tensorflow.core.util import event_pb2 33 from tensorflow.python.framework import constant_op 34 from tensorflow.python.framework import dtypes 35 from tensorflow.python.framework import errors_impl 36 from tensorflow.python.framework import meta_graph 37 from tensorflow.python.framework import ops 38 from tensorflow.python.ops import array_ops 39 from tensorflow.python.ops import io_ops 40 from tensorflow.python.ops import parsing_ops 41 from tensorflow.python.ops import variables 42 from tensorflow.python.platform import gfile 43 from tensorflow.python.platform import test 44 from tensorflow.python.summary import summary 45 from tensorflow.python.summary import summary_iterator 46 from tensorflow.python.summary.writer import writer 47 from tensorflow.python.training import input as input_lib 48 
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_manager as session_manager_lib
from tensorflow.python.training import supervisor


def _summary_iterator(test_dir):
  """Reads events from the last event file in test_dir.

  Event files are matched with the glob `test_dir/event*`. The one that sorts
  last lexicographically is read.

  Args:
    test_dir: Name of the test directory.

  Returns:
    A summary_iterator over the events of that file.
  """
  event_paths = sorted(glob.glob(os.path.join(test_dir, "event*")))
  return summary_iterator.summary_iterator(event_paths[-1])


class SupervisorTest(test.TestCase):
  """Tests for `supervisor.Supervisor`."""

  def _test_dir(self, test_name):
    """Returns a fresh path under the test temp dir, removing stale content.

    Args:
      test_name: Name of the sub-directory to use.

    Returns:
      The path. Any existing directory at that path is deleted first, but the
      directory is not re-created here.
    """
    test_dir = os.path.join(self.get_temp_dir(), test_name)
    if os.path.exists(test_dir):
      shutil.rmtree(test_dir)
    return test_dir

  def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True):
    """Waits for a checkpoint or a file matching a glob to appear.

    Polls every 0.05 seconds and fails the test if nothing shows up before
    the timeout.

    Args:
      pattern: A string. With for_checkpoint=True it is a checkpoint prefix
        checked with `saver_lib.checkpoint_exists`. Otherwise it is a glob
        pattern passed to `gfile.Glob`.
      timeout_secs: How long to wait for in seconds.
      for_checkpoint: whether we're globbing for checkpoints.
    """
    end_time = time.time() + timeout_secs
    while time.time() < end_time:
      if for_checkpoint:
        if saver_lib.checkpoint_exists(pattern):
          return
      else:
        if gfile.Glob(pattern):
          return
      time.sleep(0.05)
    self.fail("Glob never matched any file: %s" % pattern)

  # This test does not test much.
  def testBasics(self):
    logdir = self._test_dir("basics")
    with ops.Graph().as_default():
      my_op = constant_op.constant(1.0)
      sv = supervisor.Supervisor(logdir=logdir)
      sess = sv.prepare_or_wait_for_session("")
      for _ in xrange(10):
        sess.run(my_op)
      sess.close()
      sv.stop()

  def testManagedSession(self):
    logdir = self._test_dir("managed_session")
    with ops.Graph().as_default():
      my_op = constant_op.constant(1.0)
      sv = supervisor.Supervisor(logdir=logdir)
      with sv.managed_session("") as sess:
        for _ in xrange(10):
          sess.run(my_op)
      # Supervisor has been stopped.
      self.assertTrue(sv.should_stop())

  def testManagedSessionUserError(self):
    logdir = self._test_dir("managed_user_error")
    with ops.Graph().as_default():
      my_op = constant_op.constant(1.0)
      sv = supervisor.Supervisor(logdir=logdir)
      last_step = None
      with self.assertRaisesRegexp(RuntimeError, "failing here"):
        with sv.managed_session("") as sess:
          for step in xrange(10):
            last_step = step
            if step == 1:
              raise RuntimeError("failing here")
            else:
              sess.run(my_op)
      # Supervisor has been stopped.
      self.assertTrue(sv.should_stop())
      self.assertEqual(1, last_step)

  def testManagedSessionIgnoreOutOfRangeError(self):
    logdir = self._test_dir("managed_out_of_range")
    with ops.Graph().as_default():
      my_op = constant_op.constant(1.0)
      sv = supervisor.Supervisor(logdir=logdir)
      last_step = None
      with sv.managed_session("") as sess:
        for step in xrange(10):
          last_step = step
          if step == 3:
            raise errors_impl.OutOfRangeError(my_op.op.node_def, my_op.op,
                                              "all done")
          else:
            sess.run(my_op)
      # Supervisor has been stopped. OutOfRangeError was not thrown.
      self.assertTrue(sv.should_stop())
      self.assertEqual(3, last_step)

  def testManagedSessionDoNotKeepSummaryWriter(self):
    logdir = self._test_dir("managed_not_keep_summary_writer")
    with ops.Graph().as_default():
      summary.scalar("c1", constant_op.constant(1))
      summary.scalar("c2", constant_op.constant(2))
      summary.scalar("c3", constant_op.constant(3))
      summ = summary.merge_all()
      sv = supervisor.Supervisor(logdir=logdir, summary_op=None)
      with sv.managed_session(
          "", close_summary_writer=True, start_standard_services=False) as sess:
        sv.summary_computed(sess, sess.run(summ))
      # Sleep 1.2s to make sure that the next event file has a different name
      # than the current one.
      time.sleep(1.2)
      with sv.managed_session(
          "", close_summary_writer=True, start_standard_services=False) as sess:
        sv.summary_computed(sess, sess.run(summ))
    event_paths = sorted(glob.glob(os.path.join(logdir, "event*")))
    self.assertEqual(2, len(event_paths))
    # The two event files should have the same contents.
    for path in event_paths:
      # The summary iterator should report the summary once as we closed the
      # summary writer across the 2 sessions.
      rr = summary_iterator.summary_iterator(path)
      # The first event should list the file_version.
      ev = next(rr)
      self.assertEqual("brain.Event:2", ev.file_version)

      # The next one has the graph and metagraph.
      ev = next(rr)
      self.assertTrue(ev.graph_def)

      ev = next(rr)
      self.assertTrue(ev.meta_graph_def)

      # The next one should have the values from the summary.
      # But only once.
      ev = next(rr)
      self.assertProtoEquals("""
        value { tag: 'c1' simple_value: 1.0 }
        value { tag: 'c2' simple_value: 2.0 }
        value { tag: 'c3' simple_value: 3.0 }
        """, ev.summary)

      # The next one should be a stop message if we closed cleanly.
      ev = next(rr)
      self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)

      # We should be done.
      with self.assertRaises(StopIteration):
        next(rr)

  def testManagedSessionKeepSummaryWriter(self):
    logdir = self._test_dir("managed_keep_summary_writer")
    with ops.Graph().as_default():
      summary.scalar("c1", constant_op.constant(1))
      summary.scalar("c2", constant_op.constant(2))
      summary.scalar("c3", constant_op.constant(3))
      summ = summary.merge_all()
      sv = supervisor.Supervisor(logdir=logdir)
      with sv.managed_session(
          "", close_summary_writer=False,
          start_standard_services=False) as sess:
        sv.summary_computed(sess, sess.run(summ))
      with sv.managed_session(
          "", close_summary_writer=False,
          start_standard_services=False) as sess:
        sv.summary_computed(sess, sess.run(summ))
    # Now close the summary writer to flush the events.
    sv.summary_writer.close()
    # The summary iterator should report the summary twice as we reused
    # the same summary writer across the 2 sessions.
    rr = _summary_iterator(logdir)
    # The first event should list the file_version.
    ev = next(rr)
    self.assertEqual("brain.Event:2", ev.file_version)

    # The next one has the graph.
    ev = next(rr)
    self.assertTrue(ev.graph_def)

    ev = next(rr)
    self.assertTrue(ev.meta_graph_def)

    # The next one should have the values from the summary.
    ev = next(rr)
    self.assertProtoEquals("""
      value { tag: 'c1' simple_value: 1.0 }
      value { tag: 'c2' simple_value: 2.0 }
      value { tag: 'c3' simple_value: 3.0 }
      """, ev.summary)

    # The next one should also have the values from the summary.
    ev = next(rr)
    self.assertProtoEquals("""
      value { tag: 'c1' simple_value: 1.0 }
      value { tag: 'c2' simple_value: 2.0 }
      value { tag: 'c3' simple_value: 3.0 }
      """, ev.summary)

    # We should be done.
    self.assertRaises(StopIteration, lambda: next(rr))

  def _csv_data(self, logdir):
    """Writes a small CSV file with 3 records into logdir; returns its path."""
    data_path = os.path.join(logdir, "data.csv")
    with open(data_path, "w") as f:
      f.write("1,2,3\n")
      f.write("4,5,6\n")
      f.write("7,8,9\n")
    return data_path

  def testManagedEndOfInputOneQueue(self):
    # Tests that the supervisor finishes without an error when using
    # a fixed number of epochs, reading from a single queue.
    logdir = self._test_dir("managed_end_of_input_one_queue")
    os.makedirs(logdir)
    data_path = self._csv_data(logdir)
    with ops.Graph().as_default():
      # Create an input pipeline that reads the file 3 times.
      filename_queue = input_lib.string_input_producer(
          [data_path], num_epochs=3)
      reader = io_ops.TextLineReader()
      _, csv = reader.read(filename_queue)
      rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
      sv = supervisor.Supervisor(logdir=logdir)
      with sv.managed_session("") as sess:
        while not sv.should_stop():
          sess.run(rec)

  def testManagedEndOfInputTwoQueues(self):
    # Tests that the supervisor finishes without an error when using
    # a fixed number of epochs, reading from two queues, the second
    # one producing a batch from the first one.
    logdir = self._test_dir("managed_end_of_input_two_queues")
    os.makedirs(logdir)
    data_path = self._csv_data(logdir)
    with ops.Graph().as_default():
      # Create an input pipeline that reads the file 3 times.
      filename_queue = input_lib.string_input_producer(
          [data_path], num_epochs=3)
      reader = io_ops.TextLineReader()
      _, csv = reader.read(filename_queue)
      rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
      shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4)
      sv = supervisor.Supervisor(logdir=logdir)
      with sv.managed_session("") as sess:
        while not sv.should_stop():
          sess.run(shuff_rec)

  def testManagedMainErrorTwoQueues(self):
    # Tests that the supervisor correctly raises a main loop
    # error even when using multiple queues for input.
    logdir = self._test_dir("managed_main_error_two_queues")
    os.makedirs(logdir)
    data_path = self._csv_data(logdir)
    with self.assertRaisesRegexp(RuntimeError, "fail at step 3"):
      with ops.Graph().as_default():
        # Create an input pipeline that reads the file 3 times.
        filename_queue = input_lib.string_input_producer(
            [data_path], num_epochs=3)
        reader = io_ops.TextLineReader()
        _, csv = reader.read(filename_queue)
        rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
        shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4)
        sv = supervisor.Supervisor(logdir=logdir)
        with sv.managed_session("") as sess:
          for step in range(9):
            if sv.should_stop():
              break
            elif step == 3:
              raise RuntimeError("fail at step 3")
            else:
              sess.run(shuff_rec)

  def testSessionConfig(self):
    logdir = self._test_dir("session_config")
    with ops.Graph().as_default():
      with ops.device("/cpu:1"):
        my_op = constant_op.constant([1.0])
      sv = supervisor.Supervisor(logdir=logdir)
      sess = sv.prepare_or_wait_for_session(
          "", config=config_pb2.ConfigProto(device_count={"CPU": 2}))
      for _ in xrange(10):
        sess.run(my_op)
      sess.close()
      sv.stop()

  def testChiefCanWriteEvents(self):
    logdir = self._test_dir("can_write")
    with ops.Graph().as_default():
      summary.scalar("c1", constant_op.constant(1))
      summary.scalar("c2", constant_op.constant(2))
      summary.scalar("c3", constant_op.constant(3))
      summ = summary.merge_all()
      sv = supervisor.Supervisor(is_chief=True, logdir=logdir, summary_op=None)
      meta_graph_def = meta_graph.create_meta_graph_def()
      sess = sv.prepare_or_wait_for_session("")
      sv.summary_computed(sess, sess.run(summ))
      sess.close()
      # Wait to make sure everything is written to file before stopping.
      time.sleep(1)
      sv.stop()

    rr = _summary_iterator(logdir)

    # The first event should list the file_version.
    ev = next(rr)
    self.assertEqual("brain.Event:2", ev.file_version)

    # The next one has the graph.
    ev = next(rr)
    ev_graph = graph_pb2.GraphDef()
    ev_graph.ParseFromString(ev.graph_def)
    self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)

    # Stored MetaGraphDef
    ev = next(rr)
    ev_meta_graph = meta_graph_pb2.MetaGraphDef()
    ev_meta_graph.ParseFromString(ev.meta_graph_def)
    self.assertProtoEquals(meta_graph_def, ev_meta_graph)
    self.assertProtoEquals(
        sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)
    # The next one should have the values from the summary.
    ev = next(rr)
    self.assertProtoEquals("""
      value { tag: 'c1' simple_value: 1.0 }
      value { tag: 'c2' simple_value: 2.0 }
      value { tag: 'c3' simple_value: 3.0 }
      """, ev.summary)

    # The next one should be a stop message if we closed cleanly.
    ev = next(rr)
    self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)

    # We should be done.
    self.assertRaises(StopIteration, lambda: next(rr))

  def testNonChiefCannotWriteEvents(self):

    def _summary_computed():
      with ops.Graph().as_default():
        sv = supervisor.Supervisor(is_chief=False)
        sess = sv.prepare_or_wait_for_session("")
        summary.scalar("c1", constant_op.constant(1))
        summary.scalar("c2", constant_op.constant(2))
        summ = summary.merge_all()
        sv.summary_computed(sess, sess.run(summ))

    def _start_standard_services():
      with ops.Graph().as_default():
        sv = supervisor.Supervisor(is_chief=False)
        sess = sv.prepare_or_wait_for_session("")
        sv.start_standard_services(sess)

    self.assertRaises(RuntimeError, _summary_computed)
    self.assertRaises(RuntimeError, _start_standard_services)

  def testNoLogdirButWantSummary(self):
    with ops.Graph().as_default():
      summary.scalar("c1", constant_op.constant(1))
      summary.scalar("c2", constant_op.constant(2))
      summary.scalar("c3", constant_op.constant(3))
      summ = summary.merge_all()
      sv = supervisor.Supervisor(logdir="", summary_op=None)
      sess = sv.prepare_or_wait_for_session("")
      with self.assertRaisesRegexp(RuntimeError, "requires a summary writer"):
        sv.summary_computed(sess, sess.run(summ))

  def testLogdirButExplicitlyNoSummaryWriter(self):
    logdir = self._test_dir("explicit_no_summary_writer")
    with ops.Graph().as_default():
      variables.Variable([1.0], name="foo")
      summary.scalar("c1", constant_op.constant(1))
      summary.scalar("c2", constant_op.constant(2))
      summary.scalar("c3", constant_op.constant(3))
      summ = summary.merge_all()
      sv = supervisor.Supervisor(logdir=logdir, summary_writer=None)
      sess = sv.prepare_or_wait_for_session("")
      # Check that a checkpoint is still generated.
      self._wait_for_glob(sv.save_path, 3.0)
      # Check that we cannot write a summary
      with self.assertRaisesRegexp(RuntimeError, "requires a summary writer"):
        sv.summary_computed(sess, sess.run(summ))
      # A logdir was given, so background services were started for this
      # supervisor (the checkpoint above proves it). Stop them so they do not
      # outlive the test.
      sv.stop()

  def testNoLogdirButExplicitSummaryWriter(self):
    logdir = self._test_dir("explicit_summary_writer")
    with ops.Graph().as_default():
      summary.scalar("c1", constant_op.constant(1))
      summary.scalar("c2", constant_op.constant(2))
      summary.scalar("c3", constant_op.constant(3))
      summ = summary.merge_all()
      sw = writer.FileWriter(logdir)
      sv = supervisor.Supervisor(logdir="", summary_op=None, summary_writer=sw)
      meta_graph_def = meta_graph.create_meta_graph_def()
      sess = sv.prepare_or_wait_for_session("")
      sv.summary_computed(sess, sess.run(summ))
      sess.close()
      # Wait to make sure everything is written to file before stopping.
      time.sleep(1)
      sv.stop()

    # Check the summary was written to 'logdir'
    rr = _summary_iterator(logdir)

    # The first event should list the file_version.
    ev = next(rr)
    self.assertEqual("brain.Event:2", ev.file_version)

    # The next one has the graph.
    ev = next(rr)
    ev_graph = graph_pb2.GraphDef()
    ev_graph.ParseFromString(ev.graph_def)
    self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)

    # Stored MetaGraphDef
    ev = next(rr)
    ev_meta_graph = meta_graph_pb2.MetaGraphDef()
    ev_meta_graph.ParseFromString(ev.meta_graph_def)
    self.assertProtoEquals(meta_graph_def, ev_meta_graph)
    self.assertProtoEquals(
        sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)

    # The next one should have the values from the summary.
    ev = next(rr)
    self.assertProtoEquals("""
      value { tag: 'c1' simple_value: 1.0 }
      value { tag: 'c2' simple_value: 2.0 }
      value { tag: 'c3' simple_value: 3.0 }
      """, ev.summary)

    # The next one should be a stop message if we closed cleanly.
    ev = next(rr)
    self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)

    # We should be done.
    self.assertRaises(StopIteration, lambda: next(rr))

  def testNoLogdirSucceeds(self):
    with ops.Graph().as_default():
      variables.Variable([1.0, 2.0, 3.0])
      sv = supervisor.Supervisor(logdir="", summary_op=None)
      sess = sv.prepare_or_wait_for_session("")
      sess.close()
      sv.stop()

  def testUseSessionManager(self):
    with ops.Graph().as_default():
      variables.Variable([1.0, 2.0, 3.0])
      sm = session_manager_lib.SessionManager()
      # Pass in session_manager. The additional init_op is ignored.
      sv = supervisor.Supervisor(logdir="", session_manager=sm)
      sv.prepare_or_wait_for_session("")

  def testInitOp(self):
    logdir = self._test_dir("default_init_op")
    with ops.Graph().as_default():
      v = variables.Variable([1.0, 2.0, 3.0])
      sv = supervisor.Supervisor(logdir=logdir)
      sess = sv.prepare_or_wait_for_session("")
      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
      sv.stop()

  def testInitFn(self):
    # Uses its own directory rather than sharing testInitOp's.
    logdir = self._test_dir("default_init_fn")
    with ops.Graph().as_default():
      v = variables.Variable([1.0, 2.0, 3.0])

      def _init_fn(sess):
        sess.run(v.initializer)

      sv = supervisor.Supervisor(logdir=logdir, init_op=None, init_fn=_init_fn)
      sess = sv.prepare_or_wait_for_session("")
      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
      sv.stop()

  def testInitOpWithFeedDict(self):
    logdir = self._test_dir("feed_dict_init_op")
    with ops.Graph().as_default():
      p = array_ops.placeholder(dtypes.float32, shape=(3,))
      v = variables.Variable(p, name="v")
      sv = supervisor.Supervisor(
          logdir=logdir,
          init_op=variables.global_variables_initializer(),
          init_feed_dict={p: [1.0, 2.0, 3.0]})
      sess = sv.prepare_or_wait_for_session("")
      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
      sv.stop()

  def testReadyForLocalInitOp(self):
    server = server_lib.Server.create_local_server()
    logdir = self._test_dir("default_ready_for_local_init_op")

    uid = uuid.uuid4().hex

    def get_session(is_chief):
      g = ops.Graph()
      with g.as_default():
        with ops.device("/job:local"):
          v = variables.Variable(
              1, name="default_ready_for_local_init_op_v_" + str(uid))
          vadd = v.assign_add(1)
          w = variables.Variable(
              v,
              trainable=False,
              collections=[ops.GraphKeys.LOCAL_VARIABLES],
              name="default_ready_for_local_init_op_w_" + str(uid))
          ready_for_local_init_op = variables.report_uninitialized_variables(
              variables.global_variables())
        sv = supervisor.Supervisor(
            logdir=logdir,
            is_chief=is_chief,
            graph=g,
            recovery_wait_secs=1,
            init_op=v.initializer,
            ready_for_local_init_op=ready_for_local_init_op)
        sess = sv.prepare_or_wait_for_session(server.target)

      return sv, sess, v, vadd, w

    sv0, sess0, v0, _, w0 = get_session(True)
    sv1, sess1, _, vadd1, w1 = get_session(False)

    self.assertEqual(1, sess0.run(w0))
    self.assertEqual(2, sess1.run(vadd1))
    self.assertEqual(1, sess1.run(w1))
    self.assertEqual(2, sess0.run(v0))

    sv0.stop()
    sv1.stop()

  def testReadyForLocalInitOpRestoreFromCheckpoint(self):
    server = server_lib.Server.create_local_server()
    logdir = self._test_dir("ready_for_local_init_op_restore")

    uid = uuid.uuid4().hex

    # Create a checkpoint.
    with ops.Graph().as_default():
      v = variables.Variable(
          10.0, name="ready_for_local_init_op_restore_v_" + str(uid))
      summary.scalar("ready_for_local_init_op_restore_v_" + str(uid), v)
      sv = supervisor.Supervisor(logdir=logdir)
      sv.prepare_or_wait_for_session(server.target)
      save_path = sv.save_path
      self._wait_for_glob(save_path, 3.0)
      self._wait_for_glob(
          os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False)
      # Wait to make sure everything is written to file before stopping.
      time.sleep(1)
      sv.stop()

    def get_session(is_chief):
      g = ops.Graph()
      with g.as_default():
        with ops.device("/job:local"):
          v = variables.Variable(
              1.0, name="ready_for_local_init_op_restore_v_" + str(uid))
          vadd = v.assign_add(1)
          w = variables.Variable(
              v,
              trainable=False,
              collections=[ops.GraphKeys.LOCAL_VARIABLES],
              name="ready_for_local_init_op_restore_w_" + str(uid))
          ready_for_local_init_op = variables.report_uninitialized_variables(
              variables.global_variables())
        sv = supervisor.Supervisor(
            logdir=logdir,
            is_chief=is_chief,
            graph=g,
            recovery_wait_secs=1,
            ready_for_local_init_op=ready_for_local_init_op)
        sess = sv.prepare_or_wait_for_session(server.target)

      return sv, sess, v, vadd, w

    sv0, sess0, v0, _, w0 = get_session(True)
    sv1, sess1, _, vadd1, w1 = get_session(False)

    self.assertEqual(10, sess0.run(w0))
    self.assertEqual(11, sess1.run(vadd1))
    self.assertEqual(10, sess1.run(w1))
    self.assertEqual(11, sess0.run(v0))

    sv0.stop()
    sv1.stop()

  def testLocalInitOp(self):
    logdir = self._test_dir("default_local_init_op")
    with ops.Graph().as_default():
      # A local variable.
      v = variables.Variable(
          [1.0, 2.0, 3.0],
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES])

      # An entity which is initialized through a TABLE_INITIALIZER.
      w = variables.Variable([4, 5, 6], trainable=False, collections=[])
      ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, w.initializer)

      # This shouldn't add a variable to the VARIABLES collection responsible
      # for variables that are saved/restored from checkpoints.
      self.assertEqual(len(variables.global_variables()), 0)

      # Suppress normal variable inits to make sure the local one is
      # initialized via local_init_op.
      sv = supervisor.Supervisor(logdir=logdir, init_op=None)
      sess = sv.prepare_or_wait_for_session("")
      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
      self.assertAllClose([4, 5, 6], sess.run(w))
      sv.stop()

  def testLocalInitOpForNonChief(self):
    logdir = self._test_dir("default_local_init_op_non_chief")
    with ops.Graph().as_default():
      with ops.device("/job:localhost"):
        # A local variable.
        v = variables.Variable(
            [1.0, 2.0, 3.0],
            trainable=False,
            collections=[ops.GraphKeys.LOCAL_VARIABLES])
        # This shouldn't add a variable to the VARIABLES collection responsible
        # for variables that are saved/restored from checkpoints.
        self.assertEqual(len(variables.global_variables()), 0)

      # Suppress normal variable inits to make sure the local one is
      # initialized via local_init_op.
      sv = supervisor.Supervisor(logdir=logdir, init_op=None, is_chief=False)
      sess = sv.prepare_or_wait_for_session("")
      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
      sv.stop()

  def testInitOpFails(self):
    server = server_lib.Server.create_local_server()
    logdir = self._test_dir("default_init_op_fails")
    with ops.Graph().as_default():
      v = variables.Variable([1.0, 2.0, 3.0], name="v")
      variables.Variable([4.0, 5.0, 6.0], name="w")
      # w will not be initialized.
      sv = supervisor.Supervisor(logdir=logdir, init_op=v.initializer)
      with self.assertRaisesRegexp(RuntimeError,
                                   "Variables not initialized: w"):
        sv.prepare_or_wait_for_session(server.target)

  def testInitOpFailsForTransientVariable(self):
    server = server_lib.Server.create_local_server()
    logdir = self._test_dir("default_init_op_fails_for_local_variable")
    with ops.Graph().as_default():
      v = variables.Variable(
          [1.0, 2.0, 3.0],
          name="v",
          collections=[ops.GraphKeys.LOCAL_VARIABLES])
      variables.Variable(
          [1.0, 2.0, 3.0],
          name="w",
          collections=[ops.GraphKeys.LOCAL_VARIABLES])
      # w will not be initialized.
      sv = supervisor.Supervisor(logdir=logdir, local_init_op=v.initializer)
      with self.assertRaisesRegexp(RuntimeError,
                                   "Variables not initialized: w"):
        sv.prepare_or_wait_for_session(server.target)

  def testSetupFail(self):
    logdir = self._test_dir("setup_fail")
    with ops.Graph().as_default():
      variables.Variable([1.0, 2.0, 3.0], name="v")
      with self.assertRaisesRegexp(ValueError, "must have their device set"):
        supervisor.Supervisor(logdir=logdir, is_chief=False)
    with ops.Graph().as_default(), ops.device("/job:ps"):
      variables.Variable([1.0, 2.0, 3.0], name="v")
      supervisor.Supervisor(logdir=logdir, is_chief=False)

  def testDefaultGlobalStep(self):
    logdir = self._test_dir("default_global_step")
    with ops.Graph().as_default():
      variables.Variable(287, name="global_step")
      sv = supervisor.Supervisor(logdir=logdir)
      sess = sv.prepare_or_wait_for_session("")
      self.assertEqual(287, sess.run(sv.global_step))
      sv.stop()

  def testRestoreFromMetaGraph(self):
    logdir = self._test_dir("restore_from_meta_graph")
    with ops.Graph().as_default():
      variables.Variable(1, name="v0")
      sv = supervisor.Supervisor(logdir=logdir)
      sess = sv.prepare_or_wait_for_session("")
      filename = sv.saver.save(sess, sv.save_path)
      sv.stop()
    # Create a new Graph and Supervisor and recover.
    with ops.Graph().as_default():
      new_saver = saver_lib.import_meta_graph(".".join([filename, "meta"]))
      self.assertIsNotNone(new_saver)
      sv2 = supervisor.Supervisor(logdir=logdir, saver=new_saver)
      sess = sv2.prepare_or_wait_for_session("")
      self.assertEqual(1, sess.run("v0:0"))
      sv2.saver.save(sess, sv2.save_path)
      sv2.stop()

  # This test is based on the fact that the standard services start
  # right away and get to run once before sv.stop() returns.
  # We still sleep a bit to make the test robust.
  def testStandardServicesWithoutGlobalStep(self):
    logdir = self._test_dir("standard_services_without_global_step")
    # Create a checkpoint.
    with ops.Graph().as_default():
      v = variables.Variable([1.0], name="foo")
      summary.scalar("v", v[0])
      sv = supervisor.Supervisor(logdir=logdir)
      meta_graph_def = meta_graph.create_meta_graph_def(
          saver_def=sv.saver.saver_def)
      sess = sv.prepare_or_wait_for_session("")
      save_path = sv.save_path
      self._wait_for_glob(save_path, 3.0)
      self._wait_for_glob(
          os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False)
      # Wait to make sure everything is written to file before stopping.
      time.sleep(1)
      sv.stop()
    # There should be an event file with a version number.
    rr = _summary_iterator(logdir)
    ev = next(rr)
    self.assertEqual("brain.Event:2", ev.file_version)
    ev = next(rr)
    ev_graph = graph_pb2.GraphDef()
    ev_graph.ParseFromString(ev.graph_def)
    self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)

    # Stored MetaGraphDef
    ev = next(rr)
    ev_meta_graph = meta_graph_pb2.MetaGraphDef()
    ev_meta_graph.ParseFromString(ev.meta_graph_def)
    self.assertProtoEquals(meta_graph_def, ev_meta_graph)
    self.assertProtoEquals(
        sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)

    ev = next(rr)
    self.assertProtoEquals("value { tag: 'v' simple_value: 1.0 }", ev.summary)

    ev = next(rr)
    self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)

    self.assertRaises(StopIteration, lambda: next(rr))
    # There should be a checkpoint file with the variable "foo"
    with ops.Graph().as_default(), self.test_session() as sess:
      v = variables.Variable([10.10], name="foo")
      sav = saver_lib.Saver([v])
      sav.restore(sess, save_path)
      self.assertEqual(1.0, v.eval()[0])

  # Same as testStandardServicesWithoutGlobalStep but with a global step.
  # We should get a summary about the step time.
  def testStandardServicesWithGlobalStep(self):
    logdir = self._test_dir("standard_services_with_global_step")
    # Create a checkpoint.
    with ops.Graph().as_default():
      v = variables.Variable([123], name="global_step")
      sv = supervisor.Supervisor(logdir=logdir)
      meta_graph_def = meta_graph.create_meta_graph_def(
          saver_def=sv.saver.saver_def)
      sess = sv.prepare_or_wait_for_session("")
      # This is where the checkpoint will appear, with step number 123.
      save_path = "%s-123" % sv.save_path
      self._wait_for_glob(save_path, 3.0)
      self._wait_for_glob(
          os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False)
      # Wait to make sure everything is written to file before stopping.
      time.sleep(1)
      sv.stop()
    # There should be an event file with a version number.
    rr = _summary_iterator(logdir)
    ev = next(rr)
    self.assertEqual("brain.Event:2", ev.file_version)
    ev = next(rr)
    ev_graph = graph_pb2.GraphDef()
    ev_graph.ParseFromString(ev.graph_def)
    self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)
    ev = next(rr)
    ev_meta_graph = meta_graph_pb2.MetaGraphDef()
    ev_meta_graph.ParseFromString(ev.meta_graph_def)
    self.assertProtoEquals(meta_graph_def, ev_meta_graph)
    self.assertProtoEquals(
        sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)
    ev = next(rr)
    # It is actually nondeterministic whether SessionLog.START gets written
    # before the summary or the checkpoint, but this works when run 10000 times.
    self.assertEqual(123, ev.step)
    self.assertEqual(event_pb2.SessionLog.START, ev.session_log.status)
    first = next(rr)
    second = next(rr)
    # It is nondeterministic whether the value gets written before the
    # checkpoint since they are on separate threads, so we check for both
    # conditions.
    if first.HasField("summary"):
      self.assertProtoEquals("""value { tag: 'global_step/sec'
                                        simple_value: 0.0 }""", first.summary)
      self.assertEqual(123, second.step)
      self.assertEqual(event_pb2.SessionLog.CHECKPOINT,
                       second.session_log.status)
    else:
      self.assertEqual(123, first.step)
      self.assertEqual(event_pb2.SessionLog.CHECKPOINT,
                       first.session_log.status)
      self.assertProtoEquals("""value { tag: 'global_step/sec'
                                        simple_value: 0.0 }""", second.summary)
    ev = next(rr)
    self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
    self.assertRaises(StopIteration, lambda: next(rr))
    # There should be a checkpoint file with the variable "global_step".
    with ops.Graph().as_default(), self.test_session() as sess:
      v = variables.Variable([-12], name="global_step")
      sav = saver_lib.Saver([v])
      sav.restore(sess, save_path)
      self.assertEqual(123, v.eval()[0])

  def testNoQueueRunners(self):
    with ops.Graph().as_default(), self.test_session() as sess:
      sv = supervisor.Supervisor(logdir=self._test_dir("no_queue_runners"))
      self.assertEqual(0, len(sv.start_queue_runners(sess)))
      sv.stop()

  def testPrepareSessionAfterStopForChief(self):
    logdir = self._test_dir("prepare_after_stop_chief")
    with ops.Graph().as_default():
      sv = supervisor.Supervisor(logdir=logdir, is_chief=True)

      # Create a first session and then stop.
      sess = sv.prepare_or_wait_for_session("")
      sv.stop()
      sess.close()
      self.assertTrue(sv.should_stop())

      # Now create a second session and test that we don't stay stopped, until
      # we ask to stop again.
      sess2 = sv.prepare_or_wait_for_session("")
      self.assertFalse(sv.should_stop())
      sv.stop()
      sess2.close()
      self.assertTrue(sv.should_stop())

  def testPrepareSessionAfterStopForNonChief(self):
    logdir = self._test_dir("prepare_after_stop_nonchief")
    with ops.Graph().as_default():
      sv = supervisor.Supervisor(logdir=logdir, is_chief=False)

      # Create a first session and then stop.
      sess = sv.prepare_or_wait_for_session("")
      sv.stop()
      sess.close()
      self.assertTrue(sv.should_stop())

      # Now create a second session and test that we don't stay stopped, until
      # we ask to stop again.
      sess2 = sv.prepare_or_wait_for_session("")
      self.assertFalse(sv.should_stop())
      sv.stop()
      sess2.close()
      self.assertTrue(sv.should_stop())


if __name__ == "__main__":
  test.main()