# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import a trackable object from a SavedModel."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import function_deserialization
from tensorflow.python.saved_model import load_v1_in_v2
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training.tracking import base
from tensorflow.python.training.tracking import graph_view
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.tracking import util
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


class _Loader(object):
  """Helper class to load an object-based SavedModel."""

  def __init__(self, object_graph_proto, saved_model_proto, export_dir):
    meta_graph = saved_model_proto.meta_graphs[0]
    self._asset_file_def = meta_graph.asset_file_def
    self._operation_attributes = {
        node.name: node.attr for node in meta_graph.graph_def.node}
    self._proto = object_graph_proto
    self._export_dir = export_dir
    self._concrete_functions = (
        function_deserialization.load_function_def_library(
            meta_graph.graph_def.library))
    self._load_all()
    # TODO(b/124045874): There are limitations with functions whose captures
    # trigger other functions to be executed. For now it is only guaranteed to
    # work if the captures of a function only trigger functions without
    # captures.
    self._setup_functions_structures()
    self._setup_functions_captures()
    self._restore_checkpoint()

    for node in self._nodes:
      if isinstance(node, tracking.TrackableResource):
        init_op = node._initialize()  # pylint: disable=protected-access
        ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
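    # These initializer ops land in the TABLE_INITIALIZERS collection, so when
    # loading inside a TF1-style graph the restored resources (for example
    # lookup tables) should be initializable with the usual
    # `tf.compat.v1.tables_initializer()` pattern.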

  def _setup_functions_structures(self):
    """Set up input and output structures for restored functions."""
    coder = nested_structure_coder.StructureCoder()
    for name, proto in sorted(self._proto.concrete_functions.items()):
      concrete_function = self._concrete_functions[name]
      # By setting structured_outputs directly, we can rely on this
      # function_lib.ConcreteFunction object to perform the output repacking
      # logic. The only limitation of that logic is that it only works with
      # outputs that are convertible to Tensors, and the conversion always
      # happens. For example, tf.TensorShape([2, 3]) will be converted to a
      # Tensor representing [2, 3].
      original_outputs = coder.decode_proto(proto.output_signature)
      # The original_outputs here had Tensors converted to TensorSpecs, so
      # the restored function's structured_outputs field will not be
      # exactly the same. Fortunately the repacking logic cares only about
      # the structure.
      # TODO(vbardiovsky): Should we just replicate the structures, with
      # Nones instead of real objects?
      concrete_function._func_graph.structured_outputs = original_outputs  # pylint: disable=protected-access
      concrete_function._func_graph.structured_input_signature = (  # pylint: disable=protected-access
          coder.decode_proto(proto.canonicalized_input_signature))
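      # For example, if the saved function originally returned a dict such as
      # {"logits": <Tensor>}, original_outputs will be a matching
      # {"logits": TensorSpec(...)} structure, and the restored
      # ConcreteFunction will repack its flat Tensor outputs into that dict
      # when called.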

  def _setup_functions_captures(self):
    """Set up captures and variables in restored functions."""
    concrete_functions = sorted(self._proto.concrete_functions.items())
    for name, proto in concrete_functions:
      concrete_function = self._concrete_functions[name]
      bound_inputs = [
          self._get_tensor_from_node(node_id)
          for node_id in proto.bound_inputs]
      bound_variables = [
          self._nodes[node_id]
          for node_id in proto.bound_inputs
          if self._proto.nodes[node_id].WhichOneof("kind") == "variable"
      ]
      # TODO(andresp): This is only injecting the captured inputs into the
      # concrete function; note that we did not modify the FuncGraph
      # itself.
      concrete_function._captured_inputs = bound_inputs  # pylint: disable=protected-access
      concrete_function._func_graph.variables = bound_variables  # pylint: disable=protected-access

  def _get_tensor_from_node(self, node_id):
    """Resolves a node id into a tensor to be captured for a function."""
    with ops.init_scope():
      obj = self._nodes[node_id]
      if resource_variable_ops.is_resource_variable(obj):
        return obj.handle
      elif isinstance(obj, tracking.TrackableAsset):
        return obj.asset_path
      elif tensor_util.is_tensor(obj):
        return obj
      elif isinstance(obj, tracking.TrackableResource):
        # Note: this executes restored functions in the TrackableResource.
        return obj.resource_handle
      raise ValueError("Can't convert node %s to tensor" % (type(obj)))

  def _load_all(self):
    """Load all saved objects and wire their properties."""
    # Maps from node ids to recreated objects.
    nodes = {}
    # Maps from node ids to setter functions (same signature as setattr) for
    # setting dependencies.
    node_setters = {}

    # Figure out which objects are slot variables. These objects are created
    # with Optimizer.add_slot rather than _recreate_variable.
    slot_variable_node_ids = set()
    for proto in self._proto.nodes:
      for slot_variable_proto in proto.slot_variables:
        slot_variable_node_ids.add(slot_variable_proto.slot_variable_node_id)

    # Re-create everything except slot variables.
    for node_id, proto in enumerate(self._proto.nodes):
      if node_id in slot_variable_node_ids:
        # Defer recreating slot variables so we can use the public Optimizer
        # interface.
        continue
      node, setter = self._recreate(proto)
      nodes[node_id] = node
      node_setters[node_id] = setter

    # Now that we have created the variables being optimized, we have enough
    # information to re-create slot variables for them.
    for node_id, proto in enumerate(self._proto.nodes):
      optimizer_object = nodes[node_id]
      for slot_variable_proto in proto.slot_variables:
        optimized_variable = nodes[
            slot_variable_proto.original_variable_node_id]
        slot_variable = optimizer_object.add_slot(
            var=optimized_variable,
            slot_name=slot_variable_proto.slot_name)
        nodes[slot_variable_proto.slot_variable_node_id] = slot_variable
        node_setters[slot_variable_proto.slot_variable_node_id] = setattr

    self._nodes = []

    # After creating the objects, construct the edges between the objects.
    for node_id, object_proto in enumerate(self._proto.nodes):
      obj = nodes[node_id]
      setter = node_setters[node_id]
      self._nodes.append(obj)

      for reference in object_proto.children:
        setter(obj, reference.local_name, nodes[reference.node_id])
        # Note: if an object has a `__call__` attribute, add a class method so
        # that `obj()` syntax works. This is done per-instance to allow
        # `callable` to be used to find out if an object is callable.
        if reference.local_name == "__call__":
          setattr(type(obj), "__call__", _call_attribute)

  def _restore_checkpoint(self):
    """Load state from checkpoint into the deserialized objects."""
    variables_path = saved_model_utils.get_variables_path(self._export_dir)
    # TODO(andresp): Clean use of private methods of TrackableSaver.
    # pylint: disable=protected-access
    saver = util.TrackableSaver(graph_view.ObjectGraphView(self.get(0)))
    saver._file_prefix_placeholder = constant_op.constant(variables_path)
    load_status = saver.restore(variables_path)
    load_status.assert_existing_objects_matched()
    checkpoint = load_status._checkpoint

    # When running in eager mode, the `restore` call above has already run and
    # restored the state of trackables, so calling `position.restore_ops()`
    # will return an empty list as there is nothing left to do. In graph mode,
    # it will return the list of ops that must run to restore the object at
    # that position. We have to wire them into the initializers of the objects
    # so that they get initialized properly when using common practices (e.g.
    # the ones used by ManagedSession) without further user action.
    for object_id, obj in dict(checkpoint.object_by_proto_id).items():
      position = base.CheckpointPosition(checkpoint=checkpoint,
                                         proto_id=object_id)
      restore_ops = position.restore_ops()
      if restore_ops:
        if resource_variable_ops.is_resource_variable(obj):
          obj._initializer_op = restore_ops
        else:
          raise NotImplementedError(
              ("Missing functionality to restore state of object "
               "%r from the checkpoint." % obj))
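    # Note: in graph mode the assignment above means a restored resource
    # variable's `initializer` op now runs its restore ops, so standard
    # initialization idioms (for example
    # `sess.run(tf.compat.v1.global_variables_initializer())`) should also
    # pull in the checkpointed values without further user action.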

  def get(self, node_id):
    return self._nodes[node_id]

  def _recreate(self, proto):
    """Creates a Python object from a SavedObject protocol buffer."""
    factory = {
        "user_object": lambda: self._recreate_user_object(proto.user_object),
        "asset": lambda: self._recreate_asset(proto.asset),
        "function": lambda: self._recreate_function(proto.function),
        "bare_concrete_function": functools.partial(
            self._recreate_bare_concrete_function,
            proto.bare_concrete_function),
        "variable": lambda: self._recreate_variable(proto.variable),
        "constant": lambda: self._recreate_constant(proto.constant),
        "resource": lambda: self._recreate_resource(proto.resource),
    }
    kind = proto.WhichOneof("kind")
    if kind not in factory:
      raise ValueError("Unknown SavedObject type: %r" % kind)
    return factory[kind]()

  def _recreate_user_object(self, proto):
    """Instantiates a SavedUserObject."""
    looked_up = revived_types.deserialize(proto)
    if looked_up is None:
      # Note: each user object gets its own class. This allows making each one
      # individually callable by adding a `__call__` method to the class of
      # the instances that have a `__call__` property.

      class _UserObject(tracking.AutoTrackable):
        pass

      return _UserObject(), setattr
    return looked_up

  def _recreate_asset(self, proto):
    filename = os.path.join(
        saved_model_utils.get_assets_dir(self._export_dir),
        self._asset_file_def[proto.asset_file_def_index].filename)
    return tracking.TrackableAsset(filename), setattr

  def _recreate_function(self, proto):
    return function_deserialization.recreate_function(
        proto, self._concrete_functions), setattr

  def _recreate_bare_concrete_function(self, proto):
    return function_deserialization.setup_bare_concrete_function(
        proto, self._concrete_functions), setattr

  def _recreate_variable(self, proto):
    # TODO(andresp): Can we use the checkpointed value as initializer?
    # The zeros value below is only a placeholder: `_restore_checkpoint` later
    # overwrites the variable with the value stored in the checkpoint.
    dummy_value = init_ops.Zeros(dtype=proto.dtype)(shape=proto.shape)
    return variables.Variable(dummy_value, trainable=proto.trainable), setattr

  def _recreate_constant(self, proto):
    tensor_proto = self._operation_attributes[proto.operation]["value"].tensor
    imported_constant = constant_op.constant(
        tensor_util.MakeNdarray(tensor_proto))
    return imported_constant, setattr

  def _recreate_resource(self, proto):
    del proto
    return _RestoredResource(), setattr


# TODO(b/124205571,b/124092991): Solve destruction of resources.
class _RestoredResource(tracking.TrackableResource):
  """Restored SavedResource."""

  def _create_resource(self):
    raise RuntimeError()

  def _initialize(self):
    raise RuntimeError()

  def _list_functions_for_serialization(self):
    # Override this method so that the base class implementation does not
    # re-wrap the polymorphic functions in another layer of `tf.function`.
    return {
        "_create_resource": self._create_resource,
        "_initialize": self._initialize,
    }


def _call_attribute(instance, *args, **kwargs):
  return instance.__call__(*args, **kwargs)
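

# Note on `_call_attribute`: Python looks up special methods such as
# `__call__` on the type, not on the instance, which is why `_load_all`
# attaches this function to `type(obj)` rather than to `obj`. A minimal
# sketch of the same idea (hypothetical names, for illustration only):
#
#   class _Revived(object):
#     pass
#
#   obj = _Revived()
#   obj.__call__ = lambda: 1         # instance attribute: obj() still fails
#   type(obj).__call__ = lambda self: 1  # class attribute: obj() now works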


@tf_export("saved_model.load", v1=["saved_model.load_v2"])
def load(export_dir, tags=None):
  """Load a SavedModel from `export_dir`.

  Signatures associated with the SavedModel are available as functions:

  ```python
  imported = tf.saved_model.load(path)
  f = imported.signatures["serving_default"]
  print(f(x=tf.constant([[1.]])))
  ```

  Objects exported with `tf.saved_model.save` additionally have trackable
  objects and functions assigned to attributes:

  ```python
  exported = tf.train.Checkpoint(v=tf.Variable(3.))
  exported.f = tf.function(
      lambda x: exported.v * x,
      input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)])
  tf.saved_model.save(exported, path)
  imported = tf.saved_model.load(path)
  assert 3. == imported.v.numpy()
  assert 6. == imported.f(x=tf.constant(2.)).numpy()
  ```

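  When the SavedModel contains more than one MetaGraph (as SavedModels
  exported from TensorFlow 1.x APIs may), `tags` selects which MetaGraph to
  load. A minimal sketch, assuming the MetaGraph was exported with the
  "serve" tag and defines a "serving_default" signature:

  ```python
  imported = tf.saved_model.load(path, tags=["serve"])
  f = imported.signatures["serving_default"]
  ```
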
  Args:
    export_dir: The SavedModel directory to load from.
    tags: A tag or sequence of tags identifying the MetaGraph to load. Optional
      if the SavedModel contains a single MetaGraph, as for those exported from
      `tf.saved_model.save`.

  Returns:
    A trackable object with a `signatures` attribute mapping from signature
    keys to functions. If the SavedModel was exported by `tf.saved_model.save`,
    it also points to trackable objects and functions which were attached
    to the exported object.

  Raises:
    ValueError: If `tags` don't match a MetaGraph in the SavedModel.
  """
  if tags is not None and not isinstance(tags, set):
    # Supports e.g. tags=SERVING and tags=[SERVING]. Sets aren't considered
    # sequences for nest.flatten, so we put those through as-is.
    tags = nest.flatten(tags)
  saved_model_proto = loader_impl.parse_saved_model(export_dir)
  if (len(saved_model_proto.meta_graphs) == 1
      and saved_model_proto.meta_graphs[0].HasField("object_graph_def")):
    meta_graph_def = saved_model_proto.meta_graphs[0]
    if (tags is not None
        and set(tags) != set(meta_graph_def.meta_info_def.tags)):
      raise ValueError(
          ("The SavedModel at {} has one MetaGraph with tags {}, but got an "
           "incompatible argument tags={} to tf.saved_model.load. You may omit "
           "it, pass 'None', or pass matching tags.")
          .format(export_dir, meta_graph_def.meta_info_def.tags, tags))
    object_graph_proto = meta_graph_def.object_graph_def
    with ops.init_scope():
      loader = _Loader(object_graph_proto,
                       saved_model_proto,
                       export_dir)
      root = loader.get(0)
  else:
    with ops.init_scope():
      root = load_v1_in_v2.load(export_dir, tags)
  return root
    366