      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 
     16 # pylint: disable=invalid-name
     17 """Save and restore variables."""
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import collections
     23 import os.path
     24 import re
     25 import time
     26 import uuid
     27 
     28 import numpy as np
     29 import six
     30 
     31 from google.protobuf import text_format
     32 
     33 from tensorflow.core.protobuf import meta_graph_pb2
     34 from tensorflow.core.protobuf import saver_pb2
     35 from tensorflow.python.client import session
     36 from tensorflow.python.eager import context
     37 from tensorflow.python.framework import constant_op
     38 from tensorflow.python.framework import device as pydev
     39 from tensorflow.python.framework import errors
     40 from tensorflow.python.framework import meta_graph
     41 from tensorflow.python.framework import ops
     42 from tensorflow.python.lib.io import file_io
     43 from tensorflow.python.ops import array_ops
     44 from tensorflow.python.ops import control_flow_ops
     45 from tensorflow.python.ops import gen_io_ops
     46 from tensorflow.python.ops import io_ops
     47 from tensorflow.python.ops import resource_variable_ops
     48 from tensorflow.python.ops import state_ops
     49 from tensorflow.python.ops import string_ops
     50 from tensorflow.python.ops import variables
     51 from tensorflow.python.platform import gfile
     52 from tensorflow.python.platform import tf_logging as logging
     53 from tensorflow.python.training import training_util
     54 from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
     55 from tensorflow.python.util import compat
     56 from tensorflow.python.util.tf_export import tf_export
     57 
     58 
      59 # Op types that identify variables and variable reads that should be saved.
     60 _VARIABLE_OPS = set(["Variable",
     61                      "VariableV2",
     62                      "AutoReloadVariable",
     63                      "VarHandleOp",
     64                      "ReadVariableOp"])
     65 
     66 
     67 def _set_cpu0(device_string):
     68   """Creates a new device string based on `device_string` but using /CPU:0.
     69 
     70   If the device is already on /CPU:0, this is a no-op.
     71 
     72   Args:
     73     device_string: A device string.
     74 
     75   Returns:
     76     A device string.
     77   """
     78   parsed_device = pydev.DeviceSpec.from_string(device_string)
     79   parsed_device.device_type = "CPU"
     80   parsed_device.device_index = 0
     81   return parsed_device.to_string()
     82 
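# The rewrite above is purely textual on the device string.  A minimal, hedged
# illustration (the job/task fields below are hypothetical):
#
#   _set_cpu0("/job:worker/replica:0/task:3/device:GPU:1")
#   # -> "/job:worker/replica:0/task:3/device:CPU:0"
#
#   _set_cpu0("/device:CPU:0")  # already on /CPU:0; unchanged
#   # -> "/device:CPU:0"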
     83 
     84 class BaseSaverBuilder(object):
     85   """Base class for Savers.
     86 
     87   Can be extended to create different Ops.
     88   """
     89 
     90   class SaveSpec(object):
     91     """Class used to describe tensor slices that need to be saved."""
     92 
     93     def __init__(self, tensor, slice_spec, name):
     94       """Creates a `SaveSpec` object.
     95 
     96       Args:
     97         tensor: the tensor to save or callable that produces a tensor to save.
     98         slice_spec: the slice to be saved. See `Variable.SaveSliceInfo`.
     99         name: the name to save the tensor under.
    100       """
    101       self._tensor = tensor
    102       self.slice_spec = slice_spec
    103       self.name = name
    104 
    105     @property
    106     def tensor(self):
    107       return self._tensor() if callable(self._tensor) else self._tensor
    108 
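  # A minimal sketch (the names below are hypothetical) of why `tensor` may be
  # a callable: the read is deferred until the save op is actually built, which
  # lets the value be produced on a chosen device at save time.
  #
  #   spec = BaseSaverBuilder.SaveSpec(
  #       tensor=lambda: my_variable.read_value(),  # assumed ResourceVariable
  #       slice_spec="",
  #       name="my_variable")
  #   t = spec.tensor  # invokes the lambda lazily
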
    109   class SaveableObject(object):
    110     """Base class for saving and restoring saveable objects."""
    111 
    112     def __init__(self, op, specs, name):
    113       """Creates a `SaveableObject` object.
    114 
    115       Args:
    116         op: the "producer" object that this class wraps; it produces a list of
    117           tensors to save.  E.g., a "Variable" object saving its backing tensor.
    118         specs: a list of SaveSpec, each element of which describes one tensor to
    119           save under this object.
    120         name: the name to save the object under.
    121       """
    122       self.op = op
    123       self.specs = specs
    124       self.name = name
    125       # The device of this saveable. All tensors must be on the same device.
    126       self.device = specs[0].tensor.device
    127 
    128     def restore(self, restored_tensors, restored_shapes):
    129       """Restores this object from 'restored_tensors'.
    130 
    131       Args:
    132         restored_tensors: the tensors that were loaded from a checkpoint
    133         restored_shapes: the shapes this object should conform to after
    134           restore, or None.
    135 
    136       Returns:
    137         An operation that restores the state of the object.
    138 
    139       Raises:
    140         ValueError: If the object cannot be restored using the provided
    141           parameters.
    142       """
    143       # pylint: disable=unused-argument
    144       raise ValueError("Calling an abstract method.")
    145 
    146   class VariableSaveable(SaveableObject):
    147     """SaveableObject implementation that handles Variables."""
    148 
    149     def __init__(self, var, slice_spec, name):
    150       spec = BaseSaverBuilder.SaveSpec(var, slice_spec, name)
    151       super(BaseSaverBuilder.VariableSaveable, self).__init__(var, [spec], name)
    152 
    153     def restore(self, restored_tensors, restored_shapes):
    154       restored_tensor = restored_tensors[0]
    155       if restored_shapes is not None:
    156         restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
    157       return state_ops.assign(
    158           self.op,
    159           restored_tensor,
    160           validate_shape=restored_shapes is None and
    161           self.op.get_shape().is_fully_defined())
    162 
    163   class ResourceVariableSaveable(SaveableObject):
    164     """SaveableObject implementation that handles ResourceVariables."""
    165 
    166     def __init__(self, var, slice_spec, name):
    167       self._var_device = var.device
    168       self._var_shape = var.shape
    169       if isinstance(var, ops.Tensor):
    170         self.handle_op = var.op.inputs[0]
    171         tensor = var
    172       elif isinstance(var, resource_variable_ops.ResourceVariable):
    173 
    174         def _read_variable_closure(v):
    175           def f():
    176             with ops.device(v.device):
    177               x = v.read_value()
    178             with ops.device("/device:CPU:0"):
    179               return array_ops.identity(x)
    180           return f
    181 
    182         self.handle_op = var.handle
    183         tensor = _read_variable_closure(var)
    184       else:
    185         raise ValueError(
    186             "Saveable is neither a resource variable nor a read operation."
    187             " Got: %s" % repr(var))
    188       spec = BaseSaverBuilder.SaveSpec(tensor, slice_spec, name)
    189       super(BaseSaverBuilder.ResourceVariableSaveable, self).__init__(
    190           var, [spec], name)
    191 
    192     def restore(self, restored_tensors, restored_shapes):
    193       restored_tensor = restored_tensors[0]
    194       if restored_shapes is not None:
    195         restored_tensor = array_ops.reshape(restored_tensor, restored_shapes[0])
    196       # Copy the restored tensor to the variable's device.
    197       with ops.device(self._var_device):
    198         restored_tensor = array_ops.identity(restored_tensor)
    199       return resource_variable_ops.shape_safe_assign_variable_handle(
    200           self.handle_op, self._var_shape, restored_tensor)
    201 
    202   def __init__(self, write_version=saver_pb2.SaverDef.V2):
    203     self._write_version = write_version
    204 
    205   def save_op(self, filename_tensor, saveables):
    206     """Create an Op to save 'saveables'.
    207 
    208     This is intended to be overridden by subclasses that want to generate
    209     different Ops.
    210 
    211     Args:
    212       filename_tensor: String Tensor.
    213       saveables: A list of BaseSaverBuilder.SaveableObject objects.
    214 
    215     Returns:
     216       An Operation that saves the variables.
    217 
    218     Raises:
    219       RuntimeError: (implementation detail) if "self._write_version" is an
    220         unexpected value.
    221     """
    222     # pylint: disable=protected-access
    223     tensor_names = []
    224     tensors = []
    225     tensor_slices = []
    226     for saveable in saveables:
    227       for spec in saveable.specs:
    228         tensor_names.append(spec.name)
    229         tensors.append(spec.tensor)
    230         tensor_slices.append(spec.slice_spec)
    231     if self._write_version == saver_pb2.SaverDef.V1:
    232       return io_ops._save(
    233           filename=filename_tensor,
    234           tensor_names=tensor_names,
    235           tensors=tensors,
    236           tensor_slices=tensor_slices)
    237     elif self._write_version == saver_pb2.SaverDef.V2:
    238       # "filename_tensor" is interpreted *NOT AS A FILENAME*, but as a prefix
    239       # of a V2 checkpoint: e.g. "/fs/train/ckpt-<step>/tmp/worker<i>-<step>".
    240       return io_ops.save_v2(filename_tensor, tensor_names, tensor_slices,
    241                             tensors)
    242     else:
     243       raise RuntimeError("Unexpected write_version: %s" % self._write_version)
    244 
    245   def bulk_restore(self, filename_tensor, saveables, preferred_shard,
    246                    restore_sequentially):
    247     """Restore all tensors contained in saveables.
    248 
    249     By default, this issues separate calls to `restore_op` for each saveable.
    250     Subclasses may override to load multiple saveables in a single call.
    251 
    252     Args:
    253       filename_tensor: String Tensor.
    254       saveables: List of BaseSaverBuilder.SaveableObject objects.
    255       preferred_shard: Int.  Shard to open first when loading a sharded file.
    256       restore_sequentially: Bool.  If true, each restore is sequential.
    257 
    258     Returns:
     259       A list of Tensors resulting from reading 'saveables' from
     260         'filename_tensor'.
    261 
    262     """
    263     all_tensors = []
    264     assign_ops = []
    265     for saveable in saveables:
    266       restore_control_inputs = assign_ops[-1:] if restore_sequentially else []
    267       with ops.device(_set_cpu0(saveable.device) if saveable.device else None):
    268         with ops.control_dependencies(restore_control_inputs):
    269           all_tensors.extend(
    270               self.restore_op(filename_tensor, saveable, preferred_shard))
    271     return all_tensors
    272 
    273   # pylint: disable=unused-argument
    274   def restore_op(self, filename_tensor, saveable, preferred_shard):
    275     """Create ops to restore 'saveable'.
    276 
    277     This is intended to be overridden by subclasses that want to generate
    278     different Ops.
    279 
    280     Args:
    281       filename_tensor: String Tensor.
    282       saveable: A BaseSaverBuilder.SaveableObject object.
    283       preferred_shard: Int.  Shard to open first when loading a sharded file.
    284 
    285     Returns:
    286       A list of Tensors resulting from reading 'saveable' from
    287         'filename'.
    288     """
    289     # pylint: disable=protected-access
    290     tensors = []
    291     for spec in saveable.specs:
    292       tensors.append(
    293           io_ops.restore_v2(
    294               filename_tensor,
    295               [spec.name],
    296               [spec.slice_spec],
    297               [spec.tensor.dtype])[0])
    298 
    299     return tensors
    300   # pylint: enable=unused-argument
    301 
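  # Minimal sketch (hypothetical `builder` and `saveable`): the default
  # restore_op issues one RestoreV2 per SaveSpec, so the returned list lines
  # up one-to-one with saveable.specs.
  #
  #   tensors = builder.restore_op(filename_tensor, saveable, preferred_shard=-1)
  #   assert len(tensors) == len(saveable.specs)
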
    302   def sharded_filename(self, filename_tensor, shard, num_shards):
    303     """Append sharding information to a filename.
    304 
    305     Args:
    306       filename_tensor: A string tensor.
    307       shard: Integer.  The shard for the filename.
    308       num_shards: An int Tensor for the number of shards.
    309 
    310     Returns:
    311       A string tensor.
    312     """
    313     # pylint: disable=protected-access
    314     return gen_io_ops._sharded_filename(filename_tensor, shard, num_shards)
    315 
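  # Illustrative only: the underlying ShardedFilename op appends a
  # "-%05d-of-%05d" style suffix, so a prefix "ckpt" with shard=2 of 8 becomes
  # something like "ckpt-00002-of-00008".
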
    316   def _AddSaveOps(self, filename_tensor, saveables):
    317     """Add ops to save variables that are on the same shard.
    318 
    319     Args:
    320       filename_tensor: String Tensor.
    321       saveables: A list of SaveableObject objects.
    322 
    323     Returns:
    324       A tensor with the filename used to save.
    325     """
    326     save = self.save_op(filename_tensor, saveables)
    327     return control_flow_ops.with_dependencies([save], filename_tensor)
    328 
    329   def _AddShardedSaveOpsForV2(self, checkpoint_prefix, per_device):
    330     """Add ops to save the params per shard, for the V2 format.
    331 
    332     Note that the sharded save procedure for the V2 format is different from
    333     V1: there is a special "merge" step that merges the small metadata produced
    334     from each device.
    335 
    336     Args:
    337       checkpoint_prefix: scalar String Tensor.  Interpreted *NOT AS A
    338         FILENAME*, but as a prefix of a V2 checkpoint;
    339       per_device: A list of (device, BaseSaverBuilder.VarToSave) pairs, as
    340         returned by _GroupByDevices().
    341 
    342     Returns:
    343       An op to save the variables, which, when evaluated, returns the prefix
    344         "<user-fed prefix>" only and does not include the sharded spec suffix.
    345     """
    346     # IMPLEMENTATION DETAILS: most clients should skip.
    347     #
    348     # Suffix for any well-formed "checkpoint_prefix", when sharded.
    349     # Transformations:
    350     # * Users pass in "save_path" in save() and restore().  Say "myckpt".
    351     # * checkpoint_prefix gets fed <save_path><_SHARDED_SUFFIX>.
    352     #
    353     # Example:
    354     #   During runtime, a temporary directory is first created, which contains
    355     #   files
    356     #
    357     #     <train dir>/myckpt_temp/
    358     #        part-?????-of-?????{.index, .data-00000-of-00001}
    359     #
    360     #   Before .save() finishes, they will be (hopefully, atomically) renamed to
    361     #
    362     #     <train dir>/
    363     #        myckpt{.index, .data-?????-of-?????}
    364     #
    365     # Users only need to interact with the user-specified prefix, which is
    366     # "<train dir>/myckpt" in this case.  Save() and Restore() work with the
    367     # prefix directly, instead of any physical pathname.  (On failure and
    368     # subsequent restore, an outdated and orphaned temporary directory can be
    369     # safely removed.)
    370     _SHARDED_SUFFIX = "_temp_%s/part" % uuid.uuid4().hex
    371     tmp_checkpoint_prefix = string_ops.string_join(
    372         [checkpoint_prefix, _SHARDED_SUFFIX])
    373 
    374     num_shards = len(per_device)
    375     sharded_saves = []
    376     sharded_prefixes = []
    377     num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    378     last_device = None
    379     for shard, (device, saveables) in enumerate(per_device):
    380       last_device = device
    381       with ops.device(_set_cpu0(device)):
    382         sharded_filename = self.sharded_filename(tmp_checkpoint_prefix, shard,
    383                                                  num_shards_tensor)
    384         sharded_prefixes.append(sharded_filename)
    385         sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
    386 
    387     with ops.control_dependencies([x.op for x in sharded_saves]):
    388       # Co-locates the merge step with the last device.
    389       with ops.device(_set_cpu0(last_device)):
    390         # V2 format write path consists of a metadata merge step.  Once merged,
    391         # attempts to delete the temporary directory, "<user-fed prefix>_temp".
    392         merge_step = gen_io_ops.merge_v2_checkpoints(
    393             sharded_prefixes, checkpoint_prefix, delete_old_dirs=True)
    394         with ops.control_dependencies([merge_step]):
    395           # Returns the prefix "<user-fed prefix>" only.  DOES NOT include the
    396           # sharded spec suffix.
    397           return array_ops.identity(checkpoint_prefix)
    398 
    399   def _AddShardedSaveOps(self, filename_tensor, per_device):
    400     """Add ops to save the params per shard.
    401 
    402     Args:
    403       filename_tensor: a scalar String Tensor.
    404       per_device: A list of (device, BaseSaverBuilder.SaveableObject) pairs, as
    405         returned by _GroupByDevices().
    406 
    407     Returns:
    408       An op to save the variables.
    409     """
    410     if self._write_version == saver_pb2.SaverDef.V2:
    411       return self._AddShardedSaveOpsForV2(filename_tensor, per_device)
    412 
    413     num_shards = len(per_device)
    414     sharded_saves = []
    415     num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
    416     for shard, (device, saveables) in enumerate(per_device):
    417       with ops.device(device):
    418         sharded_filename = self.sharded_filename(filename_tensor, shard,
    419                                                  num_shards_tensor)
    420         sharded_saves.append(self._AddSaveOps(sharded_filename, saveables))
    421     # Return the sharded name for the save path.
    422     with ops.control_dependencies([x.op for x in sharded_saves]):
    423       # pylint: disable=protected-access
    424       return gen_io_ops._sharded_filespec(filename_tensor, num_shards_tensor)
    425 
    426   def _AddRestoreOps(self,
    427                      filename_tensor,
    428                      saveables,
    429                      restore_sequentially,
    430                      reshape,
    431                      preferred_shard=-1,
    432                      name="restore_all"):
    433     """Add operations to restore saveables.
    434 
    435     Args:
    436       filename_tensor: Tensor for the path of the file to load.
    437       saveables: A list of SaveableObject objects.
    438       restore_sequentially: True if we want to restore variables sequentially
    439         within a shard.
    440       reshape: True if we want to reshape loaded tensors to the shape of
    441         the corresponding variable.
    442       preferred_shard: Shard to open first when loading a sharded file.
    443       name: Name for the returned op.
    444 
    445     Returns:
    446       An Operation that restores the variables.
    447     """
    448     all_tensors = self.bulk_restore(filename_tensor, saveables, preferred_shard,
    449                                     restore_sequentially)
    450 
    451     assign_ops = []
    452     idx = 0
    453     # Load and optionally reshape on the CPU, as string tensors are not
    454     # available on the GPU.
    455     # TODO(touts): Re-enable restore on GPU when we can support annotating
    456     # string tensors as "HostMemory" inputs.
    457     for saveable in saveables:
    458       shapes = None
    459       if reshape:
    460         # Compute the shapes, let the restore op decide if and how to do
    461         # the reshape.
    462         shapes = []
    463         for spec in saveable.specs:
    464           v = spec.tensor
    465           shape = v.get_shape()
    466           if not shape.is_fully_defined():
    467             shape = array_ops.shape(v)
    468           shapes.append(shape)
    469       saveable_tensors = all_tensors[idx:idx + len(saveable.specs)]
    470       idx += len(saveable.specs)
    471       assign_ops.append(saveable.restore(saveable_tensors, shapes))
    472 
    473     # Create a Noop that has control dependencies from all the updates.
    474     return control_flow_ops.group(*assign_ops, name=name)
    475 
    476   def _AddShardedRestoreOps(self, filename_tensor, per_device,
    477                             restore_sequentially, reshape):
    478     """Add Ops to restore variables from multiple devices.
    479 
    480     Args:
    481       filename_tensor: Tensor for the path of the file to load.
    482       per_device: A list of (device, SaveableObject) pairs, as
    483         returned by _GroupByDevices().
    484       restore_sequentially: True if we want to restore variables sequentially
    485         within a shard.
    486       reshape: True if we want to reshape loaded tensors to the shape of
    487         the corresponding variable.
    488 
    489     Returns:
    490       An Operation that restores the variables.
    491     """
    492     sharded_restores = []
    493     for shard, (device, saveables) in enumerate(per_device):
    494       with ops.device(device):
    495         sharded_restores.append(
    496             self._AddRestoreOps(
    497                 filename_tensor,
    498                 saveables,
    499                 restore_sequentially,
    500                 reshape,
    501                 preferred_shard=shard,
    502                 name="restore_shard"))
    503     return control_flow_ops.group(*sharded_restores, name="restore_all")
    504 
    505   @staticmethod
    506   def _IsVariable(v):
    507     return isinstance(v, ops.Tensor) and v.op.type in _VARIABLE_OPS
    508 
    509   def _GroupByDevices(self, saveables):
    510     """Group Variable tensor slices per device.
    511 
    512     TODO(touts): Make sure that all the devices found are on different
    513     job/replica/task/cpu|gpu.  It would be bad if 2 were on the same device.
    514     It can happen if the devices are unspecified.
    515 
    516     Args:
    517       saveables: A list of BaseSaverBuilder.SaveableObject objects.
    518 
    519     Returns:
    520       A list of tuples: (device_name, BaseSaverBuilder.SaveableObject) tuples.
    521       The list is sorted by ascending device_name.
    522 
    523     Raises:
    524       ValueError: If the tensors of a saveable are on different devices.
    525     """
    526     per_device = collections.defaultdict(lambda: [])
    527     for saveable in saveables:
    528       canonical_device = set(
    529           pydev.canonical_name(spec.tensor.device) for spec in saveable.specs)
    530       if len(canonical_device) != 1:
    531         raise ValueError("All tensors of a saveable object must be "
    532                          "on the same device: %s" % saveable.name)
    533       per_device[canonical_device.pop()].append(saveable)
    534     return sorted(per_device.items(), key=lambda t: t[0])
    535 
    536   @staticmethod
    537   def OpListToDict(op_list, convert_variable_to_tensor=True):
    538     """Create a dictionary of names to operation lists.
    539 
    540     Args:
    541       op_list: A list, tuple, or set of Variables or SaveableObjects.
    542       convert_variable_to_tensor: Whether or not to convert single Variables
    543         with no slice info into Tensors.
    544 
    545     Returns:
    546       A dictionary of names to the operations that must be saved under
    547       that name.  Variables with save_slice_info are grouped together under the
    548       same key in no particular order.
    549 
    550     Raises:
    551       TypeError: If the type of op_list or its elements is not supported.
    552       ValueError: If at least two saveables share the same name.
    553     """
    554     if not isinstance(op_list, (list, tuple, set)):
    555       raise TypeError("Variables to save should be passed in a dict or a "
    556                       "list: %s" % op_list)
    557     # When ResourceVariables are converted to Tensors, read ops are added to the
    558     # graph. Sorting the op_list ensures that the resulting graph is always
    559     # constructed in a deterministic way:
    560     op_list = sorted(op_list, key=lambda x: x.name)
    561     names_to_saveables = {}
    562     # pylint: disable=protected-access
    563     for var in op_list:
    564       if isinstance(var, BaseSaverBuilder.SaveableObject):
    565         names_to_saveables[var.name] = var
    566       elif isinstance(var, variables.PartitionedVariable):
    567         if var.name in names_to_saveables:
    568           raise ValueError("At least two variables have the same name: %s" %
    569                            var.name)
    570         names_to_saveables[var.name] = var
    571       elif isinstance(var, variables.Variable) and var._save_slice_info:
    572         name = var._save_slice_info.full_name
    573         if name in names_to_saveables:
    574           if not isinstance(names_to_saveables[name], list):
    575             raise ValueError("Mixing slices and non-slices with the same name: "
    576                              "%s" % name)
    577           names_to_saveables[name].append(var)
    578         else:
    579           names_to_saveables[name] = [var]
    580       else:
    581         if context.in_graph_mode():
    582           if convert_variable_to_tensor:
    583             var = ops.internal_convert_to_tensor(var, as_ref=True)
    584             if not BaseSaverBuilder._IsVariable(var):
    585               raise TypeError("Variable to save is not a Variable: %s" % var)
    586           if var.op.type == "ReadVariableOp":
    587             name = var.op.inputs[0].op.name
    588           else:
    589             name = var.op.name
    590           if name in names_to_saveables:
    591             raise ValueError("At least two variables have the same name: %s" %
    592                              name)
    593           names_to_saveables[name] = var
    594         else:
    595           if not isinstance(var, resource_variable_ops.ResourceVariable):
     596             raise ValueError("Can only save/restore ResourceVariables when "
     597                              "eager execution is enabled, type: %s." % type(var))
    598           set_var = names_to_saveables.setdefault(var._shared_name, var)
    599           if set_var is not var:
    600             raise ValueError(
    601                 ("Two different ResourceVariable objects with the same "
    602                  "shared_name '%s' were passed to the Saver. This likely means "
    603                  "that they were created in different Graphs or isolation "
    604                  "contexts, and may not be checkpointed together.") % (
    605                      var._shared_name,))
    606 
    607       # pylint: enable=protected-access
    608     return names_to_saveables
    609 
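  # A hedged example (hypothetical variable) of the mapping OpListToDict
  # produces: plain variables are keyed by their op name, while partitioned
  # slices are grouped together under the full variable name.
  #
  #   v = tf.Variable(1.0, name="v")        # assumed graph-mode variable
  #   BaseSaverBuilder.OpListToDict([v])    # -> {"v": <tensor backing v>}
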
    610   def _ValidateAndSliceInputs(self, names_to_saveables):
    611     """Returns the variables and names that will be used for a Saver.
    612 
    613     Args:
    614       names_to_saveables: A dict (k, v) where k is the name of an operation and
     615          v is an operation to save or a BaseSaverBuilder.SaveableObject.
    616 
    617     Returns:
    618       A list of BaseSaverBuilder.SaveableObject objects.
    619 
    620     Raises:
    621       TypeError: If any of the keys are not strings or any of the
    622         values are not one of Tensor or Variable or a checkpointable operation.
    623       ValueError: If the same operation is given in more than one value
    624         (this also applies to slices of SlicedVariables).
    625     """
    626     if not isinstance(names_to_saveables, dict):
    627       names_to_saveables = BaseSaverBuilder.OpListToDict(names_to_saveables)
    628 
    629     saveables = []
    630     seen_ops = set()
    631     for name in sorted(names_to_saveables.keys()):
    632       if not isinstance(name, six.string_types):
    633         raise TypeError(
    634             "names_to_saveables must be a dict mapping string names to "
    635             "checkpointable operations. Name is not a string: %s" % name)
    636       op = names_to_saveables[name]
    637       if isinstance(op, BaseSaverBuilder.SaveableObject):
    638         self._AddSaveable(saveables, seen_ops, op)
    639       elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
    640         if isinstance(op, variables.PartitionedVariable):
    641           op = list(op)
    642         # A set of slices.
    643         slice_name = None
    644         # pylint: disable=protected-access
    645         for variable in op:
    646           if not isinstance(variable, variables.Variable):
    647             raise ValueError("Slices must all be Variables: %s" % variable)
    648           if not variable._save_slice_info:
    649             raise ValueError("Slices must all be slices: %s" % variable)
    650           if slice_name is None:
    651             slice_name = variable._save_slice_info.full_name
    652           elif slice_name != variable._save_slice_info.full_name:
    653             raise ValueError(
    654                 "Slices must all be from the same tensor: %s != %s" %
    655                 (slice_name, variable._save_slice_info.full_name))
    656           if variable.op.type in ["Variable", "VariableV2",
    657                                   "AutoReloadVariable"]:
    658             saveable = BaseSaverBuilder.VariableSaveable(
    659                 variable, variable._save_slice_info.spec, name)
    660           else:
    661             saveable = BaseSaverBuilder.ResourceVariableSaveable(
    662                 variable, variable._save_slice_info.spec, name)
    663           self._AddSaveable(saveables, seen_ops, saveable)
    664         # pylint: enable=protected-access
    665       else:
    666         # A variable or tensor.
    667         if context.in_eager_mode():
    668           if not isinstance(op, resource_variable_ops.ResourceVariable):
     669             raise ValueError("Can only save/restore ResourceVariables when "
     670                              "eager execution is enabled, type: %s." % type(op))
    671           saveable = BaseSaverBuilder.ResourceVariableSaveable(op, "", name)
    672         else:
    673           variable = ops.internal_convert_to_tensor(op, as_ref=True)
    674           if not BaseSaverBuilder._IsVariable(variable):
    675             raise TypeError("names_to_saveables must be a dict mapping string "
    676                             "names to Tensors/Variables. Not a variable: %s" %
    677                             variable)
    678           if variable.op.type in ["Variable", "VariableV2",
    679                                   "AutoReloadVariable"]:
    680             saveable = BaseSaverBuilder.VariableSaveable(variable, "", name)
    681           else:
    682             saveable = BaseSaverBuilder.ResourceVariableSaveable(
    683                 variable, "", name)
    684         self._AddSaveable(saveables, seen_ops, saveable)
    685     return saveables
    686 
    687   def _AddSaveable(self, saveables, seen_ops, saveable):
    688     """Adds the saveable to the saveables list.
    689 
    690     Args:
    691       saveables: List to append the SaveableObject to.
    692       seen_ops: Set of the ops of the saveables already processed.  Used to
    693         check that each saveable is only saved once.
    694       saveable: The saveable.
    695 
    696     Raises:
    697       ValueError: If the saveable has already been processed.
    698     """
    699     if saveable.op in seen_ops:
    700       raise ValueError("The same saveable will be restored with two names: %s" %
    701                        saveable.name)
    702     saveables.append(saveable)
    703     seen_ops.add(saveable.op)
    704 
    705   def build(self,
    706             names_to_saveables,
    707             reshape=False,
    708             sharded=False,
    709             max_to_keep=5,
    710             keep_checkpoint_every_n_hours=10000.0,
    711             name=None,
    712             restore_sequentially=False,
    713             filename="model"):
    714     """Builds save/restore graph nodes or runs save/restore in eager mode.
    715 
    716     Args:
    717       names_to_saveables: A dictionary mapping name to a Variable or
    718         SaveableObject. Each name will be associated with the
    719         corresponding variable in the checkpoint.
     720       reshape: If True, allow restoring parameters from a checkpoint
     721         where the parameters have a different shape.  This is
     722         only needed when you try to restore from a Dist-Belief checkpoint,
     723         and only sometimes.
    724       sharded: If True, shard the checkpoints, one per device that has
    725         Variable nodes.
    726       max_to_keep: Maximum number of checkpoints to keep.  As new checkpoints
    727         are created, old ones are deleted.  If None or 0, no checkpoints are
    728         deleted from the filesystem but only the last one is kept in the
    729         `checkpoint` file.  Presently the number is only roughly enforced.  For
     730         example, in case of restarts, more than max_to_keep checkpoints may be
    731         kept.
    732       keep_checkpoint_every_n_hours: How often checkpoints should be kept.
    733         Defaults to 10,000 hours.
    734       name: String.  Optional name to use as a prefix when adding operations.
    735       restore_sequentially: A Bool, which if true, causes restore of different
    736         variables to happen sequentially within each device.
    737       filename: If known at graph construction time, filename used for variable
    738         loading/saving. If None, then the default name "model" will be used.
    739 
    740     Returns:
    741       A SaverDef proto.
    742 
    743     Raises:
    744       TypeError: If 'names_to_saveables' is not a dictionary mapping string
    745         keys to variable Tensors.
    746       ValueError: If any of the keys or values in 'names_to_saveables' is not
    747         unique.
    748     """
    749     return self._build_internal(
    750         names_to_saveables=names_to_saveables,
    751         reshape=reshape,
    752         sharded=sharded,
    753         max_to_keep=max_to_keep,
    754         keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
    755         name=name,
    756         restore_sequentially=restore_sequentially,
    757         filename=filename)
    758 
    759   def _build_internal(self,
    760                       names_to_saveables,
    761                       reshape=False,
    762                       sharded=False,
    763                       max_to_keep=5,
    764                       keep_checkpoint_every_n_hours=10000.0,
    765                       name=None,
    766                       restore_sequentially=False,
    767                       filename="model",
    768                       build_save=True,
    769                       build_restore=True):
    770     """build() with option to only perform save and restore."""
    771     if context.in_graph_mode() and (not build_save or not build_restore):
    772       raise ValueError("Graph mode needs to build save and restore together.")
    773 
    774     saveables = self._ValidateAndSliceInputs(names_to_saveables)
    775     if max_to_keep is None:
    776       max_to_keep = 0
    777 
    778     with ops.name_scope(name, "save",
    779                         [saveable.op for saveable in saveables]) as name:
    780       # Add the Constant string tensor for the filename.
    781       filename_tensor = constant_op.constant(filename or "model")
    782 
    783       # Add the save ops.
    784       if sharded:
    785         per_device = self._GroupByDevices(saveables)
    786         if build_save:
    787           save_tensor = self._AddShardedSaveOps(filename_tensor, per_device)
    788         if build_restore:
    789           restore_op = self._AddShardedRestoreOps(filename_tensor, per_device,
    790                                                   restore_sequentially, reshape)
    791       else:
    792         if build_save:
    793           save_tensor = self._AddSaveOps(filename_tensor, saveables)
    794         if build_restore:
    795           restore_op = self._AddRestoreOps(filename_tensor, saveables,
    796                                            restore_sequentially, reshape)
    797 
    798     # In the following use case, it's possible to have restore_ops be called
    799     # something else:
    800     # - Build inference graph and export a meta_graph.
    801     # - Import the inference meta_graph
    802     # - Extend the inference graph to a train graph.
    803     # - Export a new meta_graph.
    804     # Now the second restore_op will be called "restore_all_1".
    805     # As such, comment out the assert for now until we know whether supporting
    806     # such usage model makes sense.
    807     #
    808     # assert restore_op.name.endswith("restore_all"), restore_op.name
    809     if context.in_graph_mode():
    810       return saver_pb2.SaverDef(
    811           filename_tensor_name=filename_tensor.name,
    812           save_tensor_name=save_tensor.name,
    813           restore_op_name=restore_op.name,
    814           max_to_keep=max_to_keep,
    815           sharded=sharded,
    816           keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
    817           version=self._write_version)
    818     else:
    819       # Store the tensor values to the tensor_names.
    820       save_tensor_name = save_tensor.numpy() if build_save else ""
    821       return saver_pb2.SaverDef(
    822           filename_tensor_name=filename_tensor.numpy(),
    823           save_tensor_name=save_tensor_name,
    824           restore_op_name="",
    825           max_to_keep=max_to_keep,
    826           sharded=sharded,
    827           keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
    828           version=self._write_version)
    829 
    830 
    831 class BulkSaverBuilder(BaseSaverBuilder):
    832   """SaverBuilder with support for bulk restoring multiple saveables."""
    833 
    834   def bulk_restore(self, filename_tensor, saveables, preferred_shard,
    835                    restore_sequentially):
    836 
    837     # Ignored: bulk restore is internally sequential.
    838     del restore_sequentially
    839     restore_specs = []
    840     for saveable in saveables:
    841       for spec in saveable.specs:
    842         restore_specs.append((spec.name, spec.slice_spec, spec.tensor.dtype))
    843 
    844     names, slices, dtypes = zip(*restore_specs)
    845     # Load all tensors onto CPU 0 for compatibility with existing code.
    846     with ops.device("cpu:0"):
    847       return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
    848 
    849 
    850 def _get_saver_or_default():
    851   """Returns the saver from SAVERS collection, or creates a default one.
    852 
    853   This method is used by other members of the training module, such as
    854   `Scaffold`, or `CheckpointSaverHook`.
    855 
    856   Returns:
    857     `Saver`.
    858 
    859   Raises:
     860     RuntimeError: If the SAVERS collection already has more than one item.
    861   """
    862   collection_key = ops.GraphKeys.SAVERS
    863   savers = ops.get_collection(collection_key)
    864   if savers:
    865     if len(savers) > 1:
    866       raise RuntimeError(
    867           "More than one item in collection {}. "
    868           "Please indicate which one to use by passing it to the constructor.".
    869           format(collection_key))
    870     return savers[0]
    871   saver = Saver(sharded=True, allow_empty=True)
    872   if saver is not None:
    873     ops.add_to_collection(collection_key, saver)
    874   return saver
    875 
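# Minimal usage sketch (graph mode assumed): helpers such as Scaffold and
# CheckpointSaverHook call _get_saver_or_default() so that exactly one shared
# Saver ends up in the GraphKeys.SAVERS collection.
#
#   saver = _get_saver_or_default()
#   assert saver in ops.get_collection(ops.GraphKeys.SAVERS)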
    876 
    877 def _GetCheckpointFilename(save_dir, latest_filename):
    878   """Returns a filename for storing the CheckpointState.
    879 
    880   Args:
    881     save_dir: The directory for saving and restoring checkpoints.
    882     latest_filename: Name of the file in 'save_dir' that is used
    883       to store the CheckpointState.
    884 
    885   Returns:
    886     The path of the file that contains the CheckpointState proto.
    887   """
    888   if latest_filename is None:
    889     latest_filename = "checkpoint"
    890   return os.path.join(save_dir, latest_filename)
    891 
    892 
    893 @tf_export("train.generate_checkpoint_state_proto")
    894 def generate_checkpoint_state_proto(save_dir,
    895                                     model_checkpoint_path,
    896                                     all_model_checkpoint_paths=None):
    897   """Generates a checkpoint state proto.
    898 
    899   Args:
    900     save_dir: Directory where the model was saved.
    901     model_checkpoint_path: The checkpoint file.
    902     all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
    903       checkpoints, sorted from oldest to newest.  If this is a non-empty list,
    904       the last element must be equal to model_checkpoint_path.  These paths
    905       are also saved in the CheckpointState proto.
    906 
    907   Returns:
    908     CheckpointState proto with model_checkpoint_path and
    909     all_model_checkpoint_paths updated to either absolute paths or
    910     relative paths to the current save_dir.
    911   """
    912   if all_model_checkpoint_paths is None:
    913     all_model_checkpoint_paths = []
    914 
    915   if (not all_model_checkpoint_paths or
    916       all_model_checkpoint_paths[-1] != model_checkpoint_path):
    917     logging.info("%s is not in all_model_checkpoint_paths. Manually adding it.",
    918                  model_checkpoint_path)
    919     all_model_checkpoint_paths.append(model_checkpoint_path)
    920 
    921   # Relative paths need to be rewritten to be relative to the "save_dir"
    922   # if model_checkpoint_path already contains "save_dir".
    923   if not os.path.isabs(save_dir):
    924     if not os.path.isabs(model_checkpoint_path):
    925       model_checkpoint_path = os.path.relpath(model_checkpoint_path, save_dir)
    926     for i in range(len(all_model_checkpoint_paths)):
    927       p = all_model_checkpoint_paths[i]
    928       if not os.path.isabs(p):
    929         all_model_checkpoint_paths[i] = os.path.relpath(p, save_dir)
    930 
    931   coord_checkpoint_proto = CheckpointState(
    932       model_checkpoint_path=model_checkpoint_path,
    933       all_model_checkpoint_paths=all_model_checkpoint_paths)
    934 
    935   return coord_checkpoint_proto
    936 
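# Illustrative example (hypothetical paths): with an absolute save_dir the
# checkpoint paths are kept as given, and the last entry matches
# model_checkpoint_path.
#
#   ckpt = generate_checkpoint_state_proto(
#       "/tmp/train", "/tmp/train/model.ckpt-1000",
#       all_model_checkpoint_paths=["/tmp/train/model.ckpt-500",
#                                   "/tmp/train/model.ckpt-1000"])
#   ckpt.model_checkpoint_path             # "/tmp/train/model.ckpt-1000"
#   list(ckpt.all_model_checkpoint_paths)  # both paths, oldest to newest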
    937 
    938 @tf_export("train.update_checkpoint_state")
    939 def update_checkpoint_state(save_dir,
    940                             model_checkpoint_path,
    941                             all_model_checkpoint_paths=None,
    942                             latest_filename=None):
    943   """Updates the content of the 'checkpoint' file.
    944 
    945   This updates the checkpoint file containing a CheckpointState
    946   proto.
    947 
    948   Args:
    949     save_dir: Directory where the model was saved.
    950     model_checkpoint_path: The checkpoint file.
    951     all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
    952       checkpoints, sorted from oldest to newest.  If this is a non-empty list,
    953       the last element must be equal to model_checkpoint_path.  These paths
    954       are also saved in the CheckpointState proto.
     955     latest_filename: Optional name of the checkpoint file.  Defaults to
    956       'checkpoint'.
    957 
    958   Raises:
    959     RuntimeError: If any of the model checkpoint paths conflict with the file
     960       containing CheckpointState.
    961   """
    962   _update_checkpoint_state(
    963       save_dir=save_dir,
    964       model_checkpoint_path=model_checkpoint_path,
    965       all_model_checkpoint_paths=all_model_checkpoint_paths,
    966       latest_filename=latest_filename,
    967       save_relative_paths=False)
    968 
    969 
    970 def _update_checkpoint_state(save_dir,
    971                              model_checkpoint_path,
    972                              all_model_checkpoint_paths=None,
    973                              latest_filename=None,
    974                              save_relative_paths=False):
    975   """Updates the content of the 'checkpoint' file.
    976 
    977   This updates the checkpoint file containing a CheckpointState
    978   proto.
    979 
    980   Args:
    981     save_dir: Directory where the model was saved.
    982     model_checkpoint_path: The checkpoint file.
    983     all_model_checkpoint_paths: List of strings.  Paths to all not-yet-deleted
    984       checkpoints, sorted from oldest to newest.  If this is a non-empty list,
    985       the last element must be equal to model_checkpoint_path.  These paths
    986       are also saved in the CheckpointState proto.
     987     latest_filename: Optional name of the checkpoint file.  Defaults to
    988       'checkpoint'.
    989     save_relative_paths: If `True`, will write relative paths to the checkpoint
    990       state file.
    991 
    992   Raises:
    993     RuntimeError: If any of the model checkpoint paths conflict with the file
     994       containing CheckpointState.
    995   """
    996   # Writes the "checkpoint" file for the coordinator for later restoration.
    997   coord_checkpoint_filename = _GetCheckpointFilename(save_dir, latest_filename)
    998   if save_relative_paths:
    999     if os.path.isabs(model_checkpoint_path):
   1000       rel_model_checkpoint_path = os.path.relpath(
   1001           model_checkpoint_path, save_dir)
   1002     else:
   1003       rel_model_checkpoint_path = model_checkpoint_path
   1004     rel_all_model_checkpoint_paths = []
   1005     for p in all_model_checkpoint_paths:
   1006       if os.path.isabs(p):
   1007         rel_all_model_checkpoint_paths.append(os.path.relpath(p, save_dir))
   1008       else:
   1009         rel_all_model_checkpoint_paths.append(p)
   1010     ckpt = generate_checkpoint_state_proto(
   1011         save_dir,
   1012         rel_model_checkpoint_path,
   1013         all_model_checkpoint_paths=rel_all_model_checkpoint_paths)
   1014   else:
   1015     ckpt = generate_checkpoint_state_proto(
   1016         save_dir,
   1017         model_checkpoint_path,
   1018         all_model_checkpoint_paths=all_model_checkpoint_paths)
   1019 
   1020   if coord_checkpoint_filename == ckpt.model_checkpoint_path:
   1021     raise RuntimeError("Save path '%s' conflicts with path used for "
   1022                        "checkpoint state.  Please use a different save path." %
   1023                        model_checkpoint_path)
   1024 
   1025   # Preventing potential read/write race condition by *atomically* writing to a
   1026   # file.
   1027   file_io.atomic_write_string_to_file(coord_checkpoint_filename,
   1028                                       text_format.MessageToString(ckpt))
   1029 
   1030 
   1031 @tf_export("train.get_checkpoint_state")
   1032 def get_checkpoint_state(checkpoint_dir, latest_filename=None):
   1033   """Returns CheckpointState proto from the "checkpoint" file.
   1034 
   1035   If the "checkpoint" file contains a valid CheckpointState
   1036   proto, returns it.
   1037 
   1038   Args:
   1039     checkpoint_dir: The directory of checkpoints.
    1040     latest_filename: Optional name of the checkpoint file.  Defaults to
   1041       'checkpoint'.
   1042 
   1043   Returns:
   1044     A CheckpointState if the state was available, None
   1045     otherwise.
   1046 
   1047   Raises:
   1048     ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
   1049   """
   1050   ckpt = None
   1051   coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
   1052                                                      latest_filename)
   1053   f = None
   1054   try:
   1055     # Check that the file exists before opening it to avoid
   1056     # many lines of errors from colossus in the logs.
   1057     if file_io.file_exists(coord_checkpoint_filename):
   1058       file_content = file_io.read_file_to_string(
   1059           coord_checkpoint_filename)
   1060       ckpt = CheckpointState()
   1061       text_format.Merge(file_content, ckpt)
   1062       if not ckpt.model_checkpoint_path:
    1063         raise ValueError("Invalid checkpoint state loaded from %s" %
    1064                          checkpoint_dir)
   1065       # For relative model_checkpoint_path and all_model_checkpoint_paths,
   1066       # prepend checkpoint_dir.
   1067       if not os.path.isabs(ckpt.model_checkpoint_path):
   1068         ckpt.model_checkpoint_path = os.path.join(checkpoint_dir,
   1069                                                   ckpt.model_checkpoint_path)
   1070       for i in range(len(ckpt.all_model_checkpoint_paths)):
   1071         p = ckpt.all_model_checkpoint_paths[i]
   1072         if not os.path.isabs(p):
   1073           ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
   1074   except errors.OpError as e:
   1075     # It's ok if the file cannot be read
   1076     logging.warning("%s: %s", type(e).__name__, e)
   1077     logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
   1078     return None
   1079   except text_format.ParseError as e:
   1080     logging.warning("%s: %s", type(e).__name__, e)
   1081     logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
   1082     return None
   1083   finally:
   1084     if f:
   1085       f.close()
   1086   return ckpt
   1087 
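# Minimal usage sketch (hypothetical directory): read the "checkpoint" file
# and restore from the most recent checkpoint if one is recorded.
#
#   ckpt = get_checkpoint_state("/tmp/train")
#   if ckpt and ckpt.model_checkpoint_path:
#     saver.restore(sess, ckpt.model_checkpoint_path)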
   1088 
   1089 @tf_export("train.Saver")
   1090 class Saver(object):
   1091   """Saves and restores variables.
   1092 
   1093   See @{$variables$Variables}
   1094   for an overview of variables, saving and restoring.
   1095 
   1096   The `Saver` class adds ops to save and restore variables to and from
   1097   *checkpoints*.  It also provides convenience methods to run these ops.
   1098 
   1099   Checkpoints are binary files in a proprietary format which map variable names
   1100   to tensor values.  The best way to examine the contents of a checkpoint is to
   1101   load it using a `Saver`.
   1102 
   1103   Savers can automatically number checkpoint filenames with a provided counter.
   1104   This lets you keep multiple checkpoints at different steps while training a
   1105   model.  For example you can number the checkpoint filenames with the training
   1106   step number.  To avoid filling up disks, savers manage checkpoint files
   1107   automatically. For example, they can keep only the N most recent files, or
   1108   one checkpoint for every N hours of training.
   1109 
   1110   You number checkpoint filenames by passing a value to the optional
   1111   `global_step` argument to `save()`:
   1112 
   1113   ```python
   1114   saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
   1115   ...
   1116   saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
   1117   ```
   1118 
   1119   Additionally, optional arguments to the `Saver()` constructor let you control
   1120   the proliferation of checkpoint files on disk:
   1121 
   1122   * `max_to_keep` indicates the maximum number of recent checkpoint files to
   1123     keep.  As new files are created, older files are deleted.  If None or 0,
   1124     all checkpoint files are kept.  Defaults to 5 (that is, the 5 most recent
   1125     checkpoint files are kept.)
   1126 
   1127   * `keep_checkpoint_every_n_hours`: In addition to keeping the most recent
   1128     `max_to_keep` checkpoint files, you might want to keep one checkpoint file
   1129     for every N hours of training.  This can be useful if you want to later
   1130     analyze how a model progressed during a long training session.  For
   1131     example, passing `keep_checkpoint_every_n_hours=2` ensures that you keep
   1132     one checkpoint file for every 2 hours of training.  The default value of
   1133     10,000 hours effectively disables the feature.
   1134 
   1135   Note that you still have to call the `save()` method to save the model.
   1136   Passing these arguments to the constructor will not save variables
   1137   automatically for you.
   1138 
   1139   A training program that saves regularly looks like:
   1140 
   1141   ```python
   1142   ...
   1143   # Create a saver.
   1144   saver = tf.train.Saver(...variables...)
   1145   # Launch the graph and train, saving the model every 1,000 steps.
   1146   sess = tf.Session()
   1147   for step in xrange(1000000):
   1148       sess.run(..training_op..)
   1149       if step % 1000 == 0:
   1150           # Append the step number to the checkpoint name:
   1151           saver.save(sess, 'my-model', global_step=step)
   1152   ```
   1153 
   1154   In addition to checkpoint files, savers keep a protocol buffer on disk with
   1155   the list of recent checkpoints. This is used to manage numbered checkpoint
   1156   files and by `latest_checkpoint()`, which makes it easy to discover the path
   1157   to the most recent checkpoint. That protocol buffer is stored in a file named
   1158   'checkpoint' next to the checkpoint files.
   1159 
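  For example, a short sketch (the directory name is illustrative) of
  recovering the most recent checkpoint recorded in that file:

  ```python
  ckpt_path = tf.train.latest_checkpoint('/tmp/train')
  if ckpt_path:
    saver.restore(sess, ckpt_path)
  ```
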
   1160   If you create several savers, you can specify a different filename for the
   1161   protocol buffer file in the call to `save()`.
   1162   """
   1163 
   1164   def __init__(self,
   1165                var_list=None,
   1166                reshape=False,
   1167                sharded=False,
   1168                max_to_keep=5,
   1169                keep_checkpoint_every_n_hours=10000.0,
   1170                name=None,
   1171                restore_sequentially=False,
   1172                saver_def=None,
   1173                builder=None,
   1174                defer_build=False,
   1175                allow_empty=False,
   1176                write_version=saver_pb2.SaverDef.V2,
   1177                pad_step_number=False,
   1178                save_relative_paths=False,
   1179                filename=None):
   1180     """Creates a `Saver`.
   1181 
   1182     The constructor adds ops to save and restore variables.
   1183 
   1184     `var_list` specifies the variables that will be saved and restored. It can
   1185     be passed as a `dict` or a list:
   1186 
   1187     * A `dict` of names to variables: The keys are the names that will be
   1188       used to save or restore the variables in the checkpoint files.
   1189     * A list of variables: The variables will be keyed with their op name in
   1190       the checkpoint files.
   1191 
   1192     For example:
   1193 
   1194     ```python
   1195     v1 = tf.Variable(..., name='v1')
   1196     v2 = tf.Variable(..., name='v2')
   1197 
   1198     # Pass the variables as a dict:
   1199     saver = tf.train.Saver({'v1': v1, 'v2': v2})
   1200 
   1201     # Or pass them as a list.
   1202     saver = tf.train.Saver([v1, v2])
   1203     # Passing a list is equivalent to passing a dict with the variable op names
   1204     # as keys:
   1205     saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})
   1206     ```
   1207 
   1208     The optional `reshape` argument, if `True`, allows restoring a variable from
   1209     a save file where the variable had a different shape, but the same number
   1210     of elements and type.  This is useful if you have reshaped a variable and
   1211     want to reload it from an older checkpoint.
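
    A minimal sketch of the `reshape` option (the shapes are illustrative): a
    variable saved with shape `[2, 3]` can be restored into a `[3, 2]` variable
    because the element count and dtype match:

    ```python
    v = tf.Variable(tf.zeros([3, 2]), name='v')  # previously saved as [2, 3]
    saver = tf.train.Saver({'v': v}, reshape=True)
    ```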
   1212 
   1213     The optional `sharded` argument, if `True`, instructs the saver to shard
   1214     checkpoints per device.
   1215 
   1216     Args:
   1217       var_list: A list of `Variable`/`SaveableObject`, or a dictionary mapping
   1218         names to `SaveableObject`s. If `None`, defaults to the list of all
   1219         saveable objects.
   1220       reshape: If `True`, allows restoring parameters from a checkpoint
   1221         where the variables have a different shape.
   1222       sharded: If `True`, shard the checkpoints, one per device.
   1223       max_to_keep: Maximum number of recent checkpoints to keep.
   1224         Defaults to 5.
   1225       keep_checkpoint_every_n_hours: How often to keep checkpoints.
   1226         Defaults to 10,000 hours.
   1227       name: String.  Optional name to use as a prefix when adding operations.
   1228       restore_sequentially: A `Bool`, which if true, causes restore of different
   1229         variables to happen sequentially within each device.  This can lower
   1230         memory usage when restoring very large models.
   1231       saver_def: Optional `SaverDef` proto to use instead of running the
   1232         builder. This is only useful for specialty code that wants to recreate
   1233         a `Saver` object for a previously built `Graph` that had a `Saver`.
   1234         The `saver_def` proto should be the one returned by the
   1235         `as_saver_def()` call of the `Saver` that was created for that `Graph`.
   1236       builder: Optional `SaverBuilder` to use if a `saver_def` was not provided.
   1237         Defaults to `BulkSaverBuilder()`.
   1238       defer_build: If `True`, defer adding the save and restore ops to the
   1239         `build()` call. In that case `build()` should be called before
   1240         finalizing the graph or using the saver.
   1241       allow_empty: If `False` (default) raise an error if there are no
   1242         variables in the graph. Otherwise, construct the saver anyway and make
   1243         it a no-op.
   1244       write_version: controls what format to use when saving checkpoints.  It
   1245         also affects certain filepath matching logic.  The V2 format is the
   1246         recommended choice: it is much more optimized than V1 in terms of
   1247         memory required and latency incurred during restore.  Regardless of
   1248         this flag, the Saver is able to restore from both V2 and V1 checkpoints.
   1249       pad_step_number: if True, pads the global step number in the checkpoint
   1250         filepaths to some fixed width (8 by default).  This is turned off by
   1251         default.
   1252       save_relative_paths: If `True`, will write relative paths to the
   1253         checkpoint state file. This is needed if the user wants to copy the
   1254         checkpoint directory and reload from the copied directory.
   1255       filename: If known at graph construction time, filename used for variable
   1256         loading/saving.
   1257 
   1258     Raises:
   1259       TypeError: If `var_list` is invalid.
   1260       ValueError: If any of the keys or values in `var_list` are not unique.
   1261       RuntimeError: If eager execution is enabled and `var_list` does not
   1262         specify a list of variables to save.
   1263 
   1264     @compatibility(eager)
   1265     When eager execution is enabled, `var_list` must specify a `list` or `dict`
   1266     of variables to save. Otherwise, a `RuntimeError` will be raised.
   1267     @end_compatibility
   1268     """
   1269     if defer_build and var_list:
   1270       raise ValueError(
   1271           "If `var_list` is provided then build cannot be deferred. "
   1272           "Either set defer_build=False or var_list=None.")
   1273     if context.in_eager_mode() and var_list is None:
   1274       raise RuntimeError(
   1275           "When eager execution is enabled, `var_list` must specify a list or "
   1276           "dict of variables to save")
   1277     self._var_list = var_list
   1278     self._reshape = reshape
   1279     self._sharded = sharded
   1280     self._max_to_keep = max_to_keep
   1281     self._keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours
   1282     self._name = name
   1283     self._restore_sequentially = restore_sequentially
   1284     self.saver_def = saver_def
   1285     self._builder = builder
   1286     self._is_built = False
   1287     self._allow_empty = allow_empty
   1288     self._is_empty = None
   1289     self._write_version = write_version
   1290     self._pad_step_number = pad_step_number
   1291     self._filename = filename
   1292     if not defer_build and context.in_graph_mode():
   1293       self.build()
   1294     if self.saver_def:
   1295       self._check_saver_def()
   1296       self._write_version = self.saver_def.version
   1297     self._save_relative_paths = save_relative_paths
   1298 
   1299   def build(self):
   1300     if context.in_eager_mode():
   1301       raise RuntimeError("Use save/restore instead of build in eager mode.")
   1302     self._build(self._filename, build_save=True, build_restore=True)
   1303 
   1304   def _build_eager(self, checkpoint_path, build_save, build_restore):
   1305     self._build(
   1306         checkpoint_path, build_save=build_save, build_restore=build_restore)
   1307 
   1308   def _build(self, checkpoint_path, build_save, build_restore):
   1309     """Builds saver_def."""
   1310     if context.in_graph_mode():
   1311       if self._is_built:
   1312         return
   1313       self._is_built = True
   1314 
   1315     if not self.saver_def or context.in_eager_mode():
   1316       if self._builder is None:
   1317         self._builder = BulkSaverBuilder(self._write_version)
   1318 
   1319       if self._var_list is None:
   1320         # pylint: disable=protected-access
   1321         self._var_list = variables._all_saveable_objects()
   1322       if not self._var_list:
   1323         if self._allow_empty:
   1324           self._is_empty = True
   1325           return
   1326         else:
   1327           raise ValueError("No variables to save")
   1328       self._is_empty = False
   1329 
   1330       self.saver_def = self._builder._build_internal(  # pylint: disable=protected-access
   1331           self._var_list,
   1332           reshape=self._reshape,
   1333           sharded=self._sharded,
   1334           max_to_keep=self._max_to_keep,
   1335           keep_checkpoint_every_n_hours=self._keep_checkpoint_every_n_hours,
   1336           name=self._name,
   1337           restore_sequentially=self._restore_sequentially,
   1338           filename=checkpoint_path,
   1339           build_save=build_save, build_restore=build_restore)
   1340     elif self.saver_def and self._name:
   1341       # Since self._name is used as a name_scope by builder(), we are
   1342       # overloading the use of this field to represent the "import_scope" as
   1343       # well.
   1344       self.saver_def.filename_tensor_name = ops.prepend_name_scope(
   1345           self.saver_def.filename_tensor_name, self._name)
   1346       self.saver_def.save_tensor_name = ops.prepend_name_scope(
   1347           self.saver_def.save_tensor_name, self._name)
   1348       self.saver_def.restore_op_name = ops.prepend_name_scope(
   1349           self.saver_def.restore_op_name, self._name)
   1350 
   1351     self._check_saver_def()
   1352     # Updates next checkpoint time.
   1353     self._next_checkpoint_time = (
   1354         time.time() + self.saver_def.keep_checkpoint_every_n_hours * 3600)
   1355     self._last_checkpoints = []
   1356     self._checkpoints_to_be_deleted = []
   1357 
   1358   def _check_saver_def(self):
   1359     if not isinstance(self.saver_def, saver_pb2.SaverDef):
   1360       raise ValueError("saver_def must be a saver_pb2.SaverDef: %s" %
   1361                        self.saver_def)
   1362     if context.in_graph_mode():
   1363       if not self.saver_def.save_tensor_name:
   1364         raise ValueError("saver_def must specify the save_tensor_name: %s" %
   1365                          str(self.saver_def))
   1366       if not self.saver_def.restore_op_name:
   1367         raise ValueError("saver_def must specify the restore_op_name: %s" %
   1368                          str(self.saver_def))
   1369 
   1370   def _CheckpointFilename(self, p):
   1371     """Returns the checkpoint filename given a `(filename, time)` pair.
   1372 
   1373     Args:
   1374       p: (filename, time) pair.
   1375 
   1376     Returns:
   1377       Checkpoint file name.
   1378     """
   1379     name, _ = p
   1380     return name
   1381 
   1382   def _MetaGraphFilename(self, checkpoint_filename, meta_graph_suffix="meta"):
   1383     """Returns the meta graph filename.
   1384 
   1385     Args:
   1386       checkpoint_filename: Name of the checkpoint file.
   1387       meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
   1388 
   1389     Returns:
   1390       MetaGraph file name.
   1391     """
   1392     # If the checkpoint_filename is sharded, the checkpoint_filename could
   1393     # be of format model.ckpt-step#-?????-of-shard#. For example,
   1394     # model.ckpt-123456-?????-of-00005, or model.ckpt-123456-00001-of-00002.
   1395     basename = re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
   1396     meta_graph_filename = ".".join([basename, meta_graph_suffix])
   1397     return meta_graph_filename
   1398 
   1399   def _RecordLastCheckpoint(self, latest_save_path):
   1400     """Manages the list of the latest checkpoints."""
   1401     if not self.saver_def.max_to_keep:
   1402       return
   1403     # Remove first from list if the same name was used before.
   1404     for p in self._last_checkpoints:
   1405       if latest_save_path == self._CheckpointFilename(p):
   1406         self._last_checkpoints.remove(p)
   1407     # Append new path to list
   1408     self._last_checkpoints.append((latest_save_path, time.time()))
   1409 
   1410     # If more than max_to_keep, remove oldest.
   1411     if len(self._last_checkpoints) > self.saver_def.max_to_keep:
   1412       self._checkpoints_to_be_deleted.append(self._last_checkpoints.pop(0))
   1413 
   1414   def _MaybeDeleteOldCheckpoints(self, meta_graph_suffix="meta"):
   1415     """Deletes old checkpoints if necessary.
   1416 
   1417     `self._checkpoints_to_be_deleted` contains checkpoints that exceed
   1418     `max_to_keep`.  They will be deleted.  If
   1419     `keep_checkpoint_every_n_hours` was specified, keep an additional checkpoint
   1420     every `N` hours. For example, if `N` is 0.5, an additional checkpoint is
   1421     kept for every 0.5 hours of training; if `N` is 10, an additional
   1422     checkpoint is kept for every 10 hours of training.
   1423 
   1424     Args:
   1425       meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
   1426     """
   1427     if self._checkpoints_to_be_deleted:
   1428       p = self._checkpoints_to_be_deleted.pop(0)
   1429       # Do not delete the file if keep_checkpoint_every_n_hours is set and we
   1430       # have reached N hours of training.
   1431       should_keep = p[1] > self._next_checkpoint_time
   1432       if should_keep:
   1433         self._next_checkpoint_time += (
   1434             self.saver_def.keep_checkpoint_every_n_hours * 3600)
   1435         return
   1436 
   1437       # Otherwise delete the files.
   1438       try:
   1439         checkpoint_prefix = self._CheckpointFilename(p)
   1440         self._delete_file_if_exists(
   1441             self._MetaGraphFilename(checkpoint_prefix, meta_graph_suffix))
   1442         if self.saver_def.version == saver_pb2.SaverDef.V2:
   1443           # V2 has a metadata file and some data files.
   1444           self._delete_file_if_exists(checkpoint_prefix + ".index")
   1445           self._delete_file_if_exists(checkpoint_prefix +
   1446                                       ".data-?????-of-?????")
   1447         else:
   1448           # V1, Legacy.  Exact match on the data file.
   1449           self._delete_file_if_exists(checkpoint_prefix)
   1450       except Exception as e:  # pylint: disable=broad-except
   1451         logging.warning("Ignoring: %s", str(e))
   1452 
   1453   def _delete_file_if_exists(self, filespec):
   1454     for pathname in file_io.get_matching_files(filespec):
   1455       file_io.delete_file(pathname)
   1456 
   1457   def as_saver_def(self):
   1458     """Generates a `SaverDef` representation of this saver.
   1459 
   1460     Returns:
   1461       A `SaverDef` proto.
   1462     """
   1463     return self.saver_def
   1464 
   1465   def to_proto(self, export_scope=None):
   1466     """Converts this `Saver` to a `SaverDef` protocol buffer.
   1467 
   1468     Args:
   1469       export_scope: Optional `string`. Name scope to remove.
   1470 
   1471     Returns:
   1472       A `SaverDef` protocol buffer.
   1473     """
   1474     if export_scope is None:
   1475       return self.saver_def
   1476 
   1477     if not (self.saver_def.filename_tensor_name.startswith(export_scope) and
   1478             self.saver_def.save_tensor_name.startswith(export_scope) and
   1479             self.saver_def.restore_op_name.startswith(export_scope)):
   1480       return None
   1481 
   1482     saver_def = saver_pb2.SaverDef()
   1483     saver_def.CopyFrom(self.saver_def)
   1484     saver_def.filename_tensor_name = ops.strip_name_scope(
   1485         saver_def.filename_tensor_name, export_scope)
   1486     saver_def.save_tensor_name = ops.strip_name_scope(
   1487         saver_def.save_tensor_name, export_scope)
   1488     saver_def.restore_op_name = ops.strip_name_scope(
   1489         saver_def.restore_op_name, export_scope)
   1490     return saver_def
   1491 
   1492   @staticmethod
   1493   def from_proto(saver_def, import_scope=None):
   1494     """Returns a `Saver` object created from `saver_def`.
   1495 
   1496     Args:
   1497       saver_def: a `SaverDef` protocol buffer.
   1498       import_scope: Optional `string`. Name scope to use.
   1499 
   1500     Returns:
   1501       A `Saver` built from saver_def.
   1502     """
   1503     return Saver(saver_def=saver_def, name=import_scope)
   1504 
   1505   @property
   1506   def last_checkpoints(self):
   1507     """List of not-yet-deleted checkpoint filenames.
   1508 
   1509     You can pass any of the returned values to `restore()`.
   1510 
   1511     Returns:
   1512       A list of checkpoint filenames, sorted from oldest to newest.
   1513     """
   1514     return list(self._CheckpointFilename(p) for p in self._last_checkpoints)
   1515 
   1516   def set_last_checkpoints(self, last_checkpoints):
   1517     """DEPRECATED: Use set_last_checkpoints_with_time.
   1518 
   1519     Sets the list of old checkpoint filenames.
   1520 
   1521     Args:
   1522       last_checkpoints: A list of checkpoint filenames.
   1523 
   1524     Raises:
   1525       AssertionError: If last_checkpoints is not a list.
   1526     """
   1527     assert isinstance(last_checkpoints, list)
   1528     # We use a timestamp of +inf so that this checkpoint will never be
   1529     # deleted.  This is both safe and backwards compatible to a previous
   1530     # version of the code which used s[1] as the "timestamp".
   1531     self._last_checkpoints = [(s, np.inf) for s in last_checkpoints]
   1532 
   1533   def set_last_checkpoints_with_time(self, last_checkpoints_with_time):
   1534     """Sets the list of old checkpoint filenames and timestamps.
   1535 
   1536     Args:
   1537       last_checkpoints_with_time: A list of tuples of checkpoint filenames and
   1538         timestamps.
   1539 
   1540     Raises:
   1541       AssertionError: If last_checkpoints_with_time is not a list.
   1542     """
   1543     assert isinstance(last_checkpoints_with_time, list)
   1544     self._last_checkpoints = last_checkpoints_with_time
   1545 
   1546   def recover_last_checkpoints(self, checkpoint_paths):
   1547     """Recovers the internal saver state after a crash.
   1548 
   1549     This method is useful for recovering the "self._last_checkpoints" state.
   1550 
   1551     Globs for the checkpoints pointed to by `checkpoint_paths`.  If the files
   1552     exist, use their mtime as the checkpoint timestamp.
   1553 
   1554     Args:
   1555       checkpoint_paths: a list of checkpoint paths.
   1556     """
   1557     mtimes = get_checkpoint_mtimes(checkpoint_paths)
   1558     self.set_last_checkpoints_with_time(list(zip(checkpoint_paths, mtimes)))
   1559 
   1560   def save(self,
   1561            sess,
   1562            save_path,
   1563            global_step=None,
   1564            latest_filename=None,
   1565            meta_graph_suffix="meta",
   1566            write_meta_graph=True,
   1567            write_state=True,
   1568            strip_default_attrs=False):
   1569     # pylint: disable=line-too-long
   1570     """Saves variables.
   1571 
   1572     This method runs the ops added by the constructor for saving variables.
   1573     It requires a session in which the graph was launched.  The variables to
   1574     save must also have been initialized.
   1575 
   1576     The method returns the path prefix of the newly created checkpoint files.
   1577     This string can be passed directly to a call to `restore()`.
   1578 
   1579     Args:
   1580       sess: A Session to use to save the variables.
   1581       save_path: String.  Prefix of filenames created for the checkpoint.
   1582       global_step: If provided, the global step number is appended to
   1583         `save_path` to create the checkpoint filenames. The optional argument
   1584         can be a `Tensor`, a `Tensor` name or an integer.
   1585       latest_filename: Optional name for the protocol buffer file that will
   1586         contain the list of most recent checkpoints.  That file,
   1587         kept in the same directory as the checkpoint files, is automatically
   1588         managed by the saver to keep track of recent checkpoints.  Defaults to
   1589         'checkpoint'.
   1590       meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.
   1591       write_meta_graph: `Boolean` indicating whether or not to write the meta
   1592         graph file.
   1593       write_state: `Boolean` indicating whether or not to write the
   1594         `CheckpointStateProto`.
   1595       strip_default_attrs: Boolean. If `True`, default-valued attributes will be
   1596         removed from the NodeDefs. For a detailed guide, see
   1597         [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
   1598 
   1599     Returns:
   1600       A string: path prefix used for the checkpoint files.  If the saver is
   1601         sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn'
   1602         is the number of shards created.
   1603       If the saver is empty, returns None.
   1604 
   1605     Raises:
   1606       TypeError: If `sess` is not a `Session`.
   1607       ValueError: If `latest_filename` contains path components, or if it
   1608         collides with `save_path`.
   1609       RuntimeError: If save and restore ops weren't built.
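
    For example (a sketch; the path prefix is hypothetical):

    ```python
    save_path = saver.save(sess, '/tmp/model.ckpt', global_step=1000)
    # save_path is now a prefix such as '/tmp/model.ckpt-1000'.
    ```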
   1610     """
   1611     # pylint: enable=line-too-long
   1612     if not self._is_built and context.in_graph_mode():
   1613       raise RuntimeError(
   1614           "`build()` should be called before save if defer_build==True")
   1615     if latest_filename is None:
   1616       latest_filename = "checkpoint"
   1617     if self._write_version != saver_pb2.SaverDef.V2:
   1618       logging.warning("*******************************************************")
   1619       logging.warning("TensorFlow's V1 checkpoint format has been deprecated.")
   1620       logging.warning("Consider switching to the more efficient V2 format:")
   1621       logging.warning("   `tf.train.Saver(write_version=tf.train.SaverDef.V2)`")
   1622       logging.warning("now on by default.")
   1623       logging.warning("*******************************************************")
   1624 
   1625     if os.path.split(latest_filename)[0]:
   1626       raise ValueError("'latest_filename' must not contain path components")
   1627 
   1628     if global_step is not None:
   1629       if not isinstance(global_step, compat.integral_types):
   1630         global_step = training_util.global_step(sess, global_step)
   1631       checkpoint_file = "%s-%d" % (save_path, global_step)
   1632       if self._pad_step_number:
   1633         # Zero-pads the step numbers, so that they are sorted when listed.
   1634         checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))
   1635     else:
   1636       checkpoint_file = save_path
   1637       if os.path.basename(
   1638           save_path) == latest_filename and not self._sharded:
   1639         # Guard against collision between data file and checkpoint state file.
   1640         raise ValueError(
   1641             "'latest_filename' collides with 'save_path': '%s' and '%s'" %
   1642             (latest_filename, save_path))
   1643 
   1644     if (context.in_graph_mode() and
   1645         not isinstance(sess, session.SessionInterface)):
   1646       raise TypeError("'sess' must be a Session; %s" % sess)
   1647 
   1648     save_path_parent = os.path.dirname(save_path)
   1649     if not self._is_empty:
   1650       try:
   1651         if context.in_graph_mode():
   1652           model_checkpoint_path = sess.run(
   1653               self.saver_def.save_tensor_name,
   1654               {self.saver_def.filename_tensor_name: checkpoint_file})
   1655         else:
   1656           self._build_eager(
   1657               checkpoint_file, build_save=True, build_restore=False)
   1658           model_checkpoint_path = self.saver_def.save_tensor_name
   1659 
   1660         model_checkpoint_path = compat.as_str(model_checkpoint_path)
   1661         if write_state:
   1662           self._RecordLastCheckpoint(model_checkpoint_path)
   1663           _update_checkpoint_state(
   1664               save_dir=save_path_parent,
   1665               model_checkpoint_path=model_checkpoint_path,
   1666               all_model_checkpoint_paths=self.last_checkpoints,
   1667               latest_filename=latest_filename,
   1668               save_relative_paths=self._save_relative_paths)
   1669           self._MaybeDeleteOldCheckpoints(meta_graph_suffix=meta_graph_suffix)
   1670       except (errors.FailedPreconditionError, errors.NotFoundError) as exc:
   1671         if not gfile.IsDirectory(save_path_parent):
   1672           exc = ValueError(
   1673               "Parent directory of {} doesn't exist, can't save.".format(
   1674                   save_path))
   1675         raise exc
   1676 
   1677     if write_meta_graph:
   1678       meta_graph_filename = self._MetaGraphFilename(
   1679           checkpoint_file, meta_graph_suffix=meta_graph_suffix)
   1680       if context.in_graph_mode():
   1681         with sess.graph.as_default():
   1682           self.export_meta_graph(
   1683               meta_graph_filename, strip_default_attrs=strip_default_attrs)
   1684 
   1685     if self._is_empty:
   1686       return None
   1687     else:
   1688       return model_checkpoint_path
   1689 
   1690   def export_meta_graph(self,
   1691                         filename=None,
   1692                         collection_list=None,
   1693                         as_text=False,
   1694                         export_scope=None,
   1695                         clear_devices=False,
   1696                         clear_extraneous_savers=False,
   1697                         strip_default_attrs=False):
   1698     # pylint: disable=line-too-long
   1699     """Writes `MetaGraphDef` to save_path/filename.
   1700 
   1701     Args:
   1702       filename: Optional meta_graph filename including the path.
   1703       collection_list: List of string keys to collect.
   1704       as_text: If `True`, writes the meta_graph as an ASCII proto.
   1705       export_scope: Optional `string`. Name scope to remove.
   1706       clear_devices: Whether or not to clear the device field for an `Operation`
   1707         or `Tensor` during export.
   1708       clear_extraneous_savers: Remove any Saver-related information from the
   1709         graph (both Save/Restore ops and SaverDefs) that are not associated
   1710         with this Saver.
   1711       strip_default_attrs: Boolean. If `True`, default-valued attributes will be
   1712         removed from the NodeDefs. For a detailed guide, see
   1713         [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
   1714 
   1715     Returns:
   1716       A `MetaGraphDef` proto.
   1717     """
   1718     # pylint: enable=line-too-long
   1719     return export_meta_graph(
   1720         filename=filename,
   1721         graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
   1722         saver_def=self.saver_def,
   1723         collection_list=collection_list,
   1724         as_text=as_text,
   1725         export_scope=export_scope,
   1726         clear_devices=clear_devices,
   1727         clear_extraneous_savers=clear_extraneous_savers,
   1728         strip_default_attrs=strip_default_attrs)
   1729 
   1730   def restore(self, sess, save_path):
   1731     """Restores previously saved variables.
   1732 
   1733     This method runs the ops added by the constructor for restoring variables.
   1734     It requires a session in which the graph was launched.  The variables to
   1735     restore do not have to have been initialized, as restoring is itself a way
   1736     to initialize variables.
   1737 
   1738     The `save_path` argument is typically a value previously returned from a
   1739     `save()` call, or a call to `latest_checkpoint()`.
   1740 
   1741     Args:
   1742       sess: A `Session` to use to restore the parameters. None in eager mode.
   1743       save_path: Path where parameters were previously saved.
   1744 
   1745     Raises:
   1746       ValueError: If save_path is None.
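
    For example (a sketch; the checkpoint directory is hypothetical):

    ```python
    ckpt = tf.train.latest_checkpoint('/tmp/train_dir')
    if ckpt is not None:
      saver.restore(sess, ckpt)
    ```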
   1747     """
   1748     if self._is_empty:
   1749       return
   1750     if save_path is None:
   1751       raise ValueError("Can't load save_path when it is None.")
   1752     logging.info("Restoring parameters from %s", save_path)
   1753     if context.in_graph_mode():
   1754       sess.run(self.saver_def.restore_op_name,
   1755                {self.saver_def.filename_tensor_name: save_path})
   1756     else:
   1757       self._build_eager(save_path, build_save=False, build_restore=True)
   1758 
   1759   @staticmethod
   1760   def _add_collection_def(meta_graph_def, key, export_scope=None):
   1761     """Adds a collection to MetaGraphDef protocol buffer.
   1762 
   1763     Args:
   1764       meta_graph_def: MetaGraphDef protocol buffer.
   1765       key: One of the GraphKeys or user-defined string.
   1766       export_scope: Optional `string`. Name scope to remove.
   1767     """
   1768     meta_graph.add_collection_def(meta_graph_def, key,
   1769                                   export_scope=export_scope)
   1770 
   1771 
   1772 def _prefix_to_checkpoint_path(prefix, format_version):
   1773   """Returns the pathname of a checkpoint file, given the checkpoint prefix.
   1774 
   1775   For a V1 checkpoint, simply returns the prefix itself (the data file).  For V2,
   1776   returns the pathname to the index file.
   1777 
   1778   Args:
   1779     prefix: a string, the prefix of a checkpoint.
   1780     format_version: the checkpoint format version that corresponds to the
   1781       prefix.
   1782   Returns:
   1783     The pathname of a checkpoint file, taking into account the checkpoint
   1784       format version.
   1785   """
   1786   if format_version == saver_pb2.SaverDef.V2:
   1787     return prefix + ".index"  # The index file identifies a checkpoint.
   1788   return prefix  # Just the data file.
   1789 
   1790 
   1791 @tf_export("train.latest_checkpoint")
   1792 def latest_checkpoint(checkpoint_dir, latest_filename=None):
   1793   """Finds the filename of latest saved checkpoint file.
   1794 
   1795   Args:
   1796     checkpoint_dir: Directory where the variables were saved.
   1797     latest_filename: Optional name for the protocol buffer file that
   1798       contains the list of most recent checkpoint filenames.
   1799       See the corresponding argument to `Saver.save()`.
   1800 
   1801   Returns:
   1802     The full path to the latest checkpoint or `None` if no checkpoint was found.
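
  For example (a sketch; the directory is hypothetical):

  ```python
  ckpt_path = tf.train.latest_checkpoint('/tmp/train_dir')
  if ckpt_path is None:
    print('No checkpoint found in /tmp/train_dir')
  ```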
   1803   """
   1804   # Pick the latest checkpoint based on checkpoint state.
   1805   ckpt = get_checkpoint_state(checkpoint_dir, latest_filename)
   1806   if ckpt and ckpt.model_checkpoint_path:
   1807     # Look for either a V2 path or a V1 path, with priority for V2.
   1808     v2_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
   1809                                          saver_pb2.SaverDef.V2)
   1810     v1_path = _prefix_to_checkpoint_path(ckpt.model_checkpoint_path,
   1811                                          saver_pb2.SaverDef.V1)
   1812     if file_io.get_matching_files(v2_path) or file_io.get_matching_files(
   1813         v1_path):
   1814       return ckpt.model_checkpoint_path
   1815     else:
   1816       logging.error("Couldn't match files for checkpoint %s",
   1817                     ckpt.model_checkpoint_path)
   1818   return None
   1819 
   1820 
   1821 @tf_export("train.import_meta_graph")
   1822 def import_meta_graph(meta_graph_or_file, clear_devices=False,
   1823                       import_scope=None, **kwargs):
   1824   """Recreates a Graph saved in a `MetaGraphDef` proto.
   1825 
   1826   This function takes a `MetaGraphDef` protocol buffer as input. If
   1827   the argument is a file containing a `MetaGraphDef` protocol buffer,
   1828   it constructs a protocol buffer from the file content. The function
   1829   then adds all the nodes from the `graph_def` field to the
   1830   current graph, recreates all the collections, and returns a saver
   1831   constructed from the `saver_def` field.
   1832 
   1833   In combination with `export_meta_graph()`, this function can be used to
   1834 
   1835   * Serialize a graph along with other Python objects such as `QueueRunner`
   1836     and `Variable` into a `MetaGraphDef`.
   1837 
   1838   * Restart training from a saved graph and checkpoints.
   1839 
   1840   * Run inference from a saved graph and checkpoints.
   1841 
   1842   ```Python
   1843   ...
   1844   # Create a saver.
   1845   saver = tf.train.Saver(...variables...)
   1846   # Remember the training_op we want to run by adding it to a collection.
   1847   tf.add_to_collection('train_op', train_op)
   1848   sess = tf.Session()
   1849   for step in xrange(1000000):
   1850       sess.run(train_op)
   1851       if step % 1000 == 0:
   1852           # Saves checkpoint, which by default also exports a meta_graph
   1853           # named 'my-model-global_step.meta'.
   1854           saver.save(sess, 'my-model', global_step=step)
   1855   ```
   1856 
   1857   Later we can continue training from this saved `meta_graph` without building
   1858   the model from scratch.
   1859 
   1860   ```Python
   1861   with tf.Session() as sess:
   1862     new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
   1863     new_saver.restore(sess, 'my-save-dir/my-model-10000')
   1864     # tf.get_collection() returns a list. In this example we only want the
   1865     # first one.
   1866     train_op = tf.get_collection('train_op')[0]
   1867     for step in xrange(1000000):
   1868       sess.run(train_op)
   1869   ```
   1870 
   1871   NOTE: Restarting training from saved `meta_graph` only works if the
   1872   device assignments have not changed.
   1873 
   1874   Args:
   1875     meta_graph_or_file: `MetaGraphDef` protocol buffer or filename (including
   1876       the path) containing a `MetaGraphDef`.
   1877     clear_devices: Whether or not to clear the device field for an `Operation`
   1878       or `Tensor` during import.
   1879     import_scope: Optional `string`. Name scope to add. Only used when
   1880       initializing from protocol buffer.
   1881     **kwargs: Optional keyed arguments.
   1882 
   1883   Returns:
   1884     A saver constructed from `saver_def` in `MetaGraphDef` or None.
   1885 
   1886     A None value is returned if no variables exist in the `MetaGraphDef`
   1887     (i.e., there are no variables to restore).
   1888 
   1889   Raises:
   1890     RuntimeError: If called with eager execution enabled.
   1891 
   1892   @compatibility(eager)
   1893   Exporting/importing meta graphs is not supported. No graph exists when eager
   1894   execution is enabled.
   1895   @end_compatibility
   1896   """  # pylint: disable=g-doc-exception
   1897   if context.in_eager_mode():
   1898     raise RuntimeError("Exporting/importing meta graphs is not supported when "
   1899                        "eager execution is enabled. No graph exists when eager "
   1900                        "execution is enabled.")
   1901   if not isinstance(meta_graph_or_file, meta_graph_pb2.MetaGraphDef):
   1902     meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file)
   1903   else:
   1904     meta_graph_def = meta_graph_or_file
   1905 
   1906   meta_graph.import_scoped_meta_graph(meta_graph_def,
   1907                                       clear_devices=clear_devices,
   1908                                       import_scope=import_scope,
   1909                                       **kwargs)
   1910   if meta_graph_def.HasField("saver_def"):
   1911     return Saver(saver_def=meta_graph_def.saver_def, name=import_scope)
   1912   else:
   1913     if variables._all_saveable_objects():  # pylint: disable=protected-access
   1914       # Return the default saver instance for all graph variables.
   1915       return Saver()
   1916     else:
   1917       # If no graph variables exist, then a Saver cannot be constructed.
   1918       logging.info("Saver not created because there are no variables in the"
   1919                    " graph to restore")
   1920       return None
   1921 
   1922 
   1923 @tf_export("train.export_meta_graph")
   1924 def export_meta_graph(filename=None,
   1925                       meta_info_def=None,
   1926                       graph_def=None,
   1927                       saver_def=None,
   1928                       collection_list=None,
   1929                       as_text=False,
   1930                       graph=None,
   1931                       export_scope=None,
   1932                       clear_devices=False,
   1933                       clear_extraneous_savers=False,
   1934                       strip_default_attrs=False,
   1935                       **kwargs):
   1936   # pylint: disable=line-too-long
   1937   """Returns `MetaGraphDef` proto. Optionally writes it to filename.
   1938 
   1939   This function exports the graph, saver, and collection objects into
   1940   a `MetaGraphDef` protocol buffer so that it can be imported at a later
   1941   time or in another location to restart training, run inference, or serve
   1942   as a subgraph.
   1943 
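  For example (a sketch; the filename is hypothetical):

  ```python
  # Exports the default graph and also writes the proto to the given file.
  meta_graph_def = tf.train.export_meta_graph(filename='/tmp/my-model.meta')
  ```
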
   1944   Args:
   1945     filename: Optional filename including the path for writing the
   1946       generated `MetaGraphDef` protocol buffer.
   1947     meta_info_def: `MetaInfoDef` protocol buffer.
   1948     graph_def: `GraphDef` protocol buffer.
   1949     saver_def: `SaverDef` protocol buffer.
   1950     collection_list: List of string keys to collect.
   1951     as_text: If `True`, writes the `MetaGraphDef` as an ASCII proto.
   1952     graph: The `Graph` to export. If `None`, use the default graph.
   1953     export_scope: Optional `string`. Name scope under which to extract
   1954       the subgraph. The scope name will be stripped from the node definitions
   1955       for easy import later into new name scopes. If `None`, the whole graph
   1956       is exported. graph_def and export_scope cannot both be specified.
   1957     clear_devices: Whether or not to clear the device field for an `Operation`
   1958       or `Tensor` during export.
   1959     clear_extraneous_savers: Remove any Saver-related information from the
   1960         graph (both Save/Restore ops and SaverDefs) that are not associated
   1961         with the provided SaverDef.
   1962     strip_default_attrs: Boolean. If `True`, default-valued attributes will be
   1963       removed from the NodeDefs. For a detailed guide, see
   1964       [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
   1965     **kwargs: Optional keyed arguments.
   1966 
   1967   Returns:
   1968     A `MetaGraphDef` proto.
   1969 
   1970   Raises:
   1971     ValueError: When the `GraphDef` is larger than 2GB.
   1972     RuntimeError: If called with eager execution enabled.
   1973 
   1974   @compatibility(eager)
   1975   Exporting/importing meta graphs is not supported. No graph exists when eager
   1976   execution is enabled.
   1977   @end_compatibility
   1978   """
   1979   # pylint: enable=line-too-long
   1980   if context.in_eager_mode():
   1981     raise RuntimeError("Exporting/importing meta graphs is not supported when "
   1982                        "eager execution is enabled. No graph exists when eager "
   1983                        "execution is enabled.")
   1984   meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
   1985       filename=filename,
   1986       meta_info_def=meta_info_def,
   1987       graph_def=graph_def,
   1988       saver_def=saver_def,
   1989       collection_list=collection_list,
   1990       as_text=as_text,
   1991       graph=graph,
   1992       export_scope=export_scope,
   1993       clear_devices=clear_devices,
   1994       clear_extraneous_savers=clear_extraneous_savers,
   1995       strip_default_attrs=strip_default_attrs,
   1996       **kwargs)
   1997   return meta_graph_def
   1998 
   1999 
   2000 @tf_export("train.checkpoint_exists")
   2001 def checkpoint_exists(checkpoint_prefix):
   2002   """Checks whether a V1 or V2 checkpoint exists with the specified prefix.
   2003 
   2004   This is the recommended way to check if a checkpoint exists, since it takes
   2005   into account the naming difference between V1 and V2 formats.
   2006 
   2007   Args:
   2008     checkpoint_prefix: the prefix of a V1 or V2 checkpoint, with V2 taking
   2009       priority.  Typically the result of `Saver.save()` or that of
   2010       `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
   2011       V1/V2.
   2012   Returns:
   2013     A bool, true iff a checkpoint referred to by `checkpoint_prefix` exists.
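
  For example (a sketch; the prefix is hypothetical):

  ```python
  prefix = '/tmp/train_dir/model.ckpt-1000'
  if tf.train.checkpoint_exists(prefix):
    saver.restore(sess, prefix)
  ```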
   2014   """
   2015   pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
   2016                                         saver_pb2.SaverDef.V2)
   2017   if file_io.get_matching_files(pathname):
   2018     return True
   2019   elif file_io.get_matching_files(checkpoint_prefix):
   2020     return True
   2021   else:
   2022     return False
   2023 
   2024 
   2025 @tf_export("train.get_checkpoint_mtimes")
   2026 def get_checkpoint_mtimes(checkpoint_prefixes):
   2027   """Returns the mtimes (modification timestamps) of the checkpoints.
   2028 
   2029   Globs for the checkpoints pointed to by `checkpoint_prefixes`.  If the files
   2030   exist, collect their mtime.  Both V2 and V1 checkpoints are considered, in
   2031   that priority.
   2032 
   2033   This is the recommended way to get the mtimes, since it takes into account
   2034   the naming difference between V1 and V2 formats.
   2035 
   2036   Args:
   2037     checkpoint_prefixes: a list of checkpoint paths, typically the results of
   2038       `Saver.save()` or those of `tf.train.latest_checkpoint()`, regardless of
   2039       sharded/non-sharded or V1/V2.
   2040   Returns:
   2041     A list of mtimes (in seconds since the epoch) of the found checkpoints.
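
  For example (a sketch; the prefixes are hypothetical results of `Saver.save()`):

  ```python
  mtimes = tf.train.get_checkpoint_mtimes(
      ['/tmp/train_dir/model.ckpt-500', '/tmp/train_dir/model.ckpt-1000'])
  ```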
   2042   """
   2043   mtimes = []
   2044 
   2045   def match_maybe_append(pathname):
   2046     fnames = file_io.get_matching_files(pathname)
   2047     if fnames:
   2048       mtimes.append(file_io.stat(fnames[0]).mtime_nsec / 1e9)
   2049       return True
   2050     return False
   2051 
   2052   for checkpoint_prefix in checkpoint_prefixes:
   2053     # Tries V2's metadata file first.
   2054     pathname = _prefix_to_checkpoint_path(checkpoint_prefix,
   2055                                           saver_pb2.SaverDef.V2)
   2056     if match_maybe_append(pathname):
   2057       continue
   2058     # Otherwise, tries V1, where the prefix is the complete pathname.
   2059     match_maybe_append(checkpoint_prefix)
   2060 
   2061   return mtimes
   2062 
   2063 
   2064 ops.register_proto_function(
   2065     ops.GraphKeys.SAVERS,
   2066     proto_type=saver_pb2.SaverDef,
   2067     to_proto=Saver.to_proto,
   2068     from_proto=Saver.from_proto)
   2069