# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools to work with checkpoints."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "load_checkpoint", "load_variable", "list_variables", "init_from_checkpoint"
]


@tf_export("train.load_checkpoint")
def load_checkpoint(ckpt_dir_or_file):
  """Returns `CheckpointReader` for checkpoint found in `ckpt_dir_or_file`.

  If `ckpt_dir_or_file` resolves to a directory with multiple checkpoints,
  a reader for the latest checkpoint is returned.

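  Example (a minimal sketch; '/tmp/model_dir' is a hypothetical directory
  containing a checkpoint):

  ```python
  reader = tf.train.load_checkpoint('/tmp/model_dir')
  # List every tensor stored in the checkpoint together with its shape.
  for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)
  ```
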
  Args:
    ckpt_dir_or_file: Directory containing checkpoint file(s) or path to a
      specific checkpoint file.

  Returns:
    `CheckpointReader` object.

  Raises:
    ValueError: If `ckpt_dir_or_file` resolves to a directory with no
      checkpoints.
  """
  filename = _get_checkpoint_filename(ckpt_dir_or_file)
  if filename is None:
    raise ValueError("Couldn't find 'checkpoint' file or checkpoints in "
                     "given directory %s" % ckpt_dir_or_file)
  return pywrap_tensorflow.NewCheckpointReader(filename)


@tf_export("train.load_variable")
def load_variable(ckpt_dir_or_file, name):
  """Returns the tensor value of the given variable in the checkpoint.

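  Example (a minimal sketch; the checkpoint path and variable name are
  hypothetical):

  ```python
  kernel = tf.train.load_variable('/tmp/model.ckpt', 'dense/kernel')
  print(kernel.shape)  # A numpy array holding the checkpointed value.
  ```
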
  Args:
    ckpt_dir_or_file: Directory containing checkpoint file(s) or path to a
      specific checkpoint.
    name: Name of the variable to return.

  Returns:
    A numpy `ndarray` with a copy of the value of this variable.
  """
  # TODO(b/29227106): Fix this in the right place and remove this.
  if name.endswith(":0"):
    name = name[:-2]
  reader = load_checkpoint(ckpt_dir_or_file)
  return reader.get_tensor(name)


@tf_export("train.list_variables")
def list_variables(ckpt_dir_or_file):
  """Returns list of all variables in the checkpoint.

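  Example (a minimal sketch; '/tmp/model.ckpt' is a hypothetical checkpoint
  path):

  ```python
  for name, shape in tf.train.list_variables('/tmp/model.ckpt'):
    print(name, shape)
  ```
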
  Args:
    ckpt_dir_or_file: Directory containing checkpoint file(s) or path to a
      specific checkpoint.

  Returns:
    List of tuples `(name, shape)`.
  """
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  names = sorted(variable_map.keys())
  result = []
  for name in names:
    result.append((name, variable_map[name]))
  return result


@tf_export(v1=["train.init_from_checkpoint"])
def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """Replaces `tf.Variable` initializers so they load from a checkpoint file.

  Values are not loaded immediately, but when the initializer is run
  (typically by running a `tf.global_variables_initializer` op).

  Note: This overrides the default initialization ops of the specified
  variables and redefines their dtype.
  The assignment map supports the following syntax:

  * `'checkpoint_scope_name/': 'scope_name/'` - will load all variables in
    the current `scope_name` from `checkpoint_scope_name` with matching tensor
    names.
  * `'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'` -
    will initialize the `scope_name/variable_name` variable
    from `checkpoint_scope_name/some_other_variable`.
  * `'scope_variable_name': variable` - will initialize the given `tf.Variable`
    object with the tensor 'scope_variable_name' from the checkpoint.
  * `'scope_variable_name': list(variable)` - will initialize the list of
    partitioned variables with the tensor 'scope_variable_name' from the
    checkpoint.
  * `'/': 'scope_name/'` - will load all variables in the current `scope_name`
    from the checkpoint's root (i.e. no scope).

  Supports loading into partitioned variables, which are represented as
  `'<variable>/part_<part #>'`.

  Example:

  ```python

  # Say, '/tmp/model.ckpt' has the following tensors:
  #  -- name='old_scope_1/var1', shape=[20, 2]
  #  -- name='old_scope_1/var2', shape=[50, 4]
  #  -- name='old_scope_2/var3', shape=[100, 100]

  # Create the new model's variables.
  with tf.variable_scope('new_scope_1'):
    var1 = tf.get_variable('var1', shape=[20, 2],
                           initializer=tf.zeros_initializer())
  with tf.variable_scope('new_scope_2'):
    var2 = tf.get_variable('var2', shape=[50, 4],
                           initializer=tf.zeros_initializer())
    # Partition into 5 variables along the first axis.
    var3 = tf.get_variable(name='var3', shape=[100, 100],
                           initializer=tf.zeros_initializer(),
                           partitioner=lambda shape, dtype: [5, 1])

  # Initialize all variables in `new_scope_1` from `old_scope_1`.
  init_from_checkpoint('/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})

  # Use names to specify which variables to initialize from the checkpoint.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': 'new_scope_1/var1',
                        'old_scope_1/var2': 'new_scope_2/var2'})

  # Or use tf.Variable objects to identify what to initialize.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_1/var1': var1,
                        'old_scope_1/var2': var2})

  # Initialize partitioned variables using the variable's name.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': 'new_scope_2/var3'})

  # Or specify the list of tf.Variable objects.
  init_from_checkpoint('/tmp/model.ckpt',
                       {'old_scope_2/var3': var3._get_variable_list()})

  ```

  Args:
    ckpt_dir_or_file: Directory containing checkpoint file(s) or path to a
      specific checkpoint.
    assignment_map: Dict where keys are names of variables in the checkpoint
      and values are current variables or names of current variables (in the
      default graph).

  Raises:
    ValueError: If a variable in `assignment_map` cannot be found in the
      current graph, or if the checkpoint or a tensor in it cannot be found.
  """
  init_from_checkpoint_fn = lambda _: _init_from_checkpoint(
      ckpt_dir_or_file, assignment_map)
  if distribution_strategy_context.get_cross_replica_context():
    init_from_checkpoint_fn(None)
  else:
    distribution_strategy_context.get_replica_context().merge_call(
        init_from_checkpoint_fn)


def _init_from_checkpoint(ckpt_dir_or_file, assignment_map):
  """See `init_from_checkpoint` for documentation."""
  ckpt_file = _get_checkpoint_filename(ckpt_dir_or_file)
  reader = load_checkpoint(ckpt_dir_or_file)
  variable_map = reader.get_variable_to_shape_map()
  for tensor_name_in_ckpt, current_var_or_name in sorted(
      six.iteritems(assignment_map)):
    var = None
    # Check if this is a Variable object or a list of Variable objects (in the
    # case of partitioned variables).
    if _is_variable(current_var_or_name) or (
        isinstance(current_var_or_name, list)
        and all(_is_variable(v) for v in current_var_or_name)):
      var = current_var_or_name
    else:
      store_vars = vs._get_default_variable_store()._vars  # pylint:disable=protected-access
      # Check if this variable is in the default variable store.
      var = store_vars.get(current_var_or_name, None)
      # Also check if the variable is partitioned as a list.
      if var is None:
        var = _collect_partitioned_variable(current_var_or_name, store_vars)
    if var is not None:
      # If a 1-to-1 mapping was provided, find the variable in the checkpoint.
      if tensor_name_in_ckpt not in variable_map:
        raise ValueError("Tensor %s is not found in %s checkpoint %s" % (
            tensor_name_in_ckpt, ckpt_dir_or_file, variable_map
        ))
      if _is_variable(var):
        # Additional at-call-time checks.
        if not var.get_shape().is_compatible_with(
            variable_map[tensor_name_in_ckpt]):
          raise ValueError(
              "Shape of variable %s (%s) doesn't match with shape of "
              "tensor %s (%s) from checkpoint reader." % (
                  var.name, str(var.get_shape()),
                  tensor_name_in_ckpt, str(variable_map[tensor_name_in_ckpt])
              ))
        var_name = var.name
      else:
        var_name = ",".join([v.name for v in var])
      _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
      logging.debug("Initialize variable %s from checkpoint %s with %s",
                    var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
    else:
      scopes = ""
      # TODO(vihanjain): Support list of 'current_var_or_name' here.
      if "/" in current_var_or_name:
        scopes = current_var_or_name[:current_var_or_name.rindex("/")]
      if not tensor_name_in_ckpt.endswith("/"):
        raise ValueError(
            "Assignment map with scope only name {} should map to scope only "
            "{}. Should be 'scope/': 'other_scope/'.".format(
                scopes, tensor_name_in_ckpt))
      # If a scope-to-scope mapping was provided, find all variables in the
      # scope and create a variable-to-variable mapping.
      scope_variables = set()
      for var_name in store_vars:
        if not scopes or var_name.startswith(scopes + "/"):
          # Strip the '/part_' suffix if this is a partitioned variable.
          if "/part_" in var_name:
            var_name = var_name[:var_name.index("/part_")]
          scope_variables.add(var_name)
      for var_name in sorted(scope_variables):
        # Look up the name with the specified prefix and suffix from the
        # current variable. If the given tensor_name is '/' (root), don't use
        # it for the full name.
        full_tensor_name = var_name[len(scopes):]
        if current_var_or_name != "/":
          full_tensor_name = full_tensor_name[1:]
        if tensor_name_in_ckpt != "/":
          full_tensor_name = tensor_name_in_ckpt + full_tensor_name
        # Remove the trailing '/', if any, from full_tensor_name.
        if full_tensor_name.endswith("/"):
          full_tensor_name = full_tensor_name[:-1]
        if full_tensor_name not in variable_map:
          raise ValueError(
              "Tensor %s (%s in %s) is not found in %s checkpoint" % (
                  full_tensor_name, var_name[len(scopes) + 1:],
                  tensor_name_in_ckpt, ckpt_dir_or_file
              ))
        var = store_vars.get(var_name, None)
        if var is None:
          var = _collect_partitioned_variable(var_name, store_vars)
        _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
        logging.debug("Initialize variable %s from checkpoint %s with %s",
                      var_name, ckpt_dir_or_file, full_tensor_name)


def _get_checkpoint_filename(ckpt_dir_or_file):
  """Returns checkpoint filename given directory or specific checkpoint file."""
  if gfile.IsDirectory(ckpt_dir_or_file):
    return checkpoint_management.latest_checkpoint(ckpt_dir_or_file)
  return ckpt_dir_or_file


def _set_checkpoint_initializer(variable,
                                ckpt_file,
                                tensor_name,
                                slice_spec,
                                name="checkpoint_initializer"):
  """Overrides the given variable's initialization op.

  Sets the variable initializer to an assign op that initializes the variable
  from the tensor's value in the checkpoint.

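  Example (an illustrative sketch; assumes `var` is an existing `tf.Variable`
  and '/tmp/model.ckpt' contains an unpartitioned tensor named 'scope/var'):

  ```python
  _set_checkpoint_initializer(var, '/tmp/model.ckpt', 'scope/var', '')
  # Running var.initializer (e.g. via tf.global_variables_initializer())
  # now assigns the checkpointed value instead of the original initial value.
  ```
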
  Args:
    variable: `tf.Variable` object.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.
    slice_spec: Slice specification for loading partitioned tensors.
    name: Name of the operation.
  """
  base_type = variable.dtype.base_dtype
  # Do not colocate with variable since RestoreV2 op only runs on CPU and
  # colocation will force variable (and other ops that colocate with variable)
  # to be on CPU as well. It is okay to place the variable's initializer op on
  # CPU since it will only be run once at the start.
  with ops.device(variable.device), ops.device("/cpu:0"):
    restore_op = io_ops.restore_v2(
        ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0]

    names_to_saveables = saveable_object_util.op_list_to_dict([variable])
    saveable_objects = []
    for name, op in names_to_saveables.items():
      for s in saveable_object_util.saveable_objects_for_op(op, name):
        saveable_objects.append(s)

    assert len(saveable_objects) == 1  # Should be only one variable.
  init_op = saveable_objects[0].restore([restore_op], restored_shapes=None)

  # pylint:disable=protected-access
  variable._initializer_op = init_op
  restore_op.set_shape(variable.shape)
  variable._initial_value = restore_op
  # pylint:enable=protected-access


def _set_variable_or_list_initializer(variable_or_list, ckpt_file,
                                      tensor_name):
  """Overrides initialization op of given variable or list of variables.

  Calls `_set_checkpoint_initializer` for each variable in the given list of
  variables.

  Args:
    variable_or_list: `tf.Variable` object or a list of `tf.Variable` objects.
    ckpt_file: string, full path of the checkpoint.
    tensor_name: Name of the tensor to load from the checkpoint.

  Raises:
    ValueError: If the objects in `variable_or_list` are not all partitions of
      the same larger variable.
  """
  if isinstance(variable_or_list, (list, tuple)):
    # A set of slices.
    slice_name = None
    for v in variable_or_list:
      slice_info = v._save_slice_info  # pylint:disable=protected-access
      if slice_name is None:
        slice_name = slice_info.full_name
      elif slice_name != slice_info.full_name:
        raise ValueError("Slices must all be from the same tensor: %s != %s" %
                         (slice_name, slice_info.full_name))
      _set_checkpoint_initializer(v, ckpt_file, tensor_name, slice_info.spec)
  else:
    _set_checkpoint_initializer(variable_or_list, ckpt_file, tensor_name, "")


def _is_variable(x):
  return (isinstance(x, variables.Variable) or
          resource_variable_ops.is_resource_variable(x))


def _collect_partitioned_variable(name, all_vars):
  """Returns list of `tf.Variable` that comprise the partitioned variable."""
  if name + "/part_0" in all_vars:
    var = []
    i = 0
    while name + "/part_%d" % i in all_vars:
      var.append(all_vars[name + "/part_%d" % i])
      i += 1
    return var
  return None