1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for SessionManager.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import os 22 23 from tensorflow.python.client import session as session_lib 24 from tensorflow.python.framework import dtypes 25 from tensorflow.python.framework import errors 26 from tensorflow.python.framework import errors_impl 27 from tensorflow.python.framework import ops 28 from tensorflow.python.ops import array_ops 29 from tensorflow.python.ops import control_flow_ops 30 from tensorflow.python.ops import variables 31 from tensorflow.python.platform import gfile 32 from tensorflow.python.platform import test 33 from tensorflow.python.training import saver as saver_lib 34 from tensorflow.python.training import server_lib 35 from tensorflow.python.training import session_manager 36 37 38 class SessionManagerTest(test.TestCase): 39 40 def testPrepareSessionSucceeds(self): 41 with ops.Graph().as_default(): 42 v = variables.Variable([1.0, 2.0, 3.0], name="v") 43 sm = session_manager.SessionManager( 44 ready_op=variables.report_uninitialized_variables()) 45 sess = sm.prepare_session( 46 "", init_op=variables.global_variables_initializer()) 47 self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) 48 49 def testPrepareSessionSucceedsWithInitFeedDict(self): 50 with ops.Graph().as_default(): 51 p = array_ops.placeholder(dtypes.float32, shape=(3,)) 52 v = variables.Variable(p, name="v") 53 sm = session_manager.SessionManager( 54 ready_op=variables.report_uninitialized_variables()) 55 sess = sm.prepare_session( 56 "", 57 init_op=variables.global_variables_initializer(), 58 init_feed_dict={p: [1.0, 2.0, 3.0]}) 59 self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) 60 61 def testPrepareSessionSucceedsWithInitFn(self): 62 with ops.Graph().as_default(): 63 v = variables.Variable([125], name="v") 64 sm = session_manager.SessionManager( 65 ready_op=variables.report_uninitialized_variables()) 66 sess = sm.prepare_session( 67 "", init_fn=lambda sess: sess.run(v.initializer)) 68 self.assertAllClose([125], sess.run(v)) 69 70 def testPrepareSessionFails(self): 71 checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session") 72 checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2") 73 try: 74 gfile.DeleteRecursively(checkpoint_dir) 75 gfile.DeleteRecursively(checkpoint_dir2) 76 except errors.OpError: 77 pass # Ignore 78 gfile.MakeDirs(checkpoint_dir) 79 80 with ops.Graph().as_default(): 81 v = variables.Variable([1.0, 2.0, 3.0], name="v") 82 sm = session_manager.SessionManager( 83 ready_op=variables.report_uninitialized_variables()) 84 saver = saver_lib.Saver({"v": v}) 85 sess = sm.prepare_session( 86 "", 87 init_op=variables.global_variables_initializer(), 88 saver=saver, 89 checkpoint_dir=checkpoint_dir) 90 self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) 91 checkpoint_filename = os.path.join(checkpoint_dir, 92 "prepare_session_checkpoint") 93 saver.save(sess, checkpoint_filename) 94 # Create a new Graph and SessionManager and recover. 95 with ops.Graph().as_default(): 96 # Renames the checkpoint directory. 97 os.rename(checkpoint_dir, checkpoint_dir2) 98 gfile.MakeDirs(checkpoint_dir) 99 v = variables.Variable([6.0, 7.0, 8.0], name="v") 100 with self.test_session(): 101 self.assertEqual(False, variables.is_variable_initialized(v).eval()) 102 session_manager.SessionManager( 103 ready_op=variables.report_uninitialized_variables()) 104 saver = saver_lib.Saver({"v": v}) 105 # This should fail as there's no checkpoint within 2 seconds. 106 with self.assertRaisesRegexp( 107 RuntimeError, "no init_op or init_fn or local_init_op was given"): 108 sess = sm.prepare_session( 109 "", 110 init_op=None, 111 saver=saver, 112 checkpoint_dir=checkpoint_dir, 113 wait_for_checkpoint=True, 114 max_wait_secs=2) 115 # Rename the checkpoint directory back. 116 gfile.DeleteRecursively(checkpoint_dir) 117 os.rename(checkpoint_dir2, checkpoint_dir) 118 # This should succeed as there's checkpoint. 119 sess = sm.prepare_session( 120 "", 121 init_op=None, 122 saver=saver, 123 checkpoint_dir=checkpoint_dir, 124 wait_for_checkpoint=True, 125 max_wait_secs=2) 126 self.assertEqual( 127 True, 128 variables.is_variable_initialized( 129 sess.graph.get_tensor_by_name("v:0")).eval(session=sess)) 130 131 def _test_recovered_variable(self, 132 checkpoint_dir=None, 133 checkpoint_filename_with_path=None): 134 # Create a new Graph and SessionManager and recover from a checkpoint. 135 with ops.Graph().as_default(): 136 v = variables.Variable(2, name="v") 137 with session_lib.Session(): 138 self.assertEqual(False, variables.is_variable_initialized(v).eval()) 139 sm2 = session_manager.SessionManager( 140 ready_op=variables.report_uninitialized_variables()) 141 saver = saver_lib.Saver({"v": v}) 142 sess, initialized = sm2.recover_session( 143 "", 144 saver=saver, 145 checkpoint_dir=checkpoint_dir, 146 checkpoint_filename_with_path=checkpoint_filename_with_path) 147 self.assertTrue(initialized) 148 self.assertEqual( 149 True, 150 variables.is_variable_initialized( 151 sess.graph.get_tensor_by_name("v:0")).eval(session=sess)) 152 self.assertEquals(1, sess.run(v)) 153 154 def testRecoverSession(self): 155 # Create a checkpoint. 156 checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session") 157 try: 158 gfile.DeleteRecursively(checkpoint_dir) 159 except errors.OpError: 160 pass # Ignore 161 gfile.MakeDirs(checkpoint_dir) 162 163 with ops.Graph().as_default(): 164 v = variables.Variable(1, name="v") 165 sm = session_manager.SessionManager( 166 ready_op=variables.report_uninitialized_variables()) 167 saver = saver_lib.Saver({"v": v}) 168 sess, initialized = sm.recover_session( 169 "", saver=saver, checkpoint_dir=checkpoint_dir) 170 self.assertFalse(initialized) 171 sess.run(v.initializer) 172 self.assertEquals(1, sess.run(v)) 173 saver.save(sess, 174 os.path.join(checkpoint_dir, "recover_session_checkpoint")) 175 self._test_recovered_variable(checkpoint_dir=checkpoint_dir) 176 self._test_recovered_variable( 177 checkpoint_filename_with_path=saver_lib.latest_checkpoint( 178 checkpoint_dir)) 179 # Cannot set both checkpoint_dir and checkpoint_filename_with_path. 180 with self.assertRaises(ValueError): 181 self._test_recovered_variable( 182 checkpoint_dir=checkpoint_dir, 183 checkpoint_filename_with_path=saver_lib.latest_checkpoint( 184 checkpoint_dir)) 185 186 def testWaitForSessionReturnsNoneAfterTimeout(self): 187 with ops.Graph().as_default(): 188 variables.Variable(1, name="v") 189 sm = session_manager.SessionManager( 190 ready_op=variables.report_uninitialized_variables(), 191 recovery_wait_secs=1) 192 193 # Set max_wait_secs to allow us to try a few times. 194 with self.assertRaises(errors.DeadlineExceededError): 195 sm.wait_for_session(master="", max_wait_secs=3) 196 197 def testInitWithNoneLocalInitOpError(self): 198 # Creating a SessionManager with a None local_init_op but 199 # non-None ready_for_local_init_op raises ValueError 200 with self.assertRaisesRegexp(ValueError, 201 "If you pass a ready_for_local_init_op " 202 "you must also pass a local_init_op "): 203 session_manager.SessionManager( 204 ready_for_local_init_op=variables.report_uninitialized_variables( 205 variables.global_variables()), 206 local_init_op=None) 207 208 def testRecoverSessionWithReadyForLocalInitOp(self): 209 # Create a checkpoint. 210 checkpoint_dir = os.path.join(self.get_temp_dir(), 211 "recover_session_ready_for_local_init") 212 try: 213 gfile.DeleteRecursively(checkpoint_dir) 214 except errors.OpError: 215 pass # Ignore 216 gfile.MakeDirs(checkpoint_dir) 217 218 with ops.Graph().as_default(): 219 v = variables.Variable(1, name="v") 220 sm = session_manager.SessionManager( 221 ready_op=variables.report_uninitialized_variables()) 222 saver = saver_lib.Saver({"v": v}) 223 sess, initialized = sm.recover_session( 224 "", saver=saver, checkpoint_dir=checkpoint_dir) 225 self.assertFalse(initialized) 226 sess.run(v.initializer) 227 self.assertEquals(1, sess.run(v)) 228 saver.save(sess, 229 os.path.join(checkpoint_dir, "recover_session_checkpoint")) 230 # Create a new Graph and SessionManager and recover. 231 with ops.Graph().as_default(): 232 v = variables.Variable(2, name="v") 233 w = variables.Variable( 234 v, 235 trainable=False, 236 collections=[ops.GraphKeys.LOCAL_VARIABLES], 237 name="w") 238 with self.test_session(): 239 self.assertEqual(False, variables.is_variable_initialized(v).eval()) 240 self.assertEqual(False, variables.is_variable_initialized(w).eval()) 241 sm2 = session_manager.SessionManager( 242 ready_op=variables.report_uninitialized_variables(), 243 ready_for_local_init_op=variables.report_uninitialized_variables( 244 variables.global_variables()), 245 local_init_op=w.initializer) 246 saver = saver_lib.Saver({"v": v}) 247 sess, initialized = sm2.recover_session( 248 "", saver=saver, checkpoint_dir=checkpoint_dir) 249 self.assertTrue(initialized) 250 self.assertEqual( 251 True, 252 variables.is_variable_initialized( 253 sess.graph.get_tensor_by_name("v:0")).eval(session=sess)) 254 self.assertEqual( 255 True, 256 variables.is_variable_initialized( 257 sess.graph.get_tensor_by_name("w:0")).eval(session=sess)) 258 self.assertEquals(1, sess.run(v)) 259 self.assertEquals(1, sess.run(w)) 260 261 def testRecoverSessionWithReadyForLocalInitOpFailsToReadyLocal(self): 262 # We use ready_for_local_init_op=tf.report_uninitialized_variables(), 263 # which causes recover_session to not run local_init_op, and to return 264 # initialized=False 265 266 # Create a checkpoint. 267 checkpoint_dir = os.path.join( 268 self.get_temp_dir(), 269 "recover_session_ready_for_local_init_fails_to_ready_local") 270 try: 271 gfile.DeleteRecursively(checkpoint_dir) 272 except errors.OpError: 273 pass # Ignore 274 gfile.MakeDirs(checkpoint_dir) 275 276 with ops.Graph().as_default(): 277 v = variables.Variable(1, name="v") 278 sm = session_manager.SessionManager( 279 ready_op=variables.report_uninitialized_variables()) 280 saver = saver_lib.Saver({"v": v}) 281 sess, initialized = sm.recover_session( 282 "", saver=saver, checkpoint_dir=checkpoint_dir) 283 self.assertFalse(initialized) 284 sess.run(v.initializer) 285 self.assertEquals(1, sess.run(v)) 286 saver.save(sess, 287 os.path.join(checkpoint_dir, "recover_session_checkpoint")) 288 # Create a new Graph and SessionManager and recover. 289 with ops.Graph().as_default(): 290 v = variables.Variable(2, name="v") 291 w = variables.Variable( 292 v, 293 trainable=False, 294 collections=[ops.GraphKeys.LOCAL_VARIABLES], 295 name="w") 296 with self.test_session(): 297 self.assertEqual(False, variables.is_variable_initialized(v).eval()) 298 self.assertEqual(False, variables.is_variable_initialized(w).eval()) 299 sm2 = session_manager.SessionManager( 300 ready_op=variables.report_uninitialized_variables(), 301 ready_for_local_init_op=variables.report_uninitialized_variables(), 302 local_init_op=w.initializer) 303 saver = saver_lib.Saver({"v": v}) 304 sess, initialized = sm2.recover_session( 305 "", saver=saver, checkpoint_dir=checkpoint_dir) 306 self.assertFalse(initialized) 307 self.assertEqual( 308 True, 309 variables.is_variable_initialized( 310 sess.graph.get_tensor_by_name("v:0")).eval(session=sess)) 311 self.assertEqual( 312 False, 313 variables.is_variable_initialized( 314 sess.graph.get_tensor_by_name("w:0")).eval(session=sess)) 315 self.assertEquals(1, sess.run(v)) 316 317 def testRecoverSessionNoChkptStillRunsLocalInitOp(self): 318 # This test checks for backwards compatibility. 319 # In particular, we continue to ensure that recover_session will execute 320 # local_init_op exactly once, regardless of whether the session was 321 # successfully recovered. 322 with ops.Graph().as_default(): 323 w = variables.Variable( 324 1, 325 trainable=False, 326 collections=[ops.GraphKeys.LOCAL_VARIABLES], 327 name="w") 328 with self.test_session(): 329 self.assertEqual(False, variables.is_variable_initialized(w).eval()) 330 sm2 = session_manager.SessionManager( 331 ready_op=variables.report_uninitialized_variables(), 332 ready_for_local_init_op=None, 333 local_init_op=w.initializer) 334 # Try to recover session from None 335 sess, initialized = sm2.recover_session( 336 "", saver=None, checkpoint_dir=None) 337 # Succeeds because recover_session still run local_init_op 338 self.assertFalse(initialized) 339 self.assertEqual( 340 True, 341 variables.is_variable_initialized( 342 sess.graph.get_tensor_by_name("w:0")).eval(session=sess)) 343 self.assertEquals(1, sess.run(w)) 344 345 def testRecoverSessionFailsStillRunsLocalInitOp(self): 346 # Create a checkpoint. 347 checkpoint_dir = os.path.join( 348 self.get_temp_dir(), 349 "recover_session_ready_for_local_init_fails_stil_run") 350 try: 351 gfile.DeleteRecursively(checkpoint_dir) 352 except errors.OpError: 353 pass # Ignore 354 gfile.MakeDirs(checkpoint_dir) 355 356 # Create a new Graph and SessionManager and recover. 357 with ops.Graph().as_default(): 358 v = variables.Variable(2, name="v") 359 w = variables.Variable( 360 1, 361 trainable=False, 362 collections=[ops.GraphKeys.LOCAL_VARIABLES], 363 name="w") 364 with self.test_session(): 365 self.assertEqual(False, variables.is_variable_initialized(v).eval()) 366 self.assertEqual(False, variables.is_variable_initialized(w).eval()) 367 sm2 = session_manager.SessionManager( 368 ready_op=variables.report_uninitialized_variables(), 369 ready_for_local_init_op=None, 370 local_init_op=w.initializer) 371 saver = saver_lib.Saver({"v": v}) 372 sess, initialized = sm2.recover_session( 373 "", 374 saver=saver, 375 checkpoint_dir=checkpoint_dir, 376 wait_for_checkpoint=False) 377 self.assertFalse(initialized) 378 self.assertEqual( 379 False, 380 variables.is_variable_initialized( 381 sess.graph.get_tensor_by_name("v:0")).eval(session=sess)) 382 self.assertEqual( 383 True, 384 variables.is_variable_initialized( 385 sess.graph.get_tensor_by_name("w:0")).eval(session=sess)) 386 self.assertEquals(1, sess.run(w)) 387 388 def testWaitForSessionLocalInit(self): 389 server = server_lib.Server.create_local_server() 390 with ops.Graph().as_default() as graph: 391 v = variables.Variable(1, name="v") 392 w = variables.Variable( 393 v, 394 trainable=False, 395 collections=[ops.GraphKeys.LOCAL_VARIABLES], 396 name="w") 397 sm = session_manager.SessionManager( 398 graph=graph, 399 ready_op=variables.report_uninitialized_variables(), 400 ready_for_local_init_op=variables.report_uninitialized_variables( 401 variables.global_variables()), 402 local_init_op=w.initializer) 403 404 # Initialize v but not w 405 s = session_lib.Session(server.target, graph=graph) 406 s.run(v.initializer) 407 408 sess = sm.wait_for_session(server.target, max_wait_secs=3) 409 self.assertEqual( 410 True, 411 variables.is_variable_initialized( 412 sess.graph.get_tensor_by_name("v:0")).eval(session=sess)) 413 self.assertEqual( 414 True, 415 variables.is_variable_initialized( 416 sess.graph.get_tensor_by_name("w:0")).eval(session=sess)) 417 self.assertEquals(1, sess.run(v)) 418 self.assertEquals(1, sess.run(w)) 419 420 def testWaitForSessionWithReadyForLocalInitOpFailsToReadyLocal(self): 421 with ops.Graph().as_default() as graph: 422 v = variables.Variable(1, name="v") 423 w = variables.Variable( 424 v, 425 trainable=False, 426 collections=[ops.GraphKeys.LOCAL_VARIABLES], 427 name="w") 428 sm = session_manager.SessionManager( 429 graph=graph, 430 ready_op=variables.report_uninitialized_variables(), 431 ready_for_local_init_op=variables.report_uninitialized_variables(), 432 local_init_op=w.initializer) 433 434 with self.assertRaises(errors_impl.DeadlineExceededError): 435 # Time-out because w fails to be initialized, 436 # because of overly restrictive ready_for_local_init_op 437 sm.wait_for_session("", max_wait_secs=3) 438 439 def testWaitForSessionInsufficientReadyForLocalInitCheck(self): 440 with ops.Graph().as_default() as graph: 441 v = variables.Variable(1, name="v") 442 w = variables.Variable( 443 v, 444 trainable=False, 445 collections=[ops.GraphKeys.LOCAL_VARIABLES], 446 name="w") 447 sm = session_manager.SessionManager( 448 graph=graph, 449 ready_op=variables.report_uninitialized_variables(), 450 ready_for_local_init_op=None, 451 local_init_op=w.initializer) 452 with self.assertRaisesRegexp(errors_impl.DeadlineExceededError, 453 "Session was not ready after waiting.*"): 454 sm.wait_for_session("", max_wait_secs=3) 455 456 def testPrepareSessionWithReadyForLocalInitOp(self): 457 with ops.Graph().as_default(): 458 v = variables.Variable(1, name="v") 459 w = variables.Variable( 460 v, 461 trainable=False, 462 collections=[ops.GraphKeys.LOCAL_VARIABLES], 463 name="w") 464 x = variables.Variable( 465 3 * v, 466 trainable=False, 467 collections=[ops.GraphKeys.LOCAL_VARIABLES], 468 name="x") 469 with self.test_session(): 470 self.assertEqual(False, variables.is_variable_initialized(v).eval()) 471 self.assertEqual(False, variables.is_variable_initialized(w).eval()) 472 self.assertEqual(False, variables.is_variable_initialized(x).eval()) 473 sm2 = session_manager.SessionManager( 474 ready_op=variables.report_uninitialized_variables(), 475 ready_for_local_init_op=variables.report_uninitialized_variables( 476 variables.global_variables()), 477 local_init_op=[w.initializer, x.initializer]) 478 sess = sm2.prepare_session("", init_op=v.initializer) 479 self.assertEqual( 480 True, 481 variables.is_variable_initialized( 482 sess.graph.get_tensor_by_name("v:0")).eval(session=sess)) 483 self.assertEqual( 484 True, 485 variables.is_variable_initialized( 486 sess.graph.get_tensor_by_name("w:0")).eval(session=sess)) 487 self.assertEqual( 488 True, 489 variables.is_variable_initialized( 490 sess.graph.get_tensor_by_name("x:0")).eval(session=sess)) 491 self.assertEquals(1, sess.run(v)) 492 self.assertEquals(1, sess.run(w)) 493 self.assertEquals(3, sess.run(x)) 494 495 def testPrepareSessionWithPartialInitOp(self): 496 with ops.Graph().as_default(): 497 v = variables.Variable(1, name="v") 498 w = variables.Variable( 499 v, 500 trainable=False, 501 collections=[ops.GraphKeys.LOCAL_VARIABLES], 502 name="w") 503 x = variables.Variable( 504 3 * v, 505 trainable=False, 506 collections=[ops.GraphKeys.LOCAL_VARIABLES], 507 name="x") 508 # TODO(b/70206927): Use ResourceVariables once they are handled properly. 509 v_res = variables.Variable(1, name="v_res") 510 w_res = variables.Variable( 511 v_res, 512 trainable=False, 513 collections=[ops.GraphKeys.LOCAL_VARIABLES], 514 name="w_res") 515 x_res = variables.Variable( 516 3 * v_res, 517 trainable=False, 518 collections=[ops.GraphKeys.LOCAL_VARIABLES], 519 name="x_res") 520 521 with self.test_session(): 522 self.assertEqual(False, variables.is_variable_initialized(v).eval()) 523 self.assertEqual(False, variables.is_variable_initialized(w).eval()) 524 self.assertEqual(False, variables.is_variable_initialized(x).eval()) 525 self.assertEqual(False, variables.is_variable_initialized(v_res).eval()) 526 self.assertEqual(False, variables.is_variable_initialized(w_res).eval()) 527 self.assertEqual(False, variables.is_variable_initialized(x_res).eval()) 528 sm2 = session_manager.SessionManager(local_init_op=[ 529 w.initializer, x.initializer, w_res.initializer, x_res.initializer 530 ]) 531 sess = sm2.prepare_session("", init_op=None) 532 self.assertEqual( 533 False, 534 variables.is_variable_initialized( 535 sess.graph.get_tensor_by_name("v:0")).eval(session=sess)) 536 self.assertEqual( 537 True, 538 variables.is_variable_initialized( 539 sess.graph.get_tensor_by_name("w:0")).eval(session=sess)) 540 self.assertEqual( 541 True, 542 variables.is_variable_initialized( 543 sess.graph.get_tensor_by_name("x:0")).eval(session=sess)) 544 self.assertEquals(1, sess.run(w)) 545 self.assertEquals(3, sess.run(x)) 546 self.assertEqual( 547 False, 548 variables.is_variable_initialized( 549 sess.graph.get_tensor_by_name("v_res:0")).eval(session=sess)) 550 self.assertEqual( 551 True, 552 variables.is_variable_initialized( 553 sess.graph.get_tensor_by_name("w_res:0")).eval(session=sess)) 554 self.assertEqual( 555 True, 556 variables.is_variable_initialized( 557 sess.graph.get_tensor_by_name("x_res:0")).eval(session=sess)) 558 self.assertEquals(1, sess.run(w_res)) 559 self.assertEquals(3, sess.run(x_res)) 560 561 def testPrepareSessionWithCyclicInitializer(self): 562 # Regression test. Previously Variable._build_initializer_expr would enter 563 # into an infinite recursion when the variable's initial_value involved 564 # cyclic dependencies. 565 with ops.Graph().as_default(): 566 i = control_flow_ops.while_loop(lambda i: i < 1, lambda i: i + 1, [0]) 567 v = variables.Variable(array_ops.identity(i), name="v") 568 with self.test_session(): 569 self.assertEqual(False, variables.is_variable_initialized(v).eval()) 570 sm = session_manager.SessionManager( 571 ready_op=variables.report_uninitialized_variables()) 572 sess = sm.prepare_session("", init_op=v.initializer) 573 self.assertEqual(1, sess.run(v)) 574 self.assertEqual( 575 True, 576 variables.is_variable_initialized( 577 sess.graph.get_tensor_by_name("v:0")).eval(session=sess)) 578 579 def testPrepareSessionDidNotInitLocalVariable(self): 580 with ops.Graph().as_default(): 581 v = variables.Variable(1, name="v") 582 w = variables.Variable( 583 v, 584 trainable=False, 585 collections=[ops.GraphKeys.LOCAL_VARIABLES], 586 name="w") 587 with self.test_session(): 588 self.assertEqual(False, variables.is_variable_initialized(v).eval()) 589 self.assertEqual(False, variables.is_variable_initialized(w).eval()) 590 sm2 = session_manager.SessionManager( 591 ready_op=variables.report_uninitialized_variables()) 592 with self.assertRaisesRegexp( 593 RuntimeError, "Init operations did not make model ready.*"): 594 sm2.prepare_session("", init_op=v.initializer) 595 596 def testPrepareSessionDidNotInitLocalVariableList(self): 597 with ops.Graph().as_default(): 598 v = variables.Variable(1, name="v") 599 w = variables.Variable( 600 v, 601 trainable=False, 602 collections=[ops.GraphKeys.LOCAL_VARIABLES], 603 name="w") 604 with self.test_session(): 605 self.assertEqual(False, variables.is_variable_initialized(v).eval()) 606 self.assertEqual(False, variables.is_variable_initialized(w).eval()) 607 sm2 = session_manager.SessionManager( 608 ready_op=variables.report_uninitialized_variables()) 609 with self.assertRaisesRegexp(RuntimeError, 610 "Init operations did not make model ready"): 611 sm2.prepare_session("", init_op=[v.initializer]) 612 613 def testPrepareSessionWithReadyNotReadyForLocal(self): 614 with ops.Graph().as_default(): 615 v = variables.Variable(1, name="v") 616 w = variables.Variable( 617 v, 618 trainable=False, 619 collections=[ops.GraphKeys.LOCAL_VARIABLES], 620 name="w") 621 with self.test_session(): 622 self.assertEqual(False, variables.is_variable_initialized(v).eval()) 623 self.assertEqual(False, variables.is_variable_initialized(w).eval()) 624 sm2 = session_manager.SessionManager( 625 ready_op=variables.report_uninitialized_variables(), 626 ready_for_local_init_op=variables.report_uninitialized_variables( 627 variables.global_variables()), 628 local_init_op=w.initializer) 629 with self.assertRaisesRegexp( 630 RuntimeError, 631 "Init operations did not make model ready for local_init"): 632 sm2.prepare_session("", init_op=None) 633 634 def testPrepareSessionWithInsufficientReadyForLocalInitCheck(self): 635 with ops.Graph().as_default(): 636 v = variables.Variable(1, name="v") 637 w = variables.Variable( 638 v, 639 trainable=False, 640 collections=[ops.GraphKeys.LOCAL_VARIABLES], 641 name="w") 642 with self.test_session(): 643 self.assertEqual(False, variables.is_variable_initialized(v).eval()) 644 self.assertEqual(False, variables.is_variable_initialized(w).eval()) 645 sm2 = session_manager.SessionManager( 646 ready_op=variables.report_uninitialized_variables(), 647 ready_for_local_init_op=None, 648 local_init_op=w.initializer) 649 with self.assertRaisesRegexp(RuntimeError, 650 "Init operations did not make model ready.*"): 651 sm2.prepare_session("", init_op=None) 652 653 654 class ObsoleteSessionManagerTest(test.TestCase): 655 656 def testPrepareSessionSucceeds(self): 657 with ops.Graph().as_default(): 658 v = variables.Variable([1.0, 2.0, 3.0], name="v") 659 sm = session_manager.SessionManager( 660 ready_op=variables.assert_variables_initialized()) 661 sess = sm.prepare_session( 662 "", init_op=variables.global_variables_initializer()) 663 self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) 664 665 def testPrepareSessionSucceedsWithInitFeedDict(self): 666 with ops.Graph().as_default(): 667 p = array_ops.placeholder(dtypes.float32, shape=(3,)) 668 v = variables.Variable(p, name="v") 669 sm = session_manager.SessionManager( 670 ready_op=variables.assert_variables_initialized()) 671 sess = sm.prepare_session( 672 "", 673 init_op=variables.global_variables_initializer(), 674 init_feed_dict={p: [1.0, 2.0, 3.0]}) 675 self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) 676 677 def testPrepareSessionSucceedsWithInitFn(self): 678 with ops.Graph().as_default(): 679 v = variables.Variable([125], name="v") 680 sm = session_manager.SessionManager( 681 ready_op=variables.assert_variables_initialized()) 682 sess = sm.prepare_session( 683 "", init_fn=lambda sess: sess.run(v.initializer)) 684 self.assertAllClose([125], sess.run(v)) 685 686 def testPrepareSessionFails(self): 687 checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session") 688 checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2") 689 try: 690 gfile.DeleteRecursively(checkpoint_dir) 691 gfile.DeleteRecursively(checkpoint_dir2) 692 except errors.OpError: 693 pass # Ignore 694 gfile.MakeDirs(checkpoint_dir) 695 696 with ops.Graph().as_default(): 697 v = variables.Variable([1.0, 2.0, 3.0], name="v") 698 sm = session_manager.SessionManager( 699 ready_op=variables.assert_variables_initialized()) 700 saver = saver_lib.Saver({"v": v}) 701 sess = sm.prepare_session( 702 "", 703 init_op=variables.global_variables_initializer(), 704 saver=saver, 705 checkpoint_dir=checkpoint_dir) 706 self.assertAllClose([1.0, 2.0, 3.0], sess.run(v)) 707 checkpoint_filename = os.path.join(checkpoint_dir, 708 "prepare_session_checkpoint") 709 saver.save(sess, checkpoint_filename) 710 # Create a new Graph and SessionManager and recover. 711 with ops.Graph().as_default(): 712 # Renames the checkpoint directory. 713 os.rename(checkpoint_dir, checkpoint_dir2) 714 gfile.MakeDirs(checkpoint_dir) 715 v = variables.Variable([6.0, 7.0, 8.0], name="v") 716 with self.test_session(): 717 self.assertEqual(False, variables.is_variable_initialized(v).eval()) 718 session_manager.SessionManager( 719 ready_op=variables.assert_variables_initialized()) 720 saver = saver_lib.Saver({"v": v}) 721 # This should fail as there's no checkpoint within 2 seconds. 722 with self.assertRaisesRegexp( 723 RuntimeError, "no init_op or init_fn or local_init_op was given"): 724 sess = sm.prepare_session( 725 "", 726 init_op=None, 727 saver=saver, 728 checkpoint_dir=checkpoint_dir, 729 wait_for_checkpoint=True, 730 max_wait_secs=2) 731 # Rename the checkpoint directory back. 732 gfile.DeleteRecursively(checkpoint_dir) 733 os.rename(checkpoint_dir2, checkpoint_dir) 734 # This should succeed as there's checkpoint. 735 sess = sm.prepare_session( 736 "", 737 init_op=None, 738 saver=saver, 739 checkpoint_dir=checkpoint_dir, 740 wait_for_checkpoint=True, 741 max_wait_secs=2) 742 self.assertEqual( 743 True, 744 variables.is_variable_initialized( 745 sess.graph.get_tensor_by_name("v:0")).eval(session=sess)) 746 747 def testRecoverSession(self): 748 # Create a checkpoint. 749 checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session") 750 try: 751 gfile.DeleteRecursively(checkpoint_dir) 752 except errors.OpError: 753 pass # Ignore 754 gfile.MakeDirs(checkpoint_dir) 755 756 with ops.Graph().as_default(): 757 v = variables.Variable(1, name="v") 758 sm = session_manager.SessionManager( 759 ready_op=variables.assert_variables_initialized()) 760 saver = saver_lib.Saver({"v": v}) 761 sess, initialized = sm.recover_session( 762 "", saver=saver, checkpoint_dir=checkpoint_dir) 763 self.assertFalse(initialized) 764 sess.run(v.initializer) 765 self.assertEquals(1, sess.run(v)) 766 saver.save(sess, 767 os.path.join(checkpoint_dir, "recover_session_checkpoint")) 768 # Create a new Graph and SessionManager and recover. 769 with ops.Graph().as_default(): 770 v = variables.Variable(2, name="v") 771 with self.test_session(): 772 self.assertEqual(False, variables.is_variable_initialized(v).eval()) 773 sm2 = session_manager.SessionManager( 774 ready_op=variables.assert_variables_initialized()) 775 saver = saver_lib.Saver({"v": v}) 776 sess, initialized = sm2.recover_session( 777 "", saver=saver, checkpoint_dir=checkpoint_dir) 778 self.assertTrue(initialized) 779 self.assertEqual( 780 True, 781 variables.is_variable_initialized( 782 sess.graph.get_tensor_by_name("v:0")).eval(session=sess)) 783 self.assertEquals(1, sess.run(v)) 784 785 def testWaitForSessionReturnsNoneAfterTimeout(self): 786 with ops.Graph().as_default(): 787 variables.Variable(1, name="v") 788 sm = session_manager.SessionManager( 789 ready_op=variables.assert_variables_initialized(), 790 recovery_wait_secs=1) 791 792 # Set max_wait_secs to allow us to try a few times. 793 with self.assertRaises(errors.DeadlineExceededError): 794 sm.wait_for_session(master="", max_wait_secs=3) 795 796 797 if __name__ == "__main__": 798 test.main() 799