1 """Tests for eager mode Saver.""" 2 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 # 4 # Licensed under the Apache License, Version 2.0 (the "License"); 5 # you may not use this file except in compliance with the License. 6 # You may obtain a copy of the License at 7 # 8 # http://www.apache.org/licenses/LICENSE-2.0 9 # 10 # Unless required by applicable law or agreed to in writing, software 11 # distributed under the License is distributed on an "AS IS" BASIS, 12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 # See the License for the specific language governing permissions and 14 # limitations under the License. 15 # ============================================================================== 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import os 21 22 from tensorflow.contrib.eager.python import saver as _saver 23 from tensorflow.python.eager import context 24 from tensorflow.python.eager import graph_callable 25 from tensorflow.python.eager import test 26 from tensorflow.python.framework import dtypes 27 from tensorflow.python.framework import errors 28 from tensorflow.python.framework import ops 29 from tensorflow.python.ops import array_ops 30 from tensorflow.python.ops import init_ops 31 from tensorflow.python.ops import resource_variable_ops 32 from tensorflow.python.ops import variable_scope 33 from tensorflow.python.training import adam 34 from tensorflow.python.training import gradient_descent 35 from tensorflow.python.training import momentum 36 from tensorflow.python.training import rmsprop 37 38 39 class SaverTest(test.TestCase): 40 41 def _dev(self): 42 return '/device:GPU:0' if context.num_gpus() else '/device:CPU:0' 43 44 def testBasics(self): 45 with ops.device(self._dev()): 46 v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') 47 def model(): 48 return array_ops.constant(2.0) * v1 49 50 ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') 51 52 _ = model() 53 saver = _saver.Saver([v1]) 54 saver.save(ckpt_prefix) 55 v1.assign(2.0) 56 self.assertEqual(v1.read_value().numpy(), 2.0) 57 58 saver.restore(ckpt_prefix) 59 self.assertEqual(v1.read_value().numpy(), 1.0) 60 61 def testSameNameNoClobbering(self): 62 with ops.device(self._dev()): 63 # Note that this test purposefully uses Graphs rather than 64 # IsolateTest. Users are more likely to accidentally create the same 65 # variable name this way. 66 first_graph = ops.Graph() 67 with first_graph.as_default(): 68 v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1') 69 with ops.Graph().as_default(): 70 v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1') 71 saver = _saver.Saver([v1_first_graph, v1_second_graph]) 72 ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') 73 with self.assertRaisesRegexp(ValueError, 'v1'): 74 saver.save(ckpt_prefix) 75 76 def testSameObjectOK(self): 77 with ops.device(self._dev()): 78 v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') 79 # While different objects with the same shared_name are not good, passing 80 # in the same object multiple times is fine. 
      saver = _saver.Saver([v1, v1])
      ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
      saver.save(ckpt_prefix)

  def testSaveByDict(self):
    with ops.device(self._dev()):
      v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
      v2 = resource_variable_ops.ResourceVariable(1.0, name='v2')

      def model():
        return array_ops.constant(2.0) * v1 * v2

      ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')

      # Save the variables under different (non-default) names.
      _ = model()
      saver = _saver.Saver({'ckpt/v1': v1, 'ckpt/v2': v2})
      saver.save(ckpt_prefix)
      v1.assign(2.0)
      v2.assign(2.0)
      self.assertEqual(v1.read_value().numpy(), 2.0)
      self.assertEqual(v2.read_value().numpy(), 2.0)
      # The same saver can still restore the values.
      saver.restore(ckpt_prefix)
      self.assertEqual(v1.read_value().numpy(), 1.0)
      self.assertEqual(v2.read_value().numpy(), 1.0)
      # However, a saver built with the default names cannot restore them.
      with self.assertRaisesOpError('not found in checkpoint'):
        _saver.Saver([v1, v2]).restore(ckpt_prefix)

      # A map function can specify which name in the checkpoint is restored
      # into which variable.
      def map_func(x):
        return {'v3': 'ckpt/v1', 'v4': 'ckpt/v2'}.get(x, x)

      with _saver.restore_variables_on_create(ckpt_prefix, map_func):
        v3 = resource_variable_ops.ResourceVariable(2.0, name='v3')
        v4 = resource_variable_ops.ResourceVariable(2.0, name='v4')
        self.assertEqual(v3.read_value().numpy(), 1.0)
        self.assertEqual(v4.read_value().numpy(), 1.0)

  def testRestoreOnCreate(self):
    with ops.device(self._dev()):

      def model(init_val):
        v1 = resource_variable_ops.ResourceVariable(init_val, name='v1')
        return array_ops.constant(1.0) * v1, v1

      ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
      _, v1 = model(1.0)
      saver = _saver.Saver([v1])
      saver.save(ckpt_prefix)

      with ops.Graph().as_default():
        saver = _saver.Saver([v1])
        with _saver.restore_variables_on_create(ckpt_prefix):
          # The variable's value comes from the checkpoint, not from the
          # init_val argument.
          ret, _ = model(2.0)
          self.assertEqual(ret.numpy(), 1.0)

  def testRestoreNotFound(self):
    with ops.device(self._dev()):

      def model(v):
        return array_ops.constant(1.0) * v

      ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
      v = resource_variable_ops.ResourceVariable(1.0, name='v1')
      _ = model(v)
      saver = _saver.Saver([v])
      saver.save(ckpt_prefix)

      # Restoring a variable that is absent from the checkpoint fails.
      with self.assertRaisesRegexp(errors.NotFoundError,
                                   'v2 not found in checkpoint'):
        with _saver.restore_variables_on_create(ckpt_prefix):
          _ = model(resource_variable_ops.ResourceVariable(1.0, name='v2'))

  def testSaveRestoreGraphCallable(self):
    with ops.device(self._dev()):

      @graph_callable.graph_callable(
          [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
      def model(x):
        v = variable_scope.get_variable(
            'v', initializer=init_ops.zeros_initializer(), shape=())
        return v + x

      # The variable starts at its default value 0, so 2 + 0 = 2.
      self.assertEqual(
          2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())

      # Save the variable value 0.
      ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
      _saver.Saver(model.variables).save(ckpt_prefix)

      # Update the variable to 1, so that 2 + 1 = 3.
      model.variables[0].assign(1.)
      self.assertEqual(
          3, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())

      # Reload the saved value 0, so that 2 + 0 = 2 again.
      _saver.Saver(model.variables).restore(ckpt_prefix)
      self.assertEqual(
          2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())

      # Update the checkpointed value to 1 and the in-memory value to 2,
      # so that 2 + 2 = 4.
      model.variables[0].assign(1.)
      _saver.Saver(model.variables).save(ckpt_prefix)
      model.variables[0].assign(2.)
      self.assertEqual(
          4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())

      # Reset the graph and restore on create; the checkpointed value 1 is
      # picked up, so that 1 + 2 = 3.
      with ops.Graph().as_default():
        with _saver.restore_variables_on_create(ckpt_prefix):

          @graph_callable.graph_callable(
              [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
          def model2(x):
            v = variable_scope.get_variable(
                'v', initializer=init_ops.zeros_initializer(), shape=())
            return v + x

          self.assertEqual(
              3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy())


class GetOptimizerTests(test.TestCase):

  def _optimizer_test_template(self, optimizer):
    """Checks save and restore. Returns the optimizer variables."""
    v = resource_variable_ops.ResourceVariable([[2., 3.]], name='v')
    loss_fn = lambda: v[0, 0] ** 2 + v[0, 1] ** 2
    optimizer.minimize(loss_fn)
    optimizer_variables = _saver.get_optimizer_variables(optimizer)
    saver = _saver.Saver(optimizer_variables + [v])
    checkpoint_path = saver.save(self.get_temp_dir())
    optimizer.minimize(loss_fn)
    after_first_minimize = v.numpy()
    # After restoring, the next minimize step should be exactly the same as
    # the one we just took, since the optimizer's state was saved as well.
    saver.restore(checkpoint_path)
    optimizer.minimize(loss_fn)
    self.assertAllEqual(after_first_minimize, v.numpy())
    return optimizer_variables

  def testAdam(self):
    optimizer = adam.AdamOptimizer(0.1)
    self._optimizer_test_template(optimizer)

  def testGradientDescent(self):
    optimizer = gradient_descent.GradientDescentOptimizer(0.02)
    # Plain gradient descent keeps no extra state, so there are no optimizer
    # variables to save.
    self.assertEqual(0, len(self._optimizer_test_template(optimizer)))

  def testMomentum(self):
    optimizer = momentum.MomentumOptimizer(
        learning_rate=0.03,
        momentum=0.5)
    self._optimizer_test_template(optimizer)

  def testRMSProp(self):
    optimizer = rmsprop.RMSPropOptimizer(0.01)
    self._optimizer_test_template(optimizer)


if __name__ == '__main__':
  test.main()