1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for bundle_shim.py.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import os.path 22 23 from tensorflow.contrib.session_bundle import bundle_shim 24 from tensorflow.contrib.session_bundle import constants 25 from tensorflow.contrib.session_bundle import manifest_pb2 26 from tensorflow.core.protobuf import config_pb2 27 from tensorflow.core.protobuf import meta_graph_pb2 28 from tensorflow.python.framework import meta_graph 29 from tensorflow.python.framework import ops 30 import tensorflow.python.ops.parsing_ops # pylint: disable=unused-import 31 from tensorflow.python.platform import test 32 from tensorflow.python.saved_model import signature_constants 33 from tensorflow.python.saved_model import tag_constants 34 from tensorflow.python.util import compat 35 36 SAVED_MODEL_PATH = ("cc/saved_model/testdata/half_plus_two/00000123") 37 SESSION_BUNDLE_PATH = "contrib/session_bundle/testdata/half_plus_two/00000123" 38 39 40 class BundleShimTest(test.TestCase): 41 42 def testBadPath(self): 43 base_path = test.test_src_dir_path("/no/such/a/dir") 44 ops.reset_default_graph() 45 with self.assertRaises(RuntimeError): 46 _, _ = bundle_shim.load_session_bundle_or_saved_model_bundle_from_path( 47 base_path) 48 49 def testAddInputToSignatureDef(self): 50 signature_def = meta_graph_pb2.SignatureDef() 51 signature_def_compare = meta_graph_pb2.SignatureDef() 52 53 # Add input to signature-def corresponding to `foo_key`. 54 bundle_shim._add_input_to_signature_def("foo-name", "foo-key", 55 signature_def) 56 self.assertEqual(len(signature_def.inputs), 1) 57 self.assertEqual(len(signature_def.outputs), 0) 58 self.assertProtoEquals( 59 signature_def.inputs["foo-key"], 60 meta_graph_pb2.TensorInfo(name="foo-name")) 61 62 # Attempt to add another input to the signature-def with the same tensor 63 # name and key. 64 bundle_shim._add_input_to_signature_def("foo-name", "foo-key", 65 signature_def) 66 self.assertEqual(len(signature_def.inputs), 1) 67 self.assertEqual(len(signature_def.outputs), 0) 68 self.assertProtoEquals( 69 signature_def.inputs["foo-key"], 70 meta_graph_pb2.TensorInfo(name="foo-name")) 71 72 # Add another input to the signature-def corresponding to `bar-key`. 73 bundle_shim._add_input_to_signature_def("bar-name", "bar-key", 74 signature_def) 75 self.assertEqual(len(signature_def.inputs), 2) 76 self.assertEqual(len(signature_def.outputs), 0) 77 self.assertProtoEquals( 78 signature_def.inputs["bar-key"], 79 meta_graph_pb2.TensorInfo(name="bar-name")) 80 81 # Add an input to the signature-def corresponding to `foo-key` with an 82 # updated tensor name. 83 bundle_shim._add_input_to_signature_def("bar-name", "foo-key", 84 signature_def) 85 self.assertEqual(len(signature_def.inputs), 2) 86 self.assertEqual(len(signature_def.outputs), 0) 87 self.assertProtoEquals( 88 signature_def.inputs["foo-key"], 89 meta_graph_pb2.TensorInfo(name="bar-name")) 90 91 # Test that there are no other side-effects. 92 del signature_def.inputs["foo-key"] 93 del signature_def.inputs["bar-key"] 94 self.assertProtoEquals(signature_def, signature_def_compare) 95 96 def testAddOutputToSignatureDef(self): 97 signature_def = meta_graph_pb2.SignatureDef() 98 signature_def_compare = meta_graph_pb2.SignatureDef() 99 100 # Add output to signature-def corresponding to `foo_key`. 101 bundle_shim._add_output_to_signature_def("foo-name", "foo-key", 102 signature_def) 103 self.assertEqual(len(signature_def.outputs), 1) 104 self.assertEqual(len(signature_def.inputs), 0) 105 self.assertProtoEquals( 106 signature_def.outputs["foo-key"], 107 meta_graph_pb2.TensorInfo(name="foo-name")) 108 109 # Attempt to add another output to the signature-def with the same tensor 110 # name and key. 111 bundle_shim._add_output_to_signature_def("foo-name", "foo-key", 112 signature_def) 113 self.assertEqual(len(signature_def.outputs), 1) 114 self.assertEqual(len(signature_def.inputs), 0) 115 self.assertProtoEquals( 116 signature_def.outputs["foo-key"], 117 meta_graph_pb2.TensorInfo(name="foo-name")) 118 119 # Add another output to the signature-def corresponding to `bar-key`. 120 bundle_shim._add_output_to_signature_def("bar-name", "bar-key", 121 signature_def) 122 self.assertEqual(len(signature_def.outputs), 2) 123 self.assertEqual(len(signature_def.inputs), 0) 124 self.assertProtoEquals( 125 signature_def.outputs["bar-key"], 126 meta_graph_pb2.TensorInfo(name="bar-name")) 127 128 # Add an output to the signature-def corresponding to `foo-key` with an 129 # updated tensor name. 130 bundle_shim._add_output_to_signature_def("bar-name", "foo-key", 131 signature_def) 132 self.assertEqual(len(signature_def.outputs), 2) 133 self.assertEqual(len(signature_def.inputs), 0) 134 self.assertProtoEquals( 135 signature_def.outputs["foo-key"], 136 meta_graph_pb2.TensorInfo(name="bar-name")) 137 138 # Test that there are no other sideeffects. 139 del signature_def.outputs["foo-key"] 140 del signature_def.outputs["bar-key"] 141 self.assertProtoEquals(signature_def, signature_def_compare) 142 143 def testConvertDefaultSignatureGenericToSignatureDef(self): 144 signatures_proto = manifest_pb2.Signatures() 145 generic_signature = manifest_pb2.GenericSignature() 146 signatures_proto.default_signature.generic_signature.CopyFrom( 147 generic_signature) 148 signature_def = bundle_shim._convert_default_signature_to_signature_def( 149 signatures_proto) 150 self.assertEquals(signature_def, None) 151 152 def testConvertDefaultSignatureRegressionToSignatureDef(self): 153 signatures_proto = manifest_pb2.Signatures() 154 regression_signature = manifest_pb2.RegressionSignature() 155 regression_signature.input.CopyFrom( 156 manifest_pb2.TensorBinding( 157 tensor_name=signature_constants.REGRESS_INPUTS)) 158 regression_signature.output.CopyFrom( 159 manifest_pb2.TensorBinding( 160 tensor_name=signature_constants.REGRESS_OUTPUTS)) 161 signatures_proto.default_signature.regression_signature.CopyFrom( 162 regression_signature) 163 signature_def = bundle_shim._convert_default_signature_to_signature_def( 164 signatures_proto) 165 166 # Validate regression signature correctly copied over. 167 self.assertEqual(signature_def.method_name, 168 signature_constants.REGRESS_METHOD_NAME) 169 self.assertEqual(len(signature_def.inputs), 1) 170 self.assertEqual(len(signature_def.outputs), 1) 171 self.assertProtoEquals( 172 signature_def.inputs[signature_constants.REGRESS_INPUTS], 173 meta_graph_pb2.TensorInfo(name=signature_constants.REGRESS_INPUTS)) 174 self.assertProtoEquals( 175 signature_def.outputs[signature_constants.REGRESS_OUTPUTS], 176 meta_graph_pb2.TensorInfo(name=signature_constants.REGRESS_OUTPUTS)) 177 178 def testConvertDefaultSignatureClassificationToSignatureDef(self): 179 signatures_proto = manifest_pb2.Signatures() 180 classification_signature = manifest_pb2.ClassificationSignature() 181 classification_signature.input.CopyFrom( 182 manifest_pb2.TensorBinding( 183 tensor_name=signature_constants.CLASSIFY_INPUTS)) 184 classification_signature.classes.CopyFrom( 185 manifest_pb2.TensorBinding( 186 tensor_name=signature_constants.CLASSIFY_OUTPUT_CLASSES)) 187 classification_signature.scores.CopyFrom( 188 manifest_pb2.TensorBinding( 189 tensor_name=signature_constants.CLASSIFY_OUTPUT_SCORES)) 190 signatures_proto.default_signature.classification_signature.CopyFrom( 191 classification_signature) 192 193 signatures_proto.default_signature.classification_signature.CopyFrom( 194 classification_signature) 195 signature_def = bundle_shim._convert_default_signature_to_signature_def( 196 signatures_proto) 197 198 # Validate classification signature correctly copied over. 199 self.assertEqual(signature_def.method_name, 200 signature_constants.CLASSIFY_METHOD_NAME) 201 self.assertEqual(len(signature_def.inputs), 1) 202 self.assertEqual(len(signature_def.outputs), 2) 203 self.assertProtoEquals( 204 signature_def.inputs[signature_constants.CLASSIFY_INPUTS], 205 meta_graph_pb2.TensorInfo(name=signature_constants.CLASSIFY_INPUTS)) 206 self.assertProtoEquals( 207 signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES], 208 meta_graph_pb2.TensorInfo( 209 name=signature_constants.CLASSIFY_OUTPUT_SCORES)) 210 self.assertProtoEquals( 211 signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES], 212 meta_graph_pb2.TensorInfo( 213 name=signature_constants.CLASSIFY_OUTPUT_CLASSES)) 214 215 def testConvertNamedSignatureNonGenericToSignatureDef(self): 216 signatures_proto = manifest_pb2.Signatures() 217 regression_signature = manifest_pb2.RegressionSignature() 218 signatures_proto.named_signatures[ 219 signature_constants.PREDICT_INPUTS].regression_signature.CopyFrom( 220 regression_signature) 221 with self.assertRaises(RuntimeError): 222 _ = bundle_shim._convert_named_signatures_to_signature_def( 223 signatures_proto) 224 signatures_proto = manifest_pb2.Signatures() 225 classification_signature = manifest_pb2.ClassificationSignature() 226 signatures_proto.named_signatures[ 227 signature_constants.PREDICT_INPUTS].classification_signature.CopyFrom( 228 classification_signature) 229 with self.assertRaises(RuntimeError): 230 _ = bundle_shim._convert_named_signatures_to_signature_def( 231 signatures_proto) 232 233 def testConvertNamedSignatureToSignatureDef(self): 234 signatures_proto = manifest_pb2.Signatures() 235 generic_signature = manifest_pb2.GenericSignature() 236 generic_signature.map["input_key"].CopyFrom( 237 manifest_pb2.TensorBinding(tensor_name="input")) 238 signatures_proto.named_signatures[ 239 signature_constants.PREDICT_INPUTS].generic_signature.CopyFrom( 240 generic_signature) 241 242 generic_signature = manifest_pb2.GenericSignature() 243 generic_signature.map["output_key"].CopyFrom( 244 manifest_pb2.TensorBinding(tensor_name="output")) 245 signatures_proto.named_signatures[ 246 signature_constants.PREDICT_OUTPUTS].generic_signature.CopyFrom( 247 generic_signature) 248 signature_def = bundle_shim._convert_named_signatures_to_signature_def( 249 signatures_proto) 250 self.assertEqual(signature_def.method_name, 251 signature_constants.PREDICT_METHOD_NAME) 252 self.assertEqual(len(signature_def.inputs), 1) 253 self.assertEqual(len(signature_def.outputs), 1) 254 self.assertProtoEquals( 255 signature_def.inputs["input_key"], 256 meta_graph_pb2.TensorInfo(name="input")) 257 self.assertProtoEquals( 258 signature_def.outputs["output_key"], 259 meta_graph_pb2.TensorInfo(name="output")) 260 261 def testConvertSignaturesToSignatureDefs(self): 262 base_path = test.test_src_dir_path(SESSION_BUNDLE_PATH) 263 meta_graph_filename = os.path.join(base_path, 264 constants.META_GRAPH_DEF_FILENAME) 265 metagraph_def = meta_graph.read_meta_graph_file(meta_graph_filename) 266 default_signature_def, named_signature_def = ( 267 bundle_shim._convert_signatures_to_signature_defs(metagraph_def)) 268 self.assertEqual(default_signature_def.method_name, 269 signature_constants.REGRESS_METHOD_NAME) 270 self.assertEqual(len(default_signature_def.inputs), 1) 271 self.assertEqual(len(default_signature_def.outputs), 1) 272 self.assertProtoEquals( 273 default_signature_def.inputs[signature_constants.REGRESS_INPUTS], 274 meta_graph_pb2.TensorInfo(name="tf_example:0")) 275 self.assertProtoEquals( 276 default_signature_def.outputs[signature_constants.REGRESS_OUTPUTS], 277 meta_graph_pb2.TensorInfo(name="Identity:0")) 278 self.assertEqual(named_signature_def.method_name, 279 signature_constants.PREDICT_METHOD_NAME) 280 self.assertEqual(len(named_signature_def.inputs), 1) 281 self.assertEqual(len(named_signature_def.outputs), 1) 282 self.assertProtoEquals( 283 named_signature_def.inputs["x"], meta_graph_pb2.TensorInfo(name="x:0")) 284 self.assertProtoEquals( 285 named_signature_def.outputs["y"], meta_graph_pb2.TensorInfo(name="y:0")) 286 287 # Now try default signature only 288 collection_def = metagraph_def.collection_def 289 signatures_proto = manifest_pb2.Signatures() 290 signatures = collection_def[constants.SIGNATURES_KEY].any_list.value[0] 291 signatures.Unpack(signatures_proto) 292 named_only_signatures_proto = manifest_pb2.Signatures() 293 named_only_signatures_proto.CopyFrom(signatures_proto) 294 295 default_only_signatures_proto = manifest_pb2.Signatures() 296 default_only_signatures_proto.CopyFrom(signatures_proto) 297 default_only_signatures_proto.named_signatures.clear() 298 default_only_signatures_proto.ClearField("named_signatures") 299 metagraph_def.collection_def[constants.SIGNATURES_KEY].any_list.value[ 300 0].Pack(default_only_signatures_proto) 301 default_signature_def, named_signature_def = ( 302 bundle_shim._convert_signatures_to_signature_defs(metagraph_def)) 303 self.assertEqual(default_signature_def.method_name, 304 signature_constants.REGRESS_METHOD_NAME) 305 self.assertEqual(named_signature_def, None) 306 307 named_only_signatures_proto.ClearField("default_signature") 308 metagraph_def.collection_def[constants.SIGNATURES_KEY].any_list.value[ 309 0].Pack(named_only_signatures_proto) 310 default_signature_def, named_signature_def = ( 311 bundle_shim._convert_signatures_to_signature_defs(metagraph_def)) 312 self.assertEqual(named_signature_def.method_name, 313 signature_constants.PREDICT_METHOD_NAME) 314 self.assertEqual(default_signature_def, None) 315 316 def testLegacyBasic(self): 317 base_path = test.test_src_dir_path(SESSION_BUNDLE_PATH) 318 ops.reset_default_graph() 319 sess, meta_graph_def = ( 320 bundle_shim.load_session_bundle_or_saved_model_bundle_from_path( 321 base_path, 322 tags=[""], 323 target="", 324 config=config_pb2.ConfigProto(device_count={"CPU": 2}))) 325 326 self.assertTrue(sess) 327 asset_path = os.path.join(base_path, constants.ASSETS_DIRECTORY) 328 with sess.as_default(): 329 path1, path2 = sess.run(["filename1:0", "filename2:0"]) 330 self.assertEqual( 331 compat.as_bytes(os.path.join(asset_path, "hello1.txt")), path1) 332 self.assertEqual( 333 compat.as_bytes(os.path.join(asset_path, "hello2.txt")), path2) 334 335 collection_def = meta_graph_def.collection_def 336 337 signatures_any = collection_def[constants.SIGNATURES_KEY].any_list.value 338 self.assertEqual(len(signatures_any), 1) 339 340 def testSavedModelBasic(self): 341 base_path = test.test_src_dir_path(SAVED_MODEL_PATH) 342 ops.reset_default_graph() 343 sess, meta_graph_def = ( 344 bundle_shim.load_session_bundle_or_saved_model_bundle_from_path( 345 base_path, 346 tags=[tag_constants.SERVING], 347 target="", 348 config=config_pb2.ConfigProto(device_count={"CPU": 2}))) 349 350 self.assertTrue(sess) 351 352 # Check basic signature def property. 353 signature_def = meta_graph_def.signature_def 354 self.assertEqual(signature_def["regress_x_to_y"].method_name, 355 signature_constants.REGRESS_METHOD_NAME) 356 with sess.as_default(): 357 output1 = sess.run(["filename_tensor:0"]) 358 self.assertEqual([compat.as_bytes("foo.txt")], output1) 359 360 361 if __name__ == "__main__": 362 test.main() 363