# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Utility classes for testing checkpointing."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.ops import gen_lookup_ops
from tensorflow.python.training import saver as saver_module


class CheckpointedOp(object):
  """Wrapper around a lookup-table op that supplies its own checkpoint logic.

  The class lives with the tests because the MutableHashTable Python code is
  currently in contrib.

  When no table handle is supplied, a string -> float32 mutable hash table is
  created through the generated lookup ops. Exports always request string keys
  and float32 values.
  """

  # pylint: disable=protected-access
  def __init__(self, name, table_ref=None):
    """Creates (or wraps) the table and, in graph mode, registers a saveable.

    Args:
      name: Name given to a newly created table op and to the saveable; also
        exposed through the `name` property.
      table_ref: Optional existing table handle to wrap. When `None`, a new
        mutable hash table (string keys, float32 values) is created.
    """
    self._name = name
    if table_ref is not None:
      self.table_ref = table_ref
    else:
      self.table_ref = gen_lookup_ops._mutable_hash_table_v2(
          key_dtype=dtypes.string, value_dtype=dtypes.float32, name=name)

    if not context.in_graph_mode():
      # Outside graph mode no saveable is cached; the `saveable` property
      # builds one on each access instead.
      return

    self._saveable = CheckpointedOp.CustomSaveable(self, name)
    ops_lib.add_to_collection(
        ops_lib.GraphKeys.SAVEABLE_OBJECTS, self._saveable)

  @property
  def name(self):
    """The name passed to the constructor."""
    return self._name

  @property
  def saveable(self):
    """The saveable for this table.

    In graph mode this is the instance built (and added to the
    SAVEABLE_OBJECTS collection) by the constructor; otherwise a fresh
    `CustomSaveable` is constructed on every access.
    """
    if not context.in_graph_mode():
      return CheckpointedOp.CustomSaveable(self, self.name)
    return self._saveable

  def insert(self, keys, values):
    """Returns the result of the table-insert op for `keys` and `values`."""
    handle = self.table_ref
    return gen_lookup_ops._lookup_table_insert_v2(handle, keys, values)

  def lookup(self, keys, default):
    """Returns the result of the table-find op for `keys` with `default`."""
    handle = self.table_ref
    return gen_lookup_ops._lookup_table_find_v2(handle, keys, default)

  def keys(self):
    """Returns the first output of the table export (the keys)."""
    exported = self._export()
    return exported[0]

  def values(self):
    """Returns the second output of the table export (the values)."""
    exported = self._export()
    return exported[1]

  def _export(self):
    """Runs the table-export op, requesting string keys and float32 values."""
    key_dtype = dtypes.string
    value_dtype = dtypes.float32
    return gen_lookup_ops._lookup_table_export_v2(
        self.table_ref, key_dtype, value_dtype)

  class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject):
    """Saveable that checkpoints a CheckpointedOp's exported keys and values."""

    def __init__(self, table, name):
      """Builds one SaveSpec for the keys and one for the values.

      Args:
        table: The `CheckpointedOp` whose contents are saved; it is handed to
          the base class as the saveable's op.
        name: Base name; the specs are named `name + "-keys"` and
          `name + "-values"`.
      """
      make_spec = saver_module.BaseSaverBuilder.SaveSpec
      exported = table._export()
      keys_spec = make_spec(exported[0], "", name + "-keys")
      values_spec = make_spec(exported[1], "", name + "-values")
      super(CheckpointedOp.CustomSaveable, self).__init__(
          table, [keys_spec, values_spec], name)

    def restore(self, restore_tensors, shapes):
      """Imports the restored keys and values back into the table.

      Args:
        restore_tensors: Sequence whose first two entries are the restored
          keys and values, in that order.
        shapes: Unused by this implementation.

      Returns:
        The result of the table-import op.
      """
      restored_keys = restore_tensors[0]
      restored_values = restore_tensors[1]
      return gen_lookup_ops._lookup_table_import_v2(
          self.op.table_ref, restored_keys, restored_values)
  # pylint: enable=protected-access