1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 from __future__ import absolute_import 16 from __future__ import division 17 from __future__ import print_function 18 19 import io 20 import os 21 22 import numpy 23 24 from tensorflow.python.client import session 25 from tensorflow.python.framework import ops 26 from tensorflow.python.framework import test_util 27 from tensorflow.python.module import module 28 from tensorflow.python.platform import test 29 from tensorflow.python.training.tracking import python_state 30 from tensorflow.python.training.tracking import util 31 32 33 class _NumpyState(module.Module): 34 """A checkpointable object whose NumPy array attributes are saved/restored. 35 36 Example usage: 37 38 ```python 39 arrays = _NumpyState() 40 checkpoint = tf.train.Checkpoint(numpy_arrays=arrays) 41 arrays.x = numpy.zeros([3, 4]) 42 save_path = checkpoint.save("/tmp/ckpt") 43 arrays.x[1, 1] = 4. 44 checkpoint.restore(save_path) 45 assert (arrays.x == numpy.zeros([3, 4])).all() 46 47 second_checkpoint = tf.train.Checkpoint( 48 numpy_arrays=_NumpyState()) 49 # Attributes of NumpyState objects are created automatically by restore() 50 second_checkpoint.restore(save_path) 51 assert (second_checkpoint.numpy_arrays.x == numpy.zeros([3, 4])).all() 52 ``` 53 54 Note that `NumpyState` objects re-create the attributes of the previously 55 saved object on `restore()`. This is in contrast to TensorFlow variables, for 56 which a `Variable` object must be created and assigned to an attribute. 57 58 This snippet works both when graph building and when executing eagerly. On 59 save, the NumPy array(s) are fed as strings to be saved in the checkpoint (via 60 a placeholder when graph building, or as a string constant when executing 61 eagerly). When restoring they skip the TensorFlow graph entirely, and so no 62 restore ops need be run. This means that restoration always happens eagerly, 63 rather than waiting for `checkpoint.restore(...).run_restore_ops()` like 64 TensorFlow variables when graph building. 65 """ 66 67 def __init__(self): 68 super(_NumpyState, self).__setattr__("_arrays", module.Module()) 69 70 def __getattribute__(self, name): 71 """Un-wrap `_NumpyWrapper` objects when accessing attributes.""" 72 try: 73 arrays = super(_NumpyState, self).__getattribute__("_arrays") 74 except AttributeError: 75 # _arrays hasn't been assigned yet 76 return super(_NumpyState, self).__getattribute__(name) 77 try: 78 value = getattr(arrays, name) 79 except AttributeError: 80 dummy_array = numpy.array([]) 81 setattr(arrays, name, _NumpyWrapper(dummy_array)) 82 value = getattr(arrays, name) 83 if value.array is dummy_array: 84 # No set or restored attribute with this name 85 delattr(arrays, name) 86 return super(_NumpyState, self).__getattribute__(name) 87 88 if isinstance(value, _NumpyWrapper): 89 return value.array 90 return super(_NumpyState, self).__getattribute__(name) 91 92 def __setattr__(self, name, value): 93 """Automatically wrap NumPy arrays assigned to attributes.""" 94 if isinstance(value, (numpy.ndarray, numpy.generic)): 95 try: 96 existing = getattr(self._arrays, name) 97 existing.array = value 98 return 99 except AttributeError: 100 value = _NumpyWrapper(value) 101 setattr(self._arrays, name, value) 102 return 103 super(_NumpyState, self).__setattr__(name, value) 104 105 106 class _NumpyWrapper(python_state.PythonState): 107 """Wraps a NumPy array for storage in an object-based checkpoint.""" 108 109 def __init__(self, array): 110 """Specify a NumPy array to wrap. 111 112 Args: 113 array: The NumPy array to save and restore (may be overwritten). 114 """ 115 self.array = array 116 117 def serialize(self): 118 """Callback to serialize the array.""" 119 string_file = io.BytesIO() 120 try: 121 numpy.save(string_file, self.array, allow_pickle=False) 122 serialized = string_file.getvalue() 123 finally: 124 string_file.close() 125 return serialized 126 127 def deserialize(self, string_value): 128 """Callback to deserialize the array.""" 129 string_file = io.BytesIO(string_value) 130 try: 131 self.array = numpy.load(string_file, allow_pickle=False) 132 finally: 133 string_file.close() 134 135 136 class NumpyStateTests(test.TestCase): 137 138 def testWrapper(self): 139 directory = self.get_temp_dir() 140 prefix = os.path.join(directory, "ckpt") 141 root = util.Checkpoint(numpy=_NumpyWrapper(numpy.array([1.]))) 142 save_path = root.save(prefix) 143 root.numpy.array *= 2. 144 self.assertEqual([2.], root.numpy.array) 145 root.restore(save_path) 146 self.assertEqual([1.], root.numpy.array) 147 148 @test_util.run_in_graph_and_eager_modes 149 def testSaveRestoreNumpyState(self): 150 directory = self.get_temp_dir() 151 prefix = os.path.join(directory, "ckpt") 152 save_state = _NumpyState() 153 saver = util.Checkpoint(numpy=save_state) 154 save_state.a = numpy.ones([2, 2]) 155 save_state.b = numpy.ones([2, 2]) 156 save_state.b = numpy.zeros([2, 2]) 157 save_state.c = numpy.int64(3) 158 self.assertAllEqual(numpy.ones([2, 2]), save_state.a) 159 self.assertAllEqual(numpy.zeros([2, 2]), save_state.b) 160 self.assertEqual(3, save_state.c) 161 first_save_path = saver.save(prefix) 162 save_state.a[1, 1] = 2. 163 save_state.c = numpy.int64(4) 164 second_save_path = saver.save(prefix) 165 166 load_state = _NumpyState() 167 loader = util.Checkpoint(numpy=load_state) 168 loader.restore(first_save_path).initialize_or_restore() 169 self.assertAllEqual(numpy.ones([2, 2]), load_state.a) 170 self.assertAllEqual(numpy.zeros([2, 2]), load_state.b) 171 self.assertEqual(3, load_state.c) 172 load_state.a[0, 0] = 42. 173 self.assertAllEqual([[42., 1.], [1., 1.]], load_state.a) 174 loader.restore(first_save_path).run_restore_ops() 175 self.assertAllEqual(numpy.ones([2, 2]), load_state.a) 176 loader.restore(second_save_path).run_restore_ops() 177 self.assertAllEqual([[1., 1.], [1., 2.]], load_state.a) 178 self.assertAllEqual(numpy.zeros([2, 2]), load_state.b) 179 self.assertEqual(4, load_state.c) 180 181 def testNoGraphPollution(self): 182 graph = ops.Graph() 183 with graph.as_default(), session.Session(): 184 directory = self.get_temp_dir() 185 prefix = os.path.join(directory, "ckpt") 186 save_state = _NumpyState() 187 saver = util.Checkpoint(numpy=save_state) 188 save_state.a = numpy.ones([2, 2]) 189 save_path = saver.save(prefix) 190 saver.restore(save_path) 191 graph.finalize() 192 saver.save(prefix) 193 save_state.a = numpy.zeros([2, 2]) 194 saver.save(prefix) 195 saver.restore(save_path) 196 197 @test_util.run_in_graph_and_eager_modes 198 def testDocstringExample(self): 199 arrays = _NumpyState() 200 checkpoint = util.Checkpoint(numpy_arrays=arrays) 201 arrays.x = numpy.zeros([3, 4]) 202 save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt")) 203 arrays.x[1, 1] = 4. 204 checkpoint.restore(save_path) 205 self.assertAllEqual(numpy.zeros([3, 4]), arrays.x) 206 207 second_checkpoint = util.Checkpoint(numpy_arrays=_NumpyState()) 208 second_checkpoint.restore(save_path) 209 self.assertAllEqual(numpy.zeros([3, 4]), second_checkpoint.numpy_arrays.x) 210 211 212 if __name__ == "__main__": 213 ops.enable_eager_execution() 214 test.main() 215