Home | History | Annotate | Download | only in python
      1 """Tests for eager mode Saver."""
      2 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #     http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 # ==============================================================================
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import os
     21 
     22 from tensorflow.contrib.eager.python import saver as _saver
     23 from tensorflow.python.eager import context
     24 from tensorflow.python.eager import graph_callable
     25 from tensorflow.python.eager import test
     26 from tensorflow.python.framework import dtypes
     27 from tensorflow.python.framework import errors
     28 from tensorflow.python.framework import ops
     29 from tensorflow.python.ops import array_ops
     30 from tensorflow.python.ops import init_ops
     31 from tensorflow.python.ops import resource_variable_ops
     32 from tensorflow.python.ops import variable_scope
     33 from tensorflow.python.training import adam
     34 from tensorflow.python.training import gradient_descent
     35 from tensorflow.python.training import momentum
     36 from tensorflow.python.training import rmsprop
     37 
     38 
     39 class SaverTest(test.TestCase):
     40 
     41   def _dev(self):
     42     return '/device:GPU:0' if context.num_gpus() else '/device:CPU:0'
     43 
     44   def testBasics(self):
     45     with ops.device(self._dev()):
     46       v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
     47       def model():
     48         return array_ops.constant(2.0) * v1
     49 
     50       ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
     51 
     52       _ = model()
     53       saver = _saver.Saver([v1])
     54       saver.save(ckpt_prefix)
     55       v1.assign(2.0)
     56       self.assertEqual(v1.read_value().numpy(), 2.0)
     57 
     58       saver.restore(ckpt_prefix)
     59       self.assertEqual(v1.read_value().numpy(), 1.0)
     60 
     61   def testSameNameNoClobbering(self):
     62     with ops.device(self._dev()):
     63       # Note that this test purposefully uses Graphs rather than
     64       # IsolateTest. Users are more likely to accidentally create the same
     65       # variable name this way.
     66       first_graph = ops.Graph()
     67       with first_graph.as_default():
     68         v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1')
     69       with ops.Graph().as_default():
     70         v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1')
     71         saver = _saver.Saver([v1_first_graph, v1_second_graph])
     72       ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
     73       with self.assertRaisesRegexp(ValueError, 'v1'):
     74         saver.save(ckpt_prefix)
     75 
     76   def testSameObjectOK(self):
     77     with ops.device(self._dev()):
     78       v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
     79       # While different objects with the same shared_name are not good, passing
     80       # in the same object multiple times is fine.
     81       saver = _saver.Saver([v1, v1])
     82       ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
     83       saver.save(ckpt_prefix)
     84 
     85   def testSaveByDict(self):
     86     with ops.device(self._dev()):
     87       v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
     88       v2 = resource_variable_ops.ResourceVariable(1.0, name='v2')
     89       def model():
     90         return array_ops.constant(2.0) * v1 * v2
     91 
     92       ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
     93 
     94       # Save the variables under different names.
     95       _ = model()
     96       saver = _saver.Saver({'ckpt/v1': v1, 'ckpt/v2': v2})
     97       saver.save(ckpt_prefix)
     98       v1.assign(2.0)
     99       v2.assign(2.0)
    100       self.assertEqual(v1.read_value().numpy(), 2.0)
    101       self.assertEqual(v2.read_value().numpy(), 2.0)
    102       # Can still restore it.
    103       saver.restore(ckpt_prefix)
    104       self.assertEqual(v1.read_value().numpy(), 1.0)
    105       self.assertEqual(v1.read_value().numpy(), 1.0)
    106       # However, cannot restore it with default name.
    107       with self.assertRaisesOpError('not found in checkpoint'):
    108         saver = _saver.Saver([v1, v2]).restore(ckpt_prefix)
    109 
    110       # Can specify which variable in ckpt to restore to which variable.
    111       def map_func(x):
    112         return {'v3': 'ckpt/v1', 'v4': 'ckpt/v2'}.get(x, x)
    113       with _saver.restore_variables_on_create(ckpt_prefix, map_func):
    114         v3 = resource_variable_ops.ResourceVariable(2.0, name='v3')
    115         v4 = resource_variable_ops.ResourceVariable(2.0, name='v4')
    116       self.assertEqual(v3.read_value().numpy(), 1.0)
    117       self.assertEqual(v4.read_value().numpy(), 1.0)
    118 
    119   def testRestoreOnCreate(self):
    120     with ops.device(self._dev()):
    121       def model(init_val):
    122         v1 = resource_variable_ops.ResourceVariable(init_val, name='v1')
    123         return array_ops.constant(1.0) * v1, v1
    124 
    125       ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
    126       _, v1 = model(1.0)
    127       saver = _saver.Saver([v1])
    128       saver.save(ckpt_prefix)
    129 
    130       with ops.Graph().as_default():
    131         saver = _saver.Saver([v1])
    132         with _saver.restore_variables_on_create(ckpt_prefix):
    133           # Value is from checkpoint, but not from argument.
    134           ret, _ = model(2.0)
    135           self.assertEqual(ret.numpy(), 1.0)
    136 
    137   def testRestoreNotFound(self):
    138     with ops.device(self._dev()):
    139       def model(v):
    140         return array_ops.constant(1.0) * v
    141 
    142       ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
    143       v = resource_variable_ops.ResourceVariable(1.0, name='v1')
    144       _ = model(v)
    145       saver = _saver.Saver([v])
    146       saver.save(ckpt_prefix)
    147 
    148       with self.assertRaisesRegexp(errors.NotFoundError,
    149                                    'v2 not found in checkpoint'):
    150         with _saver.restore_variables_on_create(ckpt_prefix):
    151           _ = model(resource_variable_ops.ResourceVariable(1.0, name='v2'))
    152 
    153   def testSaveRestoreGraphCallable(self):
    154     with ops.device(self._dev()):
    155       @graph_callable.graph_callable(
    156           [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
    157       def model(x):
    158         v = variable_scope.get_variable(
    159             'v', initializer=init_ops.zeros_initializer(), shape=())
    160         return v + x
    161 
    162       # Default 2 + 0 = 2
    163       self.assertEqual(
    164           2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
    165 
    166       # Save the variable value 0.
    167       ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
    168       _saver.Saver(model.variables).save(ckpt_prefix)
    169 
    170       # update variable to 1, so that 2 + 1 = 3
    171       model.variables[0].assign(1.)
    172       self.assertEqual(
    173           3, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
    174 
    175       # load the variable value 0, so that 2 + 0 = 2
    176       _saver.Saver(model.variables).restore(ckpt_prefix)
    177       self.assertEqual(
    178           2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
    179 
    180       # update checkpoint variable to 1 and memory value to 2.
    181       model.variables[0].assign(1.)
    182       _saver.Saver(model.variables).save(ckpt_prefix)
    183       model.variables[0].assign(2.)
    184       self.assertEqual(
    185           4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy())
    186 
    187       # reset the graph and reload on create, so that 1 + 2 = 3
    188       with ops.Graph().as_default():
    189         with _saver.restore_variables_on_create(ckpt_prefix):
    190           @graph_callable.graph_callable(
    191               [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
    192           def model2(x):
    193             v = variable_scope.get_variable(
    194                 'v', initializer=init_ops.zeros_initializer(), shape=())
    195             return v + x
    196 
    197           self.assertEqual(
    198               3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy())
    199 
    200 
    201 class GetOptimizerTests(test.TestCase):
    202 
    203   def _optimizer_test_template(self, optimizer):
    204     """Checks save and restore. Returns the optimizer variables."""
    205     v = resource_variable_ops.ResourceVariable([[2., 3.]], name='v')
    206     loss_fn = lambda: v[0, 0] ** 2 + v[0, 1] ** 2
    207     optimizer.minimize(loss_fn)
    208     optimizer_variables = _saver.get_optimizer_variables(optimizer)
    209     saver = _saver.Saver(optimizer_variables + [v])
    210     checkpoint_path = saver.save(self.get_temp_dir())
    211     optimizer.minimize(loss_fn)
    212     after_first_minimize = v.numpy()
    213     # After we restore, the next step should be exactly the same as the one we
    214     # just did.
    215     saver.restore(checkpoint_path)
    216     optimizer.minimize(loss_fn)
    217     self.assertAllEqual(after_first_minimize, v.numpy())
    218     return optimizer_variables
    219 
    220   def testAdam(self):
    221     optimizer = adam.AdamOptimizer(0.1)
    222     self._optimizer_test_template(optimizer)
    223 
    224   def testGradientDescent(self):
    225     optimizer = gradient_descent.GradientDescentOptimizer(0.02)
    226     self.assertEqual(0, len(self._optimizer_test_template(optimizer)))
    227 
    228   def testMomentum(self):
    229     optimizer = momentum.MomentumOptimizer(
    230         learning_rate=0.03,
    231         momentum=0.5)
    232     self._optimizer_test_template(optimizer)
    233 
    234   def testRMSProp(self):
    235     optimizer = rmsprop.RMSPropOptimizer(0.01)
    236     self._optimizer_test_template(optimizer)
    237 
# Invoke the test module's main() only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
  test.main()
    240