Home | History | Annotate | Download | only in training
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for SessionManager."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import os
     22 
     23 from tensorflow.python.client import session as session_lib
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.framework import errors
     26 from tensorflow.python.framework import errors_impl
     27 from tensorflow.python.framework import ops
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.ops import control_flow_ops
     30 from tensorflow.python.ops import variables
     31 from tensorflow.python.platform import gfile
     32 from tensorflow.python.platform import test
     33 from tensorflow.python.training import saver as saver_lib
     34 from tensorflow.python.training import server_lib
     35 from tensorflow.python.training import session_manager
     36 
     37 
     38 class SessionManagerTest(test.TestCase):
     39 
     40   def testPrepareSessionSucceeds(self):
     41     with ops.Graph().as_default():
     42       v = variables.Variable([1.0, 2.0, 3.0], name="v")
     43       sm = session_manager.SessionManager(
     44           ready_op=variables.report_uninitialized_variables())
     45       sess = sm.prepare_session(
     46           "", init_op=variables.global_variables_initializer())
     47       self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
     48 
     49   def testPrepareSessionSucceedsWithInitFeedDict(self):
     50     with ops.Graph().as_default():
     51       p = array_ops.placeholder(dtypes.float32, shape=(3,))
     52       v = variables.Variable(p, name="v")
     53       sm = session_manager.SessionManager(
     54           ready_op=variables.report_uninitialized_variables())
     55       sess = sm.prepare_session(
     56           "",
     57           init_op=variables.global_variables_initializer(),
     58           init_feed_dict={p: [1.0, 2.0, 3.0]})
     59       self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
     60 
     61   def testPrepareSessionSucceedsWithInitFn(self):
     62     with ops.Graph().as_default():
     63       v = variables.Variable([125], name="v")
     64       sm = session_manager.SessionManager(
     65           ready_op=variables.report_uninitialized_variables())
     66       sess = sm.prepare_session(
     67           "", init_fn=lambda sess: sess.run(v.initializer))
     68       self.assertAllClose([125], sess.run(v))
     69 
     70   def testPrepareSessionFails(self):
     71     checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")
     72     checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2")
     73     try:
     74       gfile.DeleteRecursively(checkpoint_dir)
     75       gfile.DeleteRecursively(checkpoint_dir2)
     76     except errors.OpError:
     77       pass  # Ignore
     78     gfile.MakeDirs(checkpoint_dir)
     79 
     80     with ops.Graph().as_default():
     81       v = variables.Variable([1.0, 2.0, 3.0], name="v")
     82       sm = session_manager.SessionManager(
     83           ready_op=variables.report_uninitialized_variables())
     84       saver = saver_lib.Saver({"v": v})
     85       sess = sm.prepare_session(
     86           "",
     87           init_op=variables.global_variables_initializer(),
     88           saver=saver,
     89           checkpoint_dir=checkpoint_dir)
     90       self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
     91       checkpoint_filename = os.path.join(checkpoint_dir,
     92                                          "prepare_session_checkpoint")
     93       saver.save(sess, checkpoint_filename)
     94     # Create a new Graph and SessionManager and recover.
     95     with ops.Graph().as_default():
     96       # Renames the checkpoint directory.
     97       os.rename(checkpoint_dir, checkpoint_dir2)
     98       gfile.MakeDirs(checkpoint_dir)
     99       v = variables.Variable([6.0, 7.0, 8.0], name="v")
    100       with self.test_session():
    101         self.assertEqual(False, variables.is_variable_initialized(v).eval())
    102       session_manager.SessionManager(
    103           ready_op=variables.report_uninitialized_variables())
    104       saver = saver_lib.Saver({"v": v})
    105       # This should fail as there's no checkpoint within 2 seconds.
    106       with self.assertRaisesRegexp(
    107           RuntimeError, "no init_op or init_fn or local_init_op was given"):
    108         sess = sm.prepare_session(
    109             "",
    110             init_op=None,
    111             saver=saver,
    112             checkpoint_dir=checkpoint_dir,
    113             wait_for_checkpoint=True,
    114             max_wait_secs=2)
    115       # Rename the checkpoint directory back.
    116       gfile.DeleteRecursively(checkpoint_dir)
    117       os.rename(checkpoint_dir2, checkpoint_dir)
    118       # This should succeed as there's checkpoint.
    119       sess = sm.prepare_session(
    120           "",
    121           init_op=None,
    122           saver=saver,
    123           checkpoint_dir=checkpoint_dir,
    124           wait_for_checkpoint=True,
    125           max_wait_secs=2)
    126       self.assertEqual(
    127           True,
    128           variables.is_variable_initialized(
    129               sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    130 
    131   def _test_recovered_variable(self,
    132                                checkpoint_dir=None,
    133                                checkpoint_filename_with_path=None):
    134     # Create a new Graph and SessionManager and recover from a checkpoint.
    135     with ops.Graph().as_default():
    136       v = variables.Variable(2, name="v")
    137       with session_lib.Session():
    138         self.assertEqual(False, variables.is_variable_initialized(v).eval())
    139       sm2 = session_manager.SessionManager(
    140           ready_op=variables.report_uninitialized_variables())
    141       saver = saver_lib.Saver({"v": v})
    142       sess, initialized = sm2.recover_session(
    143           "",
    144           saver=saver,
    145           checkpoint_dir=checkpoint_dir,
    146           checkpoint_filename_with_path=checkpoint_filename_with_path)
    147       self.assertTrue(initialized)
    148       self.assertEqual(
    149           True,
    150           variables.is_variable_initialized(
    151               sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    152       self.assertEquals(1, sess.run(v))
    153 
    154   def testRecoverSession(self):
    155     # Create a checkpoint.
    156     checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
    157     try:
    158       gfile.DeleteRecursively(checkpoint_dir)
    159     except errors.OpError:
    160       pass  # Ignore
    161     gfile.MakeDirs(checkpoint_dir)
    162 
    163     with ops.Graph().as_default():
    164       v = variables.Variable(1, name="v")
    165       sm = session_manager.SessionManager(
    166           ready_op=variables.report_uninitialized_variables())
    167       saver = saver_lib.Saver({"v": v})
    168       sess, initialized = sm.recover_session(
    169           "", saver=saver, checkpoint_dir=checkpoint_dir)
    170       self.assertFalse(initialized)
    171       sess.run(v.initializer)
    172       self.assertEquals(1, sess.run(v))
    173       saver.save(sess,
    174                  os.path.join(checkpoint_dir, "recover_session_checkpoint"))
    175     self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
    176     self._test_recovered_variable(
    177         checkpoint_filename_with_path=saver_lib.latest_checkpoint(
    178             checkpoint_dir))
    179     # Cannot set both checkpoint_dir and checkpoint_filename_with_path.
    180     with self.assertRaises(ValueError):
    181       self._test_recovered_variable(
    182           checkpoint_dir=checkpoint_dir,
    183           checkpoint_filename_with_path=saver_lib.latest_checkpoint(
    184               checkpoint_dir))
    185 
    186   def testWaitForSessionReturnsNoneAfterTimeout(self):
    187     with ops.Graph().as_default():
    188       variables.Variable(1, name="v")
    189       sm = session_manager.SessionManager(
    190           ready_op=variables.report_uninitialized_variables(),
    191           recovery_wait_secs=1)
    192 
    193       # Set max_wait_secs to allow us to try a few times.
    194       with self.assertRaises(errors.DeadlineExceededError):
    195         sm.wait_for_session(master="", max_wait_secs=3)
    196 
    197   def testInitWithNoneLocalInitOpError(self):
    198     # Creating a SessionManager with a None local_init_op but
    199     # non-None ready_for_local_init_op raises ValueError
    200     with self.assertRaisesRegexp(ValueError,
    201                                  "If you pass a ready_for_local_init_op "
    202                                  "you must also pass a local_init_op "):
    203       session_manager.SessionManager(
    204           ready_for_local_init_op=variables.report_uninitialized_variables(
    205               variables.global_variables()),
    206           local_init_op=None)
    207 
    208   def testRecoverSessionWithReadyForLocalInitOp(self):
    209     # Create a checkpoint.
    210     checkpoint_dir = os.path.join(self.get_temp_dir(),
    211                                   "recover_session_ready_for_local_init")
    212     try:
    213       gfile.DeleteRecursively(checkpoint_dir)
    214     except errors.OpError:
    215       pass  # Ignore
    216     gfile.MakeDirs(checkpoint_dir)
    217 
    218     with ops.Graph().as_default():
    219       v = variables.Variable(1, name="v")
    220       sm = session_manager.SessionManager(
    221           ready_op=variables.report_uninitialized_variables())
    222       saver = saver_lib.Saver({"v": v})
    223       sess, initialized = sm.recover_session(
    224           "", saver=saver, checkpoint_dir=checkpoint_dir)
    225       self.assertFalse(initialized)
    226       sess.run(v.initializer)
    227       self.assertEquals(1, sess.run(v))
    228       saver.save(sess,
    229                  os.path.join(checkpoint_dir, "recover_session_checkpoint"))
    230     # Create a new Graph and SessionManager and recover.
    231     with ops.Graph().as_default():
    232       v = variables.Variable(2, name="v")
    233       w = variables.Variable(
    234           v,
    235           trainable=False,
    236           collections=[ops.GraphKeys.LOCAL_VARIABLES],
    237           name="w")
    238       with self.test_session():
    239         self.assertEqual(False, variables.is_variable_initialized(v).eval())
    240         self.assertEqual(False, variables.is_variable_initialized(w).eval())
    241       sm2 = session_manager.SessionManager(
    242           ready_op=variables.report_uninitialized_variables(),
    243           ready_for_local_init_op=variables.report_uninitialized_variables(
    244               variables.global_variables()),
    245           local_init_op=w.initializer)
    246       saver = saver_lib.Saver({"v": v})
    247       sess, initialized = sm2.recover_session(
    248           "", saver=saver, checkpoint_dir=checkpoint_dir)
    249       self.assertTrue(initialized)
    250       self.assertEqual(
    251           True,
    252           variables.is_variable_initialized(
    253               sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    254       self.assertEqual(
    255           True,
    256           variables.is_variable_initialized(
    257               sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
    258       self.assertEquals(1, sess.run(v))
    259       self.assertEquals(1, sess.run(w))
    260 
    261   def testRecoverSessionWithReadyForLocalInitOpFailsToReadyLocal(self):
    262     # We use ready_for_local_init_op=tf.report_uninitialized_variables(),
    263     # which causes recover_session to not run local_init_op, and to return
    264     # initialized=False
    265 
    266     # Create a checkpoint.
    267     checkpoint_dir = os.path.join(
    268         self.get_temp_dir(),
    269         "recover_session_ready_for_local_init_fails_to_ready_local")
    270     try:
    271       gfile.DeleteRecursively(checkpoint_dir)
    272     except errors.OpError:
    273       pass  # Ignore
    274     gfile.MakeDirs(checkpoint_dir)
    275 
    276     with ops.Graph().as_default():
    277       v = variables.Variable(1, name="v")
    278       sm = session_manager.SessionManager(
    279           ready_op=variables.report_uninitialized_variables())
    280       saver = saver_lib.Saver({"v": v})
    281       sess, initialized = sm.recover_session(
    282           "", saver=saver, checkpoint_dir=checkpoint_dir)
    283       self.assertFalse(initialized)
    284       sess.run(v.initializer)
    285       self.assertEquals(1, sess.run(v))
    286       saver.save(sess,
    287                  os.path.join(checkpoint_dir, "recover_session_checkpoint"))
    288     # Create a new Graph and SessionManager and recover.
    289     with ops.Graph().as_default():
    290       v = variables.Variable(2, name="v")
    291       w = variables.Variable(
    292           v,
    293           trainable=False,
    294           collections=[ops.GraphKeys.LOCAL_VARIABLES],
    295           name="w")
    296       with self.test_session():
    297         self.assertEqual(False, variables.is_variable_initialized(v).eval())
    298         self.assertEqual(False, variables.is_variable_initialized(w).eval())
    299       sm2 = session_manager.SessionManager(
    300           ready_op=variables.report_uninitialized_variables(),
    301           ready_for_local_init_op=variables.report_uninitialized_variables(),
    302           local_init_op=w.initializer)
    303       saver = saver_lib.Saver({"v": v})
    304       sess, initialized = sm2.recover_session(
    305           "", saver=saver, checkpoint_dir=checkpoint_dir)
    306       self.assertFalse(initialized)
    307       self.assertEqual(
    308           True,
    309           variables.is_variable_initialized(
    310               sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    311       self.assertEqual(
    312           False,
    313           variables.is_variable_initialized(
    314               sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
    315       self.assertEquals(1, sess.run(v))
    316 
    317   def testRecoverSessionNoChkptStillRunsLocalInitOp(self):
    318     # This test checks for backwards compatibility.
    319     # In particular, we continue to ensure that recover_session will execute
    320     # local_init_op exactly once, regardless of whether the session was
    321     # successfully recovered.
    322     with ops.Graph().as_default():
    323       w = variables.Variable(
    324           1,
    325           trainable=False,
    326           collections=[ops.GraphKeys.LOCAL_VARIABLES],
    327           name="w")
    328       with self.test_session():
    329         self.assertEqual(False, variables.is_variable_initialized(w).eval())
    330       sm2 = session_manager.SessionManager(
    331           ready_op=variables.report_uninitialized_variables(),
    332           ready_for_local_init_op=None,
    333           local_init_op=w.initializer)
    334       # Try to recover session from None
    335       sess, initialized = sm2.recover_session(
    336           "", saver=None, checkpoint_dir=None)
    337       # Succeeds because recover_session still run local_init_op
    338       self.assertFalse(initialized)
    339       self.assertEqual(
    340           True,
    341           variables.is_variable_initialized(
    342               sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
    343       self.assertEquals(1, sess.run(w))
    344 
    345   def testRecoverSessionFailsStillRunsLocalInitOp(self):
    346     # Create a checkpoint.
    347     checkpoint_dir = os.path.join(
    348         self.get_temp_dir(),
    349         "recover_session_ready_for_local_init_fails_stil_run")
    350     try:
    351       gfile.DeleteRecursively(checkpoint_dir)
    352     except errors.OpError:
    353       pass  # Ignore
    354     gfile.MakeDirs(checkpoint_dir)
    355 
    356     # Create a new Graph and SessionManager and recover.
    357     with ops.Graph().as_default():
    358       v = variables.Variable(2, name="v")
    359       w = variables.Variable(
    360           1,
    361           trainable=False,
    362           collections=[ops.GraphKeys.LOCAL_VARIABLES],
    363           name="w")
    364       with self.test_session():
    365         self.assertEqual(False, variables.is_variable_initialized(v).eval())
    366         self.assertEqual(False, variables.is_variable_initialized(w).eval())
    367       sm2 = session_manager.SessionManager(
    368           ready_op=variables.report_uninitialized_variables(),
    369           ready_for_local_init_op=None,
    370           local_init_op=w.initializer)
    371       saver = saver_lib.Saver({"v": v})
    372       sess, initialized = sm2.recover_session(
    373           "",
    374           saver=saver,
    375           checkpoint_dir=checkpoint_dir,
    376           wait_for_checkpoint=False)
    377       self.assertFalse(initialized)
    378       self.assertEqual(
    379           False,
    380           variables.is_variable_initialized(
    381               sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    382       self.assertEqual(
    383           True,
    384           variables.is_variable_initialized(
    385               sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
    386       self.assertEquals(1, sess.run(w))
    387 
    388   def testWaitForSessionLocalInit(self):
    389     server = server_lib.Server.create_local_server()
    390     with ops.Graph().as_default() as graph:
    391       v = variables.Variable(1, name="v")
    392       w = variables.Variable(
    393           v,
    394           trainable=False,
    395           collections=[ops.GraphKeys.LOCAL_VARIABLES],
    396           name="w")
    397       sm = session_manager.SessionManager(
    398           graph=graph,
    399           ready_op=variables.report_uninitialized_variables(),
    400           ready_for_local_init_op=variables.report_uninitialized_variables(
    401               variables.global_variables()),
    402           local_init_op=w.initializer)
    403 
    404       # Initialize v but not w
    405       s = session_lib.Session(server.target, graph=graph)
    406       s.run(v.initializer)
    407 
    408       sess = sm.wait_for_session(server.target, max_wait_secs=3)
    409       self.assertEqual(
    410           True,
    411           variables.is_variable_initialized(
    412               sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    413       self.assertEqual(
    414           True,
    415           variables.is_variable_initialized(
    416               sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
    417       self.assertEquals(1, sess.run(v))
    418       self.assertEquals(1, sess.run(w))
    419 
    420   def testWaitForSessionWithReadyForLocalInitOpFailsToReadyLocal(self):
    421     with ops.Graph().as_default() as graph:
    422       v = variables.Variable(1, name="v")
    423       w = variables.Variable(
    424           v,
    425           trainable=False,
    426           collections=[ops.GraphKeys.LOCAL_VARIABLES],
    427           name="w")
    428       sm = session_manager.SessionManager(
    429           graph=graph,
    430           ready_op=variables.report_uninitialized_variables(),
    431           ready_for_local_init_op=variables.report_uninitialized_variables(),
    432           local_init_op=w.initializer)
    433 
    434       with self.assertRaises(errors_impl.DeadlineExceededError):
    435         # Time-out because w fails to be initialized,
    436         # because of overly restrictive ready_for_local_init_op
    437         sm.wait_for_session("", max_wait_secs=3)
    438 
    439   def testWaitForSessionInsufficientReadyForLocalInitCheck(self):
    440     with ops.Graph().as_default() as graph:
    441       v = variables.Variable(1, name="v")
    442       w = variables.Variable(
    443           v,
    444           trainable=False,
    445           collections=[ops.GraphKeys.LOCAL_VARIABLES],
    446           name="w")
    447       sm = session_manager.SessionManager(
    448           graph=graph,
    449           ready_op=variables.report_uninitialized_variables(),
    450           ready_for_local_init_op=None,
    451           local_init_op=w.initializer)
    452     with self.assertRaisesRegexp(errors_impl.DeadlineExceededError,
    453                                  "Session was not ready after waiting.*"):
    454       sm.wait_for_session("", max_wait_secs=3)
    455 
    456   def testPrepareSessionWithReadyForLocalInitOp(self):
    457     with ops.Graph().as_default():
    458       v = variables.Variable(1, name="v")
    459       w = variables.Variable(
    460           v,
    461           trainable=False,
    462           collections=[ops.GraphKeys.LOCAL_VARIABLES],
    463           name="w")
    464       x = variables.Variable(
    465           3 * v,
    466           trainable=False,
    467           collections=[ops.GraphKeys.LOCAL_VARIABLES],
    468           name="x")
    469       with self.test_session():
    470         self.assertEqual(False, variables.is_variable_initialized(v).eval())
    471         self.assertEqual(False, variables.is_variable_initialized(w).eval())
    472         self.assertEqual(False, variables.is_variable_initialized(x).eval())
    473       sm2 = session_manager.SessionManager(
    474           ready_op=variables.report_uninitialized_variables(),
    475           ready_for_local_init_op=variables.report_uninitialized_variables(
    476               variables.global_variables()),
    477           local_init_op=[w.initializer, x.initializer])
    478       sess = sm2.prepare_session("", init_op=v.initializer)
    479       self.assertEqual(
    480           True,
    481           variables.is_variable_initialized(
    482               sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    483       self.assertEqual(
    484           True,
    485           variables.is_variable_initialized(
    486               sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
    487       self.assertEqual(
    488           True,
    489           variables.is_variable_initialized(
    490               sess.graph.get_tensor_by_name("x:0")).eval(session=sess))
    491       self.assertEquals(1, sess.run(v))
    492       self.assertEquals(1, sess.run(w))
    493       self.assertEquals(3, sess.run(x))
    494 
    495   def testPrepareSessionWithPartialInitOp(self):
    496     with ops.Graph().as_default():
    497       v = variables.Variable(1, name="v")
    498       w = variables.Variable(
    499           v,
    500           trainable=False,
    501           collections=[ops.GraphKeys.LOCAL_VARIABLES],
    502           name="w")
    503       x = variables.Variable(
    504           3 * v,
    505           trainable=False,
    506           collections=[ops.GraphKeys.LOCAL_VARIABLES],
    507           name="x")
    508       # TODO(b/70206927): Use ResourceVariables once they are handled properly.
    509       v_res = variables.Variable(1, name="v_res")
    510       w_res = variables.Variable(
    511           v_res,
    512           trainable=False,
    513           collections=[ops.GraphKeys.LOCAL_VARIABLES],
    514           name="w_res")
    515       x_res = variables.Variable(
    516           3 * v_res,
    517           trainable=False,
    518           collections=[ops.GraphKeys.LOCAL_VARIABLES],
    519           name="x_res")
    520 
    521       with self.test_session():
    522         self.assertEqual(False, variables.is_variable_initialized(v).eval())
    523         self.assertEqual(False, variables.is_variable_initialized(w).eval())
    524         self.assertEqual(False, variables.is_variable_initialized(x).eval())
    525         self.assertEqual(False, variables.is_variable_initialized(v_res).eval())
    526         self.assertEqual(False, variables.is_variable_initialized(w_res).eval())
    527         self.assertEqual(False, variables.is_variable_initialized(x_res).eval())
    528       sm2 = session_manager.SessionManager(local_init_op=[
    529           w.initializer, x.initializer, w_res.initializer, x_res.initializer
    530       ])
    531       sess = sm2.prepare_session("", init_op=None)
    532       self.assertEqual(
    533           False,
    534           variables.is_variable_initialized(
    535               sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    536       self.assertEqual(
    537           True,
    538           variables.is_variable_initialized(
    539               sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
    540       self.assertEqual(
    541           True,
    542           variables.is_variable_initialized(
    543               sess.graph.get_tensor_by_name("x:0")).eval(session=sess))
    544       self.assertEquals(1, sess.run(w))
    545       self.assertEquals(3, sess.run(x))
    546       self.assertEqual(
    547           False,
    548           variables.is_variable_initialized(
    549               sess.graph.get_tensor_by_name("v_res:0")).eval(session=sess))
    550       self.assertEqual(
    551           True,
    552           variables.is_variable_initialized(
    553               sess.graph.get_tensor_by_name("w_res:0")).eval(session=sess))
    554       self.assertEqual(
    555           True,
    556           variables.is_variable_initialized(
    557               sess.graph.get_tensor_by_name("x_res:0")).eval(session=sess))
    558       self.assertEquals(1, sess.run(w_res))
    559       self.assertEquals(3, sess.run(x_res))
    560 
    561   def testPrepareSessionWithCyclicInitializer(self):
    562     # Regression test. Previously Variable._build_initializer_expr would enter
    563     # into an infinite recursion when the variable's initial_value involved
    564     # cyclic dependencies.
    565     with ops.Graph().as_default():
    566       i = control_flow_ops.while_loop(lambda i: i < 1, lambda i: i + 1, [0])
    567       v = variables.Variable(array_ops.identity(i), name="v")
    568       with self.test_session():
    569         self.assertEqual(False, variables.is_variable_initialized(v).eval())
    570       sm = session_manager.SessionManager(
    571           ready_op=variables.report_uninitialized_variables())
    572       sess = sm.prepare_session("", init_op=v.initializer)
    573       self.assertEqual(1, sess.run(v))
    574       self.assertEqual(
    575           True,
    576           variables.is_variable_initialized(
    577               sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    578 
    579   def testPrepareSessionDidNotInitLocalVariable(self):
    580     with ops.Graph().as_default():
    581       v = variables.Variable(1, name="v")
    582       w = variables.Variable(
    583           v,
    584           trainable=False,
    585           collections=[ops.GraphKeys.LOCAL_VARIABLES],
    586           name="w")
    587       with self.test_session():
    588         self.assertEqual(False, variables.is_variable_initialized(v).eval())
    589         self.assertEqual(False, variables.is_variable_initialized(w).eval())
    590       sm2 = session_manager.SessionManager(
    591           ready_op=variables.report_uninitialized_variables())
    592       with self.assertRaisesRegexp(
    593           RuntimeError, "Init operations did not make model ready.*"):
    594         sm2.prepare_session("", init_op=v.initializer)
    595 
    596   def testPrepareSessionDidNotInitLocalVariableList(self):
    597     with ops.Graph().as_default():
    598       v = variables.Variable(1, name="v")
    599       w = variables.Variable(
    600           v,
    601           trainable=False,
    602           collections=[ops.GraphKeys.LOCAL_VARIABLES],
    603           name="w")
    604       with self.test_session():
    605         self.assertEqual(False, variables.is_variable_initialized(v).eval())
    606         self.assertEqual(False, variables.is_variable_initialized(w).eval())
    607       sm2 = session_manager.SessionManager(
    608           ready_op=variables.report_uninitialized_variables())
    609       with self.assertRaisesRegexp(RuntimeError,
    610                                    "Init operations did not make model ready"):
    611         sm2.prepare_session("", init_op=[v.initializer])
    612 
    613   def testPrepareSessionWithReadyNotReadyForLocal(self):
    614     with ops.Graph().as_default():
    615       v = variables.Variable(1, name="v")
    616       w = variables.Variable(
    617           v,
    618           trainable=False,
    619           collections=[ops.GraphKeys.LOCAL_VARIABLES],
    620           name="w")
    621       with self.test_session():
    622         self.assertEqual(False, variables.is_variable_initialized(v).eval())
    623         self.assertEqual(False, variables.is_variable_initialized(w).eval())
    624       sm2 = session_manager.SessionManager(
    625           ready_op=variables.report_uninitialized_variables(),
    626           ready_for_local_init_op=variables.report_uninitialized_variables(
    627               variables.global_variables()),
    628           local_init_op=w.initializer)
    629       with self.assertRaisesRegexp(
    630           RuntimeError,
    631           "Init operations did not make model ready for local_init"):
    632         sm2.prepare_session("", init_op=None)
    633 
    634   def testPrepareSessionWithInsufficientReadyForLocalInitCheck(self):
    635     with ops.Graph().as_default():
    636       v = variables.Variable(1, name="v")
    637       w = variables.Variable(
    638           v,
    639           trainable=False,
    640           collections=[ops.GraphKeys.LOCAL_VARIABLES],
    641           name="w")
    642       with self.test_session():
    643         self.assertEqual(False, variables.is_variable_initialized(v).eval())
    644         self.assertEqual(False, variables.is_variable_initialized(w).eval())
    645       sm2 = session_manager.SessionManager(
    646           ready_op=variables.report_uninitialized_variables(),
    647           ready_for_local_init_op=None,
    648           local_init_op=w.initializer)
    649     with self.assertRaisesRegexp(RuntimeError,
    650                                  "Init operations did not make model ready.*"):
    651       sm2.prepare_session("", init_op=None)
    652 
    653 
    654 class ObsoleteSessionManagerTest(test.TestCase):
    655 
    656   def testPrepareSessionSucceeds(self):
    657     with ops.Graph().as_default():
    658       v = variables.Variable([1.0, 2.0, 3.0], name="v")
    659       sm = session_manager.SessionManager(
    660           ready_op=variables.assert_variables_initialized())
    661       sess = sm.prepare_session(
    662           "", init_op=variables.global_variables_initializer())
    663       self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
    664 
    665   def testPrepareSessionSucceedsWithInitFeedDict(self):
    666     with ops.Graph().as_default():
    667       p = array_ops.placeholder(dtypes.float32, shape=(3,))
    668       v = variables.Variable(p, name="v")
    669       sm = session_manager.SessionManager(
    670           ready_op=variables.assert_variables_initialized())
    671       sess = sm.prepare_session(
    672           "",
    673           init_op=variables.global_variables_initializer(),
    674           init_feed_dict={p: [1.0, 2.0, 3.0]})
    675       self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
    676 
    677   def testPrepareSessionSucceedsWithInitFn(self):
    678     with ops.Graph().as_default():
    679       v = variables.Variable([125], name="v")
    680       sm = session_manager.SessionManager(
    681           ready_op=variables.assert_variables_initialized())
    682       sess = sm.prepare_session(
    683           "", init_fn=lambda sess: sess.run(v.initializer))
    684       self.assertAllClose([125], sess.run(v))
    685 
    686   def testPrepareSessionFails(self):
    687     checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")
    688     checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2")
    689     try:
    690       gfile.DeleteRecursively(checkpoint_dir)
    691       gfile.DeleteRecursively(checkpoint_dir2)
    692     except errors.OpError:
    693       pass  # Ignore
    694     gfile.MakeDirs(checkpoint_dir)
    695 
    696     with ops.Graph().as_default():
    697       v = variables.Variable([1.0, 2.0, 3.0], name="v")
    698       sm = session_manager.SessionManager(
    699           ready_op=variables.assert_variables_initialized())
    700       saver = saver_lib.Saver({"v": v})
    701       sess = sm.prepare_session(
    702           "",
    703           init_op=variables.global_variables_initializer(),
    704           saver=saver,
    705           checkpoint_dir=checkpoint_dir)
    706       self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
    707       checkpoint_filename = os.path.join(checkpoint_dir,
    708                                          "prepare_session_checkpoint")
    709       saver.save(sess, checkpoint_filename)
    710     # Create a new Graph and SessionManager and recover.
    711     with ops.Graph().as_default():
    712       # Renames the checkpoint directory.
    713       os.rename(checkpoint_dir, checkpoint_dir2)
    714       gfile.MakeDirs(checkpoint_dir)
    715       v = variables.Variable([6.0, 7.0, 8.0], name="v")
    716       with self.test_session():
    717         self.assertEqual(False, variables.is_variable_initialized(v).eval())
    718       session_manager.SessionManager(
    719           ready_op=variables.assert_variables_initialized())
    720       saver = saver_lib.Saver({"v": v})
    721       # This should fail as there's no checkpoint within 2 seconds.
    722       with self.assertRaisesRegexp(
    723           RuntimeError, "no init_op or init_fn or local_init_op was given"):
    724         sess = sm.prepare_session(
    725             "",
    726             init_op=None,
    727             saver=saver,
    728             checkpoint_dir=checkpoint_dir,
    729             wait_for_checkpoint=True,
    730             max_wait_secs=2)
    731       # Rename the checkpoint directory back.
    732       gfile.DeleteRecursively(checkpoint_dir)
    733       os.rename(checkpoint_dir2, checkpoint_dir)
    734       # This should succeed as there's checkpoint.
    735       sess = sm.prepare_session(
    736           "",
    737           init_op=None,
    738           saver=saver,
    739           checkpoint_dir=checkpoint_dir,
    740           wait_for_checkpoint=True,
    741           max_wait_secs=2)
    742       self.assertEqual(
    743           True,
    744           variables.is_variable_initialized(
    745               sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    746 
    747   def testRecoverSession(self):
    748     # Create a checkpoint.
    749     checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
    750     try:
    751       gfile.DeleteRecursively(checkpoint_dir)
    752     except errors.OpError:
    753       pass  # Ignore
    754     gfile.MakeDirs(checkpoint_dir)
    755 
    756     with ops.Graph().as_default():
    757       v = variables.Variable(1, name="v")
    758       sm = session_manager.SessionManager(
    759           ready_op=variables.assert_variables_initialized())
    760       saver = saver_lib.Saver({"v": v})
    761       sess, initialized = sm.recover_session(
    762           "", saver=saver, checkpoint_dir=checkpoint_dir)
    763       self.assertFalse(initialized)
    764       sess.run(v.initializer)
    765       self.assertEquals(1, sess.run(v))
    766       saver.save(sess,
    767                  os.path.join(checkpoint_dir, "recover_session_checkpoint"))
    768     # Create a new Graph and SessionManager and recover.
    769     with ops.Graph().as_default():
    770       v = variables.Variable(2, name="v")
    771       with self.test_session():
    772         self.assertEqual(False, variables.is_variable_initialized(v).eval())
    773       sm2 = session_manager.SessionManager(
    774           ready_op=variables.assert_variables_initialized())
    775       saver = saver_lib.Saver({"v": v})
    776       sess, initialized = sm2.recover_session(
    777           "", saver=saver, checkpoint_dir=checkpoint_dir)
    778       self.assertTrue(initialized)
    779       self.assertEqual(
    780           True,
    781           variables.is_variable_initialized(
    782               sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    783       self.assertEquals(1, sess.run(v))
    784 
    785   def testWaitForSessionReturnsNoneAfterTimeout(self):
    786     with ops.Graph().as_default():
    787       variables.Variable(1, name="v")
    788       sm = session_manager.SessionManager(
    789           ready_op=variables.assert_variables_initialized(),
    790           recovery_wait_secs=1)
    791 
    792       # Set max_wait_secs to allow us to try a few times.
    793       with self.assertRaises(errors.DeadlineExceededError):
    794         sm.wait_for_session(master="", max_wait_secs=3)
    795 
    796 
    797 if __name__ == "__main__":
    798   test.main()
    799