Home | History | Annotate | Download | only in framework
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for checkpoints tools."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import os
     22 
     23 from tensorflow.contrib.framework.python.framework import checkpoint_utils
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.framework import errors_impl
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.ops import init_ops
     28 from tensorflow.python.ops import partitioned_variables
     29 from tensorflow.python.ops import variable_scope
     30 from tensorflow.python.ops import variables
     31 from tensorflow.python.platform import test
     32 from tensorflow.python.training import saver as saver_lib
     33 
     34 
     35 def _create_checkpoints(sess, checkpoint_dir):
     36   checkpoint_prefix = os.path.join(checkpoint_dir, "model")
     37   checkpoint_state_name = "checkpoint"
     38   v1 = variable_scope.get_variable("var1", [1, 10])
     39   v2 = variable_scope.get_variable("var2", [10, 10])
     40   v3 = variable_scope.get_variable("var3", [100, 100])
     41   with variable_scope.variable_scope("useful_scope"):
     42     v4 = variable_scope.get_variable("var4", [9, 9])
     43   sess.run(variables.global_variables_initializer())
     44   v1_value, v2_value, v3_value, v4_value = sess.run([v1, v2, v3, v4])
     45   saver = saver_lib.Saver()
     46   saver.save(
     47       sess,
     48       checkpoint_prefix,
     49       global_step=0,
     50       latest_filename=checkpoint_state_name)
     51   return v1_value, v2_value, v3_value, v4_value
     52 
     53 
     54 def _create_partition_checkpoints(sess, checkpoint_dir):
     55   checkpoint_prefix = os.path.join(checkpoint_dir, "model")
     56   checkpoint_state_name = "checkpoint"
     57   with variable_scope.variable_scope("scope"):
     58     v1 = variable_scope.get_variable(
     59         name="var1",
     60         shape=[100, 100],
     61         initializer=init_ops.truncated_normal_initializer(0.5),
     62         partitioner=partitioned_variables.min_max_variable_partitioner(
     63             max_partitions=5, axis=0, min_slice_size=8 << 10))
     64   sess.run(variables.global_variables_initializer())
     65   v1_value = sess.run(v1._get_variable_list())
     66   saver = saver_lib.Saver()
     67   saver.save(
     68       sess,
     69       checkpoint_prefix,
     70       global_step=0,
     71       latest_filename=checkpoint_state_name)
     72   return v1_value
     73 
     74 
     75 class CheckpointsTest(test.TestCase):
     76 
     77   def testNoCheckpoints(self):
     78     checkpoint_dir = self.get_temp_dir() + "/no_checkpoints"
     79     with self.assertRaises(errors_impl.OpError):
     80       self.assertAllEqual(
     81           checkpoint_utils.load_variable(checkpoint_dir, "var1"), [])
     82 
     83   def testNoTensor(self):
     84     checkpoint_dir = self.get_temp_dir()
     85     with self.test_session() as session:
     86       _, _, _, _ = _create_checkpoints(session, checkpoint_dir)
     87     with self.assertRaises(errors_impl.OpError):
     88       self.assertAllEqual(
     89           checkpoint_utils.load_variable(checkpoint_dir, "var5"), [])
     90 
     91   def testGetTensor(self):
     92     checkpoint_dir = self.get_temp_dir()
     93     with self.test_session() as session:
     94       v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
     95     self.assertAllEqual(
     96         checkpoint_utils.load_variable(checkpoint_dir, "var1"), v1)
     97     self.assertAllEqual(
     98         checkpoint_utils.load_variable(checkpoint_dir, "var2"), v2)
     99     self.assertAllEqual(
    100         checkpoint_utils.load_variable(checkpoint_dir, "var3"), v3)
    101     self.assertAllEqual(
    102         checkpoint_utils.load_variable(checkpoint_dir, "useful_scope/var4"), v4)
    103 
    104   def testGetAllVariables(self):
    105     checkpoint_dir = self.get_temp_dir()
    106     with self.test_session() as session:
    107       _create_checkpoints(session, checkpoint_dir)
    108     self.assertEqual(
    109         checkpoint_utils.list_variables(checkpoint_dir),
    110         [("useful_scope/var4", [9, 9]), ("var1", [1, 10]), ("var2", [10, 10]),
    111          ("var3", [100, 100])])
    112 
    113   def testInitFromCheckpoint(self):
    114     checkpoint_dir = self.get_temp_dir()
    115     with self.test_session() as session:
    116       v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
    117 
    118     # New graph and session.
    119     with ops.Graph().as_default() as g:
    120       with self.test_session(graph=g) as session:
    121         with variable_scope.variable_scope("some_scope"):
    122           my1 = variable_scope.get_variable("my1", [1, 10])
    123           with variable_scope.variable_scope("some_other_scope"):
    124             my2 = variable_scope.get_variable("my2", [10, 10])
    125             with variable_scope.variable_scope("other_useful_scope"):
    126               my4 = variable_scope.get_variable("var4", [9, 9])
    127         my3 = variable_scope.get_variable("my3", [100, 100])
    128 
    129         checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
    130             "var1": "some_scope/my1",
    131             "useful_scope/": "some_scope/some_other_scope/other_useful_scope/",
    132         })
    133         checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
    134             "var2": "some_scope/some_other_scope/my2",
    135             "var3": my3,
    136         })
    137 
    138         session.run(variables.global_variables_initializer())
    139         self.assertAllEqual(my1.eval(session), v1)
    140         self.assertAllEqual(my2.eval(session), v2)
    141         self.assertAllEqual(my3.eval(session), v3)
    142         self.assertAllEqual(my4.eval(session), v4)
    143 
    144         # Check that tensors are not explicitly in the graph.
    145         self.assertLess(len(str(session.graph.as_graph_def())), 27000)
    146 
    147   def testInitWithScopeDoesNotCaptureSuffixes(self):
    148     checkpoint_dir = self.get_temp_dir()
    149     with self.test_session() as session:
    150       _, _, _, v4 = _create_checkpoints(session, checkpoint_dir)
    151 
    152     with ops.Graph().as_default() as g:
    153       with variable_scope.variable_scope("useful_scope"):
    154         my4 = variable_scope.get_variable("var4", [9, 9])
    155       with variable_scope.variable_scope("useful_scope_1"):
    156         my5_init = [[1.0, 2.0], [3.0, 4.0]]
    157         my5 = variable_scope.get_variable("var5", initializer=my5_init)
    158 
    159       checkpoint_utils.init_from_checkpoint(checkpoint_dir,
    160                                             {"useful_scope/": "useful_scope/"})
    161       with self.test_session(graph=g) as session:
    162         session.run(variables.global_variables_initializer())
    163         self.assertAllEqual(my4.eval(session), v4)
    164         self.assertAllEqual(my5.eval(session), my5_init)
    165 
    166   def testInitFromRootCheckpoint(self):
    167     checkpoint_dir = self.get_temp_dir()
    168     with self.test_session() as session:
    169       v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
    170 
    171     # New graph and session.
    172     with ops.Graph().as_default() as g:
    173       with self.test_session(graph=g) as session:
    174         with variable_scope.variable_scope("some_scope"):
    175           my1 = variable_scope.get_variable("var1", [1, 10])
    176           my2 = variable_scope.get_variable("var2", [10, 10])
    177           my3 = variable_scope.get_variable("var3", [100, 100])
    178           with variable_scope.variable_scope("useful_scope"):
    179             my4 = variable_scope.get_variable("var4", [9, 9])
    180 
    181         checkpoint_utils.init_from_checkpoint(checkpoint_dir,
    182                                               {"/": "some_scope/",})
    183 
    184         session.run(variables.global_variables_initializer())
    185         self.assertAllEqual(my1.eval(session), v1)
    186         self.assertAllEqual(my2.eval(session), v2)
    187         self.assertAllEqual(my3.eval(session), v3)
    188         self.assertAllEqual(my4.eval(session), v4)
    189 
    190   def testInitToRootCheckpoint(self):
    191     checkpoint_dir = self.get_temp_dir()
    192     with self.test_session() as session:
    193       v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir)
    194 
    195     # New graph and session.
    196     with ops.Graph().as_default() as g:
    197       with self.test_session(graph=g) as session:
    198         my1 = variable_scope.get_variable("var1", [1, 10])
    199         my2 = variable_scope.get_variable("var2", [10, 10])
    200         my3 = variable_scope.get_variable("var3", [100, 100])
    201         with variable_scope.variable_scope("useful_scope"):
    202           my4 = variable_scope.get_variable("var4", [9, 9])
    203 
    204         checkpoint_utils.init_from_checkpoint(checkpoint_dir,
    205                                               {"/": "/",})
    206 
    207         session.run(variables.global_variables_initializer())
    208         self.assertAllEqual(my1.eval(session), v1)
    209         self.assertAllEqual(my2.eval(session), v2)
    210         self.assertAllEqual(my3.eval(session), v3)
    211         self.assertAllEqual(my4.eval(session), v4)
    212 
    213   def testInitFromPartitionVar(self):
    214     checkpoint_dir = self.get_temp_dir()
    215     with self.test_session() as session:
    216       v1 = _create_partition_checkpoints(session, checkpoint_dir)
    217 
    218     # New graph and session.
    219     with ops.Graph().as_default() as g:
    220       with self.test_session(graph=g) as session:
    221         with variable_scope.variable_scope("some_scope"):
    222           my1 = variable_scope.get_variable(
    223               name="my1",
    224               shape=[100, 100],
    225               initializer=init_ops.truncated_normal_initializer(0.5),
    226               partitioner=partitioned_variables.min_max_variable_partitioner(
    227                   max_partitions=5, axis=0, min_slice_size=8 << 10))
    228           my1_var_list = my1._get_variable_list()
    229         with variable_scope.variable_scope("some_other_scope"):
    230           my2 = variable_scope.get_variable(
    231               name="var1",
    232               shape=[100, 100],
    233               initializer=init_ops.truncated_normal_initializer(0.5),
    234               partitioner=partitioned_variables.min_max_variable_partitioner(
    235                   max_partitions=5, axis=0, min_slice_size=8 << 10))
    236           my2_var_list = my2._get_variable_list()
    237 
    238         checkpoint_utils.init_from_checkpoint(checkpoint_dir, {
    239             "scope/var1": "some_scope/my1",
    240             "scope/": "some_other_scope/"})
    241 
    242         session.run(variables.global_variables_initializer())
    243         my1_values = session.run(my1_var_list)
    244         self.assertAllEqual(my1_values, v1)
    245         my2_values = session.run(my2_var_list)
    246         self.assertAllEqual(my2_values, v1)
    247 
    248     # New graph and session.
    249     with ops.Graph().as_default() as g:
    250       with self.test_session(graph=g) as session:
    251         with variable_scope.variable_scope("some_scope"):
    252           my1 = variable_scope.get_variable(
    253               name="my1",
    254               shape=[100, 100],
    255               initializer=init_ops.truncated_normal_initializer(0.5),
    256               partitioner=partitioned_variables.min_max_variable_partitioner(
    257                   max_partitions=5, axis=0, min_slice_size=8 << 10))
    258           my1_var_list = my1._get_variable_list()
    259 
    260         checkpoint_utils.init_from_checkpoint(checkpoint_dir,
    261                                               {"scope/var1": my1_var_list,})
    262 
    263         session.run(variables.global_variables_initializer())
    264         my1_values = session.run(my1_var_list)
    265         self.assertAllEqual(my1_values, v1)
    266 
    267   def testInitFromCheckpointMissing(self):
    268     checkpoint_dir = self.get_temp_dir()
    269     with self.test_session() as session:
    270       _, _, _, _ = _create_checkpoints(session, checkpoint_dir)
    271 
    272     # New graph and session.
    273     with ops.Graph().as_default() as g:
    274       with self.test_session(graph=g) as session:
    275         with variable_scope.variable_scope("some_scope"):
    276           _ = variable_scope.get_variable("my1", [10, 10])
    277           _ = variable_scope.get_variable(
    278               "my2", [1, 10],
    279               dtype=dtypes.int64,
    280               initializer=init_ops.zeros_initializer())
    281 
    282         # No directory.
    283         with self.assertRaises(errors_impl.OpError):
    284           checkpoint_utils.init_from_checkpoint("no_dir",
    285                                                 {"var1": "some_scope/my1"})
    286 
    287         # No variable in checkpoint.
    288         with self.assertRaises(ValueError):
    289           checkpoint_utils.init_from_checkpoint(checkpoint_dir,
    290                                                 {"no_var": "some_scope/my1"})
    291 
    292         # No variable in the graph.
    293         with self.assertRaises(ValueError):
    294           checkpoint_utils.init_from_checkpoint(checkpoint_dir,
    295                                                 {"var3": "some_scope/no_var"})
    296 
    297         # Shape mismatch.
    298         with self.assertRaises(ValueError):
    299           checkpoint_utils.init_from_checkpoint(checkpoint_dir,
    300                                                 {"var1": "some_scope/my1"})
    301 
    302         # Variable 'my1' and 'my2' are missing in given checkpoint scope.
    303         with self.assertRaises(ValueError):
    304           checkpoint_utils.init_from_checkpoint(
    305               checkpoint_dir, {"useful_scope/": "some_scope/"})
    306 
    307         # Mapping is not to scope name.
    308         with self.assertRaises(ValueError):
    309           checkpoint_utils.init_from_checkpoint(checkpoint_dir,
    310                                                 {"useful_scope": "some_scope/"})
    311 
    312 
# When executed as a script, hand control to the TensorFlow test module's
# `main()` entry point.
if __name__ == "__main__":
  test.main()
    315