1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 16 """Tools to work with checkpoints.""" 17 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 22 import six 23 24 from tensorflow.python.ops import io_ops 25 from tensorflow.python.ops import state_ops 26 from tensorflow.python.ops import variable_scope as vs 27 from tensorflow.python.ops import variables 28 from tensorflow.python.platform import gfile 29 from tensorflow.python.platform import tf_logging as logging 30 from tensorflow.python.training import saver 31 from tensorflow.python.training import training as train 32 33 __all__ = [ 34 "load_checkpoint", 35 "load_variable", 36 "list_variables", 37 "init_from_checkpoint"] 38 39 40 def _get_checkpoint_filename(filepattern): 41 """Returns checkpoint filename given directory or specific filepattern.""" 42 if gfile.IsDirectory(filepattern): 43 return saver.latest_checkpoint(filepattern) 44 return filepattern 45 46 47 def load_checkpoint(filepattern): 48 """Returns CheckpointReader for latest checkpoint. 49 50 Args: 51 filepattern: Directory with checkpoints file or path to checkpoint. 52 53 Returns: 54 `CheckpointReader` object. 55 56 Raises: 57 ValueError: if checkpoint_dir doesn't have 'checkpoint' file or checkpoints. 58 """ 59 filename = _get_checkpoint_filename(filepattern) 60 if filename is None: 61 raise ValueError("Couldn't find 'checkpoint' file or checkpoints in " 62 "given directory %s" % filepattern) 63 return train.NewCheckpointReader(filename) 64 65 66 def load_variable(checkpoint_dir, name): 67 """Returns a Tensor with the contents of the given variable in the checkpoint. 68 69 Args: 70 checkpoint_dir: Directory with checkpoints file or path to checkpoint. 71 name: Name of the tensor to return. 72 73 Returns: 74 `Tensor` object. 75 """ 76 # TODO(b/29227106): Fix this in the right place and remove this. 77 if name.endswith(":0"): 78 name = name[:-2] 79 reader = load_checkpoint(checkpoint_dir) 80 return reader.get_tensor(name) 81 82 83 def list_variables(checkpoint_dir): 84 """Returns list of all variables in the latest checkpoint. 85 86 Args: 87 checkpoint_dir: Directory with checkpoints file or path to checkpoint. 88 89 Returns: 90 List of tuples `(name, shape)`. 91 """ 92 reader = load_checkpoint(checkpoint_dir) 93 variable_map = reader.get_variable_to_shape_map() 94 names = sorted(variable_map.keys()) 95 result = [] 96 for name in names: 97 result.append((name, variable_map[name])) 98 return result 99 100 101 # pylint: disable=protected-access 102 # Currently variable_scope doesn't provide very good APIs to access 103 # all variables under scope and retrieve and check existing scopes. 104 # TODO(ipolosukhin): Refactor variable_scope module to provide nicer APIs. 105 106 107 def _set_checkpoint_initializer(variable, file_pattern, tensor_name, slice_spec, 108 name="checkpoint_initializer"): 109 """Sets variable initializer to assign op form value in checkpoint's tensor. 110 111 Args: 112 variable: `Variable` object. 113 file_pattern: string, where to load checkpoints from. 114 tensor_name: Name of the `Tensor` to load from checkpoint reader. 115 slice_spec: Slice specification for loading partitioned variables. 116 name: Name of the operation. 117 """ 118 base_type = variable.dtype.base_dtype 119 restore_op = io_ops.restore_v2( 120 file_pattern, [tensor_name], [slice_spec], [base_type], name=name)[0] 121 variable._initializer_op = state_ops.assign(variable, restore_op) 122 123 124 def _set_variable_or_list_initializer(variable_or_list, file_pattern, 125 tensor_name): 126 if isinstance(variable_or_list, (list, tuple)): 127 # A set of slices. 128 slice_name = None 129 for v in variable_or_list: 130 if slice_name is None: 131 slice_name = v._save_slice_info.full_name 132 elif slice_name != v._save_slice_info.full_name: 133 raise ValueError("Slices must all be from the same tensor: %s != %s" % 134 (slice_name, v._save_slice_info.full_name)) 135 _set_checkpoint_initializer(v, file_pattern, tensor_name, 136 v._save_slice_info.spec) 137 else: 138 _set_checkpoint_initializer(variable_or_list, file_pattern, tensor_name, "") 139 140 141 def _collect_partitioned_variable(name, var_scope): 142 if name + "/part_0" in var_scope._vars: 143 var = [] 144 i = 0 145 while name + "/part_%d" % i in var_scope._vars: 146 var.append(var_scope._vars[name + "/part_%d" % i]) 147 i += 1 148 return var 149 return None 150 151 152 def init_from_checkpoint(checkpoint_dir, assignment_map): 153 """Using assignment map initializes current variables with loaded tensors. 154 155 Note: This overrides default initialization ops of specified variables and 156 redefines dtype. 157 158 Assignment map supports following syntax: 159 160 * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in 161 current `scope_name` from `checkpoint_scope_name` with matching variable 162 names. 163 * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` - 164 will initialize `scope_name/variable_name` variable 165 from `checkpoint_scope_name/some_other_variable`. 166 * `'scope_variable_name': variable` - will initialize given `tf.Variable` 167 object with variable from the checkpoint. 168 * `'scope_variable_name': list(variable)` - will initialize list of 169 partitioned variables with variable from the checkpoint. 170 * `'/': 'scope_name/'` - will load all variables in current `scope_name` from 171 checkpoint's root (e.g. no scope). 172 173 Supports loading into partitioned variables, which are represented as 174 `'<variable>/part_<part #>'`. 175 176 Example: 177 178 ```python 179 # Create variables. 180 with tf.variable_scope('test'): 181 m = tf.get_variable('my_var') 182 with tf.variable_scope('test2'): 183 var2 = tf.get_variable('my_var') 184 var3 = tf.get_variable(name="my1", shape=[100, 100], 185 partitioner=lambda shape, dtype: [5, 1]) 186 ... 187 # Specify which variables to initialize from checkpoint. 188 init_from_checkpoint(checkpoint_dir, { 189 'some_var': 'test/my_var', 190 'some_scope/': 'test2/'}) 191 ... 192 # Or use `Variable` objects to identify what to initialize. 193 init_from_checkpoint(checkpoint_dir, { 194 'some_scope/var2': var2, 195 }) 196 # Initialize partitioned variables 197 init_from_checkpoint(checkpoint_dir, { 198 'some_var_from_ckpt': 'part_var', 199 }) 200 # Or specifying the list of `Variable` objects. 201 init_from_checkpoint(checkpoint_dir, { 202 'some_var_from_ckpt': var3._get_variable_list(), 203 }) 204 ... 205 # Initialize variables as usual. 206 session.run(tf.get_all_variables()) 207 ``` 208 209 Args: 210 checkpoint_dir: Directory with checkpoints file or path to checkpoint. 211 assignment_map: Dict, where keys are names of the variables in the 212 checkpoint and values are current variables or names of current variables 213 (in default graph). 214 215 Raises: 216 tf.errors.OpError: If missing checkpoints or tensors in checkpoints. 217 ValueError: If missing variables in current graph. 218 """ 219 filepattern = _get_checkpoint_filename(checkpoint_dir) 220 reader = load_checkpoint(checkpoint_dir) 221 variable_map = reader.get_variable_to_shape_map() 222 for tensor_name_in_ckpt, current_var_or_name in six.iteritems(assignment_map): 223 var = None 224 # Check if this is Variable object or list of Variable objects (in case of 225 # partitioned variables). 226 is_var = lambda x: isinstance(x, variables.Variable) 227 if is_var(current_var_or_name) or ( 228 isinstance(current_var_or_name, list) 229 and all(is_var(v) for v in current_var_or_name)): 230 var = current_var_or_name 231 else: 232 var_scope = vs._get_default_variable_store() 233 # Check if this variable is in var_store. 234 var = var_scope._vars.get(current_var_or_name, None) 235 # Also check if variable is partitioned as list. 236 if var is None: 237 var = _collect_partitioned_variable(current_var_or_name, var_scope) 238 if var is not None: 239 # If 1 to 1 mapping was provided, find variable in the checkpoint. 240 if tensor_name_in_ckpt not in variable_map: 241 raise ValueError("Tensor %s is not found in %s checkpoint %s" % ( 242 tensor_name_in_ckpt, checkpoint_dir, variable_map 243 )) 244 if is_var(var): 245 # Additional at-call-time checks. 246 if not var.get_shape().is_compatible_with( 247 variable_map[tensor_name_in_ckpt]): 248 raise ValueError( 249 "Shape of variable %s (%s) doesn't match with shape of " 250 "tensor %s (%s) from checkpoint reader." % ( 251 var.name, str(var.get_shape()), 252 tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt]) 253 )) 254 var_name = var.name 255 else: 256 var_name = ",".join([v.name for v in var]) 257 _set_variable_or_list_initializer(var, filepattern, tensor_name_in_ckpt) 258 logging.info("Initialize variable %s from checkpoint %s with %s" % ( 259 var_name, checkpoint_dir, tensor_name_in_ckpt 260 )) 261 else: 262 scopes = "" 263 # TODO(vihanjain): Support list of 'current_var_or_name' here. 264 if "/" in current_var_or_name: 265 scopes = current_var_or_name[:current_var_or_name.rindex("/")] 266 if not tensor_name_in_ckpt.endswith("/"): 267 raise ValueError( 268 "Assignment map with scope only name {} should map to scope only " 269 "{}. Should be 'scope/': 'other_scope/'.".format( 270 scopes, tensor_name_in_ckpt)) 271 # If scope to scope mapping was provided, find all variables in the scope 272 # and create variable to variable mapping. 273 scope_variables = set() 274 for var_name in var_scope._vars: 275 if not scopes or var_name.startswith(scopes + "/"): 276 # Consume /part_ if partitioned variable. 277 if "/part_" in var_name: 278 var_name = var_name[:var_name.index("/part_")] 279 scope_variables.add(var_name) 280 for var_name in scope_variables: 281 # Lookup name with specified prefix and suffix from current variable. 282 # If tensor_name given is '/' (root), don't use it for full name. 283 full_tensor_name = var_name[len(scopes):] 284 if current_var_or_name != "/": 285 full_tensor_name = full_tensor_name[1:] 286 if tensor_name_in_ckpt != "/": 287 full_tensor_name = tensor_name_in_ckpt + full_tensor_name 288 if full_tensor_name not in variable_map: 289 raise ValueError( 290 "Tensor %s (%s in %s) is not found in %s checkpoint" % ( 291 full_tensor_name, var_name[len(scopes) + 1:], 292 tensor_name_in_ckpt, checkpoint_dir 293 )) 294 var = var_scope._vars.get(var_name, None) 295 if var is None: 296 var = _collect_partitioned_variable(var_name, var_scope) 297 _set_variable_or_list_initializer(var, filepattern, full_tensor_name) 298 logging.info("Initialize variable %s from checkpoint %s with %s" % ( 299 var_name, checkpoint_dir, full_tensor_name 300 )) 301 # pylint: enable=protected-access 302