1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for SavedModel.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import os 22 23 from tensorflow.core.framework import types_pb2 24 from tensorflow.core.protobuf import config_pb2 25 from tensorflow.core.protobuf import meta_graph_pb2 26 from tensorflow.python.client import session 27 from tensorflow.python.framework import constant_op 28 from tensorflow.python.framework import dtypes 29 from tensorflow.python.framework import errors 30 from tensorflow.python.framework import ops 31 from tensorflow.python.framework import test_ops 32 from tensorflow.python.framework import test_util 33 from tensorflow.python.lib.io import file_io 34 from tensorflow.python.ops import control_flow_ops 35 from tensorflow.python.ops import math_ops 36 from tensorflow.python.ops import state_ops 37 from tensorflow.python.ops import variables 38 from tensorflow.python.platform import test 39 from tensorflow.python.saved_model import builder as saved_model_builder 40 from tensorflow.python.saved_model import constants 41 from tensorflow.python.saved_model import loader 42 from tensorflow.python.saved_model import loader_impl 43 from tensorflow.python.saved_model import main_op 44 from tensorflow.python.saved_model import signature_def_utils 45 from tensorflow.python.saved_model import tag_constants 46 from tensorflow.python.training import saver_test_utils 47 from tensorflow.python.util import compat 48 49 SAVED_MODEL_PATH = ("cc/saved_model/testdata/half_plus_two/00000123") 50 51 52 def tearDownModule(): 53 file_io.delete_recursively(test.get_temp_dir()) 54 55 56 @test_util.with_c_api 57 class SavedModelTest(test.TestCase): 58 59 def _get_export_dir(self, label): 60 if ops._USE_C_API: 61 label += "_c_api" 62 return os.path.join(test.get_temp_dir(), label) 63 64 def _init_and_validate_variable(self, sess, variable_name, variable_value): 65 v = variables.Variable(variable_value, name=variable_name) 66 sess.run(variables.global_variables_initializer()) 67 self.assertEqual(variable_value, v.eval()) 68 69 def _build_asset_collection(self, asset_file_name, asset_file_contents, 70 asset_file_tensor_name): 71 asset_filepath = os.path.join( 72 compat.as_bytes(test.get_temp_dir()), compat.as_bytes(asset_file_name)) 73 file_io.write_string_to_file(asset_filepath, asset_file_contents) 74 asset_file_tensor = constant_op.constant( 75 asset_filepath, name=asset_file_tensor_name) 76 ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, asset_file_tensor) 77 asset_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS) 78 return asset_collection 79 80 def _validate_asset_collection(self, export_dir, graph_collection_def, 81 expected_asset_file_name, 82 expected_asset_file_contents, 83 expected_asset_tensor_name): 84 assets_any = graph_collection_def[constants.ASSETS_KEY].any_list.value 85 asset = meta_graph_pb2.AssetFileDef() 86 assets_any[0].Unpack(asset) 87 assets_path = os.path.join( 88 compat.as_bytes(export_dir), 89 compat.as_bytes(constants.ASSETS_DIRECTORY), 90 compat.as_bytes(expected_asset_file_name)) 91 actual_asset_contents = file_io.read_file_to_string(assets_path) 92 self.assertEqual(expected_asset_file_contents, 93 compat.as_text(actual_asset_contents)) 94 self.assertEqual(expected_asset_file_name, asset.filename) 95 self.assertEqual(expected_asset_tensor_name, asset.tensor_info.name) 96 97 def _validate_inputs_tensor_info(self, builder, tensor_info): 98 with self.test_session(graph=ops.Graph()) as sess: 99 self._init_and_validate_variable(sess, "v", 42) 100 101 foo_signature = signature_def_utils.build_signature_def({ 102 "foo_inputs": tensor_info 103 }, dict(), "foo") 104 self.assertRaises( 105 AssertionError, 106 builder.add_meta_graph_and_variables, 107 sess, ["foo"], 108 signature_def_map={"foo_key": foo_signature}) 109 110 def _validate_outputs_tensor_info(self, builder, tensor_info): 111 with self.test_session(graph=ops.Graph()) as sess: 112 self._init_and_validate_variable(sess, "v", 42) 113 114 foo_signature = signature_def_utils.build_signature_def( 115 dict(), {"foo_outputs": tensor_info}, "foo") 116 self.assertRaises( 117 AssertionError, 118 builder.add_meta_graph_and_variables, 119 sess, ["foo"], 120 signature_def_map={"foo_key": foo_signature}) 121 122 def testMaybeSavedModelDir(self): 123 base_path = test.test_src_dir_path("/python/saved_model") 124 self.assertFalse(loader.maybe_saved_model_directory(base_path)) 125 base_path = test.test_src_dir_path(SAVED_MODEL_PATH) 126 self.assertTrue(loader.maybe_saved_model_directory(base_path)) 127 base_path = "complete_garbage" 128 self.assertFalse(loader.maybe_saved_model_directory(base_path)) 129 130 def testBadSavedModelFileFormat(self): 131 export_dir = self._get_export_dir("test_bad_saved_model_file_format") 132 # Attempt to load a SavedModel from an export directory that does not exist. 133 with self.test_session(graph=ops.Graph()) as sess: 134 with self.assertRaisesRegexp(IOError, 135 "SavedModel file does not exist at: %s" % 136 export_dir): 137 loader.load(sess, ["foo"], export_dir) 138 139 os.makedirs(export_dir) 140 # Write an invalid binary proto to saved_model.pb. 141 path_to_pb = os.path.join(export_dir, constants.SAVED_MODEL_FILENAME_PB) 142 with open(path_to_pb, "w") as f: 143 f.write("invalid content") 144 with self.test_session(graph=ops.Graph()) as sess: 145 with self.assertRaisesRegexp(IOError, "Cannot parse file.*%s" % 146 constants.SAVED_MODEL_FILENAME_PB): 147 loader.load(sess, ["foo"], export_dir) 148 149 # Cleanup the directory and start again. 150 file_io.delete_recursively(export_dir) 151 152 os.makedirs(export_dir) 153 # Write an invalid text proto to saved_model.pbtxt 154 path_to_pbtxt = os.path.join(export_dir, 155 constants.SAVED_MODEL_FILENAME_PBTXT) 156 with open(path_to_pbtxt, "w") as f: 157 f.write("invalid content") 158 with self.test_session(graph=ops.Graph()) as sess: 159 with self.assertRaisesRegexp(IOError, "Cannot parse file.*%s" % 160 constants.SAVED_MODEL_FILENAME_PBTXT): 161 loader.load(sess, ["foo"], export_dir) 162 163 def testVerifySessionGraphUsage(self): 164 export_dir = self._get_export_dir("test_verify_session_graph_usage") 165 builder = saved_model_builder.SavedModelBuilder(export_dir) 166 167 with self.test_session(graph=ops.Graph()) as sess: 168 self._init_and_validate_variable(sess, "v", 42) 169 builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING]) 170 171 # Save the SavedModel to disk. 172 builder.save() 173 174 # Build a session and supply it to the load operation. 175 sess = session.Session(graph=ops.Graph()) 176 loader.load(sess, [tag_constants.TRAINING], export_dir) 177 178 # Check the variable within the scope of the session and its graph. 179 with sess: 180 self.assertEqual( 181 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) 182 183 def testSequence(self): 184 export_dir = self._get_export_dir("test_sequence") 185 builder = saved_model_builder.SavedModelBuilder(export_dir) 186 187 # Expect an assertion error since add_meta_graph_and_variables() should be 188 # invoked before any add_meta_graph() calls. 189 with self.test_session(graph=ops.Graph()) as sess: 190 self.assertRaises(AssertionError, builder.add_meta_graph, ["foo"]) 191 192 # Expect an assertion error for multiple calls of 193 # add_meta_graph_and_variables() since weights should be saved exactly once. 194 with self.test_session(graph=ops.Graph()) as sess: 195 self._init_and_validate_variable(sess, "v", 42) 196 builder.add_meta_graph_and_variables(sess, ["bar"]) 197 self.assertRaises(AssertionError, builder.add_meta_graph_and_variables, 198 sess, ["baz"]) 199 200 def testTags(self): 201 export_dir = self._get_export_dir("test_tags") 202 builder = saved_model_builder.SavedModelBuilder(export_dir) 203 204 # Graph with a single variable. SavedModel invoked to: 205 # - add with weights. 206 # - a single tag (from predefined constants). 207 with self.test_session(graph=ops.Graph()) as sess: 208 self._init_and_validate_variable(sess, "v", 42) 209 builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING]) 210 211 # Graph that updates the single variable. SavedModel invoked to: 212 # - simply add the model (weights are not updated). 213 # - a single tag (from predefined constants). 214 with self.test_session(graph=ops.Graph()) as sess: 215 self._init_and_validate_variable(sess, "v", 43) 216 builder.add_meta_graph([tag_constants.SERVING]) 217 218 # Graph that updates the single variable. SavedModel invoked to: 219 # - simply add the model (weights are not updated). 220 # - multiple tags (from predefined constants). 221 with self.test_session(graph=ops.Graph()) as sess: 222 self._init_and_validate_variable(sess, "v", 45) 223 builder.add_meta_graph([tag_constants.SERVING, tag_constants.GPU]) 224 225 # Graph that updates the single variable. SavedModel invoked to: 226 # - simply add the model (weights are not updated). 227 # - multiple tags (from predefined constants for serving on TPU). 228 with self.test_session(graph=ops.Graph()) as sess: 229 self._init_and_validate_variable(sess, "v", 45) 230 builder.add_meta_graph([tag_constants.SERVING, tag_constants.TPU]) 231 232 # Graph that updates the single variable. SavedModel is invoked: 233 # - to add the model (weights are not updated). 234 # - multiple custom tags. 235 with self.test_session(graph=ops.Graph()) as sess: 236 self._init_and_validate_variable(sess, "v", 44) 237 builder.add_meta_graph(["foo", "bar"]) 238 239 # Save the SavedModel to disk. 240 builder.save() 241 242 # Restore the graph with a single predefined tag whose variables were saved. 243 with self.test_session(graph=ops.Graph()) as sess: 244 loader.load(sess, [tag_constants.TRAINING], export_dir) 245 self.assertEqual( 246 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) 247 248 # Restore the graph with a single predefined tag whose variables were not 249 # saved. 250 with self.test_session(graph=ops.Graph()) as sess: 251 loader.load(sess, [tag_constants.SERVING], export_dir) 252 self.assertEqual( 253 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) 254 255 # Restore the graph with multiple predefined tags whose variables were not 256 # saved. 257 with self.test_session(graph=ops.Graph()) as sess: 258 loader.load(sess, [tag_constants.SERVING, tag_constants.GPU], export_dir) 259 self.assertEqual( 260 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) 261 262 # Restore the graph with multiple predefined tags (for serving on TPU) 263 # whose variables were not saved. 264 with self.test_session(graph=ops.Graph()) as sess: 265 loader.load(sess, [tag_constants.SERVING, tag_constants.TPU], export_dir) 266 self.assertEqual( 267 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) 268 269 # Restore the graph with multiple tags. Provide duplicate tags to test set 270 # semantics. 271 with self.test_session(graph=ops.Graph()) as sess: 272 loader.load(sess, ["foo", "bar", "foo"], export_dir) 273 self.assertEqual( 274 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) 275 276 # Try restoring a graph with a non-existent tag. This should yield a runtime 277 # error. 278 with self.test_session(graph=ops.Graph()) as sess: 279 self.assertRaises(RuntimeError, loader.load, sess, ["INVALID"], 280 export_dir) 281 282 # Try restoring a graph where a subset of the tags match. Since tag matching 283 # for meta graph defs follows "all" semantics, this should yield a runtime 284 # error. 285 with self.test_session(graph=ops.Graph()) as sess: 286 self.assertRaises(RuntimeError, loader.load, sess, ["foo", "baz"], 287 export_dir) 288 289 def testVariables(self): 290 export_dir = self._get_export_dir("test_variables") 291 builder = saved_model_builder.SavedModelBuilder(export_dir) 292 293 # Graph with two variables. SavedModel invoked to: 294 # - add with weights. 295 with self.test_session(graph=ops.Graph()) as sess: 296 self._init_and_validate_variable(sess, "v1", 1) 297 self._init_and_validate_variable(sess, "v2", 2) 298 builder.add_meta_graph_and_variables(sess, ["foo"]) 299 300 # Graph with a single variable (subset of the variables from the previous 301 # graph whose weights were saved). SavedModel invoked to: 302 # - simply add the model (weights are not updated). 303 with self.test_session(graph=ops.Graph()) as sess: 304 self._init_and_validate_variable(sess, "v2", 3) 305 builder.add_meta_graph(["bar"]) 306 307 # Graph with a single variable (disjoint set of variables from the previous 308 # graph whose weights were saved). SavedModel invoked to: 309 # - simply add the model (weights are not updated). 310 with self.test_session(graph=ops.Graph()) as sess: 311 self._init_and_validate_variable(sess, "v3", 4) 312 builder.add_meta_graph(["baz"]) 313 314 # Save the SavedModel to disk. 315 builder.save() 316 317 # Restore the graph with tag "foo", whose variables were saved. 318 with self.test_session(graph=ops.Graph()) as sess: 319 loader.load(sess, ["foo"], export_dir) 320 collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 321 self.assertEqual(len(collection_vars), 2) 322 self.assertEqual(1, collection_vars[0].eval()) 323 self.assertEqual(2, collection_vars[1].eval()) 324 325 # Restore the graph with tag "bar", whose variables were not saved. Only the 326 # subset of the variables added to the graph will be restored with the 327 # checkpointed value. 328 with self.test_session(graph=ops.Graph()) as sess: 329 loader.load(sess, ["bar"], export_dir) 330 collection_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 331 self.assertEqual(len(collection_vars), 1) 332 self.assertEqual(2, collection_vars[0].eval()) 333 334 # Try restoring the graph with tag "baz", whose variables were not saved. 335 # Since this graph has a disjoint set of variables from the set that was 336 # saved, this should raise an error. 337 with self.test_session(graph=ops.Graph()) as sess: 338 self.assertRaises(errors.NotFoundError, loader.load, sess, ["baz"], 339 export_dir) 340 341 def testGraphWithoutVariables(self): 342 export_dir = self._get_export_dir("test_graph_has_variables") 343 builder = saved_model_builder.SavedModelBuilder(export_dir) 344 345 # Graph with no variables. 346 with self.test_session(graph=ops.Graph()) as sess: 347 constant_5_name = constant_op.constant(5.0).name 348 builder.add_meta_graph_and_variables(sess, ["foo"]) 349 350 # Second graph with no variables 351 with self.test_session(graph=ops.Graph()) as sess: 352 constant_6_name = constant_op.constant(6.0).name 353 builder.add_meta_graph(["bar"]) 354 355 # Save the SavedModel to disk. 356 builder.save() 357 358 # Restore the graph with tag "foo". 359 with self.test_session(graph=ops.Graph()) as sess: 360 loader.load(sess, ["foo"], export_dir) 361 # Read the constant a from the graph. 362 a = ops.get_default_graph().get_tensor_by_name(constant_5_name) 363 b = constant_op.constant(6.0) 364 c = a * b 365 self.assertEqual(30.0, sess.run(c)) 366 367 # Restore the graph with tag "bar". 368 with self.test_session(graph=ops.Graph()) as sess: 369 loader.load(sess, ["bar"], export_dir) 370 # Read the constant a from the graph. 371 a = ops.get_default_graph().get_tensor_by_name(constant_6_name) 372 b = constant_op.constant(5.0) 373 c = a * b 374 self.assertEqual(30.0, sess.run(c)) 375 376 def testNoOverwrite(self): 377 export_dir = self._get_export_dir("test_no_overwrite") 378 builder = saved_model_builder.SavedModelBuilder(export_dir) 379 380 # Graph with a single variable. SavedModel invoked to: 381 # - add with weights. 382 with self.test_session(graph=ops.Graph()) as sess: 383 self._init_and_validate_variable(sess, "v", 42) 384 builder.add_meta_graph_and_variables(sess, ["foo"]) 385 386 # Save the SavedModel to disk in text format. 387 builder.save(as_text=True) 388 389 # Restore the graph with tag "foo", whose variables were saved. 390 with self.test_session(graph=ops.Graph()) as sess: 391 loader.load(sess, ["foo"], export_dir) 392 self.assertEqual( 393 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) 394 395 # An attempt to create another builder with the same export directory should 396 # result in an assertion error. 397 self.assertRaises(AssertionError, saved_model_builder.SavedModelBuilder, 398 export_dir) 399 400 def testSaveAsText(self): 401 export_dir = self._get_export_dir("test_astext") 402 builder = saved_model_builder.SavedModelBuilder(export_dir) 403 404 # Graph with a single variable. SavedModel invoked to: 405 # - add with weights. 406 with self.test_session(graph=ops.Graph()) as sess: 407 self._init_and_validate_variable(sess, "v", 42) 408 builder.add_meta_graph_and_variables(sess, ["foo"]) 409 410 # Graph with the same single variable. SavedModel invoked to: 411 # - simply add the model (weights are not updated). 412 with self.test_session(graph=ops.Graph()) as sess: 413 self._init_and_validate_variable(sess, "v", 43) 414 builder.add_meta_graph(["bar"]) 415 416 # Save the SavedModel to disk in text format. 417 builder.save(as_text=True) 418 419 # Restore the graph with tag "foo", whose variables were saved. 420 with self.test_session(graph=ops.Graph()) as sess: 421 loader.load(sess, ["foo"], export_dir) 422 self.assertEqual( 423 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) 424 425 # Restore the graph with tag "bar", whose variables were not saved. 426 with self.test_session(graph=ops.Graph()) as sess: 427 loader.load(sess, ["bar"], export_dir) 428 self.assertEqual( 429 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) 430 431 def testCollections(self): 432 export_dir = self._get_export_dir("test_collections") 433 builder = saved_model_builder.SavedModelBuilder(export_dir) 434 435 # Graph with a single variable added to a collection. SavedModel invoked to: 436 # - add with weights. 437 with self.test_session(graph=ops.Graph()) as sess: 438 v = variables.Variable(42, name="v") 439 ops.add_to_collection("foo_vars", v) 440 sess.run(variables.global_variables_initializer()) 441 self.assertEqual(42, v.eval()) 442 builder.add_meta_graph_and_variables(sess, ["foo"]) 443 444 # Graph with the same single variable added to a different collection. 445 # SavedModel invoked to: 446 # - simply add the model (weights are not updated). 447 with self.test_session(graph=ops.Graph()) as sess: 448 v = variables.Variable(43, name="v") 449 ops.add_to_collection("bar_vars", v) 450 sess.run(variables.global_variables_initializer()) 451 self.assertEqual(43, v.eval()) 452 builder.add_meta_graph(["bar"]) 453 454 # Save the SavedModel to disk. 455 builder.save() 456 457 # Restore the graph with tag "foo", whose variables were saved. The 458 # collection 'foo_vars' should contain a single element. The collection 459 # 'bar_vars' should not be found. 460 with self.test_session(graph=ops.Graph()) as sess: 461 loader.load(sess, ["foo"], export_dir) 462 collection_foo_vars = ops.get_collection("foo_vars") 463 self.assertEqual(len(collection_foo_vars), 1) 464 self.assertEqual(42, collection_foo_vars[0].eval()) 465 466 self.assertEqual(len(ops.get_collection("bar_vars")), 0) 467 468 # Restore the graph with tag "bar", whose variables were not saved. The 469 # collection-def exported as part of the meta graph def is updated to 470 # reflect the new collection. The value of the variable in the 471 # collection-def corresponds to the saved value (from the previous graph 472 # with tag "foo"). 473 with self.test_session(graph=ops.Graph()) as sess: 474 loader.load(sess, ["bar"], export_dir) 475 collection_bar_vars = ops.get_collection("bar_vars") 476 self.assertEqual(len(collection_bar_vars), 1) 477 self.assertEqual(42, collection_bar_vars[0].eval()) 478 479 self.assertEqual(len(ops.get_collection("foo_vars")), 0) 480 481 def testSignatureDefs(self): 482 export_dir = self._get_export_dir("test_signature_defs") 483 builder = saved_model_builder.SavedModelBuilder(export_dir) 484 485 # Graph with a single variable and a single entry in the signature def map. 486 # SavedModel is invoked to add with weights. 487 with self.test_session(graph=ops.Graph()) as sess: 488 self._init_and_validate_variable(sess, "v", 42) 489 # Build and populate an empty SignatureDef for testing. 490 foo_signature = signature_def_utils.build_signature_def(dict(), 491 dict(), "foo") 492 builder.add_meta_graph_and_variables( 493 sess, ["foo"], signature_def_map={"foo_key": foo_signature}) 494 495 # Graph with the same single variable and multiple entries in the signature 496 # def map. No weights are saved by SavedModel. 497 with self.test_session(graph=ops.Graph()) as sess: 498 self._init_and_validate_variable(sess, "v", 43) 499 # Build and populate a different SignatureDef for testing. 500 bar_signature = signature_def_utils.build_signature_def(dict(), 501 dict(), "bar") 502 # Also, build a different SignatureDef corresponding to "foo_key" defined 503 # in the previous graph. 504 foo_new_signature = signature_def_utils.build_signature_def(dict(), 505 dict(), 506 "foo_new") 507 builder.add_meta_graph( 508 ["bar"], 509 signature_def_map={ 510 "bar_key": bar_signature, 511 "foo_key": foo_new_signature 512 }) 513 514 # Save the SavedModel to disk. 515 builder.save() 516 517 # Restore the graph with tag "foo". The single entry in the SignatureDef map 518 # corresponding to "foo_key" should exist. 519 with self.test_session(graph=ops.Graph()) as sess: 520 foo_graph = loader.load(sess, ["foo"], export_dir) 521 self.assertEqual( 522 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) 523 524 foo_signature = foo_graph.signature_def 525 self.assertEqual(len(foo_signature), 1) 526 self.assertEqual("foo", foo_signature["foo_key"].method_name) 527 528 # Restore the graph with tag "bar". The SignatureDef map should have two 529 # entries. One corresponding to "bar_key" and another corresponding to the 530 # new value of "foo_key". 531 with self.test_session(graph=ops.Graph()) as sess: 532 bar_graph = loader.load(sess, ["bar"], export_dir) 533 self.assertEqual( 534 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) 535 536 bar_signature = bar_graph.signature_def 537 self.assertEqual(len(bar_signature), 2) 538 self.assertEqual("bar", bar_signature["bar_key"].method_name) 539 self.assertEqual("foo_new", bar_signature["foo_key"].method_name) 540 541 def testSignatureDefValidation(self): 542 export_dir = self._get_export_dir("test_signature_def_validation") 543 builder = saved_model_builder.SavedModelBuilder(export_dir) 544 545 tensor_without_name = meta_graph_pb2.TensorInfo() 546 tensor_without_name.dtype = types_pb2.DT_FLOAT 547 self._validate_inputs_tensor_info(builder, tensor_without_name) 548 self._validate_outputs_tensor_info(builder, tensor_without_name) 549 550 tensor_without_dtype = meta_graph_pb2.TensorInfo() 551 tensor_without_dtype.name = "x" 552 self._validate_inputs_tensor_info(builder, tensor_without_dtype) 553 self._validate_outputs_tensor_info(builder, tensor_without_dtype) 554 555 tensor_empty = meta_graph_pb2.TensorInfo() 556 self._validate_inputs_tensor_info(builder, tensor_empty) 557 self._validate_outputs_tensor_info(builder, tensor_empty) 558 559 def testAssets(self): 560 export_dir = self._get_export_dir("test_assets") 561 builder = saved_model_builder.SavedModelBuilder(export_dir) 562 563 with self.test_session(graph=ops.Graph()) as sess: 564 self._init_and_validate_variable(sess, "v", 42) 565 566 # Build an asset collection. 567 ignored_filepath = os.path.join( 568 compat.as_bytes(test.get_temp_dir()), compat.as_bytes("ignored.txt")) 569 file_io.write_string_to_file(ignored_filepath, "will be ignored") 570 571 asset_collection = self._build_asset_collection("hello42.txt", 572 "foo bar baz", 573 "asset_file_tensor") 574 575 builder.add_meta_graph_and_variables( 576 sess, ["foo"], assets_collection=asset_collection) 577 578 # Save the SavedModel to disk. 579 builder.save() 580 581 with self.test_session(graph=ops.Graph()) as sess: 582 foo_graph = loader.load(sess, ["foo"], export_dir) 583 self._validate_asset_collection(export_dir, foo_graph.collection_def, 584 "hello42.txt", "foo bar baz", 585 "asset_file_tensor:0") 586 ignored_asset_path = os.path.join( 587 compat.as_bytes(export_dir), 588 compat.as_bytes(constants.ASSETS_DIRECTORY), 589 compat.as_bytes("ignored.txt")) 590 self.assertFalse(file_io.file_exists(ignored_asset_path)) 591 592 def testCustomMainOp(self): 593 export_dir = self._get_export_dir("test_main_op") 594 builder = saved_model_builder.SavedModelBuilder(export_dir) 595 596 with self.test_session(graph=ops.Graph()) as sess: 597 # Add `v1` and `v2` variables to the graph. 598 v1 = variables.Variable(1, name="v1") 599 ops.add_to_collection("v", v1) 600 v2 = variables.Variable(2, name="v2") 601 ops.add_to_collection("v", v2) 602 603 # Initialize another variable `v3` to 42. 604 v3 = variables.Variable(42, name="v3") 605 ops.add_to_collection("v", v3) 606 607 # Set up an assignment op to be run as part of the main_op. 608 with ops.control_dependencies([main_op.main_op()]): 609 add_v1_v2 = math_ops.add(v1._ref(), v2._ref()) 610 custom_main_op = control_flow_ops.group(state_ops.assign(v3, add_v1_v2)) 611 612 sess.run(custom_main_op) 613 builder.add_meta_graph_and_variables( 614 sess, ["foo"], main_op=custom_main_op) 615 616 # Save the SavedModel to disk. 617 builder.save() 618 619 with self.test_session(graph=ops.Graph()) as sess: 620 loader.load(sess, ["foo"], export_dir) 621 self.assertEqual(1, ops.get_collection("v")[0].eval()) 622 self.assertEqual(2, ops.get_collection("v")[1].eval()) 623 # Evaluates to the sum of the first two variables and assigned as part of 624 # the main_op, following a restore. 625 self.assertEqual(3, ops.get_collection("v")[2].eval()) 626 627 def testLegacyInitOp(self): 628 export_dir = self._get_export_dir("test_legacy_init_op") 629 builder = saved_model_builder.SavedModelBuilder(export_dir) 630 631 with self.test_session(graph=ops.Graph()) as sess: 632 # Add `v1` and `v2` variables to the graph. 633 v1 = variables.Variable(1, name="v1") 634 ops.add_to_collection("v", v1) 635 v2 = variables.Variable(2, name="v2") 636 ops.add_to_collection("v", v2) 637 638 # Initialize another variable `v3` to 42. 639 v3 = variables.Variable(42, name="v3", trainable=False, collections=[]) 640 ops.add_to_collection("v", v3) 641 642 # Set up an assignment op to be run as part of the legacy_init_op. 643 assign_v3 = state_ops.assign(v3, math_ops.add(v1, v2)) 644 legacy_init_op = control_flow_ops.group(assign_v3, name="legacy_init_op") 645 646 sess.run(variables.global_variables_initializer()) 647 builder.add_meta_graph_and_variables( 648 sess, ["foo"], legacy_init_op=legacy_init_op) 649 650 # Save the SavedModel to disk. 651 builder.save() 652 653 with self.test_session(graph=ops.Graph()) as sess: 654 loader.load(sess, ["foo"], export_dir) 655 self.assertEqual(1, ops.get_collection("v")[0].eval()) 656 self.assertEqual(2, ops.get_collection("v")[1].eval()) 657 # Evaluates to the sum of the first two variables and assigned as part of 658 # the legacy_init_op, following a restore. 659 self.assertEqual(3, ops.get_collection("v")[2].eval()) 660 661 def testLegacyInitOpWithNonEmptyCollection(self): 662 export_dir = self._get_export_dir( 663 "test_legacy_init_op_with_non_empty_collection") 664 builder = saved_model_builder.SavedModelBuilder(export_dir) 665 666 with self.test_session(graph=ops.Graph()) as sess: 667 # Initialize variable `v1` to 1. 668 v1 = variables.Variable(1, name="v1") 669 ops.add_to_collection("v", v1) 670 671 # Initialize another variable `v2` to 42. 672 v2 = variables.Variable(42, name="v2", trainable=False, collections=[]) 673 ops.add_to_collection("v", v2) 674 675 # Set up an assignment op to be run as part of the legacy_init_op. 676 assign_v2 = state_ops.assign(v2, v1) 677 legacy_init_op = control_flow_ops.group(assign_v2, name="legacy_init_op") 678 679 sess.run(variables.global_variables_initializer()) 680 681 ops.add_to_collection(constants.LEGACY_INIT_OP_KEY, 682 control_flow_ops.no_op()) 683 # AssertionError should be raised since the LEGACY_INIT_OP_KEY collection 684 # is not empty and we don't support multiple init ops. 685 with self.assertRaises(AssertionError): 686 builder.add_meta_graph_and_variables( 687 sess, ["foo"], legacy_init_op=legacy_init_op) 688 689 def testMultipleAssets(self): 690 export_dir = self._get_export_dir("test_multiple_assets") 691 builder = saved_model_builder.SavedModelBuilder(export_dir) 692 693 with self.test_session(graph=ops.Graph()) as sess: 694 self._init_and_validate_variable(sess, "v", 42) 695 696 # Build an asset collection specific to `foo` graph. 697 asset_collection = self._build_asset_collection("foo.txt", "content_foo", 698 "asset_file_tensor") 699 700 # Add the asset collection as part of the graph with tag "foo". 701 builder.add_meta_graph_and_variables( 702 sess, ["foo"], assets_collection=asset_collection) 703 704 with self.test_session(graph=ops.Graph()) as sess: 705 self._init_and_validate_variable(sess, "v", 42) 706 707 # Build an asset collection specific to `bar` graph. 708 asset_collection = self._build_asset_collection("bar.txt", "content_bar", 709 "asset_file_tensor") 710 711 # Add the asset collection as part of the graph with tag "bar". 712 builder.add_meta_graph(["bar"], assets_collection=asset_collection) 713 714 # Save the SavedModel to disk. 715 builder.save() 716 717 # Check assets restored for graph with tag "foo". 718 with self.test_session(graph=ops.Graph()) as sess: 719 foo_graph = loader.load(sess, ["foo"], export_dir) 720 self._validate_asset_collection(export_dir, foo_graph.collection_def, 721 "foo.txt", "content_foo", 722 "asset_file_tensor:0") 723 724 # Check assets restored for graph with tag "bar". 725 with self.test_session(graph=ops.Graph()) as sess: 726 bar_graph = loader.load(sess, ["bar"], export_dir) 727 self._validate_asset_collection(export_dir, bar_graph.collection_def, 728 "bar.txt", "content_bar", 729 "asset_file_tensor:0") 730 731 def testDuplicateAssets(self): 732 export_dir = self._get_export_dir("test_duplicate_assets") 733 builder = saved_model_builder.SavedModelBuilder(export_dir) 734 735 with self.test_session(graph=ops.Graph()) as sess: 736 self._init_and_validate_variable(sess, "v", 42) 737 738 # Build an asset collection with `foo.txt` that has `foo` specific 739 # content. 740 asset_collection = self._build_asset_collection("foo.txt", "content_foo", 741 "asset_file_tensor") 742 743 # Add the asset collection as part of the graph with tag "foo". 744 builder.add_meta_graph_and_variables( 745 sess, ["foo"], assets_collection=asset_collection) 746 747 with self.test_session(graph=ops.Graph()) as sess: 748 self._init_and_validate_variable(sess, "v", 42) 749 750 # Build an asset collection with `foo.txt` that has `bar` specific 751 # content. 752 asset_collection = self._build_asset_collection("foo.txt", "content_bar", 753 "asset_file_tensor") 754 755 # Add the asset collection as part of the graph with tag "bar". 756 builder.add_meta_graph(["bar"], assets_collection=asset_collection) 757 758 # Save the SavedModel to disk. 759 builder.save() 760 761 # Check assets restored for graph with tag "foo". 762 with self.test_session(graph=ops.Graph()) as sess: 763 foo_graph = loader.load(sess, ["foo"], export_dir) 764 self._validate_asset_collection(export_dir, foo_graph.collection_def, 765 "foo.txt", "content_foo", 766 "asset_file_tensor:0") 767 768 # Check assets restored for graph with tag "bar". 769 with self.test_session(graph=ops.Graph()) as sess: 770 bar_graph = loader.load(sess, ["bar"], export_dir) 771 772 # Validate the assets for `bar` graph. `foo.txt` should contain the 773 # original contents corresponding to `foo` graph since an asset with the 774 # same name across multiple graphs is only stored the first time 775 self._validate_asset_collection(export_dir, bar_graph.collection_def, 776 "foo.txt", "content_foo", 777 "asset_file_tensor:0") 778 779 def testOp(self): 780 export_dir = self._get_export_dir("test_op") 781 builder = saved_model_builder.SavedModelBuilder(export_dir) 782 783 with session.Session( 784 graph=ops.Graph(), 785 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 786 with sess.graph.device("/cpu:0"): 787 v1 = variables.Variable(1, name="v1") 788 with sess.graph.device("/cpu:1"): 789 v2 = variables.Variable(2, name="v2") 790 791 # v3 is an unsaved variable derived from v1 and v2. It is used to 792 # exercise the ability to run an init op when restoring a graph. 793 v3 = variables.Variable(1, name="v3", trainable=False, collections=[]) 794 assign_v3 = state_ops.assign(v3, math_ops.add(v1, v2)) 795 init_op = control_flow_ops.group(assign_v3, name="init_op") 796 797 ops.add_to_collection("v", v1) 798 ops.add_to_collection("v", v2) 799 ops.add_to_collection("v", v3) 800 ops.add_to_collection("init_op", init_op) 801 802 sess.run(variables.global_variables_initializer()) 803 self.assertEqual(1, ops.get_collection("v")[0].eval()) 804 self.assertEqual(2, ops.get_collection("v")[1].eval()) 805 806 builder.add_meta_graph_and_variables(sess, ["foo"]) 807 808 # Save the SavedModel to disk. 809 builder.save() 810 811 with session.Session( 812 graph=ops.Graph(), 813 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 814 loader.load(sess, ["foo"], export_dir) 815 816 # Validate variables, run the init op and verify result. 817 self.assertEqual(1, ops.get_collection("v")[0].eval()) 818 self.assertEqual(2, ops.get_collection("v")[1].eval()) 819 ops.get_collection("init_op")[0].run() 820 self.assertEqual(3, ops.get_collection("v")[2].eval()) 821 822 def testCustomSaveable(self): 823 export_dir = self._get_export_dir("custom_saveable") 824 builder = saved_model_builder.SavedModelBuilder(export_dir) 825 826 with session.Session( 827 graph=ops.Graph(), 828 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 829 # CheckpointedOp is a key-value table that can be saved across sessions. 830 # The table register itself in SAVEABLE_OBJECTS collection. 831 v1 = saver_test_utils.CheckpointedOp(name="v1") 832 variables.global_variables_initializer().run() 833 v1.insert("k1", 3.0).run() 834 # Once the table is restored, we can access it through this reference. 835 ops.add_to_collection("table_ref", v1.table_ref) 836 builder.add_meta_graph_and_variables(sess, ["foo"]) 837 838 # Save the SavedModel to disk. 839 builder.save() 840 841 with session.Session( 842 graph=ops.Graph(), 843 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 844 loader.load(sess, ["foo"], export_dir) 845 # Instantiate a wrapper object from the checkpointed reference. 846 v1 = saver_test_utils.CheckpointedOp( 847 name="v1", table_ref=ops.get_collection("table_ref")[0]) 848 self.assertEqual(b"k1", v1.keys().eval()) 849 self.assertEqual(3.0, v1.values().eval()) 850 851 def testClearDevices(self): 852 export_dir = self._get_export_dir("test_clear_devices") 853 builder = saved_model_builder.SavedModelBuilder(export_dir) 854 855 # Specify a device and save a variable. 856 ops.reset_default_graph() 857 with session.Session( 858 target="", 859 config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: 860 with sess.graph.device("/cpu:0"): 861 self._init_and_validate_variable(sess, "v", 42) 862 builder.add_meta_graph_and_variables( 863 sess, [tag_constants.TRAINING], clear_devices=True) 864 865 # Save the SavedModel to disk. 866 builder.save() 867 868 # Restore the graph with a single predefined tag whose variables were saved 869 # without any device information. 870 with self.test_session(graph=ops.Graph()) as sess: 871 loader.load(sess, [tag_constants.TRAINING], export_dir) 872 self.assertEqual( 873 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) 874 875 def testStripDefaultAttrs(self): 876 export_dir = self._get_export_dir("test_strip_default_attrs") 877 builder = saved_model_builder.SavedModelBuilder(export_dir) 878 879 # Add a graph with two float32 variables and a Complex Op composing them 880 # with strip_default_attrs enabled. 881 with session.Session(graph=ops.Graph()) as sess: 882 real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real") 883 imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag") 884 math_ops.complex(real_num, imag_num, name="complex") 885 sess.run(variables.global_variables_initializer()) 886 builder.add_meta_graph_and_variables( 887 sess, ["foo"], strip_default_attrs=True) 888 889 # Add a graph with the same float32 variables and a Complex Op composing 890 # them with strip_default_attrs disabled. 891 with session.Session(graph=ops.Graph()) as sess: 892 real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real") 893 imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag") 894 math_ops.complex(real_num, imag_num, name="complex") 895 sess.run(variables.global_variables_initializer()) 896 builder.add_meta_graph(["bar"], strip_default_attrs=False) 897 898 # Save the SavedModel to disk in text format. 899 builder.save(as_text=True) 900 901 # Loading graph "foo" via the loader must restore the defaults for the 902 # "Complex" node based on the "Complex" OpDef in the Op registry. 903 sess = session.Session(graph=ops.Graph()) 904 meta_graph_def = loader.load(sess, ["foo"], export_dir) 905 complex_node = test_util.get_node_def_from_graph("complex", 906 meta_graph_def.graph_def) 907 self.assertIn("T", complex_node.attr) 908 self.assertIn("Tout", complex_node.attr) 909 910 # Load graph "foo" from disk as-is to verify default attrs are stripped. 911 # pylint: disable=protected-access 912 saved_model_pb = loader_impl._parse_saved_model(export_dir) 913 self.assertIsNotNone(saved_model_pb) 914 # pylint: enable=protected-access 915 916 meta_graph_foo_def = None 917 meta_graph_bar_def = None 918 for meta_graph_def in saved_model_pb.meta_graphs: 919 if set(meta_graph_def.meta_info_def.tags) == set(["foo"]): 920 meta_graph_foo_def = meta_graph_def 921 elif set(meta_graph_def.meta_info_def.tags) == set(["bar"]): 922 meta_graph_bar_def = meta_graph_def 923 924 self.assertIsNotNone(meta_graph_foo_def) 925 self.assertIsNotNone(meta_graph_bar_def) 926 927 # "Complex" Op has 2 attributes with defaults: 928 # o "T" : float32. (input type) 929 # o "Tout" : complex64. (output type) 930 931 # "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout". 932 # Graph "foo" was saved with strip_default_attrs set to True. 933 node_def = test_util.get_node_def_from_graph("complex", 934 meta_graph_foo_def.graph_def) 935 self.assertNotIn("T", node_def.attr) 936 self.assertNotIn("Tout", node_def.attr) 937 938 # "Complex" Op in graph "bar" must have attributes "T" and "Tout". 939 # Graph "bar" was saved with strip_default_attrs set to False. 940 node_def = test_util.get_node_def_from_graph("complex", 941 meta_graph_bar_def.graph_def) 942 self.assertIn("T", node_def.attr) 943 self.assertIn("Tout", node_def.attr) 944 945 # Tests the behavior of loading SavedModels that having missing attrs or attrs 946 # with incorrect types. 947 def testInconsistentConsumerDefaultAttrs(self): 948 export_dir = self._get_export_dir( 949 "test_strip_default_attrs_no_consumer_defaults") 950 builder = saved_model_builder.SavedModelBuilder(export_dir) 951 952 # Add a graph with a single variable and a test op with a defaultless 953 # float32 attr, "test_attr". 954 with session.Session(graph=ops.Graph()) as sess: 955 variables.Variable(1.0, dtype=dtypes.float64, name="var") 956 test_ops.test_attr(T=dtypes.float32, name="test_attr") 957 sess.run(variables.global_variables_initializer()) 958 builder.add_meta_graph_and_variables(sess, ["foo"]) 959 960 # Save the SavedModel to disk in text format. 961 builder.save(as_text=True) 962 963 # Rewrite the SavedModel to remove the T attr from "test_attr". 964 saved_model_file = os.path.join( 965 export_dir, constants.SAVED_MODEL_FILENAME_PBTXT) 966 with open(saved_model_file) as f: 967 original_saved_model = f.read() 968 969 no_attr_saved_model = original_saved_model.replace(""" 970 attr { 971 key: "T" 972 value { 973 type: DT_FLOAT 974 } 975 }""", "") 976 with open(saved_model_file, "w") as f: 977 f.write(no_attr_saved_model) 978 979 # Loading the SavedModel via the loader must fail because the SavedModel 980 # does not have any attr values for the "TestAttr" node, and there is no 981 # default specified in the TestAttr OpDef. 982 sess = session.Session(graph=ops.Graph()) 983 if ops._USE_C_API: 984 error_message = "NodeDef missing attr 'T' from Op<name=TestAttr" 985 else: 986 error_message = ("Expected one attr with name .*T(out)?.* in name: " 987 "\"test_attr\".*") 988 with self.assertRaisesRegexp(ValueError, error_message): 989 loader.load(sess, ["foo"], export_dir) 990 991 # Rewrite the SavedModel to change the type of the T attr in "test_attr" 992 bad_type_saved_model = original_saved_model.replace(""" 993 attr { 994 key: "T" 995 value { 996 type: DT_FLOAT 997 } 998 }""", """ 999 attr { 1000 key: "T" 1001 value { 1002 type: DT_DOUBLE 1003 } 1004 }""") 1005 with open(saved_model_file, "w") as f: 1006 f.write(bad_type_saved_model) 1007 1008 # Loading the SavedModel via the loader must fail because there is no 1009 # OpKernel registered to handle T = double. 1010 sess = session.Session(graph=ops.Graph()) 1011 with self.assertRaisesRegexp( 1012 errors.InvalidArgumentError, 1013 ".*No OpKernel was registered to support Op \'TestAttr\' with these " 1014 "attrs..*"): 1015 loader.load(sess, ["foo"], export_dir) 1016 1017 1018 if __name__ == "__main__": 1019 test.main() 1020