Home | History | Annotate | Download | only in tracking
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 from __future__ import absolute_import
     16 from __future__ import division
     17 from __future__ import print_function
     18 
     19 import io
     20 import os
     21 
     22 import numpy
     23 
     24 from tensorflow.python.client import session
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.framework import test_util
     27 from tensorflow.python.module import module
     28 from tensorflow.python.platform import test
     29 from tensorflow.python.training.tracking import python_state
     30 from tensorflow.python.training.tracking import util
     31 
     32 
class _NumpyState(module.Module):
  """A checkpointable object whose NumPy array attributes are saved/restored.

  Example usage:

  ```python
  arrays = _NumpyState()
  checkpoint = tf.train.Checkpoint(numpy_arrays=arrays)
  arrays.x = numpy.zeros([3, 4])
  save_path = checkpoint.save("/tmp/ckpt")
  arrays.x[1, 1] = 4.
  checkpoint.restore(save_path)
  assert (arrays.x == numpy.zeros([3, 4])).all()

  second_checkpoint = tf.train.Checkpoint(
      numpy_arrays=_NumpyState())
  # Attributes of _NumpyState objects are created automatically by restore()
  second_checkpoint.restore(save_path)
  assert (second_checkpoint.numpy_arrays.x == numpy.zeros([3, 4])).all()
  ```

  Note that `_NumpyState` objects re-create the attributes of the previously
  saved object on `restore()`. This is in contrast to TensorFlow variables, for
  which a `Variable` object must be created and assigned to an attribute.

  This snippet works both when graph building and when executing eagerly. On
  save, the NumPy array(s) are fed as strings to be saved in the checkpoint (via
  a placeholder when graph building, or as a string constant when executing
  eagerly). When restoring they skip the TensorFlow graph entirely, and so no
  restore ops need be run. This means that restoration always happens eagerly,
  rather than waiting for `checkpoint.restore(...).run_restore_ops()` like
  TensorFlow variables when graph building.

  Implementation overview: NumPy values assigned to attributes are not stored
  on this object directly. They are wrapped in `_NumpyWrapper` instances and
  kept on an internal `module.Module` held in `_arrays`; attribute reads
  unwrap them again.
  """

  def __init__(self):
    """Create the internal `_arrays` container that holds wrapped arrays."""
    # Use the parent's __setattr__ explicitly so the assignment does not go
    # through this class's __setattr__ override.
    super(_NumpyState, self).__setattr__("_arrays", module.Module())

  def __getattribute__(self, name):
    """Un-wrap `_NumpyWrapper` objects when accessing attributes.

    Looks `name` up on the internal `_arrays` container first. If a
    `_NumpyWrapper` is stored there, its `.array` is returned; in every other
    case the default attribute lookup on this object is used.

    Args:
      name: Name of the attribute being read.

    Returns:
      The value held by the `_NumpyWrapper` stored under `name`, or the result
      of the default attribute lookup when no such wrapper exists.

    Raises:
      AttributeError: Propagated from the default lookup when `name` is
        neither a wrapped array nor a regular attribute.
    """
    try:
      arrays = super(_NumpyState, self).__getattribute__("_arrays")
    except AttributeError:
      # _arrays hasn't been assigned yet
      return super(_NumpyState, self).__getattribute__(name)
    try:
      value = getattr(arrays, name)
    except AttributeError:
      # Probe for an attribute that has not been materialized yet: assign a
      # placeholder wrapper, read it back, and check whether its array is
      # still the placeholder.
      # NOTE(review): presumably assigning to the tracked Module is what lets
      # a pending checkpoint restore replace `wrapper.array` -- confirm
      # against the tracking implementation.
      dummy_array = numpy.array([])
      setattr(arrays, name, _NumpyWrapper(dummy_array))
      value = getattr(arrays, name)
      if value.array is dummy_array:
        # No set or restored attribute with this name
        delattr(arrays, name)
        return super(_NumpyState, self).__getattribute__(name)

    if isinstance(value, _NumpyWrapper):
      return value.array
    # `arrays` has an attribute called `name`, but it is not a wrapped array
    # (e.g. something the container itself defines), so defer to the normal
    # lookup on this object.
    return super(_NumpyState, self).__getattribute__(name)

  def __setattr__(self, name, value):
    """Automatically wrap NumPy arrays assigned to attributes.

    `numpy.ndarray` and `numpy.generic` values are stored on `_arrays`: if an
    attribute called `name` already exists there, its `.array` is replaced in
    place; otherwise a new `_NumpyWrapper` is created. Any other value is
    assigned through the parent class's `__setattr__`.

    Args:
      name: Name of the attribute being assigned.
      value: Value to assign.
    """
    if isinstance(value, (numpy.ndarray, numpy.generic)):
      try:
        # NOTE(review): whatever `_arrays` exposes under `name` is treated as
        # a wrapper here; there is no isinstance check before `.array` is set.
        existing = getattr(self._arrays, name)
        existing.array = value
        return
      except AttributeError:
        value = _NumpyWrapper(value)
      setattr(self._arrays, name, value)
      return
    super(_NumpyState, self).__setattr__(name, value)
    104 
    105 
    106 class _NumpyWrapper(python_state.PythonState):
    107   """Wraps a NumPy array for storage in an object-based checkpoint."""
    108 
    109   def __init__(self, array):
    110     """Specify a NumPy array to wrap.
    111 
    112     Args:
    113       array: The NumPy array to save and restore (may be overwritten).
    114     """
    115     self.array = array
    116 
    117   def serialize(self):
    118     """Callback to serialize the array."""
    119     string_file = io.BytesIO()
    120     try:
    121       numpy.save(string_file, self.array, allow_pickle=False)
    122       serialized = string_file.getvalue()
    123     finally:
    124       string_file.close()
    125     return serialized
    126 
    127   def deserialize(self, string_value):
    128     """Callback to deserialize the array."""
    129     string_file = io.BytesIO(string_value)
    130     try:
    131       self.array = numpy.load(string_file, allow_pickle=False)
    132     finally:
    133       string_file.close()
    134 
    135 
    136 class NumpyStateTests(test.TestCase):
    137 
    138   def testWrapper(self):
    139     directory = self.get_temp_dir()
    140     prefix = os.path.join(directory, "ckpt")
    141     root = util.Checkpoint(numpy=_NumpyWrapper(numpy.array([1.])))
    142     save_path = root.save(prefix)
    143     root.numpy.array *= 2.
    144     self.assertEqual([2.], root.numpy.array)
    145     root.restore(save_path)
    146     self.assertEqual([1.], root.numpy.array)
    147 
    148   @test_util.run_in_graph_and_eager_modes
    149   def testSaveRestoreNumpyState(self):
    150     directory = self.get_temp_dir()
    151     prefix = os.path.join(directory, "ckpt")
    152     save_state = _NumpyState()
    153     saver = util.Checkpoint(numpy=save_state)
    154     save_state.a = numpy.ones([2, 2])
    155     save_state.b = numpy.ones([2, 2])
    156     save_state.b = numpy.zeros([2, 2])
    157     save_state.c = numpy.int64(3)
    158     self.assertAllEqual(numpy.ones([2, 2]), save_state.a)
    159     self.assertAllEqual(numpy.zeros([2, 2]), save_state.b)
    160     self.assertEqual(3, save_state.c)
    161     first_save_path = saver.save(prefix)
    162     save_state.a[1, 1] = 2.
    163     save_state.c = numpy.int64(4)
    164     second_save_path = saver.save(prefix)
    165 
    166     load_state = _NumpyState()
    167     loader = util.Checkpoint(numpy=load_state)
    168     loader.restore(first_save_path).initialize_or_restore()
    169     self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
    170     self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
    171     self.assertEqual(3, load_state.c)
    172     load_state.a[0, 0] = 42.
    173     self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a)
    174     loader.restore(first_save_path).run_restore_ops()
    175     self.assertAllEqual(numpy.ones([2, 2]), load_state.a)
    176     loader.restore(second_save_path).run_restore_ops()
    177     self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a)
    178     self.assertAllEqual(numpy.zeros([2, 2]), load_state.b)
    179     self.assertEqual(4, load_state.c)
    180 
    181   def testNoGraphPollution(self):
    182     graph = ops.Graph()
    183     with graph.as_default(), session.Session():
    184       directory = self.get_temp_dir()
    185       prefix = os.path.join(directory, "ckpt")
    186       save_state = _NumpyState()
    187       saver = util.Checkpoint(numpy=save_state)
    188       save_state.a = numpy.ones([2, 2])
    189       save_path = saver.save(prefix)
    190       saver.restore(save_path)
    191       graph.finalize()
    192       saver.save(prefix)
    193       save_state.a = numpy.zeros([2, 2])
    194       saver.save(prefix)
    195       saver.restore(save_path)
    196 
    197   @test_util.run_in_graph_and_eager_modes
    198   def testDocstringExample(self):
    199     arrays = _NumpyState()
    200     checkpoint = util.Checkpoint(numpy_arrays=arrays)
    201     arrays.x = numpy.zeros([3, 4])
    202     save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
    203     arrays.x[1, 1] = 4.
    204     checkpoint.restore(save_path)
    205     self.assertAllEqual(numpy.zeros([3, 4]), arrays.x)
    206 
    207     second_checkpoint = util.Checkpoint(numpy_arrays=_NumpyState())
    208     second_checkpoint.restore(save_path)
    209     self.assertAllEqual(numpy.zeros([3, 4]), second_checkpoint.numpy_arrays.x)
    210 
    211 
if __name__ == "__main__":
  # Turn on eager execution before handing control to the test runner.
  ops.enable_eager_execution()
  test.main()
    215