# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tools to work with checkpoints."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python.ops import io_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver
from tensorflow.python.training import training as train

__all__ = [
    "load_checkpoint",
    "load_variable",
    "list_variables",
    "init_from_checkpoint"]


def _get_checkpoint_filename(filepattern):
  """Returns checkpoint filename given directory or specific filepattern."""
  if gfile.IsDirectory(filepattern):
    return saver.latest_checkpoint(filepattern)
  return filepattern


def load_checkpoint(filepattern):
  """Returns CheckpointReader for latest checkpoint.

  Args:
    filepattern: Directory with checkpoints file or path to checkpoint.

  Returns:
    `CheckpointReader` object.

  Raises:
    ValueError: If `filepattern` is a directory that doesn't contain a
      'checkpoint' file or any checkpoints.
  """
  filename = _get_checkpoint_filename(filepattern)
  if filename is None:
    raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
                     "given directory %s" % filepattern)
  return train.NewCheckpointReader(filename)
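
# A quick usage sketch for `load_checkpoint` (illustrative only, kept as a
# comment so the module stays importable; "/tmp/model_dir" is a hypothetical
# checkpoint directory):
#
#   reader = load_checkpoint("/tmp/model_dir")
#   for name, shape in reader.get_variable_to_shape_map().items():
#     print(name, shape, reader.get_tensor(name))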


def load_variable(checkpoint_dir, name):
  """Returns the value of the given variable in the checkpoint.

  Args:
    checkpoint_dir: Directory with checkpoints file or path to checkpoint.
    name: Name of the variable/tensor to return.

  Returns:
    A numpy `ndarray` with a copy of the value of this variable.
  """
  # TODO(b/29227106): Fix this in the right place and remove this.
  if name.endswith(":0"):
    name = name[:-2]
  reader = load_checkpoint(checkpoint_dir)
  return reader.get_tensor(name)
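
# A usage sketch for `load_variable` (illustrative comment only;
# "/tmp/model_dir" and "dense/kernel" are hypothetical names):
#
#   kernel = load_variable("/tmp/model_dir", "dense/kernel")
#   print(kernel.shape)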


def list_variables(checkpoint_dir):
  """Returns list of all variables in the latest checkpoint.

  Args:
    checkpoint_dir: Directory with checkpoints file or path to checkpoint.

  Returns:
    List of tuples `(name, shape)`.
  """
  reader = load_checkpoint(checkpoint_dir)
  variable_map = reader.get_variable_to_shape_map()
  names = sorted(variable_map.keys())
  result = []
  for name in names:
    result.append((name, variable_map[name]))
  return result
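
# A usage sketch for `list_variables` (illustrative comment only;
# "/tmp/model_dir" is a hypothetical checkpoint directory):
#
#   for name, shape in list_variables("/tmp/model_dir"):
#     print("%s: %s" % (name, shape))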


# pylint: disable=protected-access
# Currently variable_scope doesn't provide very good APIs to access
# all variables under scope and retrieve and check existing scopes.
# TODO(ipolosukhin): Refactor variable_scope module to provide nicer APIs.


def _set_checkpoint_initializer(variable, file_pattern, tensor_name, slice_spec,
                                name="checkpoint_initializer"):
  """Sets the variable initializer to assign the value from the checkpoint.

  Args:
    variable: `Variable` object.
    file_pattern: string, where to load checkpoints from.
    tensor_name: Name of the `Tensor` to load from checkpoint reader.
    slice_spec: Slice specification for loading partitioned variables.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  restore_op = io_ops.restore_v2(
      file_pattern, [tensor_name], [slice_spec], [base_type], name=name)[0]
  variable._initializer_op = state_ops.assign(variable, restore_op)


def _set_variable_or_list_initializer(variable_or_list, file_pattern,
                                      tensor_name):
  """Sets initializer for a variable or list of partitioned-variable slices.

  Raises:
    ValueError: if the given slices don't all belong to the same tensor.
  """
  if isinstance(variable_or_list, (list, tuple)):
    # A set of slices.
    slice_name = None
    for v in variable_or_list:
      if slice_name is None:
        slice_name = v._save_slice_info.full_name
      elif slice_name != v._save_slice_info.full_name:
        raise ValueError("Slices must all be from the same tensor: %s != %s" %
                         (slice_name, v._save_slice_info.full_name))
      _set_checkpoint_initializer(v, file_pattern, tensor_name,
                                  v._save_slice_info.spec)
  else:
    _set_checkpoint_initializer(variable_or_list, file_pattern, tensor_name, "")
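
# A sketch of how slices of a partitioned variable flow through the helper
# above (illustrative comment only; `part_0`, `part_1`, the path and the
# tensor name are hypothetical):
#
#   # A variable partitioned into two slices is passed as a list, and each
#   # slice gets its own checkpoint-backed initializer via its slice spec.
#   _set_variable_or_list_initializer([part_0, part_1], "/tmp/model_dir",
#                                     "layer/weights")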


def _collect_partitioned_variable(name, var_scope):
  """Returns list of partitions of `name` in `var_scope`, or None."""
  if name + "/part_0" in var_scope._vars:
    var = []
    i = 0
    while name + "/part_%d" % i in var_scope._vars:
      var.append(var_scope._vars[name + "/part_%d" % i])
      i += 1
    return var
  return None


def init_from_checkpoint(checkpoint_dir, assignment_map):
  """Initializes current variables with tensors loaded from a checkpoint.

  Note: This overrides the default initialization ops of the specified
  variables and redefines their dtype.

  The assignment map supports the following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    current `scope_name` from `checkpoint_scope_name` with matching variable
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize the `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize given `tf.Variable`
    object with the variable from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize a list of
    partitioned variables with the variable from the checkpoint.
  * `'/': 'scope_name/'` - will load all variables in current `scope_name` from
    the checkpoint's root (i.e. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Example:

  ```python
    # Create variables.
    with tf.variable_scope('test'):
      m = tf.get_variable('my_var')
    with tf.variable_scope('test2'):
      var2 = tf.get_variable('my_var')
    var3 = tf.get_variable(name="my1", shape=[100, 100],
                           partitioner=lambda shape, dtype: [5, 1])
    ...
    # Specify which variables to initialize from checkpoint.
    init_from_checkpoint(checkpoint_dir, {
      'some_var': 'test/my_var',
      'some_scope/': 'test2/'})
    ...
    # Or use `Variable` objects to identify what to initialize.
    init_from_checkpoint(checkpoint_dir, {
      'some_scope/var2': var2,
    })
    # Initialize partitioned variables.
    init_from_checkpoint(checkpoint_dir, {
      'some_var_from_ckpt': 'part_var',
    })
    # Or specify the list of `Variable` objects.
    init_from_checkpoint(checkpoint_dir, {
      'some_var_from_ckpt': var3._get_variable_list(),
    })
    ...
    # Initialize variables as usual.
    session.run(tf.global_variables_initializer())
  ```

  Args:
    checkpoint_dir: Directory with checkpoints file or path to checkpoint.
    assignment_map: Dict, where keys are names of the variables in the
      checkpoint and values are current variables or names of current variables
      (in default graph).

  Raises:
    tf.errors.OpError: If missing checkpoints or tensors in checkpoints.
    ValueError: If missing variables in current graph.
  """
  filepattern = _get_checkpoint_filename(checkpoint_dir)
  reader = load_checkpoint(checkpoint_dir)
  variable_map = reader.get_variable_to_shape_map()
  for tensor_name_in_ckpt, current_var_or_name in six.iteritems(assignment_map):
    var = None
    # Check if this is a Variable object or a list of Variable objects (in
    # case of partitioned variables).
    is_var = lambda x: isinstance(x, variables.Variable)
    if is_var(current_var_or_name) or (
        isinstance(current_var_or_name, list)
        and all(is_var(v) for v in current_var_or_name)):
      var = current_var_or_name
    else:
      var_scope = vs._get_default_variable_store()
      # Check if this variable is in var_store.
      var = var_scope._vars.get(current_var_or_name, None)
      # Also check if the variable is partitioned as a list.
      if var is None:
        var = _collect_partitioned_variable(current_var_or_name, var_scope)
    if var is not None:
      # If a 1-to-1 mapping was provided, find variable in the checkpoint.
      if tensor_name_in_ckpt not in variable_map:
        raise ValueError("Tensor %s is not found in %s checkpoint. "
                         "Available tensors: %s" % (
                             tensor_name_in_ckpt, checkpoint_dir, variable_map))
      if is_var(var):
        # Additional at-call-time checks.
        if not var.get_shape().is_compatible_with(
            variable_map[tensor_name_in_ckpt]):
          raise ValueError(
              "Shape of variable %s (%s) doesn't match with shape of "
              "tensor %s (%s) from checkpoint reader." % (
                  var.name, str(var.get_shape()),
                  tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
              ))
        var_name = var.name
      else:
        var_name = ",".join([v.name for v in var])
      _set_variable_or_list_initializer(var, filepattern, tensor_name_in_ckpt)
      logging.info("Initialize variable %s from checkpoint %s with %s" % (
          var_name, checkpoint_dir, tensor_name_in_ckpt
      ))
    else:
      scopes = ""
      # TODO(vihanjain): Support list of 'current_var_or_name' here.
      if "/" in current_var_or_name:
        scopes = current_var_or_name[:current_var_or_name.rindex("/")]
      if not tensor_name_in_ckpt.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name {} should map to scope only "
            "name {}. Should be of the form 'scope/': 'other_scope/'.".format(
                scopes, tensor_name_in_ckpt))
      # If a scope-to-scope mapping was provided, find all variables in the
      # scope and create a variable-to-variable mapping.
      scope_variables = set()
      for var_name in var_scope._vars:
        if not scopes or var_name.startswith(scopes + "/"):
          # Strip '/part_' if this is a partitioned variable.
          if "/part_" in var_name:
            var_name = var_name[:var_name.index("/part_")]
          scope_variables.add(var_name)
      for var_name in scope_variables:
        # Build the checkpoint tensor name: strip the current scope prefix and
        # prepend the checkpoint scope. If the tensor_name given is '/' (root),
        # don't use it for the full name.
        full_tensor_name = var_name[len(scopes):]
        if current_var_or_name != "/":
          full_tensor_name = full_tensor_name[1:]
        if tensor_name_in_ckpt != "/":
          full_tensor_name = tensor_name_in_ckpt + full_tensor_name
        if full_tensor_name not in variable_map:
          raise ValueError(
              "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                  full_tensor_name, var_name[len(scopes) + 1:],
                  tensor_name_in_ckpt, checkpoint_dir
              ))
        var = var_scope._vars.get(var_name, None)
        if var is None:
          var = _collect_partitioned_variable(var_name, var_scope)
        _set_variable_or_list_initializer(var, filepattern, full_tensor_name)
        logging.info("Initialize variable %s from checkpoint %s with %s" % (
            var_name, checkpoint_dir, full_tensor_name
        ))
# pylint: enable=protected-access