Home | History | Annotate | Download | only in saved_model
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for SavedModel."""
     16 
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import re

from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import main_op
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver_test_utils
from tensorflow.python.util import compat
     48 
     49 SAVED_MODEL_PATH = ("cc/saved_model/testdata/half_plus_two/00000123")
     50 
     51 
     52 def tearDownModule():
     53   file_io.delete_recursively(test.get_temp_dir())
     54 
     55 
     56 @test_util.with_c_api
     57 class SavedModelTest(test.TestCase):
     58 
     59   def _get_export_dir(self, label):
     60     if ops._USE_C_API:
     61       label += "_c_api"
     62     return os.path.join(test.get_temp_dir(), label)
     63 
     64   def _init_and_validate_variable(self, sess, variable_name, variable_value):
     65     v = variables.Variable(variable_value, name=variable_name)
     66     sess.run(variables.global_variables_initializer())
     67     self.assertEqual(variable_value, v.eval())
     68 
     69   def _build_asset_collection(self, asset_file_name, asset_file_contents,
     70                               asset_file_tensor_name):
     71     asset_filepath = os.path.join(
     72         compat.as_bytes(test.get_temp_dir()), compat.as_bytes(asset_file_name))
     73     file_io.write_string_to_file(asset_filepath, asset_file_contents)
     74     asset_file_tensor = constant_op.constant(
     75         asset_filepath, name=asset_file_tensor_name)
     76     ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset_file_tensor)
     77     asset_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
     78     return asset_collection
     79 
     80   def _validate_asset_collection(self, export_dir, graph_collection_def,
     81                                  expected_asset_file_name,
     82                                  expected_asset_file_contents,
     83                                  expected_asset_tensor_name):
     84     assets_any = graph_collection_def[constants.ASSETS_KEY].any_list.value
     85     asset = meta_graph_pb2.AssetFileDef()
     86     assets_any[0].Unpack(asset)
     87     assets_path = os.path.join(
     88         compat.as_bytes(export_dir),
     89         compat.as_bytes(constants.ASSETS_DIRECTORY),
     90         compat.as_bytes(expected_asset_file_name))
     91     actual_asset_contents = file_io.read_file_to_string(assets_path)
     92     self.assertEqual(expected_asset_file_contents,
     93                      compat.as_text(actual_asset_contents))
     94     self.assertEqual(expected_asset_file_name, asset.filename)
     95     self.assertEqual(expected_asset_tensor_name, asset.tensor_info.name)
     96 
     97   def _validate_inputs_tensor_info(self, builder, tensor_info):
     98     with self.test_session(graph=ops.Graph()) as sess:
     99       self._init_and_validate_variable(sess, "v", 42)
    100 
    101       foo_signature = signature_def_utils.build_signature_def({
    102           "foo_inputs": tensor_info
    103       }, dict(), "foo")
    104       self.assertRaises(
    105           AssertionError,
    106           builder.add_meta_graph_and_variables,
    107           sess, ["foo"],
    108           signature_def_map={"foo_key": foo_signature})
    109 
    110   def _validate_outputs_tensor_info(self, builder, tensor_info):
    111     with self.test_session(graph=ops.Graph()) as sess:
    112       self._init_and_validate_variable(sess, "v", 42)
    113 
    114       foo_signature = signature_def_utils.build_signature_def(
    115           dict(), {"foo_outputs": tensor_info}, "foo")
    116       self.assertRaises(
    117           AssertionError,
    118           builder.add_meta_graph_and_variables,
    119           sess, ["foo"],
    120           signature_def_map={"foo_key": foo_signature})
    121 
    122   def testMaybeSavedModelDir(self):
    123     base_path = test.test_src_dir_path("/python/saved_model")
    124     self.assertFalse(loader.maybe_saved_model_directory(base_path))
    125     base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    126     self.assertTrue(loader.maybe_saved_model_directory(base_path))
    127     base_path = "complete_garbage"
    128     self.assertFalse(loader.maybe_saved_model_directory(base_path))
    129 
    130   def testBadSavedModelFileFormat(self):
    131     export_dir = self._get_export_dir("test_bad_saved_model_file_format")
    132     # Attempt to load a SavedModel from an export directory that does not exist.
    133     with self.test_session(graph=ops.Graph()) as sess:
    134       with self.assertRaisesRegexp(IOError,
    135                                    "SavedModel file does not exist at: %s" %
    136                                    export_dir):
    137         loader.load(sess, ["foo"], export_dir)
    138 
    139     os.makedirs(export_dir)
    140     # Write an invalid binary proto to saved_model.pb.
    141     path_to_pb = os.path.join(export_dir, constants.SAVED_MODEL_FILENAME_PB)
    142     with open(path_to_pb, "w") as f:
    143       f.write("invalid content")
    144     with self.test_session(graph=ops.Graph()) as sess:
    145       with self.assertRaisesRegexp(IOError, "Cannot parse file.*%s" %
    146                                    constants.SAVED_MODEL_FILENAME_PB):
    147         loader.load(sess, ["foo"], export_dir)
    148 
    149     # Cleanup the directory and start again.
    150     file_io.delete_recursively(export_dir)
    151 
    152     os.makedirs(export_dir)
    153     # Write an invalid text proto to saved_model.pbtxt
    154     path_to_pbtxt = os.path.join(export_dir,
    155                                  constants.SAVED_MODEL_FILENAME_PBTXT)
    156     with open(path_to_pbtxt, "w") as f:
    157       f.write("invalid content")
    158     with self.test_session(graph=ops.Graph()) as sess:
    159       with self.assertRaisesRegexp(IOError, "Cannot parse file.*%s" %
    160                                    constants.SAVED_MODEL_FILENAME_PBTXT):
    161         loader.load(sess, ["foo"], export_dir)
    162 
    163   def testVerifySessionGraphUsage(self):
    164     export_dir = self._get_export_dir("test_verify_session_graph_usage")
    165     builder = saved_model_builder.SavedModelBuilder(export_dir)
    166 
    167     with self.test_session(graph=ops.Graph()) as sess:
    168       self._init_and_validate_variable(sess, "v", 42)
    169       builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
    170 
    171     # Save the SavedModel to disk.
    172     builder.save()
    173 
    174     # Build a session and supply it to the load operation.
    175     sess = session.Session(graph=ops.Graph())
    176     loader.load(sess, [tag_constants.TRAINING], export_dir)
    177 
    178     # Check the variable within the scope of the session and its graph.
    179     with sess:
    180       self.assertEqual(
    181           42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
    182 
    183   def testSequence(self):
    184     export_dir = self._get_export_dir("test_sequence")
    185     builder = saved_model_builder.SavedModelBuilder(export_dir)
    186 
    187     # Expect an assertion error since add_meta_graph_and_variables() should be
    188     # invoked before any add_meta_graph() calls.
    189     with self.test_session(graph=ops.Graph()) as sess:
    190       self.assertRaises(AssertionError, builder.add_meta_graph, ["foo"])
    191 
    192     # Expect an assertion error for multiple calls of
    193     # add_meta_graph_and_variables() since weights should be saved exactly once.
    194     with self.test_session(graph=ops.Graph()) as sess:
    195       self._init_and_validate_variable(sess, "v", 42)
    196       builder.add_meta_graph_and_variables(sess, ["bar"])
    197       self.assertRaises(AssertionError, builder.add_meta_graph_and_variables,
    198                         sess, ["baz"])
    199 
    200   def testTags(self):
    201     export_dir = self._get_export_dir("test_tags")
    202     builder = saved_model_builder.SavedModelBuilder(export_dir)
    203 
    204     # Graph with a single variable. SavedModel invoked to:
    205     # - add with weights.
    206     # - a single tag (from predefined constants).
    207     with self.test_session(graph=ops.Graph()) as sess:
    208       self._init_and_validate_variable(sess, "v", 42)
    209       builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
    210 
    211     # Graph that updates the single variable. SavedModel invoked to:
    212     # - simply add the model (weights are not updated).
    213     # - a single tag (from predefined constants).
    214     with self.test_session(graph=ops.Graph()) as sess:
    215       self._init_and_validate_variable(sess, "v", 43)
    216       builder.add_meta_graph([tag_constants.SERVING])
    217 
    218     # Graph that updates the single variable. SavedModel invoked to:
    219     # - simply add the model (weights are not updated).
    220     # - multiple tags (from predefined constants).
    221     with self.test_session(graph=ops.Graph()) as sess:
    222       self._init_and_validate_variable(sess, "v", 45)
    223       builder.add_meta_graph([tag_constants.SERVING, tag_constants.GPU])
    224 
    225     # Graph that updates the single variable. SavedModel invoked to:
    226     # - simply add the model (weights are not updated).
    227     # - multiple tags (from predefined constants for serving on TPU).
    228     with self.test_session(graph=ops.Graph()) as sess:
    229       self._init_and_validate_variable(sess, "v", 45)
    230       builder.add_meta_graph([tag_constants.SERVING, tag_constants.TPU])
    231 
    232     # Graph that updates the single variable. SavedModel is invoked:
    233     # - to add the model (weights are not updated).
    234     # - multiple custom tags.
    235     with self.test_session(graph=ops.Graph()) as sess:
    236       self._init_and_validate_variable(sess, "v", 44)
    237       builder.add_meta_graph(["foo", "bar"])
    238 
    239     # Save the SavedModel to disk.
    240     builder.save()
    241 
    242     # Restore the graph with a single predefined tag whose variables were saved.
    243     with self.test_session(graph=ops.Graph()) as sess:
    244       loader.load(sess, [tag_constants.TRAINING], export_dir)
    245       self.assertEqual(
    246           42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
    247 
    248     # Restore the graph with a single predefined tag whose variables were not
    249     # saved.
    250     with self.test_session(graph=ops.Graph()) as sess:
    251       loader.load(sess, [tag_constants.SERVING], export_dir)
    252       self.assertEqual(
    253           42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
    254 
    255     # Restore the graph with multiple predefined tags whose variables were not
    256     # saved.
    257     with self.test_session(graph=ops.Graph()) as sess:
    258       loader.load(sess, [tag_constants.SERVING, tag_constants.GPU], export_dir)
    259       self.assertEqual(
    260           42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
    261 
    262     # Restore the graph with multiple predefined tags (for serving on TPU)
    263     # whose variables were not saved.
    264     with self.test_session(graph=ops.Graph()) as sess:
    265       loader.load(sess, [tag_constants.SERVING, tag_constants.TPU], export_dir)
    266       self.assertEqual(
    267           42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
    268 
    269     # Restore the graph with multiple tags. Provide duplicate tags to test set
    270     # semantics.
    271     with self.test_session(graph=ops.Graph()) as sess:
    272       loader.load(sess, ["foo", "bar", "foo"], export_dir)
    273       self.assertEqual(
    274           42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
    275 
    276     # Try restoring a graph with a non-existent tag. This should yield a runtime
    277     # error.
    278     with self.test_session(graph=ops.Graph()) as sess:
    279       self.assertRaises(RuntimeError, loader.load, sess, ["INVALID"],
    280                         export_dir)
    281 
    282     # Try restoring a graph where a subset of the tags match. Since tag matching
    283     # for meta graph defs follows "all" semantics, this should yield a runtime
    284     # error.
    285     with self.test_session(graph=ops.Graph()) as sess:
    286       self.assertRaises(RuntimeError, loader.load, sess, ["foo", "baz"],
    287                         export_dir)
    288 
    289   def testVariables(self):
    290     export_dir = self._get_export_dir("test_variables")
    291     builder = saved_model_builder.SavedModelBuilder(export_dir)
    292 
    293     # Graph with two variables. SavedModel invoked to:
    294     # - add with weights.
    295     with self.test_session(graph=ops.Graph()) as sess:
    296       self._init_and_validate_variable(sess, "v1", 1)
    297       self._init_and_validate_variable(sess, "v2", 2)
    298       builder.add_meta_graph_and_variables(sess, ["foo"])
    299 
    300     # Graph with a single variable (subset of the variables from the previous
    301     # graph whose weights were saved). SavedModel invoked to:
    302     # - simply add the model (weights are not updated).
    303     with self.test_session(graph=ops.Graph()) as sess:
    304       self._init_and_validate_variable(sess, "v2", 3)
    305       builder.add_meta_graph(["bar"])
    306 
    307     # Graph with a single variable (disjoint set of variables from the previous
    308     # graph whose weights were saved). SavedModel invoked to:
    309     # - simply add the model (weights are not updated).
    310     with self.test_session(graph=ops.Graph()) as sess:
    311       self._init_and_validate_variable(sess, "v3", 4)
    312       builder.add_meta_graph(["baz"])
    313 
    314     # Save the SavedModel to disk.
    315     builder.save()
    316 
    317     # Restore the graph with tag "foo", whose variables were saved.
    318     with self.test_session(graph=ops.Graph()) as sess:
    319       loader.load(sess, ["foo"], export_dir)
    320       collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    321       self.assertEqual(len(collection_vars), 2)
    322       self.assertEqual(1, collection_vars[0].eval())
    323       self.assertEqual(2, collection_vars[1].eval())
    324 
    325     # Restore the graph with tag "bar", whose variables were not saved. Only the
    326     # subset of the variables added to the graph will be restored with the
    327     # checkpointed value.
    328     with self.test_session(graph=ops.Graph()) as sess:
    329       loader.load(sess, ["bar"], export_dir)
    330       collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    331       self.assertEqual(len(collection_vars), 1)
    332       self.assertEqual(2, collection_vars[0].eval())
    333 
    334     # Try restoring the graph with tag "baz", whose variables were not saved.
    335     # Since this graph has a disjoint set of variables from the set that was
    336     # saved, this should raise an error.
    337     with self.test_session(graph=ops.Graph()) as sess:
    338       self.assertRaises(errors.NotFoundError, loader.load, sess, ["baz"],
    339                         export_dir)
    340 
    341   def testGraphWithoutVariables(self):
    342     export_dir = self._get_export_dir("test_graph_has_variables")
    343     builder = saved_model_builder.SavedModelBuilder(export_dir)
    344 
    345     # Graph with no variables.
    346     with self.test_session(graph=ops.Graph()) as sess:
    347       constant_5_name = constant_op.constant(5.0).name
    348       builder.add_meta_graph_and_variables(sess, ["foo"])
    349 
    350     # Second graph with no variables
    351     with self.test_session(graph=ops.Graph()) as sess:
    352       constant_6_name = constant_op.constant(6.0).name
    353       builder.add_meta_graph(["bar"])
    354 
    355     # Save the SavedModel to disk.
    356     builder.save()
    357 
    358     # Restore the graph with tag "foo".
    359     with self.test_session(graph=ops.Graph()) as sess:
    360       loader.load(sess, ["foo"], export_dir)
    361       # Read the constant a from the graph.
    362       a = ops.get_default_graph().get_tensor_by_name(constant_5_name)
    363       b = constant_op.constant(6.0)
    364       c = a * b
    365       self.assertEqual(30.0, sess.run(c))
    366 
    367     # Restore the graph with tag "bar".
    368     with self.test_session(graph=ops.Graph()) as sess:
    369       loader.load(sess, ["bar"], export_dir)
    370       # Read the constant a from the graph.
    371       a = ops.get_default_graph().get_tensor_by_name(constant_6_name)
    372       b = constant_op.constant(5.0)
    373       c = a * b
    374       self.assertEqual(30.0, sess.run(c))
    375 
    376   def testNoOverwrite(self):
    377     export_dir = self._get_export_dir("test_no_overwrite")
    378     builder = saved_model_builder.SavedModelBuilder(export_dir)
    379 
    380     # Graph with a single variable. SavedModel invoked to:
    381     # - add with weights.
    382     with self.test_session(graph=ops.Graph()) as sess:
    383       self._init_and_validate_variable(sess, "v", 42)
    384       builder.add_meta_graph_and_variables(sess, ["foo"])
    385 
    386     # Save the SavedModel to disk in text format.
    387     builder.save(as_text=True)
    388 
    389     # Restore the graph with tag "foo", whose variables were saved.
    390     with self.test_session(graph=ops.Graph()) as sess:
    391       loader.load(sess, ["foo"], export_dir)
    392       self.assertEqual(
    393           42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
    394 
    395     # An attempt to create another builder with the same export directory should
    396     # result in an assertion error.
    397     self.assertRaises(AssertionError, saved_model_builder.SavedModelBuilder,
    398                       export_dir)
    399 
    400   def testSaveAsText(self):
    401     export_dir = self._get_export_dir("test_astext")
    402     builder = saved_model_builder.SavedModelBuilder(export_dir)
    403 
    404     # Graph with a single variable. SavedModel invoked to:
    405     # - add with weights.
    406     with self.test_session(graph=ops.Graph()) as sess:
    407       self._init_and_validate_variable(sess, "v", 42)
    408       builder.add_meta_graph_and_variables(sess, ["foo"])
    409 
    410     # Graph with the same single variable. SavedModel invoked to:
    411     # - simply add the model (weights are not updated).
    412     with self.test_session(graph=ops.Graph()) as sess:
    413       self._init_and_validate_variable(sess, "v", 43)
    414       builder.add_meta_graph(["bar"])
    415 
    416     # Save the SavedModel to disk in text format.
    417     builder.save(as_text=True)
    418 
    419     # Restore the graph with tag "foo", whose variables were saved.
    420     with self.test_session(graph=ops.Graph()) as sess:
    421       loader.load(sess, ["foo"], export_dir)
    422       self.assertEqual(
    423           42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
    424 
    425     # Restore the graph with tag "bar", whose variables were not saved.
    426     with self.test_session(graph=ops.Graph()) as sess:
    427       loader.load(sess, ["bar"], export_dir)
    428       self.assertEqual(
    429           42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
    430 
    431   def testCollections(self):
    432     export_dir = self._get_export_dir("test_collections")
    433     builder = saved_model_builder.SavedModelBuilder(export_dir)
    434 
    435     # Graph with a single variable added to a collection. SavedModel invoked to:
    436     # - add with weights.
    437     with self.test_session(graph=ops.Graph()) as sess:
    438       v = variables.Variable(42, name="v")
    439       ops.add_to_collection("foo_vars", v)
    440       sess.run(variables.global_variables_initializer())
    441       self.assertEqual(42, v.eval())
    442       builder.add_meta_graph_and_variables(sess, ["foo"])
    443 
    444     # Graph with the same single variable added to a different collection.
    445     # SavedModel invoked to:
    446     # - simply add the model (weights are not updated).
    447     with self.test_session(graph=ops.Graph()) as sess:
    448       v = variables.Variable(43, name="v")
    449       ops.add_to_collection("bar_vars", v)
    450       sess.run(variables.global_variables_initializer())
    451       self.assertEqual(43, v.eval())
    452       builder.add_meta_graph(["bar"])
    453 
    454     # Save the SavedModel to disk.
    455     builder.save()
    456 
    457     # Restore the graph with tag "foo", whose variables were saved. The
    458     # collection 'foo_vars' should contain a single element. The collection
    459     # 'bar_vars' should not be found.
    460     with self.test_session(graph=ops.Graph()) as sess:
    461       loader.load(sess, ["foo"], export_dir)
    462       collection_foo_vars = ops.get_collection("foo_vars")
    463       self.assertEqual(len(collection_foo_vars), 1)
    464       self.assertEqual(42, collection_foo_vars[0].eval())
    465 
    466       self.assertEqual(len(ops.get_collection("bar_vars")), 0)
    467 
    468     # Restore the graph with tag "bar", whose variables were not saved. The
    469     # collection-def exported as part of the meta graph def is updated to
    470     # reflect the new collection. The value of the variable in the
    471     # collection-def corresponds to the saved value (from the previous graph
    472     # with tag "foo").
    473     with self.test_session(graph=ops.Graph()) as sess:
    474       loader.load(sess, ["bar"], export_dir)
    475       collection_bar_vars = ops.get_collection("bar_vars")
    476       self.assertEqual(len(collection_bar_vars), 1)
    477       self.assertEqual(42, collection_bar_vars[0].eval())
    478 
    479       self.assertEqual(len(ops.get_collection("foo_vars")), 0)
    480 
    481   def testSignatureDefs(self):
    482     export_dir = self._get_export_dir("test_signature_defs")
    483     builder = saved_model_builder.SavedModelBuilder(export_dir)
    484 
    485     # Graph with a single variable and a single entry in the signature def map.
    486     # SavedModel is invoked to add with weights.
    487     with self.test_session(graph=ops.Graph()) as sess:
    488       self._init_and_validate_variable(sess, "v", 42)
    489       # Build and populate an empty SignatureDef for testing.
    490       foo_signature = signature_def_utils.build_signature_def(dict(),
    491                                                               dict(), "foo")
    492       builder.add_meta_graph_and_variables(
    493           sess, ["foo"], signature_def_map={"foo_key": foo_signature})
    494 
    495     # Graph with the same single variable and multiple entries in the signature
    496     # def map. No weights are saved by SavedModel.
    497     with self.test_session(graph=ops.Graph()) as sess:
    498       self._init_and_validate_variable(sess, "v", 43)
    499       # Build and populate a different SignatureDef for testing.
    500       bar_signature = signature_def_utils.build_signature_def(dict(),
    501                                                               dict(), "bar")
    502       # Also, build a different SignatureDef corresponding to "foo_key" defined
    503       # in the previous graph.
    504       foo_new_signature = signature_def_utils.build_signature_def(dict(),
    505                                                                   dict(),
    506                                                                   "foo_new")
    507       builder.add_meta_graph(
    508           ["bar"],
    509           signature_def_map={
    510               "bar_key": bar_signature,
    511               "foo_key": foo_new_signature
    512           })
    513 
    514     # Save the SavedModel to disk.
    515     builder.save()
    516 
    517     # Restore the graph with tag "foo". The single entry in the SignatureDef map
    518     # corresponding to "foo_key" should exist.
    519     with self.test_session(graph=ops.Graph()) as sess:
    520       foo_graph = loader.load(sess, ["foo"], export_dir)
    521       self.assertEqual(
    522           42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
    523 
    524       foo_signature = foo_graph.signature_def
    525       self.assertEqual(len(foo_signature), 1)
    526       self.assertEqual("foo", foo_signature["foo_key"].method_name)
    527 
    528     # Restore the graph with tag "bar". The SignatureDef map should have two
    529     # entries. One corresponding to "bar_key" and another corresponding to the
    530     # new value of "foo_key".
    531     with self.test_session(graph=ops.Graph()) as sess:
    532       bar_graph = loader.load(sess, ["bar"], export_dir)
    533       self.assertEqual(
    534           42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
    535 
    536       bar_signature = bar_graph.signature_def
    537       self.assertEqual(len(bar_signature), 2)
    538       self.assertEqual("bar", bar_signature["bar_key"].method_name)
    539       self.assertEqual("foo_new", bar_signature["foo_key"].method_name)
    540 
    541   def testSignatureDefValidation(self):
    542     export_dir = self._get_export_dir("test_signature_def_validation")
    543     builder = saved_model_builder.SavedModelBuilder(export_dir)
    544 
    545     tensor_without_name = meta_graph_pb2.TensorInfo()
    546     tensor_without_name.dtype = types_pb2.DT_FLOAT
    547     self._validate_inputs_tensor_info(builder, tensor_without_name)
    548     self._validate_outputs_tensor_info(builder, tensor_without_name)
    549 
    550     tensor_without_dtype = meta_graph_pb2.TensorInfo()
    551     tensor_without_dtype.name = "x"
    552     self._validate_inputs_tensor_info(builder, tensor_without_dtype)
    553     self._validate_outputs_tensor_info(builder, tensor_without_dtype)
    554 
    555     tensor_empty = meta_graph_pb2.TensorInfo()
    556     self._validate_inputs_tensor_info(builder, tensor_empty)
    557     self._validate_outputs_tensor_info(builder, tensor_empty)
    558 
    559   def testAssets(self):
    560     export_dir = self._get_export_dir("test_assets")
    561     builder = saved_model_builder.SavedModelBuilder(export_dir)
    562 
    563     with self.test_session(graph=ops.Graph()) as sess:
    564       self._init_and_validate_variable(sess, "v", 42)
    565 
    566       # Build an asset collection.
    567       ignored_filepath = os.path.join(
    568           compat.as_bytes(test.get_temp_dir()), compat.as_bytes("ignored.txt"))
    569       file_io.write_string_to_file(ignored_filepath, "will be ignored")
    570 
    571       asset_collection = self._build_asset_collection("hello42.txt",
    572                                                       "foo bar baz",
    573                                                       "asset_file_tensor")
    574 
    575       builder.add_meta_graph_and_variables(
    576           sess, ["foo"], assets_collection=asset_collection)
    577 
    578     # Save the SavedModel to disk.
    579     builder.save()
    580 
    581     with self.test_session(graph=ops.Graph()) as sess:
    582       foo_graph = loader.load(sess, ["foo"], export_dir)
    583       self._validate_asset_collection(export_dir, foo_graph.collection_def,
    584                                       "hello42.txt", "foo bar baz",
    585                                       "asset_file_tensor:0")
    586       ignored_asset_path = os.path.join(
    587           compat.as_bytes(export_dir),
    588           compat.as_bytes(constants.ASSETS_DIRECTORY),
    589           compat.as_bytes("ignored.txt"))
    590       self.assertFalse(file_io.file_exists(ignored_asset_path))
    591 
    def testCustomMainOp(self):
      """A custom main_op chained after main_op.main_op() runs on restore."""
      export_dir = self._get_export_dir("test_main_op")
      builder = saved_model_builder.SavedModelBuilder(export_dir)

      with self.test_session(graph=ops.Graph()) as sess:
        # Create `v1`=1, `v2`=2 and `v3`=42 (in that order), registering each
        # one in the "v" collection as it is created.
        created = []
        for var_name, initial_value in (("v1", 1), ("v2", 2), ("v3", 42)):
          var = variables.Variable(initial_value, name=var_name)
          ops.add_to_collection("v", var)
          created.append(var)
        first, second, target = created

        # The custom op depends on the stock main_op and then stores
        # `v1 + v2` into `v3`.
        with ops.control_dependencies([main_op.main_op()]):
          # pylint: disable=protected-access
          total = math_ops.add(first._ref(), second._ref())
          # pylint: enable=protected-access
          custom_main_op = control_flow_ops.group(
              state_ops.assign(target, total))

        sess.run(custom_main_op)
        builder.add_meta_graph_and_variables(
            sess, ["foo"], main_op=custom_main_op)

      # Persist the SavedModel.
      builder.save()

      with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, ["foo"], export_dir)
        restored = ops.get_collection("v")
        self.assertEqual(1, restored[0].eval())
        self.assertEqual(2, restored[1].eval())
        # `v3` holds v1 + v2 because the custom main_op ran after the restore.
        self.assertEqual(3, restored[2].eval())
    626 
    def testLegacyInitOp(self):
      """A legacy_init_op passed to the builder is run when loading."""
      export_dir = self._get_export_dir("test_legacy_init_op")
      builder = saved_model_builder.SavedModelBuilder(export_dir)

      with self.test_session(graph=ops.Graph()) as sess:
        # `v1` and `v2` are ordinary variables.
        lhs = variables.Variable(1, name="v1")
        ops.add_to_collection("v", lhs)
        rhs = variables.Variable(2, name="v2")
        ops.add_to_collection("v", rhs)

        # `v3` starts at 42 but is kept out of the variable collections
        # (`collections=[]`); its value after loading must come from the
        # legacy init op.
        result = variables.Variable(
            42, name="v3", trainable=False, collections=[])
        ops.add_to_collection("v", result)

        # The legacy init op writes `v1 + v2` into `v3`.
        legacy_init_op = control_flow_ops.group(
            state_ops.assign(result, math_ops.add(lhs, rhs)),
            name="legacy_init_op")

        sess.run(variables.global_variables_initializer())
        builder.add_meta_graph_and_variables(
            sess, ["foo"], legacy_init_op=legacy_init_op)

      # Persist the SavedModel.
      builder.save()

      with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, ["foo"], export_dir)
        restored = ops.get_collection("v")
        self.assertEqual(1, restored[0].eval())
        self.assertEqual(2, restored[1].eval())
        # `v3` equals v1 + v2 because the legacy init op ran after restoring.
        self.assertEqual(3, restored[2].eval())
    660 
    def testLegacyInitOpWithNonEmptyCollection(self):
      """Passing legacy_init_op fails if LEGACY_INIT_OP_KEY is already set."""
      export_dir = self._get_export_dir(
          "test_legacy_init_op_with_non_empty_collection")
      builder = saved_model_builder.SavedModelBuilder(export_dir)

      with self.test_session(graph=ops.Graph()) as sess:
        # `v1` is an ordinary variable initialized to 1.
        source = variables.Variable(1, name="v1")
        ops.add_to_collection("v", source)

        # `v2` starts at 42 and is kept out of the variable collections.
        sink = variables.Variable(42, name="v2", trainable=False, collections=[])
        ops.add_to_collection("v", sink)

        # Candidate legacy init op: copy `v1` into `v2`.
        legacy_init_op = control_flow_ops.group(
            state_ops.assign(sink, source), name="legacy_init_op")

        sess.run(variables.global_variables_initializer())

        # Occupy the legacy-init-op collection with an unrelated no-op first.
        ops.add_to_collection(constants.LEGACY_INIT_OP_KEY,
                              control_flow_ops.no_op())

        # Multiple init ops are not supported, so adding the meta graph with
        # another legacy_init_op must raise an AssertionError.
        with self.assertRaises(AssertionError):
          builder.add_meta_graph_and_variables(
              sess, ["foo"], legacy_init_op=legacy_init_op)
    688 
    def testMultipleAssets(self):
      """Each tagged graph restores the asset it was saved with."""
      export_dir = self._get_export_dir("test_multiple_assets")
      builder = saved_model_builder.SavedModelBuilder(export_dir)

      # Graph "foo" owns `foo.txt` and is added together with the variables.
      with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 42)
        foo_assets = self._build_asset_collection("foo.txt", "content_foo",
                                                  "asset_file_tensor")
        builder.add_meta_graph_and_variables(
            sess, ["foo"], assets_collection=foo_assets)

      # Graph "bar" owns `bar.txt` and is added as a meta graph only.
      with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 42)
        bar_assets = self._build_asset_collection("bar.txt", "content_bar",
                                                  "asset_file_tensor")
        builder.add_meta_graph(["bar"], assets_collection=bar_assets)

      # Persist the SavedModel.
      builder.save()

      # Loading by tag must yield that graph's own asset file and contents.
      expected_assets = (("foo", "foo.txt", "content_foo"),
                         ("bar", "bar.txt", "content_bar"))
      for tag, asset_name, asset_contents in expected_assets:
        with self.test_session(graph=ops.Graph()) as sess:
          meta_graph = loader.load(sess, [tag], export_dir)
          self._validate_asset_collection(export_dir, meta_graph.collection_def,
                                          asset_name, asset_contents,
                                          "asset_file_tensor:0")
    730 
    def testDuplicateAssets(self):
      """An asset name shared by two graphs keeps the first graph's contents."""
      export_dir = self._get_export_dir("test_duplicate_assets")
      builder = saved_model_builder.SavedModelBuilder(export_dir)

      # Graph "foo" ships `foo.txt` with `foo`-specific contents.
      with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 42)
        foo_assets = self._build_asset_collection("foo.txt", "content_foo",
                                                  "asset_file_tensor")
        builder.add_meta_graph_and_variables(
            sess, ["foo"], assets_collection=foo_assets)

      # Graph "bar" ships an asset with the same name, `foo.txt`, but with
      # `bar`-specific contents.
      with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", 42)
        bar_assets = self._build_asset_collection("foo.txt", "content_bar",
                                                  "asset_file_tensor")
        builder.add_meta_graph(["bar"], assets_collection=bar_assets)

      # Persist the SavedModel.
      builder.save()

      # An asset name shared across graphs is stored only the first time it is
      # added, so both tags must see the original `foo` contents.
      for tag in ("foo", "bar"):
        with self.test_session(graph=ops.Graph()) as sess:
          meta_graph = loader.load(sess, [tag], export_dir)
          self._validate_asset_collection(export_dir, meta_graph.collection_def,
                                          "foo.txt", "content_foo",
                                          "asset_file_tensor:0")
    778 
    def testOp(self):
      """An op kept in a custom collection can be run after loading."""
      export_dir = self._get_export_dir("test_op")
      builder = saved_model_builder.SavedModelBuilder(export_dir)

      def two_cpu_session():
        # A fresh graph plus a config exposing two CPU devices.
        return session.Session(
            graph=ops.Graph(),
            config=config_pb2.ConfigProto(device_count={"CPU": 2}))

      with two_cpu_session() as sess:
        # Place the two saved variables on different CPU devices.
        with sess.graph.device("/cpu:0"):
          var_a = variables.Variable(1, name="v1")
        with sess.graph.device("/cpu:1"):
          var_b = variables.Variable(2, name="v2")

        # `v3` is an unsaved variable (`collections=[]`) that the init op
        # derives from `v1` and `v2`; it checks that an op stored in a
        # collection can still be run once the graph is restored.
        var_sum = variables.Variable(
            1, name="v3", trainable=False, collections=[])
        init_op = control_flow_ops.group(
            state_ops.assign(var_sum, math_ops.add(var_a, var_b)),
            name="init_op")

        for var in (var_a, var_b, var_sum):
          ops.add_to_collection("v", var)
        ops.add_to_collection("init_op", init_op)

        sess.run(variables.global_variables_initializer())
        saved = ops.get_collection("v")
        self.assertEqual(1, saved[0].eval())
        self.assertEqual(2, saved[1].eval())

        builder.add_meta_graph_and_variables(sess, ["foo"])

      # Persist the SavedModel.
      builder.save()

      with two_cpu_session() as sess:
        loader.load(sess, ["foo"], export_dir)

        # Check the restored variables, then run the init op and check `v3`.
        restored = ops.get_collection("v")
        self.assertEqual(1, restored[0].eval())
        self.assertEqual(2, restored[1].eval())
        ops.get_collection("init_op")[0].run()
        self.assertEqual(3, restored[2].eval())
    821 
    def testCustomSaveable(self):
      """A custom saveable table survives a SavedModel round trip."""
      export_dir = self._get_export_dir("custom_saveable")
      builder = saved_model_builder.SavedModelBuilder(export_dir)

      def two_cpu_session():
        # A fresh graph plus a config exposing two CPU devices.
        return session.Session(
            graph=ops.Graph(),
            config=config_pb2.ConfigProto(device_count={"CPU": 2}))

      with two_cpu_session() as sess:
        # CheckpointedOp is a key-value table that can be saved across
        # sessions; the table registers itself in the SAVEABLE_OBJECTS
        # collection.
        table = saver_test_utils.CheckpointedOp(name="v1")
        variables.global_variables_initializer().run()
        table.insert("k1", 3.0).run()
        # Keep the table reference so the table can be re-wrapped after load.
        ops.add_to_collection("table_ref", table.table_ref)
        builder.add_meta_graph_and_variables(sess, ["foo"])

      # Persist the SavedModel.
      builder.save()

      with two_cpu_session() as sess:
        loader.load(sess, ["foo"], export_dir)
        # Rebuild a wrapper around the restored table reference.
        table_ref = ops.get_collection("table_ref")[0]
        restored_table = saver_test_utils.CheckpointedOp(
            name="v1", table_ref=table_ref)
        self.assertEqual(b"k1", restored_table.keys().eval())
        self.assertEqual(3.0, restored_table.values().eval())
    850 
    def testClearDevices(self):
      """Variables saved with clear_devices=True load without device info."""
      export_dir = self._get_export_dir("test_clear_devices")
      builder = saved_model_builder.SavedModelBuilder(export_dir)
      two_cpu_config = config_pb2.ConfigProto(device_count={"CPU": 2})

      # Create and save a variable inside an explicit device scope, asking the
      # builder to clear device placements from the saved graph.
      ops.reset_default_graph()
      with session.Session(target="", config=two_cpu_config) as sess:
        with sess.graph.device("/cpu:0"):
          self._init_and_validate_variable(sess, "v", 42)
          builder.add_meta_graph_and_variables(
              sess, [tag_constants.TRAINING], clear_devices=True)

      # Persist the SavedModel.
      builder.save()

      # Restore under the single predefined TRAINING tag; the variables were
      # saved without any device information.
      with self.test_session(graph=ops.Graph()) as sess:
        loader.load(sess, [tag_constants.TRAINING], export_dir)
        global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        self.assertEqual(42, global_vars[0].eval())
    874 
    def testStripDefaultAttrs(self):
      """strip_default_attrs removes default-valued attrs from the saved graph.

      Two meta graphs containing the same "Complex" node are saved: "foo"
      with `strip_default_attrs=True` and "bar" with
      `strip_default_attrs=False`. Loading "foo" through the loader must
      restore the stripped attrs, while the raw on-disk protos show them
      absent for "foo" and present for "bar".
      """
      export_dir = self._get_export_dir("test_strip_default_attrs")
      builder = saved_model_builder.SavedModelBuilder(export_dir)

      def add_complex_graph(sess):
        # Two float32 variables composed by a "Complex" op named "complex".
        # Ops go into the current default graph, which is `sess.graph`
        # inside the `with session.Session(...)` blocks below.
        real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
        imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
        math_ops.complex(real_num, imag_num, name="complex")
        sess.run(variables.global_variables_initializer())

      # Graph "foo": saved with strip_default_attrs enabled.
      with session.Session(graph=ops.Graph()) as sess:
        add_complex_graph(sess)
        builder.add_meta_graph_and_variables(
            sess, ["foo"], strip_default_attrs=True)

      # Graph "bar": same graph, saved with strip_default_attrs disabled.
      with session.Session(graph=ops.Graph()) as sess:
        add_complex_graph(sess)
        builder.add_meta_graph(["bar"], strip_default_attrs=False)

      # Save the SavedModel to disk in text format.
      builder.save(as_text=True)

      # Loading graph "foo" via the loader must restore the defaults for the
      # "Complex" node based on the "Complex" OpDef in the Op registry. The
      # session is a context manager so it is closed once the check is done.
      with session.Session(graph=ops.Graph()) as sess:
        loaded_meta_graph_def = loader.load(sess, ["foo"], export_dir)
        complex_node = test_util.get_node_def_from_graph(
            "complex", loaded_meta_graph_def.graph_def)
        self.assertIn("T", complex_node.attr)
        self.assertIn("Tout", complex_node.attr)

      # Read the SavedModel proto from disk as-is to verify which attrs were
      # actually written.
      # pylint: disable=protected-access
      saved_model_pb = loader_impl._parse_saved_model(export_dir)
      # pylint: enable=protected-access
      self.assertIsNotNone(saved_model_pb)

      # Index the stored meta graphs by their (unordered) tag set; as in a
      # linear scan, a later duplicate tag set would win.
      meta_graphs_by_tags = {
          frozenset(meta_graph.meta_info_def.tags): meta_graph
          for meta_graph in saved_model_pb.meta_graphs
      }
      meta_graph_foo_def = meta_graphs_by_tags.get(frozenset(["foo"]))
      meta_graph_bar_def = meta_graphs_by_tags.get(frozenset(["bar"]))

      self.assertIsNotNone(meta_graph_foo_def)
      self.assertIsNotNone(meta_graph_bar_def)

      # "Complex" Op has 2 attributes with defaults:
      #   o "T"    : float32.   (input type)
      #   o "Tout" : complex64. (output type)

      # "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout".
      # Graph "foo" was saved with strip_default_attrs set to True.
      node_def = test_util.get_node_def_from_graph("complex",
                                                   meta_graph_foo_def.graph_def)
      self.assertNotIn("T", node_def.attr)
      self.assertNotIn("Tout", node_def.attr)

      # "Complex" Op in graph "bar" must have attributes "T" and "Tout".
      # Graph "bar" was saved with strip_default_attrs set to False.
      node_def = test_util.get_node_def_from_graph("complex",
                                                   meta_graph_bar_def.graph_def)
      self.assertIn("T", node_def.attr)
      self.assertIn("Tout", node_def.attr)
    944 
    def testInconsistentConsumerDefaultAttrs(self):
      """Loading rejects SavedModels with missing or mistyped op attrs.

      A "TestAttr" node, whose "T" attr has no default in its OpDef, is saved
      in text format. The file is then rewritten twice: once with the "T"
      attr removed, and once with "T" changed from DT_FLOAT to DT_DOUBLE. The
      loader must fail on both variants.
      """
      export_dir = self._get_export_dir(
          "test_strip_default_attrs_no_consumer_defaults")
      builder = saved_model_builder.SavedModelBuilder(export_dir)

      # Text-format `attr` entries for "test_attr" as they appear in the
      # .pbtxt file. The whitespace inside these literals must match exactly.
      float_t_attr = """
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }"""
      double_t_attr = """
      attr {
        key: "T"
        value {
          type: DT_DOUBLE
        }
      }"""

      # Add a graph with a single variable and a test op with a defaultless
      # float32 attr, "test_attr".
      with session.Session(graph=ops.Graph()) as sess:
        variables.Variable(1.0, dtype=dtypes.float64, name="var")
        test_ops.test_attr(T=dtypes.float32, name="test_attr")
        sess.run(variables.global_variables_initializer())
        builder.add_meta_graph_and_variables(sess, ["foo"])

      # Save the SavedModel to disk in text format.
      builder.save(as_text=True)

      saved_model_file = os.path.join(
          export_dir, constants.SAVED_MODEL_FILENAME_PBTXT)
      with open(saved_model_file) as f:
        original_saved_model = f.read()
      # Both rewrites below are silent no-ops unless this text is present.
      self.assertIn(float_t_attr, original_saved_model)

      # Rewrite the SavedModel to remove the T attr from "test_attr".
      with open(saved_model_file, "w") as f:
        f.write(original_saved_model.replace(float_t_attr, ""))

      # Loading the SavedModel via the loader must fail because the SavedModel
      # does not have any attr values for the "TestAttr" node, and there is no
      # default specified in the TestAttr OpDef.
      if ops._USE_C_API:  # pylint: disable=protected-access
        error_message = "NodeDef missing attr 'T' from Op<name=TestAttr"
      else:
        error_message = ("Expected one attr with name .*T(out)?.* in name: "
                         "\"test_attr\".*")
      with session.Session(graph=ops.Graph()) as sess:
        with self.assertRaisesRegexp(ValueError, error_message):
          loader.load(sess, ["foo"], export_dir)

      # Rewrite the SavedModel to change the type of the T attr in
      # "test_attr" from DT_FLOAT to DT_DOUBLE.
      with open(saved_model_file, "w") as f:
        f.write(original_saved_model.replace(float_t_attr, double_t_attr))

      # Loading the SavedModel via the loader must fail because there is no
      # OpKernel registered to handle T = double.
      no_kernel_message = (".*No OpKernel was registered to support Op "
                           "\'TestAttr\' with these attrs..*")
      with session.Session(graph=ops.Graph()) as sess:
        with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                     no_kernel_message):
          loader.load(sess, ["foo"], export_dir)
   1016 
   1017 
   if __name__ == "__main__":
     # Hand control to the TensorFlow test runner for the cases in this module.
     test.main()
   1020