Home | History | Annotate | Download | only in saving
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Types for specifying saving and loading behavior."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 
     21 class SaveSpec(object):
     22   """Class used to describe tensor slices that need to be saved."""
     23 
     24   def __init__(self, tensor, slice_spec, name, dtype=None):
     25     """Creates a `SaveSpec` object.
     26 
     27     Args:
     28       tensor: the tensor to save or callable that produces a tensor to save.
     29       slice_spec: the slice to be saved. See `Variable.SaveSliceInfo`.
     30       name: the name to save the tensor under.
     31       dtype: The data type of the Tensor. Required if `tensor` is callable.
     32         Used for error checking in the restore op.
     33     """
     34     self._tensor = tensor
     35     self.slice_spec = slice_spec
     36     self.name = name
     37     if callable(self._tensor):
     38       if dtype is None:
     39         raise AssertionError(
     40             "When passing a callable `tensor` to a SaveSpec, an explicit "
     41             "dtype must be provided.")
     42       self.dtype = dtype
     43     else:
     44       self.dtype = tensor.dtype
     45 
     46   @property
     47   def tensor(self):
     48     return self._tensor() if callable(self._tensor) else self._tensor
     49 
     50 
     51 class SaveableObject(object):
     52   """Base class for saving and restoring saveable objects."""
     53 
     54   def __init__(self, op, specs, name):
     55     """Creates a `SaveableObject` object.
     56 
     57     Args:
     58       op: the "producer" object that this class wraps; it produces a list of
     59         tensors to save.  E.g., a "Variable" object saving its backing tensor.
     60       specs: a list of SaveSpec, each element of which describes one tensor to
     61         save under this object. All Tensors must be on the same device.
     62       name: the name to save the object under.
     63     """
     64     self.op = op
     65     self.specs = specs
     66     self.name = name
     67     self._device = None
     68 
     69   @property
     70   def optional_restore(self):
     71     """A hint to restore assertions that this object is optional."""
     72     return False  # Default to required
     73 
     74   @property
     75   def device(self):
     76     """The device for SaveSpec Tensors."""
     77     # Note that SaveSpec.tensor runs Tensor-gathering ops when executing
     78     # eagerly, making this call potentially very expensive.
     79     #
     80     # TODO(allenl): Consider another way to gather device information. Lower
     81     # priority since this property isn't part of the normal save()/restore()
     82     # workflow, but does come up when some alternative builders are passed to
     83     # the Saver.
     84     if self._device is None:
     85       self._device = self.specs[0].tensor.device
     86     return self._device
     87 
     88   def restore(self, restored_tensors, restored_shapes):
     89     """Restores this object from 'restored_tensors'.
     90 
     91     Args:
     92       restored_tensors: the tensors that were loaded from a checkpoint
     93       restored_shapes: the shapes this object should conform to after
     94         restore, or None.
     95 
     96     Returns:
     97       An operation that restores the state of the object.
     98 
     99     Raises:
    100       ValueError: If the object cannot be restored using the provided
    101         parameters.
    102     """
    103     # pylint: disable=unused-argument
    104     raise ValueError("Calling an abstract method.")
    105