1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for checkpoints tools.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import os 22 23 from tensorflow.contrib.framework.python.framework import checkpoint_utils 24 from tensorflow.python.framework import dtypes 25 from tensorflow.python.framework import errors_impl 26 from tensorflow.python.framework import ops 27 from tensorflow.python.ops import init_ops 28 from tensorflow.python.ops import partitioned_variables 29 from tensorflow.python.ops import variable_scope 30 from tensorflow.python.ops import variables 31 from tensorflow.python.platform import test 32 from tensorflow.python.training import saver as saver_lib 33 34 35 def _create_checkpoints(sess, checkpoint_dir): 36 checkpoint_prefix = os.path.join(checkpoint_dir, "model") 37 checkpoint_state_name = "checkpoint" 38 v1 = variable_scope.get_variable("var1", [1, 10]) 39 v2 = variable_scope.get_variable("var2", [10, 10]) 40 v3 = variable_scope.get_variable("var3", [100, 100]) 41 with variable_scope.variable_scope("useful_scope"): 42 v4 = variable_scope.get_variable("var4", [9, 9]) 43 sess.run(variables.global_variables_initializer()) 44 v1_value, v2_value, v3_value, v4_value = sess.run([v1, v2, v3, v4]) 45 saver = saver_lib.Saver() 46 saver.save( 47 sess, 48 checkpoint_prefix, 49 global_step=0, 50 latest_filename=checkpoint_state_name) 51 return v1_value, v2_value, v3_value, v4_value 52 53 54 def _create_partition_checkpoints(sess, checkpoint_dir): 55 checkpoint_prefix = os.path.join(checkpoint_dir, "model") 56 checkpoint_state_name = "checkpoint" 57 with variable_scope.variable_scope("scope"): 58 v1 = variable_scope.get_variable( 59 name="var1", 60 shape=[100, 100], 61 initializer=init_ops.truncated_normal_initializer(0.5), 62 partitioner=partitioned_variables.min_max_variable_partitioner( 63 max_partitions=5, axis=0, min_slice_size=8 << 10)) 64 sess.run(variables.global_variables_initializer()) 65 v1_value = sess.run(v1._get_variable_list()) 66 saver = saver_lib.Saver() 67 saver.save( 68 sess, 69 checkpoint_prefix, 70 global_step=0, 71 latest_filename=checkpoint_state_name) 72 return v1_value 73 74 75 class CheckpointsTest(test.TestCase): 76 77 def testNoCheckpoints(self): 78 checkpoint_dir = self.get_temp_dir() + "/no_checkpoints" 79 with self.assertRaises(errors_impl.OpError): 80 self.assertAllEqual( 81 checkpoint_utils.load_variable(checkpoint_dir, "var1"), []) 82 83 def testNoTensor(self): 84 checkpoint_dir = self.get_temp_dir() 85 with self.test_session() as session: 86 _, _, _, _ = _create_checkpoints(session, checkpoint_dir) 87 with self.assertRaises(errors_impl.OpError): 88 self.assertAllEqual( 89 checkpoint_utils.load_variable(checkpoint_dir, "var5"), []) 90 91 def testGetTensor(self): 92 checkpoint_dir = self.get_temp_dir() 93 with self.test_session() as session: 94 v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir) 95 self.assertAllEqual( 96 checkpoint_utils.load_variable(checkpoint_dir, "var1"), v1) 97 self.assertAllEqual( 98 checkpoint_utils.load_variable(checkpoint_dir, "var2"), v2) 99 self.assertAllEqual( 100 checkpoint_utils.load_variable(checkpoint_dir, "var3"), v3) 101 self.assertAllEqual( 102 checkpoint_utils.load_variable(checkpoint_dir, "useful_scope/var4"), v4) 103 104 def testGetAllVariables(self): 105 checkpoint_dir = self.get_temp_dir() 106 with self.test_session() as session: 107 _create_checkpoints(session, checkpoint_dir) 108 self.assertEqual( 109 checkpoint_utils.list_variables(checkpoint_dir), 110 [("useful_scope/var4", [9, 9]), ("var1", [1, 10]), ("var2", [10, 10]), 111 ("var3", [100, 100])]) 112 113 def testInitFromCheckpoint(self): 114 checkpoint_dir = self.get_temp_dir() 115 with self.test_session() as session: 116 v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir) 117 118 # New graph and session. 119 with ops.Graph().as_default() as g: 120 with self.test_session(graph=g) as session: 121 with variable_scope.variable_scope("some_scope"): 122 my1 = variable_scope.get_variable("my1", [1, 10]) 123 with variable_scope.variable_scope("some_other_scope"): 124 my2 = variable_scope.get_variable("my2", [10, 10]) 125 with variable_scope.variable_scope("other_useful_scope"): 126 my4 = variable_scope.get_variable("var4", [9, 9]) 127 my3 = variable_scope.get_variable("my3", [100, 100]) 128 129 checkpoint_utils.init_from_checkpoint(checkpoint_dir, { 130 "var1": "some_scope/my1", 131 "useful_scope/": "some_scope/some_other_scope/other_useful_scope/", 132 }) 133 checkpoint_utils.init_from_checkpoint(checkpoint_dir, { 134 "var2": "some_scope/some_other_scope/my2", 135 "var3": my3, 136 }) 137 138 session.run(variables.global_variables_initializer()) 139 self.assertAllEqual(my1.eval(session), v1) 140 self.assertAllEqual(my2.eval(session), v2) 141 self.assertAllEqual(my3.eval(session), v3) 142 self.assertAllEqual(my4.eval(session), v4) 143 144 # Check that tensors are not explicitly in the graph. 145 self.assertLess(len(str(session.graph.as_graph_def())), 27000) 146 147 def testInitWithScopeDoesNotCaptureSuffixes(self): 148 checkpoint_dir = self.get_temp_dir() 149 with self.test_session() as session: 150 _, _, _, v4 = _create_checkpoints(session, checkpoint_dir) 151 152 with ops.Graph().as_default() as g: 153 with variable_scope.variable_scope("useful_scope"): 154 my4 = variable_scope.get_variable("var4", [9, 9]) 155 with variable_scope.variable_scope("useful_scope_1"): 156 my5_init = [[1.0, 2.0], [3.0, 4.0]] 157 my5 = variable_scope.get_variable("var5", initializer=my5_init) 158 159 checkpoint_utils.init_from_checkpoint(checkpoint_dir, 160 {"useful_scope/": "useful_scope/"}) 161 with self.test_session(graph=g) as session: 162 session.run(variables.global_variables_initializer()) 163 self.assertAllEqual(my4.eval(session), v4) 164 self.assertAllEqual(my5.eval(session), my5_init) 165 166 def testInitFromRootCheckpoint(self): 167 checkpoint_dir = self.get_temp_dir() 168 with self.test_session() as session: 169 v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir) 170 171 # New graph and session. 172 with ops.Graph().as_default() as g: 173 with self.test_session(graph=g) as session: 174 with variable_scope.variable_scope("some_scope"): 175 my1 = variable_scope.get_variable("var1", [1, 10]) 176 my2 = variable_scope.get_variable("var2", [10, 10]) 177 my3 = variable_scope.get_variable("var3", [100, 100]) 178 with variable_scope.variable_scope("useful_scope"): 179 my4 = variable_scope.get_variable("var4", [9, 9]) 180 181 checkpoint_utils.init_from_checkpoint(checkpoint_dir, 182 {"/": "some_scope/",}) 183 184 session.run(variables.global_variables_initializer()) 185 self.assertAllEqual(my1.eval(session), v1) 186 self.assertAllEqual(my2.eval(session), v2) 187 self.assertAllEqual(my3.eval(session), v3) 188 self.assertAllEqual(my4.eval(session), v4) 189 190 def testInitToRootCheckpoint(self): 191 checkpoint_dir = self.get_temp_dir() 192 with self.test_session() as session: 193 v1, v2, v3, v4 = _create_checkpoints(session, checkpoint_dir) 194 195 # New graph and session. 196 with ops.Graph().as_default() as g: 197 with self.test_session(graph=g) as session: 198 my1 = variable_scope.get_variable("var1", [1, 10]) 199 my2 = variable_scope.get_variable("var2", [10, 10]) 200 my3 = variable_scope.get_variable("var3", [100, 100]) 201 with variable_scope.variable_scope("useful_scope"): 202 my4 = variable_scope.get_variable("var4", [9, 9]) 203 204 checkpoint_utils.init_from_checkpoint(checkpoint_dir, 205 {"/": "/",}) 206 207 session.run(variables.global_variables_initializer()) 208 self.assertAllEqual(my1.eval(session), v1) 209 self.assertAllEqual(my2.eval(session), v2) 210 self.assertAllEqual(my3.eval(session), v3) 211 self.assertAllEqual(my4.eval(session), v4) 212 213 def testInitFromPartitionVar(self): 214 checkpoint_dir = self.get_temp_dir() 215 with self.test_session() as session: 216 v1 = _create_partition_checkpoints(session, checkpoint_dir) 217 218 # New graph and session. 219 with ops.Graph().as_default() as g: 220 with self.test_session(graph=g) as session: 221 with variable_scope.variable_scope("some_scope"): 222 my1 = variable_scope.get_variable( 223 name="my1", 224 shape=[100, 100], 225 initializer=init_ops.truncated_normal_initializer(0.5), 226 partitioner=partitioned_variables.min_max_variable_partitioner( 227 max_partitions=5, axis=0, min_slice_size=8 << 10)) 228 my1_var_list = my1._get_variable_list() 229 with variable_scope.variable_scope("some_other_scope"): 230 my2 = variable_scope.get_variable( 231 name="var1", 232 shape=[100, 100], 233 initializer=init_ops.truncated_normal_initializer(0.5), 234 partitioner=partitioned_variables.min_max_variable_partitioner( 235 max_partitions=5, axis=0, min_slice_size=8 << 10)) 236 my2_var_list = my2._get_variable_list() 237 238 checkpoint_utils.init_from_checkpoint(checkpoint_dir, { 239 "scope/var1": "some_scope/my1", 240 "scope/": "some_other_scope/"}) 241 242 session.run(variables.global_variables_initializer()) 243 my1_values = session.run(my1_var_list) 244 self.assertAllEqual(my1_values, v1) 245 my2_values = session.run(my2_var_list) 246 self.assertAllEqual(my2_values, v1) 247 248 # New graph and session. 249 with ops.Graph().as_default() as g: 250 with self.test_session(graph=g) as session: 251 with variable_scope.variable_scope("some_scope"): 252 my1 = variable_scope.get_variable( 253 name="my1", 254 shape=[100, 100], 255 initializer=init_ops.truncated_normal_initializer(0.5), 256 partitioner=partitioned_variables.min_max_variable_partitioner( 257 max_partitions=5, axis=0, min_slice_size=8 << 10)) 258 my1_var_list = my1._get_variable_list() 259 260 checkpoint_utils.init_from_checkpoint(checkpoint_dir, 261 {"scope/var1": my1_var_list,}) 262 263 session.run(variables.global_variables_initializer()) 264 my1_values = session.run(my1_var_list) 265 self.assertAllEqual(my1_values, v1) 266 267 def testInitFromCheckpointMissing(self): 268 checkpoint_dir = self.get_temp_dir() 269 with self.test_session() as session: 270 _, _, _, _ = _create_checkpoints(session, checkpoint_dir) 271 272 # New graph and session. 273 with ops.Graph().as_default() as g: 274 with self.test_session(graph=g) as session: 275 with variable_scope.variable_scope("some_scope"): 276 _ = variable_scope.get_variable("my1", [10, 10]) 277 _ = variable_scope.get_variable( 278 "my2", [1, 10], 279 dtype=dtypes.int64, 280 initializer=init_ops.zeros_initializer()) 281 282 # No directory. 283 with self.assertRaises(errors_impl.OpError): 284 checkpoint_utils.init_from_checkpoint("no_dir", 285 {"var1": "some_scope/my1"}) 286 287 # No variable in checkpoint. 288 with self.assertRaises(ValueError): 289 checkpoint_utils.init_from_checkpoint(checkpoint_dir, 290 {"no_var": "some_scope/my1"}) 291 292 # No variable in the graph. 293 with self.assertRaises(ValueError): 294 checkpoint_utils.init_from_checkpoint(checkpoint_dir, 295 {"var3": "some_scope/no_var"}) 296 297 # Shape mismatch. 298 with self.assertRaises(ValueError): 299 checkpoint_utils.init_from_checkpoint(checkpoint_dir, 300 {"var1": "some_scope/my1"}) 301 302 # Variable 'my1' and 'my2' are missing in given checkpoint scope. 303 with self.assertRaises(ValueError): 304 checkpoint_utils.init_from_checkpoint( 305 checkpoint_dir, {"useful_scope/": "some_scope/"}) 306 307 # Mapping is not to scope name. 308 with self.assertRaises(ValueError): 309 checkpoint_utils.init_from_checkpoint(checkpoint_dir, 310 {"useful_scope": "some_scope/"}) 311 312 313 if __name__ == "__main__": 314 test.main() 315