      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 # pylint: disable=invalid-name
     17 """Save and restore variables.
     18 
     19 Symbols in this file are deprecated. See replacements in
     20 tensorflow/python/training/trackable and tensorflow/python/training/saving.
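
        For new code, object-based checkpointing is the recommended replacement.
        A minimal sketch (illustrative only; assumes an existing trackable
        `model`, e.g. a `tf.keras.Model`, and a hypothetical `/tmp/ckpt`
        directory):

        ```python
        ckpt = tf.train.Checkpoint(model=model)
        save_path = ckpt.save('/tmp/ckpt/model')  # e.g. '/tmp/ckpt/model-1'
        ckpt.restore(tf.train.latest_checkpoint('/tmp/ckpt'))
        ```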
     21 """
     22 from __future__ import absolute_import
     23 from __future__ import division
     24 from __future__ import print_function
     25 
     26 import collections
     27 import os.path
     28 import time
     29 import uuid
     30 
     31 import numpy as np
     32 from tensorflow.core.protobuf import meta_graph_pb2
     33 from tensorflow.core.protobuf import saver_pb2
     34 from tensorflow.core.protobuf import trackable_object_graph_pb2
     35 from tensorflow.python import pywrap_tensorflow
     36 from tensorflow.python.client import session
     37 from tensorflow.python.eager import context
     38 from tensorflow.python.framework import constant_op
     39 from tensorflow.python.framework import device as pydev
     40 from tensorflow.python.framework import errors
     41 from tensorflow.python.framework import meta_graph
     42 from tensorflow.python.framework import ops
     43 from tensorflow.python.ops import array_ops
     44 from tensorflow.python.ops import control_flow_ops
     45 from tensorflow.python.ops import gen_io_ops
     46 from tensorflow.python.ops import io_ops
     47 from tensorflow.python.ops import string_ops
     48 from tensorflow.python.ops import variables
     49 from tensorflow.python.platform import gfile
     50 from tensorflow.python.platform import tf_logging as logging
     51 from tensorflow.python.training import checkpoint_management
     52 from tensorflow.python.training import training_util
     53 from tensorflow.python.training.saving import saveable_object
     54 from tensorflow.python.training.saving import saveable_object_util
     55 from tensorflow.python.training.tracking import base as trackable
     56 from tensorflow.python.util import compat
     57 from tensorflow.python.util.tf_export import tf_export
     58 
     59 
     60 # TODO(allenl): Remove these aliases once all users are migrated off.
     61 get_checkpoint_state = checkpoint_management.get_checkpoint_state
     62 update_checkpoint_state = checkpoint_management.update_checkpoint_state
     63 generate_checkpoint_state_proto = (
     64     checkpoint_management.generate_checkpoint_state_proto)
     65 latest_checkpoint = checkpoint_management.latest_checkpoint
     66 checkpoint_exists = checkpoint_management.checkpoint_exists
     67 get_checkpoint_mtimes = checkpoint_management.get_checkpoint_mtimes
     68 remove_checkpoint = checkpoint_management.remove_checkpoint
     69 
     70 
     71 class BaseSaverBuilder(object):
     72   """Base class for Savers.
     73 
     74   Can be extended to create different Ops.
     75   """
     76 
     77   SaveSpec = saveable_object.SaveSpec
     78   SaveableObject = saveable_object.SaveableObject
     79 
     80   # Aliases for code which was moved but still has lots of users.
     81   VariableSaveable = saveable_object_util.ReferenceVariableSaveable
     82   ResourceVariableSaveable = saveable_object_util.ResourceVariableSaveable
     83 
     84   def __init__(self, write_version=saver_pb2.SaverDef.V2):
     85     self._write_version = write_version
     86 
     87   def save_op(self, filename_tensor, saveables):
     88     """Create an Op to save 'saveables'.
     89 
     90     This is intended to be overridden by subclasses that want to generate
     91     different Ops.
     92 
     93     Args:
     94       filename_tensor: String Tensor.
     95       saveables: A list of BaseSaverBuilder.SaveableObject objects.
     96 
     97     Returns:
     98       An Operation that saves the variables.
     99 
    100     Raises:
    101       RuntimeError: (implementation detail) if "self._write_version" is an
    102         unexpected value.
    103     """
    104     # pylint: disable=protected-access
    105     tensor_names = []
    106     tensors = []
    107     tensor_slices = []
    108     for saveable in saveables:
    109       for spec in saveable.specs:
    110         tensor_names.append(spec.name)
    111         tensors.append(spec.tensor)
    112         tensor_slices.append(spec.slice_spec)
    113     if self._write_version == saver_pb2.SaverDef.V1:
    114       return io_ops._save(
    115           filename=filename_tensor,
    116           tensor_names=tensor_names,
    117           tensors=tensors,
    118           tensor_slices=tensor_slices)
    119     elif self._write_version == saver_pb2.SaverDef.V2:
    120       # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
    121       # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
    122       return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
    123                             tensors)
    124     else:
    125       raise RuntimeError("Unexpected write_version: %s" % self._write_version)
    126 
    127   def bulk_restore(self, filename_tensor, saveables, preferred_shard,
    128                    restore_sequentially):
    129     """Restore all tensors contained in saveables.
    130 
    131     By default, this issues separate calls to `restore_op` for each saveable.
    132     Subclasses may override to load multiple saveables in a single call.
    133 
    134     Args:
    135       filename_tensor: String Tensor.
    136       saveables: List of BaseSaverBuilder.SaveableObject objects.
    137       preferred_shard: Int.  Shard to open first when loading a sharded file.
    138       restore_sequentially: Unused.  Bool.  If true, each restore is sequential.
    139 
    140     Returns:
    141       A list of Tensors resulting from reading 'saveables' from
    142         'filename_tensor'.
    143 
    144     """
    145     del restore_sequentially
    146     all_tensors = []
    147     for saveable in saveables:
    148       if saveable.device:
    149         device = saveable_object_util.set_cpu0(saveable.device)
    150       else:
    151         device = None
    152       with ops.device(device):
    153         all_tensors.extend(
    154             self.restore_op(filename_tensor, saveable, preferred_shard))
    155     return all_tensors
    156 
    157   # pylint: disable=unused-argument
    158   def restore_op(self, filename_tensor, saveable, preferred_shard):
    159     """Create ops to restore 'saveable'.
    160 
    161     This is intended to be overridden by subclasses that want to generate
    162     different Ops.
    163 
    164     Args:
    165       filename_tensor: String Tensor.
    166       saveable: A BaseSaverBuilder.SaveableObject object.
    167       preferred_shard: Int.  Shard to open first when loading a sharded file.
    168 
    169     Returns:
    170       A list of Tensors resulting from reading 'saveable' from
    171         'filename_tensor'.
    172     """
    173     # pylint: disable=protected-access
    174     tensors = []
    175     for spec in saveable.specs:
    176       tensors.append(
    177           io_ops.restore_v2(
    178               filename_tensor,
    179               [spec.name],
    180               [spec.slice_spec],
    181               [spec.dtype])[0])
    182 
    183     return tensors
    184   # pylint: enable=unused-argument
    185 
    186   def sharded_filename(self, filename_tensor, shard, num_shards):
    187     """Append sharding information to a filename.
    188 
    189     Args:
    190       filename_tensor: A string tensor.
    191       shard: Integer.  The shard for the filename.
    192       num_shards: An int Tensor for the number of shards.
    193 
    194     Returns:
    195       A string tensor.
    196     """
    197     return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)
    198 
    199   def _AddSaveOps(self, filename_tensor, saveables):
    200     """Add ops to save variables that are on the same shard.
    201 
    202     Args:
    203       filename_tensor: String Tensor.
    204       saveables: A list of SaveableObject objects.
    205 
    206     Returns:
    207       A tensor with the filename used to save.
    208     """
    209     save = self.save_op(filename_tensor, saveables)
    210     return control_flow_ops.with_dependencies([save], filename_tensor)
    211 
    212   def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device):
    213     """Add ops to save the params per shard, for the V2 format.
    214 
    215     Note that the sharded save procedure for the V2 format is different from
    216     V1: there is a special "merge" step that merges the small metadata produced
    217     from each device.
    218 
    219     Args:
    220       checkpoint_prefix: scalar String Tensor.  Interpreted *NOT AS A
    221         FILENAME*, but as a prefix of a V2 checkpoint;
    222       per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as
    223         returned by _GroupByDevices().
    224 
    225     Returns:
    226       An op to save the variables, which, when evaluated, returns the prefix
    227         "<user-fed prefix>" only and does not include the sharded spec suffix.
    228     """
    229     # IMPLEMENTATION DETAILS: most clients should skip.
    230     #
    231     # Suffix for any well-formed "checkpoint_prefix", when sharded.
    232     # Transformations:
    233     # * Users pass in "save_path" in save() and restore().  Say "myckpt".
    234     # * checkpoint_prefix gets fed <save_path><_SHARDED_SUFFIX>.
    235     #
    236     # Example:
    237     #   During runtime, a temporary directory is first created, which contains
    238     #   files
    239     #
    240     #     <train dir>/myckpt_temp/
    241     #        part-?????-of-?????{.index, .data-00000-of-00001}
    242     #
    243     #   Before .save() finishes, they will be (hopefully, atomically) renamed to
    244     #
    245     #     <train dir>/
    246     #        myckpt{.index, .data-?????-of-?????}
    247     #
    248     # Users only need to interact with the user-specified prefix, which is
    249     # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
    250     # prefix directly, instead of any physical pathname.  (On failure and
    251     # subsequent restore, an outdated and orphaned temporary directory can be
    252     # safely removed.)
    253     _SHARDED_SUFFIX = "_temp_%s/part" % uuid.uuid4().hex
    254     tmp_checkpoint_prefix = string_ops.string_join(
    255         [checkpoint_prefix, _SHARDED_SUFFIX])
    256 
    257     num_shards = len(per_device)
    258     sharded_saves = []
    259     sharded_prefixes = []
    260     num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    261     last_device = None
    262     for shard, (device, saveables) in enumerate(per_device):
    263       last_device = device
    264       with ops.device(saveable_object_util.set_cpu0(device)):
    265         sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
    266                                                  num_shards_tensor)
    267         sharded_prefixes.append(sharded_filename)
    268         sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
    269 
    270     with ops.control_dependencies([x.op for x in sharded_saves]):
    271       # Co-locates the merge step with the last device.
    272       with ops.device(saveable_object_util.set_cpu0(last_device)):
    273         # V2 format write path consists of a metadata merge step.  Once merged,
    274         # attempts to delete the temporary directory, "<user-fed prefix>_temp".
    275         merge_step = gen_io_ops.merge_v2_checkpoints(
    276             sharded_prefixes, checkpoint_prefix, delete_old_dirs=True)
    277         with ops.control_dependencies([merge_step]):
    278           # Returns the prefix "<user-fed prefix>" only.  DOES NOT include the
    279           # sharded spec suffix.
    280           return array_ops.identity(checkpoint_prefix)
    281 
    282   def _AddShardedSaveOps(self, filename_tensor, per_device):
    283     """Add ops to save the params per shard.
    284 
    285     Args:
    286       filename_tensor: a scalar String Tensor.
    287       per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as
    288         returned by _GroupByDevices().
    289 
    290     Returns:
    291       An op to save the variables.
    292     """
    293     if self._write_version == saver_pb2.SaverDef.V2:
    294       return self._AddShardedSaveOpsForV2(filename_tensor, per_device)
    295 
    296     num_shards = len(per_device)
    297     sharded_saves = []
    298     num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    299     for shard, (device, saveables) in enumerate(per_device):
    300       with ops.device(device):
    301         sharded_filename = self.sharded_filename(filename_tensor, shard,
    302                                                  num_shards_tensor)
    303         sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
    304     # Return the sharded name for the save path.
    305     with ops.control_dependencies([x.op for x in sharded_saves]):
    306       return gen_io_ops.sharded_filespec(filename_tensor, num_shards_tensor)
    307 
    308   def _AddRestoreOps(self,
    309                      filename_tensor,
    310                      saveables,
    311                      restore_sequentially,
    312                      reshape,
    313                      preferred_shard=-1,
    314                      name="restore_all"):
    315     """Add operations to restore saveables.
    316 
    317     Args:
    318       filename_tensor: Tensor for the path of the file to load.
    319       saveables: A list of SaveableObject objects.
    320       restore_sequentially: True if we want to restore variables sequentially
    321         within a shard.
    322       reshape: True if we want to reshape loaded tensors to the shape of
    323         the corresponding variable.
    324       preferred_shard: Shard to open first when loading a sharded file.
    325       name: Name for the returned op.
    326 
    327     Returns:
    328       An Operation that restores the variables.
    329     """
    330     all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard,
    331                                     restore_sequentially)
    332 
    333     assign_ops = []
    334     idx = 0
    335     # Load and optionally reshape on the CPU, as string tensors are not
    336     # available on the GPU.
    337     # TODO(touts): Re-enable restore on GPU when we can support annotating
    338     # string tensors as "HostMemory" inputs.
    339     for saveable in saveables:
    340       shapes = None
    341       if reshape:
    342         # Compute the shapes, let the restore op decide if and how to do
    343         # the reshape.
    344         shapes = []
    345         for spec in saveable.specs:
    346           v = spec.tensor
    347           shape = v.get_shape()
    348           if not shape.is_fully_defined():
    349             shape = array_ops.shape(v)
    350           shapes.append(shape)
    351       saveable_tensors = all_tensors[idx:idx + len(saveable.specs)]
    352       idx += len(saveable.specs)
    353       assign_ops.append(saveable.restore(saveable_tensors, shapes))
    354 
    355     # Create a Noop that has control dependencies from all the updates.
    356     return control_flow_ops.group(*assign_ops, name=name)
    357 
    358   def _AddShardedRestoreOps(self, filename_tensor, per_device,
    359                             restore_sequentially, reshape):
    360     """Add Ops to restore variables from multiple devices.
    361 
    362     Args:
    363       filename_tensor: Tensor for the path of the file to load.
    364       per_device: A list of (device, SaveableObject) pairs, as
    365         returned by _GroupByDevices().
    366       restore_sequentially: True if we want to restore variables sequentially
    367         within a shard.
    368       reshape: True if we want to reshape loaded tensors to the shape of
    369         the corresponding variable.
    370 
    371     Returns:
    372       An Operation that restores the variables.
    373     """
    374     sharded_restores = []
    375     for shard, (device, saveables) in enumerate(per_device):
    376       with ops.device(device):
    377         sharded_restores.append(
    378             self._AddRestoreOps(
    379                 filename_tensor,
    380                 saveables,
    381                 restore_sequentially,
    382                 reshape,
    383                 preferred_shard=shard,
    384                 name="restore_shard"))
    385     return control_flow_ops.group(*sharded_restores, name="restore_all")
    386 
    387   def _GroupByDevices(self, saveables):
    388     """Group Variable tensor slices per device.
    389 
    390     TODO(touts): Make sure that all the devices found are on different
    391     job/replica/task/cpu|gpu.  It would be bad if 2 were on the same device.
    392     It can happen if the devices are unspecified.
    393 
    394     Args:
    395       saveables: A list of BaseSaverBuilder.SaveableObject objects.
    396 
    397     Returns:
    398       A list of (device_name, list of BaseSaverBuilder.SaveableObject) pairs.
    399       The list is sorted by ascending device_name.
    400 
    401     Raises:
    402       ValueError: If the tensors of a saveable are on different devices.
    403     """
    404     per_device = collections.defaultdict(list)
    405     for saveable in saveables:
    406       canonical_device = set(
    407           pydev.canonical_name(spec.tensor.device) for spec in saveable.specs)
    408       if len(canonical_device) != 1:
    409         raise ValueError("All tensors of a saveable object must be "
    410                          "on the same device: %s" % saveable.name)
    411       per_device[canonical_device.pop()].append(saveable)
    412     return sorted(per_device.items(), key=lambda t: t[0])
    413 
    414   def build(self,
    415             names_to_saveables,
    416             reshape=False,
    417             sharded=False,
    418             max_to_keep=5,
    419             keep_checkpoint_every_n_hours=10000.0,
    420             name=None,
    421             restore_sequentially=False,
    422             filename="model"):
    423     """Builds save/restore graph nodes or runs save/restore in eager mode.
    424 
    425     Args:
    426       names_to_saveables: A dictionary mapping name to a Variable or
    427         SaveableObject. Each name will be associated with the
    428         corresponding variable in the checkpoint.
    429       reshape: If True, allow restoring parameters from a checkpoint
    430         where the parameters have a different shape.  This is only
    431         needed when you try to restore from a Dist-Belief checkpoint,
    432         and only sometimes.
    433       sharded: If True, shard the checkpoints, one per device that has
    434         Variable nodes.
    435       max_to_keep: Maximum number of checkpoints to keep.  As new checkpoints
    436         are created, old ones are deleted.  If None or 0, no checkpoints are
    437         deleted from the filesystem but only the last one is kept in the
    438         `checkpoint` file.  Presently the number is only roughly enforced.  For
    439         example in case of restarts more than max_to_keep checkpoints may be
    440         kept.
    441       keep_checkpoint_every_n_hours: How often checkpoints should be kept.
    442         Defaults to 10,000 hours.
    443       name: String.  Optional name to use as a prefix when adding operations.
    444       restore_sequentially: A Bool, which if true, causes restore of different
    445         variables to happen sequentially within each device.
    446       filename: If known at graph construction time, filename used for variable
    447         loading/saving. If None, then the default name "model" will be used.
    448 
    449     Returns:
    450       A SaverDef proto.
    451 
    452     Raises:
    453       TypeError: If 'names_to_saveables' is not a dictionary mapping string
    454         keys to variable Tensors.
    455       ValueError: If any of the keys or values in 'names_to_saveables' is not
    456         unique.
    457     """
    458     return self._build_internal(
    459         names_to_saveables=names_to_saveables,
    460         reshape=reshape,
    461         sharded=sharded,
    462         max_to_keep=max_to_keep,
    463         keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
    464         name=name,
    465         restore_sequentially=restore_sequentially,
    466         filename=filename)
    467 
    468   def _build_internal(self,
    469                       names_to_saveables,
    470                       reshape=False,
    471                       sharded=False,
    472                       max_to_keep=5,
    473                       keep_checkpoint_every_n_hours=10000.0,
    474                       name=None,
    475                       restore_sequentially=False,
    476                       filename="model",
    477                       build_save=True,
    478                       build_restore=True):
    479     """build() with option to only perform save and restore."""
    480     if not context.executing_eagerly() and (not build_save or
    481                                             not build_restore):
    482       raise ValueError("save and restore operations need to be built together "
    483                        " when eager execution is not enabled.")
    484 
    485     saveables = saveable_object_util.validate_and_slice_inputs(
    486         names_to_saveables)
    487     if max_to_keep is None:
    488       max_to_keep = 0
    489 
    490     with ops.name_scope(name, "save",
    491                         [saveable.op for saveable in saveables]) as name:
    492       # Add a placeholder string tensor for the filename.
    493       filename_tensor = array_ops.placeholder_with_default(
    494           filename or "model", shape=(), name="filename")
    495       # Keep the name "Const" for backwards compatibility.
    496       filename_tensor = array_ops.placeholder_with_default(
    497           filename_tensor, shape=(), name="Const")
    498 
    499       # Add the save ops.
    500       if sharded:
    501         per_device = self._GroupByDevices(saveables)
    502         if build_save:
    503           save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
    504         if build_restore:
    505           restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
    506                                                   restore_sequentially, reshape)
    507       else:
    508         if build_save:
    509           save_tensor = self._AddSaveOps(filename_tensor, saveables)
    510         if build_restore:
    511           restore_op = self._AddRestoreOps(filename_tensor, saveables,
    512                                            restore_sequentially, reshape)
    513 
    514     # In the following use case, it's possible to have restore_ops be called
    515     # something else:
    516     # - Build inference graph and export a meta_graph.
    517     # - Import the inference meta_graph
    518     # - Extend the inference graph to a train graph.
    519     # - Export a new meta_graph.
    520     # Now the second restore_op will be called "restore_all_1".
    521     # As such, comment out the assert for now until we know whether supporting
    522     # such usage model makes sense.
    523     #
    524     # assert restore_op.name.endswith("restore_all"), restore_op.name
    525     if context.executing_eagerly():
    526       # Store the tensor values to the tensor_names.
    527       save_tensor_name = save_tensor.numpy() if build_save else ""
    528       return saver_pb2.SaverDef(
    529           filename_tensor_name=filename_tensor.numpy(),
    530           save_tensor_name=save_tensor_name,
    531           restore_op_name="",
    532           max_to_keep=max_to_keep,
    533           sharded=sharded,
    534           keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
    535           version=self._write_version)
    536     else:
    537       graph = ops.get_default_graph()
    538       # Do some sanity checking on collections containing
    539       # PartitionedVariables. If a saved collection has a PartitionedVariable,
    540       # the GraphDef needs to include concat ops to get the value (or there'll
    541       # be a lookup error on load).
    542       check_collection_list = graph.get_all_collection_keys()
    543       for collection_type in check_collection_list:
    544         for element in graph.get_collection(collection_type):
    545           if isinstance(element, variables.PartitionedVariable):
    546             try:
    547               graph.get_operation_by_name(element.name)
    548             except KeyError:
    549               # Create a concat op for this PartitionedVariable. The user may
    550               # not need it, but we'll try looking it up on MetaGraph restore
    551               # since it's in a collection.
    552               element.as_tensor()
    553       return saver_pb2.SaverDef(
    554           filename_tensor_name=filename_tensor.name,
    555           save_tensor_name=save_tensor.name,
    556           restore_op_name=restore_op.name,
    557           max_to_keep=max_to_keep,
    558           sharded=sharded,
    559           keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
    560           version=self._write_version)
    561 
    562 
    563 class BulkSaverBuilder(BaseSaverBuilder):
    564   """SaverBuilder with support for bulk restoring multiple saveables."""
    565 
    566   def bulk_restore(self, filename_tensor, saveables, preferred_shard,
    567                    restore_sequentially):
    568 
    569     # Ignored: bulk restore is internally sequential.
    570     del restore_sequentially
    571     restore_specs = []
    572     for saveable in saveables:
    573       for spec in saveable.specs:
    574         restore_specs.append((spec.name, spec.slice_spec, spec.dtype))
    575 
    576     names, slices, dtypes = zip(*restore_specs)
    577     # Load all tensors onto CPU 0 for compatibility with existing code.
    578     with ops.device("cpu:0"):
    579       return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
    580 
    581 
    582 def _get_saver_or_default():
    583   """Returns the saver from SAVERS collection, or creates a default one.
    584 
    585   This method is used by other members of the training module, such as
    586   `Scaffold`, or `CheckpointSaverHook`.
    587 
    588   Returns:
    589     `Saver`.
    590 
    591   Raises:
    592     RuntimeError: If the SAVERS collection already has more than one item.
    593   """
    594   collection_key = ops.GraphKeys.SAVERS
    595   savers = ops.get_collection(collection_key)
    596   if savers:
    597     if len(savers) > 1:
    598       raise RuntimeError(
    599           "More than one item in collection {}. "
    600           "Please indicate which one to use by passing it to the constructor.".
    601           format(collection_key))
    602     return savers[0]
    603   saver = Saver(sharded=True, allow_empty=True)
    604   if saver is not None:
    605     ops.add_to_collection(collection_key, saver)
    606   return saver
    607 
    608 
    609 @tf_export(v1=["train.Saver"])
    610 class Saver(object):
    611   """Saves and restores variables.
    612 
    613   See [Variables](https://tensorflow.org/guide/variables)
    614   for an overview of variables, saving and restoring.
    615 
    616   The `Saver` class adds ops to save and restore variables to and from
    617   *checkpoints*.  It also provides convenience methods to run these ops.
    618 
    619   Checkpoints are binary files in a proprietary format which map variable names
    620   to tensor values.  The best way to examine the contents of a checkpoint is to
    621   load it using a `Saver`.
    622 
    623   Savers can automatically number checkpoint filenames with a provided counter.
    624   This lets you keep multiple checkpoints at different steps while training a
    625   model.  For example you can number the checkpoint filenames with the training
    626   step number.  To avoid filling up disks, savers manage checkpoint files
    627   automatically. For example, they can keep only the N most recent files, or
    628   one checkpoint for every N hours of training.
    629 
    630   You number checkpoint filenames by passing a value to the optional
    631   `global_step` argument to `save()`:
    632 
    633   ```python
    634   saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
    635   ...
    636   saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
    637   ```
    638 
    639   Additionally, optional arguments to the `Saver()` constructor let you control
    640   the proliferation of checkpoint files on disk:
    641 
    642   * `max_to_keep` indicates the maximum number of recent checkpoint files to
    643     keep.  As new files are created, older files are deleted.   If None or 0,
    644     no checkpoints are deleted from the filesystem but only the last one is
    645     kept in the `checkpoint` file.  Defaults to 5 (that is, the 5 most recent
    646     checkpoint files are kept).
    647 
    648   * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent
    649     `max_to_keep` checkpoint files, you might want to keep one checkpoint file
    650     for every N hours of training.  This can be useful if you want to later
    651     analyze how a model progressed during a long training session.  For
    652     example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep
    653     one checkpoint file for every 2 hours of training.  The default value of
    654     10,000 hours effectively disables the feature.
    655 
    656   Note that you still have to call the `save()` method to save the model.
    657   Passing these arguments to the constructor will not save variables
    658   automatically for you.
    659 
    660   A training program that saves regularly looks like:
    661 
    662   ```python
    663   ...
    664   # Create a saver.
    665   saver = tf.train.Saver(...variables...)
    666   # Launch the graph and train, saving the model every 1,000 steps.
    667   sess = tf.Session()
    668   for step in range(1000000):
    669       sess.run(..training_op..)
    670       if step % 1000 == 0:
    671           # Append the step number to the checkpoint name:
    672           saver.save(sess, 'my-model', global_step=step)
    673   ```
    674 
    675   In addition to checkpoint files, savers keep a protocol buffer on disk with
    676   the list of recent checkpoints. This is used to manage numbered checkpoint
    677   files and by `latest_checkpoint()`, which makes it easy to discover the path
    678   to the most recent checkpoint. That protocol buffer is stored in a file named
    679   'checkpoint' next to the checkpoint files.
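
          For example, a later run can locate and restore the most recent checkpoint
          (an illustrative sketch; assumes checkpoints were written under a
          hypothetical `./train_dir` directory and that the graph and `saver` have
          been rebuilt):

          ```python
          ckpt_path = tf.train.latest_checkpoint('./train_dir')
          if ckpt_path:
            saver.restore(sess, ckpt_path)
          ```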
    680 
    681   If you create several savers, you can specify a different filename for the
    682   protocol buffer file in the call to `save()`.
    683   """
    684 
    685   def __init__(self,
    686                var_list=None,
    687                reshape=False,
    688                sharded=False,
    689                max_to_keep=5,
    690                keep_checkpoint_every_n_hours=10000.0,
    691                name=None,
    692                restore_sequentially=False,
    693                saver_def=None,
    694                builder=None,
    695                defer_build=False,
    696                allow_empty=False,
    697                write_version=saver_pb2.SaverDef.V2,
    698                pad_step_number=False,
    699                save_relative_paths=False,
    700                filename=None):
    701     """Creates a `Saver`.
    702 
    703     The constructor adds ops to save and restore variables.
    704 
    705     `var_list` specifies the variables that will be saved and restored. It can
    706     be passed as a `dict` or a list:
    707 
    708     * A `dict` of names to variables: The keys are the names that will be
    709       used to save or restore the variables in the checkpoint files.
    710     * A list of variables: The variables will be keyed with their op name in
    711       the checkpoint files.
    712 
    713     For example:
    714 
    715     ```python
    716     v1 = tf.Variable(..., name='v1')
    717     v2 = tf.Variable(..., name='v2')
    718 
    719     # Pass the variables as a dict:
    720     saver = tf.train.Saver({'v1': v1, 'v2': v2})
    721 
    722     # Or pass them as a list.
    723     saver = tf.train.Saver([v1, v2])
    724     # Passing a list is equivalent to passing a dict with the variable op names
    725     # as keys:
    726     saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
    727     ```
    728 
    729     The optional `reshape` argument, if `True`, allows restoring a variable from
    730     a save file where the variable had a different shape, but the same number
    731     of elements and type.  This is useful if you have reshaped a variable and
    732     want to reload it from an older checkpoint.
    733 
    734     The optional `sharded` argument, if `True`, instructs the saver to shard
    735     checkpoints per device.
    736 
    737     Args:
    738       var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
    739         names to `SaveableObject`s. If `None`, defaults to the list of all
    740         saveable objects.
    741       reshape: If `True`, allows restoring parameters from a checkpoint
    742         where the variables have a different shape.
    743       sharded: If `True`, shard the checkpoints, one per device.
    744       max_to_keep: Maximum number of recent checkpoints to keep.
    745         Defaults to 5.
    746       keep_checkpoint_every_n_hours: How often to keep checkpoints.
    747         Defaults to 10,000 hours.
    748       name: String.  Optional name to use as a prefix when adding operations.
    749       restore_sequentially: A `Bool`, which if true, causes restore of different
    750         variables to happen sequentially within each device.  This can lower
    751         memory usage when restoring very large models.
    752       saver_def: Optional `SaverDef` proto to use instead of running the
    753         builder. This is only useful for specialty code that wants to recreate
    754         a `Saver` object for a previously built `Graph` that had a `Saver`.
    755         The `saver_def` proto should be the one returned by the
    756         `as_saver_def()` call of the `Saver` that was created for that `Graph`.
    757       builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
    758         Defaults to `BulkSaverBuilder()`.
    759       defer_build: If `True`, defer adding the save and restore ops to the
    760         `build()` call. In that case `build()` should be called before
    761         finalizing the graph or using the saver.
    762       allow_empty: If `False` (default) raise an error if there are no
    763         variables in the graph. Otherwise, construct the saver anyway and make
    764         it a no-op.
    765       write_version: controls what format to use when saving checkpoints.  It
    766         also affects certain filepath matching logic.  The V2 format is the
    767         recommended choice: it is much more optimized than V1 in terms of
    768         memory required and latency incurred during restore.  Regardless of
    769         this flag, the Saver is able to restore from both V2 and V1 checkpoints.
    770       pad_step_number: if True, pads the global step number in the checkpoint
    771         filepaths to some fixed width (8 by default).  This is turned off by
    772         default.
    773       save_relative_paths: If `True`, will write relative paths to the
    774         checkpoint state file. This is needed if the user wants to copy the
    775         checkpoint directory and reload from the copied directory.
    776       filename: If known at graph construction time, filename used for variable
    777         loading/saving.
    778 
    779     Raises:
    780       TypeError: If `var_list` is invalid.
    781       ValueError: If any of the keys or values in `var_list` are not unique.
    782       RuntimeError: If eager execution is enabled and `var_list` does not
    783         specify a list of variables to save.
    784 
    785     @compatibility(eager)
    786     When eager execution is enabled, `var_list` must specify a `list` or `dict`
    787     of variables to save. Otherwise, a `RuntimeError` will be raised.
    788 
    789     Although Saver works in some cases when executing eagerly, it is
    790     fragile. Please switch to `tf.train.Checkpoint` or
    791     `tf.keras.Model.save_weights`, which perform a more robust object-based
    792     saving. These APIs will load checkpoints written by `Saver`.
    793     @end_compatibility
    794     """
    795     if defer_build and var_list:
    796       raise ValueError(
    797           "If `var_list` is provided then build cannot be deferred. "
    798           "Either set defer_build=False or var_list=None.")
    799     if context.executing_eagerly():
    800       logging.warning(
    801           "Saver is deprecated, please switch to tf.train.Checkpoint or "
    802           "tf.keras.Model.save_weights for training checkpoints. When "
    803           "executing eagerly variables do not necessarily have unique names, "
    804           "and so the variable.name-based lookups Saver performs are "
    805           "error-prone.")
    806       if var_list is None:
    807         raise RuntimeError(
    808             "When eager execution is enabled, `var_list` must specify a list "
    809             "or dict of variables to save")
    810     self._var_list = var_list
    811     self._reshape = reshape
    812     self._sharded = sharded
    813     self._max_to_keep = max_to_keep
    814     self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
    815     self._name = name
    816     self._restore_sequentially = restore_sequentially
    817     self.saver_def = saver_def
    818     self._builder = builder
    819     self._is_built = False
    820     self._allow_empty = allow_empty
    821     self._is_empty = None
    822     self._write_version = write_version
    823     self._pad_step_number = pad_step_number
    824     self._filename = filename
    825     self._last_checkpoints = []
    826     self._checkpoints_to_be_deleted = []
    827     if context.executing_eagerly():
    828       self._next_checkpoint_time = (
    829           time.time() + self._keep_checkpoint_every_n_hours * 3600)
    830     elif not defer_build:
    831       self.build()
    832     if self.saver_def:
    833       self._check_saver_def()
    834       self._write_version = self.saver_def.version
    835     self._save_relative_paths = save_relative_paths
    836     # For compatibility with object-based checkpoints, we may build a second
    837     # Saver to read the renamed keys.
    838     self._object_restore_saver = None
    839 
    840   def build(self):
    841     if context.executing_eagerly():
    842       raise RuntimeError("Use save/restore instead of build in eager mode.")
    843     self._build(self._filename, build_save=True, build_restore=True)
    844 
    845   def _build_eager(self, checkpoint_path, build_save, build_restore):
    846     self._build(
    847         checkpoint_path, build_save=build_save, build_restore=build_restore)
    848 
    849   def _build(self, checkpoint_path, build_save, build_restore):
    850     """Builds saver_def."""
    851     if not context.executing_eagerly():
    852       if self._is_built:
    853         return
    854       self._is_built = True
    855 
    856     if not self.saver_def or context.executing_eagerly():
    857       if self._builder is None:
    858         self._builder = BulkSaverBuilder(self._write_version)
    859 
    860       if self._var_list is None:
    861         # pylint: disable=protected-access
    862         self._var_list = variables._all_saveable_objects()
    863       if not self._var_list:
    864         if self._allow_empty:
    865           self._is_empty = True
    866           return
    867         else:
    868           raise ValueError("No variables to save")
    869       self._is_empty = False
    870 
    871       self.saver_def = self._builder._build_internal(  # pylint: disable=protected-access
    872           self._var_list,
    873           reshape=self._reshape,
    874           sharded=self._sharded,
    875           max_to_keep=self._max_to_keep,
    876           keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
    877           name=self._name,
    878           restore_sequentially=self._restore_sequentially,
    879           filename=checkpoint_path,
    880           build_save=build_save, build_restore=build_restore)
    881     elif self.saver_def and self._name:
    882       # Since self._name is used as a name_scope by builder(), we are
    883       # overloading the use of this field to represent the "import_scope" as
    884       # well.
    885       self.saver_def.filename_tensor_name = ops.prepend_name_scope(
    886           self.saver_def.filename_tensor_name, self._name)
    887       self.saver_def.save_tensor_name = ops.prepend_name_scope(
    888           self.saver_def.save_tensor_name, self._name)
    889       self.saver_def.restore_op_name = ops.prepend_name_scope(
    890           self.saver_def.restore_op_name, self._name)
    891 
    892     self._check_saver_def()
    893     if not context.executing_eagerly():
    894       # Updates next checkpoint time.
    895       # Set in __init__ when executing eagerly.
    896       self._next_checkpoint_time = (
    897           time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600)
    898 
    899   def _check_saver_def(self):
    900     if not isinstance(self.saver_def, saver_pb2.SaverDef):
    901       raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" %
    902                        self.saver_def)
    903     if not context.executing_eagerly():
    904       if not self.saver_def.save_tensor_name:
    905         raise ValueError("saver_def must specify the save_tensor_name: %s" %
    906                          str(self.saver_def))
    907       if not self.saver_def.restore_op_name:
    908         raise ValueError("saver_def must specify the restore_op_name: %s" %
    909                          str(self.saver_def))
    910 
    911   def _CheckpointFilename(self, p):
    912     """Returns the checkpoint filename given a `(filename, time)` pair.
    913 
    914     Args:
    915       p: (filename, time) pair.
    916 
    917     Returns:
    918       Checkpoint file name.
    919     """
    920     name, _ = p
    921     return name
    922 
    923   def _RecordLastCheckpoint(self, latest_save_path):
    924     """Manages the list of the latest checkpoints."""
    925     if not self.saver_def.max_to_keep:
    926       return
    927     # Remove first from list if the same name was used before.
    928     for p in self._last_checkpoints:
    929       if latest_save_path == self._CheckpointFilename(p):
    930         self._last_checkpoints.remove(p)
    931     # Append new path to list
    932     self._last_checkpoints.append((latest_save_path, time.time()))
    933 
    934     # If more than max_to_keep, remove oldest.
    935     if len(self._last_checkpoints) > self.saver_def.max_to_keep:
    936       self._checkpoints_to_be_deleted.append(self._last_checkpoints.pop(0))
    937 
    938   def _MaybeDeleteOldCheckpoints(self, meta_graph_suffix="meta"):
    939     """Deletes old checkpoints if necessary.
    940 
    941     `self._checkpoints_to_be_deleted` is going to contain checkpoints that are
    942     over `max_to_keep`.  They are going to be deleted.  If
    943     `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint
    944     every `N` hours. For example, if `N` is 0.5, an additional checkpoint is
    945     kept for every 0.5 hours of training; if `N` is 10, an additional
    946     checkpoint is kept for every 10 hours of training.
    947 
    948     Args:
    949       meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
    950     """
    951     if self._checkpoints_to_be_deleted:
    952       p = self._checkpoints_to_be_deleted.pop(0)
    953       # Do not delete the file if keep_checkpoint_every_n_hours is set and we
    954       # have reached N hours of training.
    955       should_keep = p[1] > self._next_checkpoint_time
    956       if should_keep:
    957         self._next_checkpoint_time += (
    958             self.saver_def.keep_checkpoint_every_n_hours * 3600)
    959         return
    960 
    961       # Otherwise delete the files.
    962       try:
    963         checkpoint_management.remove_checkpoint(
    964             self._CheckpointFilename(p), self.saver_def.version,
    965             meta_graph_suffix)
    966       except Exception as e:  # pylint: disable=broad-except
    967         logging.warning("Ignoring: %s", str(e))
    968 
    969   def as_saver_def(self):
    970     """Generates a `SaverDef` representation of this saver.
    971 
    972     Returns:
    973       A `SaverDef` proto.
    974     """
    975     return self.saver_def
    976 
    977   def to_proto(self, export_scope=None):
    978     """Converts this `Saver` to a `SaverDef` protocol buffer.
    979 
    980     Args:
    981       export_scope: Optional `string`. Name scope to remove.
    982 
    983     Returns:
    984       A `SaverDef` protocol buffer.
    985     """
    986     if export_scope is None:
    987       return self.saver_def
    988 
    989     if not (self.saver_def.filename_tensor_name.startswith(export_scope) and
    990             self.saver_def.save_tensor_name.startswith(export_scope) and
    991             self.saver_def.restore_op_name.startswith(export_scope)):
    992       return None
    993 
    994     saver_def = saver_pb2.SaverDef()
    995     saver_def.CopyFrom(self.saver_def)
    996     saver_def.filename_tensor_name = ops.strip_name_scope(
    997         saver_def.filename_tensor_name, export_scope)
    998     saver_def.save_tensor_name = ops.strip_name_scope(
    999         saver_def.save_tensor_name, export_scope)
   1000     saver_def.restore_op_name = ops.strip_name_scope(
   1001         saver_def.restore_op_name, export_scope)
   1002     return saver_def
   1003 
   1004   @staticmethod
   1005   def from_proto(saver_def, import_scope=None):
   1006     """Returns a `Saver` object created from `saver_def`.
   1007 
   1008     Args:
   1009       saver_def: a `SaverDef` protocol buffer.
   1010       import_scope: Optional `string`. Name scope to use.
   1011 
   1012     Returns:
   1013       A `Saver` built from saver_def.
   1014     """
   1015     return Saver(saver_def=saver_def, name=import_scope)
   1016 
   1017   @property
   1018   def last_checkpoints(self):
   1019     """List of not-yet-deleted checkpoint filenames.
   1020 
   1021     You can pass any of the returned values to `restore()`.
   1022 
   1023     Returns:
   1024       A list of checkpoint filenames, sorted from oldest to newest.
   1025     """
   1026     return list(self._CheckpointFilename(p) for p in self._last_checkpoints)
   1027 
   1028   def set_last_checkpoints(self, last_checkpoints):
   1029     """DEPRECATED: Use set_last_checkpoints_with_time.
   1030 
   1031     Sets the list of old checkpoint filenames.
   1032 
   1033     Args:
   1034       last_checkpoints: A list of checkpoint filenames.
   1035 
   1036     Raises:
   1037       AssertionError: If last_checkpoints is not a list.
   1038     """
   1039     assert isinstance(last_checkpoints, list)
   1040     # We use a timestamp of +inf so that this checkpoint will never be
   1041     # deleted.  This is both safe and backwards compatible to a previous
   1042     # version of the code which used s[1] as the "timestamp".
   1043     self._last_checkpoints = [(s, np.inf) for s in last_checkpoints]
   1044 
   1045   def set_last_checkpoints_with_time(self, last_checkpoints_with_time):
   1046     """Sets the list of old checkpoint filenames and timestamps.
   1047 
   1048     Args:
   1049       last_checkpoints_with_time: A list of tuples of checkpoint filenames and
   1050         timestamps.
   1051 
   1052     Raises:
   1053       AssertionError: If last_checkpoints_with_time is not a list.
   1054     """
   1055     assert isinstance(last_checkpoints_with_time, list)
   1056     self._last_checkpoints = last_checkpoints_with_time
   1057 
   1058   def recover_last_checkpoints(self, checkpoint_paths):
   1059     """Recovers the internal saver state after a crash.
   1060 
   1061     This method is useful for recovering the "self._last_checkpoints" state.
   1062 
   1063     Globs for the checkpoints pointed to by `checkpoint_paths`.  If the files
   1064     exist, use their mtime as the checkpoint timestamp.
   1065 
   1066     Args:
   1067       checkpoint_paths: a list of checkpoint paths.
   1068     """
   1069     mtimes = checkpoint_management.get_checkpoint_mtimes(checkpoint_paths)
   1070     self.set_last_checkpoints_with_time(list(zip(checkpoint_paths, mtimes)))
   1071 
   1072   def save(self,
   1073            sess,
   1074            save_path,
   1075            global_step=None,
   1076            latest_filename=None,
   1077            meta_graph_suffix="meta",
   1078            write_meta_graph=True,
   1079            write_state=True,
   1080            strip_default_attrs=False,
   1081            save_debug_info=False):
   1082     # pylint: disable=line-too-long
   1083     """Saves variables.
   1084 
   1085     This method runs the ops added by the constructor for saving variables.
   1086     It requires a session in which the graph was launched.  The variables to
   1087     save must also have been initialized.
   1088 
   1089     The method returns the path prefix of the newly created checkpoint files.
   1090     This string can be passed directly to a call to `restore()`.
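
            For example (an illustrative sketch; assumes a launched `sess`, a built
            `saver`, and a writable `/tmp/train` directory):

            ```python
            prefix = saver.save(sess, '/tmp/train/model.ckpt')
            # Later, restore from the returned prefix.
            saver.restore(sess, prefix)
            ```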
   1091 
   1092     Args:
   1093       sess: A Session to use to save the variables.
   1094       save_path: String.  Prefix of filenames created for the checkpoint.
   1095       global_step: If provided the global step number is appended to
   1096         `save_path` to create the checkpoint filenames. The optional argument
   1097         can be a `Tensor`, a `Tensor` name or an integer.
   1098       latest_filename: Optional name for the protocol buffer file that will
   1099         contain the list of most recent checkpoints.  That file,
   1100         kept in the same directory as the checkpoint files, is automatically
   1101         managed by the saver to keep track of recent checkpoints.  Defaults to
   1102         'checkpoint'.
   1103       meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
   1104       write_meta_graph: `Boolean` indicating whether or not to write the meta
   1105         graph file.
   1106       write_state: `Boolean` indicating whether or not to write the
   1107         `CheckpointStateProto`.
   1108       strip_default_attrs: Boolean. If `True`, default-valued attributes will be
   1109         removed from the NodeDefs. For a detailed guide, see
   1110         [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
   1111       save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
   1112         which is in the same directory as save_path and has `_debug` added
   1113         before the file extension. This is only enabled when
   1114         `write_meta_graph` is `True`.
   1115 
   1116     Returns:
   1117       A string: path prefix used for the checkpoint files.  If the saver is
   1118         sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
   1119         is the number of shards created.
   1120       If the saver is empty, returns None.
   1121 
   1122     Raises:
   1123       TypeError: If `sess` is not a `Session`.
   1124       ValueError: If `latest_filename` contains path components, or if it
   1125         collides with `save_path`.
   1126       RuntimeError: If save and restore ops weren't built.
   1127     """
   1128     # pylint: enable=line-too-long
   1129     if not self._is_built and not context.executing_eagerly():
   1130       raise RuntimeError(
   1131           "`build()` should be called before save if defer_build==True")
   1132     if latest_filename is None:
   1133       latest_filename = "checkpoint"
   1134     if self._write_version != saver_pb2.SaverDef.V2:
   1135       logging.warning("*******************************************************")
   1136       logging.warning("TensorFlow's V1 checkpoint format has been deprecated.")
   1137       logging.warning("Consider switching to the more efficient V2 format:")
   1138       logging.warning("   `tf.train.Saver(write_version=tf.train.SaverDef.V2)`")
   1139       logging.warning("now on by default.")
   1140       logging.warning("*******************************************************")
   1141 
   1142     if os.path.split(latest_filename)[0]:
   1143       raise ValueError("'latest_filename' must not contain path components")
   1144 
   1145     if global_step is not None:
   1146       if not isinstance(global_step, compat.integral_types):
   1147         global_step = training_util.global_step(sess, global_step)
   1148       checkpoint_file = "%s-%d" % (save_path, global_step)
   1149       if self._pad_step_number:
   1150         # Zero-pads the step numbers, so that they are sorted when listed.
   1151         checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
   1152     else:
   1153       checkpoint_file = save_path
   1154       if os.path.basename(
   1155           save_path) == latest_filename and not self._sharded:
   1156         # Guard against collision between data file and checkpoint state file.
   1157         raise ValueError(
   1158             "'latest_filename' collides with 'save_path': '%s' and '%s'" %
   1159             (latest_filename, save_path))
   1160 
   1161     if (not context.executing_eagerly() and
   1162         not isinstance(sess, session.SessionInterface)):
   1163       raise TypeError("'sess' must be a Session; %s" % sess)
   1164 
   1165     save_path_parent = os.path.dirname(save_path)
   1166     if not self._is_empty:
   1167       try:
   1168         if context.executing_eagerly():
   1169           self._build_eager(
   1170               checkpoint_file, build_save=True, build_restore=False)
   1171           model_checkpoint_path = self.saver_def.save_tensor_name
   1172         else:
   1173           model_checkpoint_path = sess.run(
   1174               self.saver_def.save_tensor_name,
   1175               {self.saver_def.filename_tensor_name: checkpoint_file})
   1176 
   1177         model_checkpoint_path = compat.as_str(model_checkpoint_path)
   1178         if write_state:
   1179           self._RecordLastCheckpoint(model_checkpoint_path)
   1180           checkpoint_management.update_checkpoint_state_internal(
   1181               save_dir=save_path_parent,
   1182               model_checkpoint_path=model_checkpoint_path,
   1183               all_model_checkpoint_paths=self.last_checkpoints,
   1184               latest_filename=latest_filename,
   1185               save_relative_paths=self._save_relative_paths)
   1186           self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
   1187       except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
   1188         if not gfile.IsDirectory(save_path_parent):
   1189           exc = ValueError(
   1190               "Parent directory of {} doesn't exist, can't save.".format(
   1191                   save_path))
   1192         raise exc
   1193 
   1194     if write_meta_graph:
   1195       meta_graph_filename = checkpoint_management.meta_graph_filename(
   1196           checkpoint_file, meta_graph_suffix=meta_graph_suffix)
   1197       if not context.executing_eagerly():
   1198         with sess.graph.as_default():
   1199           self.export_meta_graph(
   1200               meta_graph_filename, strip_default_attrs=strip_default_attrs,
   1201               save_debug_info=save_debug_info)
   1202 
   1203     if self._is_empty:
   1204       return None
   1205     else:
   1206       return model_checkpoint_path
   1207 
   1208   def export_meta_graph(self,
   1209                         filename=None,
   1210                         collection_list=None,
   1211                         as_text=False,
   1212                         export_scope=None,
   1213                         clear_devices=False,
   1214                         clear_extraneous_savers=False,
   1215                         strip_default_attrs=False,
   1216                         save_debug_info=False):
   1217     # pylint: disable=line-too-long
   1218     """Writes `MetaGraphDef` to save_path/filename.
   1219 
   1220     Args:
   1221       filename: Optional meta_graph filename including the path.
   1222       collection_list: List of string keys to collect.
   1223       as_text: If `True`, writes the meta_graph as an ASCII proto.
   1224       export_scope: Optional `string`. Name scope to remove.
   1225       clear_devices: Whether or not to clear the device field for an `Operation`
   1226         or `Tensor` during export.
   1227       clear_extraneous_savers: Remove any Saver-related information from the
   1228         graph (both Save/Restore ops and SaverDefs) that is not associated
   1229         with this Saver.
   1230       strip_default_attrs: Boolean. If `True`, default-valued attributes will be
   1231         removed from the NodeDefs. For a detailed guide, see
   1232         [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
   1233       save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
   1234         which is in the same directory as `filename` and has `_debug` added
   1235         before the file extension.
   1236 
   1237     Returns:
   1238       A `MetaGraphDef` proto.
   1239     """
   1240     # pylint: enable=line-too-long
   1241     return export_meta_graph(
   1242         filename=filename,
   1243         graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
   1244         saver_def=self.saver_def,
   1245         collection_list=collection_list,
   1246         as_text=as_text,
   1247         export_scope=export_scope,
   1248         clear_devices=clear_devices,
   1249         clear_extraneous_savers=clear_extraneous_savers,
   1250         strip_default_attrs=strip_default_attrs,
   1251         save_debug_info=save_debug_info)
   1252 
   1253   def restore(self, sess, save_path):
   1254     """Restores previously saved variables.
   1255 
   1256     This method runs the ops added by the constructor for restoring variables.
   1257     It requires a session in which the graph was launched.  The variables to
   1258     restore do not have to have been initialized, as restoring is itself a way
   1259     to initialize variables.
   1260 
   1261     The `save_path` argument is typically a value previously returned from a
   1262     `save()` call, or a call to `latest_checkpoint()`.
   1263 
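            For example (a minimal sketch; the checkpoint directory is illustrative, and
            the variables to restore are assumed to already exist in the graph):

            ```Python
            saver = tf.train.Saver()
            with tf.Session() as sess:
              # Restore from the most recent checkpoint written under '/tmp/ckpt_dir'.
              ckpt_path = tf.train.latest_checkpoint('/tmp/ckpt_dir')
              saver.restore(sess, ckpt_path)
            ```
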
   1264     Args:
   1265       sess: A `Session` to use to restore the parameters. None in eager mode.
   1266       save_path: Path where parameters were previously saved.
   1267 
   1268     Raises:
   1269       ValueError: If save_path is None or not a valid checkpoint.
   1270     """
   1271     if self._is_empty:
   1272       return
   1273     if save_path is None:
   1274       raise ValueError("Can't load save_path when it is None.")
   1275 
   1276     if not checkpoint_management.checkpoint_exists(compat.as_text(save_path)):
   1277       raise ValueError("The passed save_path is not a valid checkpoint: "
   1278                        + compat.as_text(save_path))
   1279 
   1280     logging.info("Restoring parameters from %s", compat.as_text(save_path))
   1281     try:
   1282       if context.executing_eagerly():
   1283         self._build_eager(save_path, build_save=False, build_restore=True)
   1284       else:
   1285         sess.run(self.saver_def.restore_op_name,
   1286                  {self.saver_def.filename_tensor_name: save_path})
   1287     except errors.NotFoundError as err:
   1288       # There are three common conditions that might cause this error:
   1289       # 0. The file is missing. We ignore here, as this is checked above.
   1290       # 1. This is an object-based checkpoint trying name-based loading.
   1291       # 2. The graph has been altered and a variable or other name is missing.
   1292 
   1293       # 1. The checkpoint would not be loaded successfully as is. Try to parse
   1294       # it as an object-based checkpoint.
   1295       try:
   1296         names_to_keys = object_graph_key_mapping(save_path)
   1297       except errors.NotFoundError:
   1298         # 2. This is not an object-based checkpoint, which likely means there
   1299         # is a graph mismatch. Re-raise the original error with
   1300         # a helpful message (b/110263146)
   1301         raise _wrap_restore_error_with_msg(
   1302             err, "a Variable name or other graph key that is missing")
   1303 
   1304       # This is an object-based checkpoint. We'll print a warning and then do
   1305       # the restore.
   1306       logging.warning(
   1307           "Restoring an object-based checkpoint using a name-based saver. This "
   1308           "may be somewhat fragile, and will re-build the Saver. Instead, "
   1309           "consider loading object-based checkpoints using "
   1310           "tf.train.Checkpoint().")
   1311       self._object_restore_saver = saver_from_object_based_checkpoint(
   1312           checkpoint_path=save_path,
   1313           var_list=self._var_list,
   1314           builder=self._builder,
   1315           names_to_keys=names_to_keys,
   1316           cached_saver=self._object_restore_saver)
   1317       self._object_restore_saver.restore(sess=sess, save_path=save_path)
   1318     except errors.InvalidArgumentError as err:
   1319       # There is a mismatch between the graph and the checkpoint being loaded.
   1320       # We add a more reasonable error message here to help users (b/110263146)
   1321       raise _wrap_restore_error_with_msg(
   1322           err, "a mismatch between the current graph and the graph")
   1323 
   1324   @staticmethod
   1325   def _add_collection_def(meta_graph_def, key, export_scope=None):
   1326     """Adds a collection to MetaGraphDef protocol buffer.
   1327 
   1328     Args:
   1329       meta_graph_def: MetaGraphDef protocol buffer.
   1330       key: One of the GraphKeys or user-defined string.
   1331       export_scope: Optional `string`. Name scope to remove.
   1332     """
   1333     meta_graph.add_collection_def(meta_graph_def, key,
   1334                                   export_scope=export_scope)
   1335 
   1336 
   1337 @tf_export(v1=["train.import_meta_graph"])
   1338 def import_meta_graph(meta_graph_or_file, clear_devices=False,
   1339                       import_scope=None, **kwargs):
   1340   """Recreates a Graph saved in a `MetaGraphDef` proto.
   1341 
   1342   This function takes a `MetaGraphDef` protocol buffer as input. If
   1343   the argument is a file containing a `MetaGraphDef` protocol buffer,
   1344   it constructs a protocol buffer from the file content. The function
   1345   then adds all the nodes from the `graph_def` field to the
   1346   current graph, recreates all the collections, and returns a saver
   1347   constructed from the `saver_def` field.
   1348 
   1349   In combination with `export_meta_graph()`, this function can be used to
   1350 
   1351   * Serialize a graph along with other Python objects such as `QueueRunner`
   1352     and `Variable` into a `MetaGraphDef`.
   1353 
   1354   * Restart training from a saved graph and checkpoints.
   1355 
   1356   * Run inference from a saved graph and checkpoints.
   1357 
   1358   ```Python
   1359   ...
   1360   # Create a saver.
   1361   saver = tf.train.Saver(...variables...)
   1362   # Remember the training_op we want to run by adding it to a collection.
   1363   tf.add_to_collection('train_op', train_op)
   1364   sess = tf.Session()
   1365   for step in range(1000000):
   1366       sess.run(train_op)
   1367       if step % 1000 == 0:
   1368           # Saves checkpoint, which by default also exports a meta_graph
   1369           # named 'my-model-global_step.meta'.
   1370           saver.save(sess, 'my-model', global_step=step)
   1371   ```
   1372 
   1373   Later we can continue training from this saved `meta_graph` without building
   1374   the model from scratch.
   1375 
   1376   ```Python
   1377   with tf.Session() as sess:
   1378     new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
   1379     new_saver.restore(sess, 'my-save-dir/my-model-10000')
   1380     # tf.get_collection() returns a list. In this example we only want the
   1381     # first one.
   1382     train_op = tf.get_collection('train_op')[0]
   1383     for step in range(1000000):
   1384       sess.run(train_op)
   1385   ```
   1386 
   1387   NOTE: Restarting training from saved `meta_graph` only works if the
   1388   device assignments have not changed.
   1389 
   1390   Example 2:
   1391   Variables, placeholders, and independent operations can also be stored, as
   1392   shown in the following example.
   1393 
   1394   ```Python
   1395   # Saving contents and operations.
   1396   v1 = tf.placeholder(tf.float32, name="v1")
   1397   v2 = tf.placeholder(tf.float32, name="v2")
   1398   v3 = tf.multiply(v1, v2)
   1399   vx = tf.Variable(10.0, name="vx")
   1400   v4 = tf.add(v3, vx, name="v4")
   1401   saver = tf.train.Saver([vx])
   1402   sess = tf.Session()
   1403   sess.run(tf.global_variables_initializer())
   1404   sess.run(vx.assign(tf.add(vx, vx)))
   1405   result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
   1406   print(result)
   1407   saver.save(sess, "./model_ex1")
   1408   ```
   1409 
   1410   Later this model can be restored and contents loaded.
   1411 
   1412   ```Python
   1413   # Restoring variables and running operations.
   1414   saver = tf.train.import_meta_graph("./model_ex1.meta")
   1415   sess = tf.Session()
   1416   saver.restore(sess, "./model_ex1")
   1417   result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 3.3})
   1418   print(result)
   1419   ```
   1420 
   1421   Args:
   1422     meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
   1423       the path) containing a `MetaGraphDef`.
   1424     clear_devices: Whether or not to clear the device field for an `Operation`
   1425       or `Tensor` during import.
   1426     import_scope: Optional `string`. Name scope to add. Only used when
   1427       initializing from protocol buffer.
   1428     **kwargs: Optional keyword arguments.
   1429 
   1430   Returns:
   1431     A saver constructed from `saver_def` in `MetaGraphDef` or None.
   1432 
   1433     A None value is returned if no variables exist in the `MetaGraphDef`
   1434     (i.e., there are no variables to restore).
   1435 
   1436   Raises:
   1437     RuntimeError: If called with eager execution enabled.
   1438 
   1439   @compatibility(eager)
   1440   Exporting/importing meta graphs is not supported. No graph exists when eager
   1441   execution is enabled.
   1442   @end_compatibility
   1443   """  # pylint: disable=g-doc-exception
   1444   return _import_meta_graph_with_return_elements(
   1445       meta_graph_or_file, clear_devices, import_scope, **kwargs)[0]
   1446 
   1447 
   1448 def _import_meta_graph_with_return_elements(
   1449     meta_graph_or_file, clear_devices=False, import_scope=None,
   1450     return_elements=None, **kwargs):
   1451   """Import MetaGraph, and return both a saver and returned elements."""
   1452   if context.executing_eagerly():
   1453     raise RuntimeError("Exporting/importing meta graphs is not supported when "
   1454                        "eager execution is enabled. No graph exists when eager "
   1455                        "execution is enabled.")
   1456   if not isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
   1457     meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file)
   1458   else:
   1459     meta_graph_def = meta_graph_or_file
   1460 
   1461   imported_vars, imported_return_elements = (
   1462       meta_graph.import_scoped_meta_graph_with_return_elements(
   1463           meta_graph_def,
   1464           clear_devices=clear_devices,
   1465           import_scope=import_scope,
   1466           return_elements=return_elements,
   1467           **kwargs))
   1468 
   1469   saver = _create_saver_from_imported_meta_graph(
   1470       meta_graph_def, import_scope, imported_vars)
   1471   return saver, imported_return_elements
   1472 
   1473 
   1474 def _create_saver_from_imported_meta_graph(
   1475     meta_graph_def, import_scope, imported_vars):
   1476   """Return a saver for restoring variable values to an imported MetaGraph."""
   1477   if meta_graph_def.HasField("saver_def"):
   1478     # Infer the scope that is prepended by `import_scoped_meta_graph`.
   1479     scope = import_scope
   1480     var_names = list(imported_vars.keys())
   1481     if var_names:
   1482       sample_key = var_names[0]
   1483       sample_var = imported_vars[sample_key]
   1484       scope = sample_var.name[:-len(sample_key)]
   1485 
   1486     return Saver(saver_def=meta_graph_def.saver_def, name=scope)
   1487   else:
   1488     if variables._all_saveable_objects(scope=import_scope):  # pylint: disable=protected-access
   1489       # Return the default saver instance for all graph variables.
   1490       return Saver()
   1491     else:
   1492       # If no graph variables exist, then a Saver cannot be constructed.
   1493       logging.info("Saver not created because there are no variables in the"
   1494                    " graph to restore")
   1495       return None
   1496 
   1497 
   1498 @tf_export(v1=["train.export_meta_graph"])
   1499 def export_meta_graph(filename=None,
   1500                       meta_info_def=None,
   1501                       graph_def=None,
   1502                       saver_def=None,
   1503                       collection_list=None,
   1504                       as_text=False,
   1505                       graph=None,
   1506                       export_scope=None,
   1507                       clear_devices=False,
   1508                       clear_extraneous_savers=False,
   1509                       strip_default_attrs=False,
   1510                       save_debug_info=False,
   1511                       **kwargs):
   1512   # pylint: disable=line-too-long
   1513   """Returns `MetaGraphDef` proto. Optionally writes it to filename.
   1514 
   1515   This function exports the graph, saver, and collection objects into a
   1516   `MetaGraphDef` protocol buffer with the intention of it being imported
   1517   at a later time or location to restart training, run inference, or be
   1518   used as a subgraph.
   1519 
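          For example (a minimal sketch; the output path is illustrative):

          ```Python
          # Export the default graph and its collections to a file that can later be
          # re-imported with tf.train.import_meta_graph().
          meta_graph_def = tf.train.export_meta_graph(filename='/tmp/my-model.meta')
          ```
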
   1520   Args:
   1521     filename: Optional filename including the path for writing the
   1522       generated `MetaGraphDef` protocol buffer.
   1523     meta_info_def: `MetaInfoDef` protocol buffer.
   1524     graph_def: `GraphDef` protocol buffer.
   1525     saver_def: `SaverDef` protocol buffer.
   1526     collection_list: List of string keys to collect.
   1527     as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
   1528     graph: The `Graph` to export. If `None`, use the default graph.
   1529     export_scope: Optional `string`. Name scope under which to extract
   1530       the subgraph. The scope name will be stripped from the node definitions
   1531       for easy import later into new name scopes. If `None`, the whole graph
   1532       is exported. graph_def and export_scope cannot both be specified.
   1533     clear_devices: Whether or not to clear the device field for an `Operation`
   1534       or `Tensor` during export.
   1535     clear_extraneous_savers: Remove any Saver-related information from the
   1536       graph (both Save/Restore ops and SaverDefs) that is not associated
   1537       with the provided SaverDef.
   1538     strip_default_attrs: Boolean. If `True`, default-valued attributes will be
   1539       removed from the NodeDefs. For a detailed guide, see
   1540       [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
   1541     save_debug_info: If `True`, save the GraphDebugInfo to a separate file,
   1542       which is in the same directory as `filename` and has `_debug` added
   1543       before the file extension.
   1544     **kwargs: Optional keyword arguments.
   1545 
   1546   Returns:
   1547     A `MetaGraphDef` proto.
   1548 
   1549   Raises:
   1550     ValueError: When the `GraphDef` is larger than 2GB.
   1551     RuntimeError: If called with eager execution enabled.
   1552 
   1553   @compatibility(eager)
   1554   Exporting/importing meta graphs is not supported unless both `graph_def` and
   1555   `graph` are provided. No graph exists when eager execution is enabled.
   1556   @end_compatibility
   1557   """
   1558   # pylint: enable=line-too-long
   1559   if context.executing_eagerly() and not (graph_def is not None and
   1560                                           graph is not None):
   1561     raise RuntimeError("Exporting/importing meta graphs is not supported when "
   1562                        "eager execution is enabled. No graph exists when eager "
   1563                        "execution is enabled.")
   1564   meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
   1565       filename=filename,
   1566       meta_info_def=meta_info_def,
   1567       graph_def=graph_def,
   1568       saver_def=saver_def,
   1569       collection_list=collection_list,
   1570       as_text=as_text,
   1571       graph=graph,
   1572       export_scope=export_scope,
   1573       clear_devices=clear_devices,
   1574       clear_extraneous_savers=clear_extraneous_savers,
   1575       strip_default_attrs=strip_default_attrs,
   1576       save_debug_info=save_debug_info,
   1577       **kwargs)
   1578   return meta_graph_def
   1579 
   1580 
   1581 def _wrap_restore_error_with_msg(err, extra_verbiage):
   1582   err_msg = ("Restoring from checkpoint failed. This is most likely "
   1583              "due to {} from the checkpoint. Please ensure that you "
   1584              "have not altered the graph expected based on the checkpoint. "
   1585              "Original error:\n\n{}").format(extra_verbiage, err.message)
   1586   return err.__class__(err.node_def, err.op, err_msg)
   1587 
   1588 
   1589 ops.register_proto_function(
   1590     ops.GraphKeys.SAVERS,
   1591     proto_type=saver_pb2.SaverDef,
   1592     to_proto=Saver.to_proto,
   1593     from_proto=Saver.from_proto)
   1594 
   1595 
   1596 def object_graph_key_mapping(checkpoint_path):
   1597   """Return name-to-key mappings from the checkpoint.
   1598 
   1599   Args:
   1600     checkpoint_path: string, path to object-based checkpoint
   1601 
   1602   Returns:
   1603     Dictionary mapping tensor names to checkpoint keys.
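
          For example (a sketch; the checkpoint prefix is illustrative):

          ```Python
          # For a checkpoint written by tf.train.Checkpoint, this maps variable names
          # such as 'v' to checkpoint keys such as 'v/.ATTRIBUTES/VARIABLE_VALUE'.
          names_to_keys = object_graph_key_mapping('/tmp/ckpt_dir/ckpt-1')
          ```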
   1604   """
   1605   reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
   1606   object_graph_string = reader.get_tensor(
   1607       trackable.OBJECT_GRAPH_PROTO_KEY)
   1608   object_graph_proto = (
   1609       trackable_object_graph_pb2.TrackableObjectGraph())
   1610   object_graph_proto.ParseFromString(object_graph_string)
   1611   names_to_keys = {}
   1612   for node in object_graph_proto.nodes:
   1613     for attribute in node.attributes:
   1614       names_to_keys[attribute.full_name] = attribute.checkpoint_key
   1615   return names_to_keys
   1616 
   1617 
   1618 def saver_from_object_based_checkpoint(
   1619     checkpoint_path, var_list=None, builder=None, names_to_keys=None,
   1620     cached_saver=None):
   1621   """Return a `Saver` which reads from an object-based checkpoint.
   1622 
   1623   This function validates that all variables in the variables list are remapped
   1624   in the object-based checkpoint (or `names_to_keys` dict if provided). A
   1625   saver will be created with the list of remapped variables.
   1626 
   1627   The `cached_saver` argument allows the user to pass in a previously created
   1628   saver, so multiple `saver.restore()` calls don't pollute the graph when graph
   1629   building. This assumes that keys are consistent, meaning that the
   1630     1) `checkpoint_path` checkpoint, and
   1631     2) checkpoint used to create the `cached_saver`
   1632   are the same type of object-based checkpoint. If this argument is set, this
   1633   function will simply validate that all variables have been remapped by the
   1634   checkpoint at `checkpoint_path`.
   1635 
   1636   Note that in general, `tf.train.Checkpoint` should be used to restore/save an
   1637   object-based checkpoint.
   1638 
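          For example (a minimal sketch; the checkpoint path is illustrative, and the
          graph is assumed to already contain the variables being restored):

          ```Python
          ckpt_path = '/tmp/tf_ckpt/ckpt-1'
          saver = saver_from_object_based_checkpoint(ckpt_path)
          with tf.Session() as sess:
            saver.restore(sess, ckpt_path)
          ```
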
   1639   Args:
   1640     checkpoint_path: string, path to object-based checkpoint
   1641     var_list: list of `Variables` that appear in the checkpoint. If `None`,
   1642       `var_list` will be set to all saveable objects.
   1643     builder: a `BaseSaverBuilder` instance. If `None`, a new `BulkSaverBuilder`
   1644       will be created.
   1645     names_to_keys: dict mapping string tensor names to checkpoint keys. If
   1646       `None`, this dict will be generated from the checkpoint file.
   1647     cached_saver: Cached `Saver` object with remapped variables.
   1648 
   1649   Returns:
   1650     `Saver` with remapped variables for reading from an object-based checkpoint.
   1651 
   1652   Raises:
   1653     ValueError: If the checkpoint provided is not an object-based checkpoint.
   1654     NotFoundError: If one of the variables in `var_list` cannot be found in the
   1655       checkpoint. This could mean the checkpoint or `names_to_keys` mapping is
   1656       missing the variable.
   1657   """
   1658   if names_to_keys is None:
   1659     try:
   1660       names_to_keys = object_graph_key_mapping(checkpoint_path)
   1661     except errors.NotFoundError:
   1662       raise ValueError("Checkpoint in %s not an object-based checkpoint."
   1663                        % checkpoint_path)
   1664   if var_list is None:
   1665     var_list = variables._all_saveable_objects()  # pylint: disable=protected-access
   1666   if builder is None:
   1667     builder = BulkSaverBuilder()
   1668 
   1669   saveables = saveable_object_util.validate_and_slice_inputs(var_list)
   1670   current_names = set()
   1671   for saveable in saveables:
   1672     for spec in saveable.specs:
   1673       current_names.add(spec.name)
   1674   previous_names = set(names_to_keys.keys())
   1675   missing_names = current_names - previous_names
   1676   if missing_names:
   1677     extra_names = previous_names - current_names
   1678     intersecting_names = previous_names.intersection(current_names)
   1679     raise errors.NotFoundError(
   1680         None, None,
   1681         message=(
   1682             "\n\nExisting variables not in the checkpoint: %s\n\n"
   1683             "Variable names when this checkpoint was written which don't "
   1684             "exist now: %s\n\n"
   1685             "(%d variable name(s) did match)\n\n"
   1686             "Could not find some variables in the checkpoint (see names "
   1687             "above). Saver was attempting to load an object-based checkpoint "
   1688             "(saved using tf.train.Checkpoint or tf.keras.Model.save_weights) "
   1689             "using variable names. If the checkpoint was written with eager "
   1690             "execution enabled, it's possible that variable names have "
   1691             "changed (for example missing a '_1' suffix). It's also "
   1692             "possible that there are new variables which did not exist "
   1693             "when the checkpoint was written. You can construct a "
   1694             "Saver(var_list=...) with only the variables which previously "
   1695             "existed, and if variable names have changed you may need to "
   1696             "make this a dictionary with the old names as keys. If you're "
   1697             "using an Estimator, you'll need to return a tf.train.Saver "
   1698             "inside a tf.train.Scaffold from your model_fn.")
   1699         % (", ".join(sorted(missing_names)), ", ".join(sorted(extra_names)),
   1700            len(intersecting_names)))
   1701   for saveable in saveables:
   1702     for spec in saveable.specs:
   1703       spec.name = names_to_keys[spec.name]
   1704   if cached_saver is None:
   1705     return Saver(saveables)
   1706   return cached_saver
   1707