Home | History | Annotate | Download | only in session_bundle
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for bundle_shim.py."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import os.path
     22 
     23 from tensorflow.contrib.session_bundle import bundle_shim
     24 from tensorflow.contrib.session_bundle import constants
     25 from tensorflow.contrib.session_bundle import manifest_pb2
     26 from tensorflow.core.protobuf import config_pb2
     27 from tensorflow.core.protobuf import meta_graph_pb2
     28 from tensorflow.python.framework import meta_graph
     29 from tensorflow.python.framework import ops
     30 import tensorflow.python.ops.parsing_ops  # pylint: disable=unused-import
     31 from tensorflow.python.platform import test
     32 from tensorflow.python.saved_model import signature_constants
     33 from tensorflow.python.saved_model import tag_constants
     34 from tensorflow.python.util import compat
     35 
     36 SAVED_MODEL_PATH = ("cc/saved_model/testdata/half_plus_two/00000123")
     37 SESSION_BUNDLE_PATH = "contrib/session_bundle/testdata/half_plus_two/00000123"
     38 
     39 
     40 class BundleShimTest(test.TestCase):
     41 
     42   def testBadPath(self):
     43     base_path = test.test_src_dir_path("/no/such/a/dir")
     44     ops.reset_default_graph()
     45     with self.assertRaises(RuntimeError):
     46       _, _ = bundle_shim.load_session_bundle_or_saved_model_bundle_from_path(
     47           base_path)
     48 
     49   def testAddInputToSignatureDef(self):
     50     signature_def = meta_graph_pb2.SignatureDef()
     51     signature_def_compare = meta_graph_pb2.SignatureDef()
     52 
     53     # Add input to signature-def corresponding to `foo_key`.
     54     bundle_shim._add_input_to_signature_def("foo-name", "foo-key",
     55                                             signature_def)
     56     self.assertEqual(len(signature_def.inputs), 1)
     57     self.assertEqual(len(signature_def.outputs), 0)
     58     self.assertProtoEquals(
     59         signature_def.inputs["foo-key"],
     60         meta_graph_pb2.TensorInfo(name="foo-name"))
     61 
     62     # Attempt to add another input to the signature-def with the same tensor
     63     # name and key.
     64     bundle_shim._add_input_to_signature_def("foo-name", "foo-key",
     65                                             signature_def)
     66     self.assertEqual(len(signature_def.inputs), 1)
     67     self.assertEqual(len(signature_def.outputs), 0)
     68     self.assertProtoEquals(
     69         signature_def.inputs["foo-key"],
     70         meta_graph_pb2.TensorInfo(name="foo-name"))
     71 
     72     # Add another input to the signature-def corresponding to `bar-key`.
     73     bundle_shim._add_input_to_signature_def("bar-name", "bar-key",
     74                                             signature_def)
     75     self.assertEqual(len(signature_def.inputs), 2)
     76     self.assertEqual(len(signature_def.outputs), 0)
     77     self.assertProtoEquals(
     78         signature_def.inputs["bar-key"],
     79         meta_graph_pb2.TensorInfo(name="bar-name"))
     80 
     81     # Add an input to the signature-def corresponding to `foo-key` with an
     82     # updated tensor name.
     83     bundle_shim._add_input_to_signature_def("bar-name", "foo-key",
     84                                             signature_def)
     85     self.assertEqual(len(signature_def.inputs), 2)
     86     self.assertEqual(len(signature_def.outputs), 0)
     87     self.assertProtoEquals(
     88         signature_def.inputs["foo-key"],
     89         meta_graph_pb2.TensorInfo(name="bar-name"))
     90 
     91     # Test that there are no other side-effects.
     92     del signature_def.inputs["foo-key"]
     93     del signature_def.inputs["bar-key"]
     94     self.assertProtoEquals(signature_def, signature_def_compare)
     95 
     96   def testAddOutputToSignatureDef(self):
     97     signature_def = meta_graph_pb2.SignatureDef()
     98     signature_def_compare = meta_graph_pb2.SignatureDef()
     99 
    100     # Add output to signature-def corresponding to `foo_key`.
    101     bundle_shim._add_output_to_signature_def("foo-name", "foo-key",
    102                                              signature_def)
    103     self.assertEqual(len(signature_def.outputs), 1)
    104     self.assertEqual(len(signature_def.inputs), 0)
    105     self.assertProtoEquals(
    106         signature_def.outputs["foo-key"],
    107         meta_graph_pb2.TensorInfo(name="foo-name"))
    108 
    109     # Attempt to add another output to the signature-def with the same tensor
    110     # name and key.
    111     bundle_shim._add_output_to_signature_def("foo-name", "foo-key",
    112                                              signature_def)
    113     self.assertEqual(len(signature_def.outputs), 1)
    114     self.assertEqual(len(signature_def.inputs), 0)
    115     self.assertProtoEquals(
    116         signature_def.outputs["foo-key"],
    117         meta_graph_pb2.TensorInfo(name="foo-name"))
    118 
    119     # Add another output to the signature-def corresponding to `bar-key`.
    120     bundle_shim._add_output_to_signature_def("bar-name", "bar-key",
    121                                              signature_def)
    122     self.assertEqual(len(signature_def.outputs), 2)
    123     self.assertEqual(len(signature_def.inputs), 0)
    124     self.assertProtoEquals(
    125         signature_def.outputs["bar-key"],
    126         meta_graph_pb2.TensorInfo(name="bar-name"))
    127 
    128     # Add an output to the signature-def corresponding to `foo-key` with an
    129     # updated tensor name.
    130     bundle_shim._add_output_to_signature_def("bar-name", "foo-key",
    131                                              signature_def)
    132     self.assertEqual(len(signature_def.outputs), 2)
    133     self.assertEqual(len(signature_def.inputs), 0)
    134     self.assertProtoEquals(
    135         signature_def.outputs["foo-key"],
    136         meta_graph_pb2.TensorInfo(name="bar-name"))
    137 
    138     # Test that there are no other sideeffects.
    139     del signature_def.outputs["foo-key"]
    140     del signature_def.outputs["bar-key"]
    141     self.assertProtoEquals(signature_def, signature_def_compare)
    142 
    143   def testConvertDefaultSignatureGenericToSignatureDef(self):
    144     signatures_proto = manifest_pb2.Signatures()
    145     generic_signature = manifest_pb2.GenericSignature()
    146     signatures_proto.default_signature.generic_signature.CopyFrom(
    147         generic_signature)
    148     signature_def = bundle_shim._convert_default_signature_to_signature_def(
    149         signatures_proto)
    150     self.assertEquals(signature_def, None)
    151 
    152   def testConvertDefaultSignatureRegressionToSignatureDef(self):
    153     signatures_proto = manifest_pb2.Signatures()
    154     regression_signature = manifest_pb2.RegressionSignature()
    155     regression_signature.input.CopyFrom(
    156         manifest_pb2.TensorBinding(
    157             tensor_name=signature_constants.REGRESS_INPUTS))
    158     regression_signature.output.CopyFrom(
    159         manifest_pb2.TensorBinding(
    160             tensor_name=signature_constants.REGRESS_OUTPUTS))
    161     signatures_proto.default_signature.regression_signature.CopyFrom(
    162         regression_signature)
    163     signature_def = bundle_shim._convert_default_signature_to_signature_def(
    164         signatures_proto)
    165 
    166     # Validate regression signature correctly copied over.
    167     self.assertEqual(signature_def.method_name,
    168                      signature_constants.REGRESS_METHOD_NAME)
    169     self.assertEqual(len(signature_def.inputs), 1)
    170     self.assertEqual(len(signature_def.outputs), 1)
    171     self.assertProtoEquals(
    172         signature_def.inputs[signature_constants.REGRESS_INPUTS],
    173         meta_graph_pb2.TensorInfo(name=signature_constants.REGRESS_INPUTS))
    174     self.assertProtoEquals(
    175         signature_def.outputs[signature_constants.REGRESS_OUTPUTS],
    176         meta_graph_pb2.TensorInfo(name=signature_constants.REGRESS_OUTPUTS))
    177 
    178   def testConvertDefaultSignatureClassificationToSignatureDef(self):
    179     signatures_proto = manifest_pb2.Signatures()
    180     classification_signature = manifest_pb2.ClassificationSignature()
    181     classification_signature.input.CopyFrom(
    182         manifest_pb2.TensorBinding(
    183             tensor_name=signature_constants.CLASSIFY_INPUTS))
    184     classification_signature.classes.CopyFrom(
    185         manifest_pb2.TensorBinding(
    186             tensor_name=signature_constants.CLASSIFY_OUTPUT_CLASSES))
    187     classification_signature.scores.CopyFrom(
    188         manifest_pb2.TensorBinding(
    189             tensor_name=signature_constants.CLASSIFY_OUTPUT_SCORES))
    190     signatures_proto.default_signature.classification_signature.CopyFrom(
    191         classification_signature)
    192 
    193     signatures_proto.default_signature.classification_signature.CopyFrom(
    194         classification_signature)
    195     signature_def = bundle_shim._convert_default_signature_to_signature_def(
    196         signatures_proto)
    197 
    198     # Validate classification signature correctly copied over.
    199     self.assertEqual(signature_def.method_name,
    200                      signature_constants.CLASSIFY_METHOD_NAME)
    201     self.assertEqual(len(signature_def.inputs), 1)
    202     self.assertEqual(len(signature_def.outputs), 2)
    203     self.assertProtoEquals(
    204         signature_def.inputs[signature_constants.CLASSIFY_INPUTS],
    205         meta_graph_pb2.TensorInfo(name=signature_constants.CLASSIFY_INPUTS))
    206     self.assertProtoEquals(
    207         signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES],
    208         meta_graph_pb2.TensorInfo(
    209             name=signature_constants.CLASSIFY_OUTPUT_SCORES))
    210     self.assertProtoEquals(
    211         signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES],
    212         meta_graph_pb2.TensorInfo(
    213             name=signature_constants.CLASSIFY_OUTPUT_CLASSES))
    214 
    215   def testConvertNamedSignatureNonGenericToSignatureDef(self):
    216     signatures_proto = manifest_pb2.Signatures()
    217     regression_signature = manifest_pb2.RegressionSignature()
    218     signatures_proto.named_signatures[
    219         signature_constants.PREDICT_INPUTS].regression_signature.CopyFrom(
    220             regression_signature)
    221     with self.assertRaises(RuntimeError):
    222       _ = bundle_shim._convert_named_signatures_to_signature_def(
    223           signatures_proto)
    224     signatures_proto = manifest_pb2.Signatures()
    225     classification_signature = manifest_pb2.ClassificationSignature()
    226     signatures_proto.named_signatures[
    227         signature_constants.PREDICT_INPUTS].classification_signature.CopyFrom(
    228             classification_signature)
    229     with self.assertRaises(RuntimeError):
    230       _ = bundle_shim._convert_named_signatures_to_signature_def(
    231           signatures_proto)
    232 
    233   def testConvertNamedSignatureToSignatureDef(self):
    234     signatures_proto = manifest_pb2.Signatures()
    235     generic_signature = manifest_pb2.GenericSignature()
    236     generic_signature.map["input_key"].CopyFrom(
    237         manifest_pb2.TensorBinding(tensor_name="input"))
    238     signatures_proto.named_signatures[
    239         signature_constants.PREDICT_INPUTS].generic_signature.CopyFrom(
    240             generic_signature)
    241 
    242     generic_signature = manifest_pb2.GenericSignature()
    243     generic_signature.map["output_key"].CopyFrom(
    244         manifest_pb2.TensorBinding(tensor_name="output"))
    245     signatures_proto.named_signatures[
    246         signature_constants.PREDICT_OUTPUTS].generic_signature.CopyFrom(
    247             generic_signature)
    248     signature_def = bundle_shim._convert_named_signatures_to_signature_def(
    249         signatures_proto)
    250     self.assertEqual(signature_def.method_name,
    251                      signature_constants.PREDICT_METHOD_NAME)
    252     self.assertEqual(len(signature_def.inputs), 1)
    253     self.assertEqual(len(signature_def.outputs), 1)
    254     self.assertProtoEquals(
    255         signature_def.inputs["input_key"],
    256         meta_graph_pb2.TensorInfo(name="input"))
    257     self.assertProtoEquals(
    258         signature_def.outputs["output_key"],
    259         meta_graph_pb2.TensorInfo(name="output"))
    260 
    261   def testConvertSignaturesToSignatureDefs(self):
    262     base_path = test.test_src_dir_path(SESSION_BUNDLE_PATH)
    263     meta_graph_filename = os.path.join(base_path,
    264                                        constants.META_GRAPH_DEF_FILENAME)
    265     metagraph_def = meta_graph.read_meta_graph_file(meta_graph_filename)
    266     default_signature_def, named_signature_def = (
    267         bundle_shim._convert_signatures_to_signature_defs(metagraph_def))
    268     self.assertEqual(default_signature_def.method_name,
    269                      signature_constants.REGRESS_METHOD_NAME)
    270     self.assertEqual(len(default_signature_def.inputs), 1)
    271     self.assertEqual(len(default_signature_def.outputs), 1)
    272     self.assertProtoEquals(
    273         default_signature_def.inputs[signature_constants.REGRESS_INPUTS],
    274         meta_graph_pb2.TensorInfo(name="tf_example:0"))
    275     self.assertProtoEquals(
    276         default_signature_def.outputs[signature_constants.REGRESS_OUTPUTS],
    277         meta_graph_pb2.TensorInfo(name="Identity:0"))
    278     self.assertEqual(named_signature_def.method_name,
    279                      signature_constants.PREDICT_METHOD_NAME)
    280     self.assertEqual(len(named_signature_def.inputs), 1)
    281     self.assertEqual(len(named_signature_def.outputs), 1)
    282     self.assertProtoEquals(
    283         named_signature_def.inputs["x"], meta_graph_pb2.TensorInfo(name="x:0"))
    284     self.assertProtoEquals(
    285         named_signature_def.outputs["y"], meta_graph_pb2.TensorInfo(name="y:0"))
    286 
    287     # Now try default signature only
    288     collection_def = metagraph_def.collection_def
    289     signatures_proto = manifest_pb2.Signatures()
    290     signatures = collection_def[constants.SIGNATURES_KEY].any_list.value[0]
    291     signatures.Unpack(signatures_proto)
    292     named_only_signatures_proto = manifest_pb2.Signatures()
    293     named_only_signatures_proto.CopyFrom(signatures_proto)
    294 
    295     default_only_signatures_proto = manifest_pb2.Signatures()
    296     default_only_signatures_proto.CopyFrom(signatures_proto)
    297     default_only_signatures_proto.named_signatures.clear()
    298     default_only_signatures_proto.ClearField("named_signatures")
    299     metagraph_def.collection_def[constants.SIGNATURES_KEY].any_list.value[
    300         0].Pack(default_only_signatures_proto)
    301     default_signature_def, named_signature_def = (
    302         bundle_shim._convert_signatures_to_signature_defs(metagraph_def))
    303     self.assertEqual(default_signature_def.method_name,
    304                      signature_constants.REGRESS_METHOD_NAME)
    305     self.assertEqual(named_signature_def, None)
    306 
    307     named_only_signatures_proto.ClearField("default_signature")
    308     metagraph_def.collection_def[constants.SIGNATURES_KEY].any_list.value[
    309         0].Pack(named_only_signatures_proto)
    310     default_signature_def, named_signature_def = (
    311         bundle_shim._convert_signatures_to_signature_defs(metagraph_def))
    312     self.assertEqual(named_signature_def.method_name,
    313                      signature_constants.PREDICT_METHOD_NAME)
    314     self.assertEqual(default_signature_def, None)
    315 
    316   def testLegacyBasic(self):
    317     base_path = test.test_src_dir_path(SESSION_BUNDLE_PATH)
    318     ops.reset_default_graph()
    319     sess, meta_graph_def = (
    320         bundle_shim.load_session_bundle_or_saved_model_bundle_from_path(
    321             base_path,
    322             tags=[""],
    323             target="",
    324             config=config_pb2.ConfigProto(device_count={"CPU": 2})))
    325 
    326     self.assertTrue(sess)
    327     asset_path = os.path.join(base_path, constants.ASSETS_DIRECTORY)
    328     with sess.as_default():
    329       path1, path2 = sess.run(["filename1:0", "filename2:0"])
    330       self.assertEqual(
    331           compat.as_bytes(os.path.join(asset_path, "hello1.txt")), path1)
    332       self.assertEqual(
    333           compat.as_bytes(os.path.join(asset_path, "hello2.txt")), path2)
    334 
    335       collection_def = meta_graph_def.collection_def
    336 
    337       signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value
    338       self.assertEqual(len(signatures_any), 1)
    339 
    340   def testSavedModelBasic(self):
    341     base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
    342     ops.reset_default_graph()
    343     sess, meta_graph_def = (
    344         bundle_shim.load_session_bundle_or_saved_model_bundle_from_path(
    345             base_path,
    346             tags=[tag_constants.SERVING],
    347             target="",
    348             config=config_pb2.ConfigProto(device_count={"CPU": 2})))
    349 
    350     self.assertTrue(sess)
    351 
    352     # Check basic signature def property.
    353     signature_def = meta_graph_def.signature_def
    354     self.assertEqual(signature_def["regress_x_to_y"].method_name,
    355                      signature_constants.REGRESS_METHOD_NAME)
    356     with sess.as_default():
    357       output1 = sess.run(["filename_tensor:0"])
    358       self.assertEqual([compat.as_bytes("foo.txt")], output1)
    359 
    360 
    361 if __name__ == "__main__":
    362   test.main()
    363