      1 """An object-local variable management scheme."""
      2 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #     http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 # ==============================================================================
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 import collections
     21 import weakref
     22 
     23 from tensorflow.python import pywrap_tensorflow
     24 from tensorflow.python.eager import context
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.ops import gen_io_ops as io_ops
     28 from tensorflow.python.util import nest
     29 
     30 # A key indicating a variable's value in an object's checkpointed Tensors
     31 # (Checkpointable._gather_tensors_for_checkpoint). If this is the only key and
     32 # the object has no dependencies, then its value may be restored on object
     33 # creation (avoiding double assignment when executing eagerly).
     34 VARIABLE_VALUE_KEY = "VARIABLE_VALUE"
     35 
     36 _CheckpointableReference = collections.namedtuple(
     37     "_CheckpointableReference",
     38     [
     39         # The local name for this dependency.
     40         "name",
     41         # The Checkpointable object being referenced.
     42         "ref"
     43     ])
     44 
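# For illustration only: a dependency edge is just this (name, ref) pair. A
# hypothetical sketch, where `my_variable` is some Checkpointable object:
#
#   edge = _CheckpointableReference(name="v", ref=my_variable)
#   # An owning object appends `edge` to its _checkpoint_dependencies list,
#   # meaning its checkpoints include my_variable under the local name "v".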

class CheckpointInitialValue(ops.Tensor):
  """Tensor wrapper for managing update UIDs in `Variables`.

  When supplied as an initial value, objects of this type let a `Variable`
  (`Variable`, `ResourceVariable`, etc.) know the UID of the restore that the
  initial value came from. This allows deferred restorations to be sequenced in
  the order the user specified them, and lets us fall back on assignment if an
  initial value is not set (e.g. due to a custom getter interfering).

  See comments in _add_variable_with_custom_getter for more information about
  how `CheckpointInitialValue` is used.
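
  For illustration, a sketch of the intended use (hypothetical wiring; this
  class does not itself create variables):

  ```
  initial_value = CheckpointInitialValue(checkpoint_position, shape=[2, 3])
  v = resource_variable_ops.ResourceVariable(initial_value=initial_value)
  # The variable's creation code can then inspect
  # initial_value.checkpoint_position.restore_uid to order this restoration
  # against later deferred ones.
  ```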
  """

  def __init__(self, checkpoint_position, shape=None):
    self.wrapped_value = checkpoint_position.restore_ops()[
        VARIABLE_VALUE_KEY]
    if shape:
      # We need to set the static shape information on the initializer if
      # possible so we don't get a variable with an unknown shape.
      self.wrapped_value.set_shape(shape)
    self._checkpoint_position = checkpoint_position

  @property
  def __class__(self):
    # Report the wrapped Tensor's class along with CheckpointInitialValue, for
    # code which dispatches on __class__ directly.
    return (self.wrapped_value.__class__, CheckpointInitialValue)

  def __getattr__(self, attr):
    try:
      return getattr(self.wrapped_value, attr)
    except AttributeError:
      return self.__getattribute__(attr)

  @property
  def checkpoint_position(self):
    return self._checkpoint_position


class _CheckpointPosition(object):
  """Indicates a position within a `_Checkpoint`."""

  def __init__(self, checkpoint, proto_id):
    """Specify an object within a checkpoint.

    Args:
      checkpoint: A _Checkpoint object.
      proto_id: The index of this object in CheckpointableObjectGraph.nodes.
    """
    self._checkpoint = checkpoint
    self._proto_id = proto_id

  def restore(self, checkpointable):
    """Restore this value into `checkpointable`."""
    if self.bind_object(checkpointable):
      # This object's correspondence with a checkpointed object is new, so
      # process deferred restorations for it and its dependencies.
      restore_ops = checkpointable._restore_from_checkpoint_position(self)  # pylint: disable=protected-access
      if restore_ops:
        self._checkpoint.restore_ops.extend(restore_ops)

  def bind_object(self, checkpointable):
    """Set a checkpoint<->object correspondence and process slot variables.

    Args:
      checkpointable: The object to record a correspondence for.
    Returns:
      True if this is a new assignment, False if this object has already been
      mapped to a checkpointed `Object` proto.
    Raises:
      AssertionError: If another object is already bound to the `Object` proto.
    """
    checkpoint = self.checkpoint
    current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
    if current_assignment is None:
      checkpoint.object_by_proto_id[self._proto_id] = checkpointable
      for deferred_slot_restoration in (
          checkpoint.deferred_slot_restorations.pop(self._proto_id, ())):
        checkpointable._create_or_restore_slot_variable(  # pylint: disable=protected-access
            slot_variable_position=_CheckpointPosition(
                checkpoint=checkpoint,
                proto_id=deferred_slot_restoration.slot_variable_id),
            variable=deferred_slot_restoration.original_variable,
            slot_name=deferred_slot_restoration.slot_name)
      for slot_restoration in checkpoint.slot_restorations.pop(
          self._proto_id, ()):
        optimizer_object = checkpoint.object_by_proto_id.get(
            slot_restoration.optimizer_id, None)
        if optimizer_object is None:
          # The optimizer has not yet been created or tracked. Record in the
          # checkpoint that the slot variables need to be restored when it is.
          checkpoint.deferred_slot_restorations.setdefault(
              slot_restoration.optimizer_id, []).append(
                  _DeferredSlotVariableRestoration(
                      original_variable=checkpointable,
                      slot_variable_id=slot_restoration.slot_variable_id,
                      slot_name=slot_restoration.slot_name))
        else:
          optimizer_object._create_or_restore_slot_variable(  # pylint: disable=protected-access
              slot_variable_position=_CheckpointPosition(
                  checkpoint=checkpoint,
                  proto_id=slot_restoration.slot_variable_id),
              variable=checkpointable,
              slot_name=slot_restoration.slot_name)
      return True  # New assignment
    else:
      # The object was already mapped for this checkpoint load, which means
      # we don't need to do anything besides check that the mapping is
      # consistent (if the dependency DAG is not a tree then there are
      # multiple paths to the same object).
      if current_assignment is not checkpointable:
        raise AssertionError(
            ("Unable to load the checkpoint into this object graph. Either "
             "the Checkpointable object references in the Python program "
             "have changed in an incompatible way, or the checkpoint was "
             "generated in an incompatible program.\n\nTwo checkpoint "
             "references resolved to different objects (%s and %s).")
            % (current_assignment, checkpointable))
      return False  # Not a new assignment

  def is_simple_variable(self):
    """Determine whether this value is restorable with a Tensor initializer."""
    attributes = self.object_proto.attributes
    return (len(attributes) == 1
            and attributes[0].name == VARIABLE_VALUE_KEY
            and not self.object_proto.children)

  def restore_ops(self):
    """Create restore ops for this object's attributes."""
    restore_tensors = {}
    for serialized_tensor in self.object_proto.attributes:
      checkpoint_key = serialized_tensor.checkpoint_key
      dtype = self._checkpoint.dtype_map[checkpoint_key]
      base_type = dtype.base_dtype
      with ops.init_scope():
        restore, = io_ops.restore_v2(
            prefix=self._checkpoint.save_path,
            tensor_names=[checkpoint_key],
            shape_and_slices=[""],
            dtypes=[base_type],
            name="%s_checkpoint_read" % (serialized_tensor.name,))
        restore_tensors[serialized_tensor.name] = restore
    # Return after the loop so that every attribute, not just the first one,
    # gets a restore op.
    return restore_tensors

  @property
  def checkpoint(self):
    return self._checkpoint

  @property
  def checkpointable(self):
    return self._checkpoint.object_by_proto_id[self._proto_id]

  @property
  def object_proto(self):
    return self._checkpoint.object_graph_proto.nodes[self._proto_id]

  @property
  def restore_uid(self):
    return self._checkpoint.restore_uid

  def __repr__(self):
    return repr(self.object_proto)


_DeferredSlotVariableRestoration = collections.namedtuple(
    "_DeferredSlotVariableRestoration",
    [
        "original_variable",
        "slot_variable_id",
        "slot_name",
    ]
)

_SlotVariableRestoration = collections.namedtuple(
    "_SlotVariableRestoration",
    [
        # The checkpoint proto id of the optimizer object.
        "optimizer_id",
        # The checkpoint proto id of the slot variable.
        "slot_variable_id",
        "slot_name",
    ])

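# Illustrative walkthrough (hypothetical values, for exposition only): suppose
# proto id 0 is an optimizer, proto id 1 is a variable `v`, and proto id 2 is
# the optimizer's "m" slot for `v`. _Checkpoint.__init__ records
#
#   checkpoint.slot_restorations = {
#       1: [_SlotVariableRestoration(optimizer_id=0, slot_variable_id=2,
#                                    slot_name="m")]}
#
# If `v` is bound before the optimizer exists, bind_object pops that entry and
# re-keys it by the optimizer's proto id:
#
#   checkpoint.deferred_slot_restorations = {
#       0: [_DeferredSlotVariableRestoration(original_variable=v,
#                                            slot_variable_id=2,
#                                            slot_name="m")]}
#
# The slot is then restored when the optimizer is eventually bound.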

class _Checkpoint(object):
  """Holds the status of an object-based checkpoint load."""

  def __init__(self, object_graph_proto, save_path):
    """Specify the checkpoint being loaded.

    Args:
      object_graph_proto: The CheckpointableObjectGraph protocol buffer
        associated with this checkpoint.
      save_path: The path to the checkpoint, as returned by
        `tf.train.latest_checkpoint`.
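
    For illustration, roughly how a loader might construct one (a hypothetical
    sketch; reading the serialized object graph out of the checkpoint happens
    elsewhere):

    ```
    object_graph_proto = (
        checkpointable_object_graph_pb2.CheckpointableObjectGraph())
    object_graph_proto.ParseFromString(serialized_object_graph)
    checkpoint = _Checkpoint(object_graph_proto, save_path)
    # Assuming the root object was written as node 0 of the flat nodes list:
    _CheckpointPosition(checkpoint=checkpoint, proto_id=0).restore(root_obj)
    ```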
    """
    self.object_graph_proto = object_graph_proto
    self.restore_uid = ops.uid()
    # Dictionary mapping from an id in the protocol buffer flat array to
    # Checkpointable Python objects. This mapping may be deferred if a
    # checkpoint is restored before all dependencies have been tracked. Uses
    # weak references so that partial restorations don't create reference cycles
    # (as objects with deferred dependencies will generally have references to
    # this object).
    self.object_by_proto_id = weakref.WeakValueDictionary()
    self.save_path = save_path
    reader = pywrap_tensorflow.NewCheckpointReader(save_path)
    self.dtype_map = reader.get_variable_to_dtype_map()
    # When graph building, contains a list of ops to run to restore objects from
    # this checkpoint.
    self.restore_ops = []
    # A mapping from optimizer proto ids to lists of slot variables to be
    # restored when the optimizer is tracked. Only includes slot variables whose
    # regular variables have already been created, and only for optimizer
    # objects which have not yet been created/tracked.
    self.deferred_slot_restorations = {}
    # A mapping from variable proto ids to lists of slot variables to be
    # restored when the variable is created/tracked. These get shifted over to
    # deferred_slot_restorations if the optimizer hasn't been created when that
    # happens.
    self.slot_restorations = {}
    for node_index, node in enumerate(self.object_graph_proto.nodes):
      for slot_reference in node.slot_variables:
        # `node` refers to an `Optimizer`, since only these have slot variables.
        self.slot_restorations.setdefault(
            slot_reference.original_variable_node_id, []).append(
                _SlotVariableRestoration(
                    optimizer_id=node_index,
                    slot_variable_id=slot_reference.slot_variable_node_id,
                    slot_name=slot_reference.slot_name))


class CheckpointableBase(object):
  """Base class for `Checkpointable` objects without automatic dependencies.

  This class has no __setattr__ override for performance reasons. Dependencies
  must be added explicitly. Unless attribute assignment is performance-critical,
  use `Checkpointable` instead. Use `CheckpointableBase` for `isinstance`
  checks.
  """

  def _maybe_initialize_checkpointable(self):
    """Initialize dependency management.

    Not __init__, since most objects will forget to call it.
    """
    if hasattr(self, "_checkpoint_dependencies"):
      # __init__ already called. This check means that we don't need
      # Checkpointable.__init__() in the constructor of every TensorFlow object.
      return
    # A list of _CheckpointableReference objects.
    self._checkpoint_dependencies = []
    # Maps names -> Checkpointable objects
    self._dependency_names = {}
    # Restorations for other Checkpointable objects on which this object may
    # eventually depend.
    self._deferred_dependencies = {}  # local name -> _CheckpointPosition list
    # The UID of the highest assignment to this object. Used to ensure that the
    # last requested assignment determines the final value of an object.
    if hasattr(self, "_update_uid"):
      raise AssertionError(
          "Internal error: the object had an update UID set before its "
          "initialization code was run.")
    self._update_uid = -1

  def _add_variable_with_custom_getter(
      self, name, shape=None, dtype=dtypes.float32,
      initializer=None, getter=None, **kwargs_for_getter):
    """Restore-on-create for a variable to be saved with this `Checkpointable`.

    If the user has requested that this object or another `Checkpointable` which
    depends on this object be restored from a checkpoint (deferred loading
    before variable object creation), `initializer` may be ignored and the value
    from the checkpoint used instead.

    Args:
      name: A name for the variable. Must be unique within this object.
      shape: The shape of the variable.
      dtype: The data type of the variable.
      initializer: The initializer to use. Ignored if there is a deferred
        restoration left over from a call to
        `_restore_from_checkpoint_position`.
      getter: The getter to wrap which actually fetches the variable.
      **kwargs_for_getter: Passed to the getter.

    Returns:
      The new variable object.

    Raises:
      ValueError: If the variable name is not unique.
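
    For illustration, a typical call from a subclass might look like this
    (using `variable_scope.get_variable` as the getter is an assumption here,
    not a requirement of this method):

    ```
    kernel = self._add_variable_with_custom_getter(
        name="kernel", shape=[3, 3], dtype=dtypes.float32,
        initializer=init_ops.zeros_initializer(),
        getter=variable_scope.get_variable)
    ```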
    """
    self._maybe_initialize_checkpointable()
    if name in self._dependency_names:
      raise ValueError(
          ("A variable named '%s' already exists in this Checkpointable, but "
           "Checkpointable._add_variable called to create another with "
           "that name. Variable names must be unique within a Checkpointable "
           "object.") % (name,))
    if context.in_eager_mode():
      # If this is a variable with a single Tensor stored in the checkpoint, we
      # can set that value as an initializer rather than initializing and then
      # assigning (when executing eagerly). This call returns None if there is
      # nothing to restore.
      checkpoint_initializer = self._preload_simple_restoration(
          name=name, shape=shape)
    else:
      checkpoint_initializer = None
    if (checkpoint_initializer is not None
        and not (
            isinstance(initializer, CheckpointInitialValue)
            and initializer.restore_uid > checkpoint_initializer.restore_uid)):
      # If multiple Checkpointable objects are "creating" the same variable via
      # the magic of custom getters, the one with the highest restore UID (the
      # one called last) has to make the final initializer. If another custom
      # getter interrupts this process by overwriting the initializer, then
      # we'll catch that when we call _track_checkpointable. So this is "best
      # effort" to set the initializer with the highest restore UID.
      initializer = checkpoint_initializer
      shape = None

    new_variable = getter(
        name=name, shape=shape, dtype=dtype, initializer=initializer,
        **kwargs_for_getter)

    # If we set an initializer and the variable processed it, tracking will not
    # assign again. It will add this variable to our dependencies, and if there
    # is a non-trivial restoration queued, it will handle that. This also
    # handles slot variables.
    return self._track_checkpointable(new_variable, name=name)

  def _preload_simple_restoration(self, name, shape):
    """Return a dependency's value for restore-on-create.

    Note the restoration is not deleted; if for some reason preload is called
    and the value is then not assigned to the variable (for example because a
    custom getter overrides the initializer), the assignment will still happen
    once the variable is tracked (determined based on checkpoint.restore_uid).

    Args:
      name: The object-local name of the dependency holding the variable's
        value.
      shape: The shape of the variable being loaded into.
    Returns:
      A callable for use as a variable's initializer/initial_value, or None if
      one should not be set (either because there was no variable with this name
      in the checkpoint or because it needs more complex deserialization). Any
      non-trivial deserialization will happen when the variable object is
      tracked.
    """
    deferred_dependencies_list = self._deferred_dependencies.get(name, ())
    if not deferred_dependencies_list:
      # Nothing to do; we don't have a restore for this dependency queued up.
      return None
    for checkpoint_position in deferred_dependencies_list:
      if not checkpoint_position.is_simple_variable():
        # If _any_ pending restoration is too complicated to fit in an
        # initializer (because it has dependencies, or because there are
        # multiple Tensors to restore), bail and let the general tracking code
        # handle it.
        return None
    checkpoint_position = max(
        deferred_dependencies_list,
        key=lambda restore: restore.checkpoint.restore_uid)
    return CheckpointInitialValue(
        checkpoint_position=checkpoint_position, shape=shape)

  def _track_checkpointable(self, checkpointable, name, overwrite=False):
    """Declare a dependency on another `Checkpointable` object.

    Indicates that checkpoints for this object should include variables from
    `checkpointable`.

    Variables in a checkpoint are mapped to `Checkpointable`s based on names,
    if those were provided when the checkpoint was written; otherwise the order
    in which the `Checkpointable`s were declared as dependencies is used.

    To avoid breaking existing checkpoints when modifying a class, neither
    variable names nor dependency names (the names passed to
    `_track_checkpointable`) may change.

    Args:
      checkpointable: A `Checkpointable` which this object depends on.
      name: A local name for `checkpointable`, used for loading checkpoints into
        the correct objects.
      overwrite: Boolean, whether silently replacing dependencies is OK. Used
        for __setattr__, where throwing an error on attribute reassignment would
        be inappropriate.

    Returns:
      `checkpointable`, for convenience when declaring a dependency and
      assigning to a member variable in one statement.

    Raises:
      TypeError: If `checkpointable` does not inherit from `Checkpointable`.
      ValueError: If another object is already tracked by this name.
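
    For illustration, declaring a dependency and assigning in one statement
    (`MyLayer` is a hypothetical `Checkpointable` subclass):

    ```
    self._layer = self._track_checkpointable(MyLayer(), name="layer")
    ```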
    """
    self._maybe_initialize_checkpointable()
    if not isinstance(checkpointable, CheckpointableBase):
      raise TypeError(
          ("Checkpointable._track_checkpointable() passed type %s, not a "
           "Checkpointable.") % (type(checkpointable),))
    new_reference = _CheckpointableReference(name=name, ref=checkpointable)
    if name in self._dependency_names:
      if self._dependency_names[name] is not checkpointable:
        if not overwrite:
          raise ValueError(
              ("Called Checkpointable._track_checkpointable() with name='%s', "
               "but a Checkpointable with this name is already declared as a "
               "dependency. Names must be unique (or overwrite=True).")
              % (name,))
        # This is a weird thing to do, but we're not going to stop people from
        # using __setattr__.
        for index, (old_name, _) in enumerate(self._checkpoint_dependencies):
          if name == old_name:
            self._checkpoint_dependencies[index] = new_reference
      # Re-assigning the same object under an existing name is a no-op for the
      # dependency list; without this check the list would accumulate duplicate
      # entries.
    else:
      self._checkpoint_dependencies.append(new_reference)
    self._dependency_names[name] = checkpointable
    deferred_dependency_list = self._deferred_dependencies.pop(name, None)
    if deferred_dependency_list is not None:
      for checkpoint_position in deferred_dependency_list:
        checkpoint_position.restore(checkpointable=checkpointable)
    return checkpointable

  def _restore_from_checkpoint_position(self, checkpoint_position):
    """Restore this object and its dependencies (may be deferred)."""
    # Attempt a breadth-first traversal, since presumably the user has more
    # control over shorter paths. If we don't have all of the dependencies at
    # this point, the end result is not breadth-first (since other deferred
    # traversals will happen later).
    visit_queue = collections.deque([checkpoint_position])
    restore_ops = []
    while visit_queue:
      current_position = visit_queue.popleft()
      restore_ops.extend(nest.flatten(
          current_position.checkpointable  # pylint: disable=protected-access
          ._single_restoration_from_checkpoint_position(
              checkpoint_position=current_position,
              visit_queue=visit_queue)))
    return restore_ops

  def _single_restoration_from_checkpoint_position(
      self, checkpoint_position, visit_queue):
    """Restore this object, and either queue its dependencies or defer them."""
    self._maybe_initialize_checkpointable()
    checkpoint = checkpoint_position.checkpoint
    # If the UID of this restore is lower than our current update UID, we don't
    # need to actually restore the object. However, we should pass the
    # restoration on to our dependencies.
    if checkpoint.restore_uid > self._update_uid:
      restore_op = self._scatter_tensors_from_checkpoint(
          checkpoint_position.restore_ops())
      self._update_uid = checkpoint.restore_uid
    else:
      restore_op = ()
    for child in checkpoint_position.object_proto.children:
      child_position = _CheckpointPosition(
          checkpoint=checkpoint,
          proto_id=child.node_id)
      local_object = self._dependency_names.get(child.local_name, None)
      if local_object is None:
        # We don't yet have a dependency registered with this name. Save it
        # in case we do.
        self._deferred_dependencies.setdefault(child.local_name, []).append(
            child_position)
      else:
        if child_position.bind_object(checkpointable=local_object):
          # This object's correspondence is new, so dependencies need to be
          # visited. Delay doing it so that we get a breadth-first dependency
          # resolution order (shallowest paths first). The caller is responsible
          # for emptying visit_queue.
          visit_queue.append(child_position)
    return restore_op

  def _scatter_tensors_from_checkpoint(self, attributes):
    """Restores this object from a checkpoint.

    Args:
      attributes: A dictionary of Tensors, with keys corresponding to those
        returned from _gather_tensors_for_checkpoint.
    Returns:
      A restore op to run (if graph building).
    """
    if attributes:
      raise AssertionError(
          ("A Checkpointable object which was not expecting any data received "
           "some from a checkpoint. (Got %s)") % (attributes,))
    return ()  # No restore ops

  def _gather_tensors_for_checkpoint(self):
    """Returns a dictionary of Tensors to save with this object."""
    return {}

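# For illustration, a sketch of how a subclass might save one Tensor directly
# by overriding the two hooks above (hypothetical; `state_ops` and
# `self._tensor` are assumptions, not part of this module):
#
#   class _TensorHolder(CheckpointableBase):
#
#     def _gather_tensors_for_checkpoint(self):
#       return {VARIABLE_VALUE_KEY: self._tensor}
#
#     def _scatter_tensors_from_checkpoint(self, attributes):
#       # Assign the checkpointed value back; the returned op matters only
#       # when graph building.
#       return state_ops.assign(self._tensor, attributes[VARIABLE_VALUE_KEY])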

class Checkpointable(CheckpointableBase):
  """Manages dependencies on other objects.

  `Checkpointable` objects may have dependencies: other `Checkpointable` objects
  which should be saved if the object declaring the dependency is saved. A
  correctly saveable program has a dependency graph such that if changing a
  global variable affects an object (e.g. changes the behavior of any of its
  methods) then there is a chain of dependencies from the influenced object to
  the variable.

  Dependency edges have names, and are created implicitly when a
  `Checkpointable` object is assigned to an attribute of another
  `Checkpointable` object. For example:

  ```
  obj = Checkpointable()
  obj.v = ResourceVariable(0.)
  ```

  The `Checkpointable` object `obj` now has a dependency named "v" on a
  variable.

  `Checkpointable` objects may specify `Tensor`s to be saved and restored
  directly (e.g. a `Variable` indicating how to save itself) rather than through
  dependencies on other objects. See
  `Checkpointable._scatter_tensors_from_checkpoint` and
  `Checkpointable._gather_tensors_for_checkpoint` for details.
  """

  def __setattr__(self, name, value):
    """Support self.foo = checkpointable syntax."""
    # Perform the attribute assignment, and potentially call other __setattr__
    # overrides such as that for tf.keras.Model.
    super(Checkpointable, self).__setattr__(name, value)
    if isinstance(value, CheckpointableBase):
      self._track_checkpointable(
          value, name=name,
          # Allow the user to switch the Checkpointable which is tracked by this
          # name, since assigning a new variable to an attribute has
          # historically been fine (e.g. Adam did this).
          # TODO(allenl): Should this be a warning once Checkpointable save/load
          # is usable?
          overwrite=True)
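
# For illustration, dependency graphs nest naturally through attribute
# assignment (a hypothetical sketch; `ResourceVariable` lives in another
# module):
#
#   root = Checkpointable()
#   root.child = Checkpointable()        # dependency edge named "child"
#   root.child.v = ResourceVariable(1.)  # edge named "v" on the child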
    583