1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================= 15 """Tests for tensorflow.python.framework.meta_graph.py.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import math 22 import os.path 23 import random 24 import shutil 25 26 from tensorflow.core.framework import graph_pb2 27 from tensorflow.core.protobuf import meta_graph_pb2 28 from tensorflow.python.client import session 29 from tensorflow.python.framework import constant_op 30 from tensorflow.python.framework import dtypes 31 from tensorflow.python.framework import function 32 from tensorflow.python.framework import meta_graph 33 from tensorflow.python.framework import ops 34 from tensorflow.python.framework import test_util 35 from tensorflow.python.ops import array_ops 36 from tensorflow.python.ops import control_flow_ops 37 from tensorflow.python.ops import data_flow_ops 38 from tensorflow.python.ops import gradients_impl 39 from tensorflow.python.ops import math_ops 40 from tensorflow.python.ops import metrics 41 from tensorflow.python.ops import nn_ops 42 from tensorflow.python.ops import partitioned_variables 43 from tensorflow.python.ops import random_ops 44 from tensorflow.python.ops import resource_variable_ops 45 from tensorflow.python.ops import variable_scope 46 from tensorflow.python.ops import variables 47 from tensorflow.python.platform import gfile 48 from tensorflow.python.platform import test 49 from tensorflow.python.training import queue_runner_impl 50 51 52 # pylint: disable=invalid-name 53 def _TestDir(test_name): 54 test_dir = os.path.join(test.get_temp_dir(), test_name) 55 if os.path.exists(test_dir): 56 shutil.rmtree(test_dir) 57 gfile.MakeDirs(test_dir) 58 return test_dir 59 60 61 # pylint: enable=invalid-name 62 63 64 @test_util.with_c_api 65 class SimpleMetaGraphTest(test.TestCase): 66 67 def testNoVariables(self): 68 test_dir = _TestDir("no_variables") 69 filename = os.path.join(test_dir, "metafile") 70 71 input_feed_value = -10 # Arbitrary input value for feed_dict. 72 73 orig_graph = ops.Graph() 74 with self.test_session(graph=orig_graph) as sess: 75 # Create a minimal graph with zero variables. 76 input_tensor = array_ops.placeholder( 77 dtypes.float32, shape=[], name="input") 78 offset = constant_op.constant(42, dtype=dtypes.float32, name="offset") 79 output_tensor = math_ops.add(input_tensor, offset, name="add_offset") 80 81 # Add input and output tensors to graph collections. 82 ops.add_to_collection("input_tensor", input_tensor) 83 ops.add_to_collection("output_tensor", output_tensor) 84 85 output_value = sess.run(output_tensor, {input_tensor: input_feed_value}) 86 self.assertEqual(output_value, 32) 87 88 # Generates MetaGraphDef. 89 meta_graph_def, var_list = meta_graph.export_scoped_meta_graph( 90 filename=filename, 91 graph_def=ops.get_default_graph().as_graph_def(add_shapes=True), 92 collection_list=["input_tensor", "output_tensor"], 93 saver_def=None) 94 self.assertTrue(meta_graph_def.HasField("meta_info_def")) 95 self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_version, "") 96 self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_git_version, 97 "") 98 self.assertEqual({}, var_list) 99 100 # Create a clean graph and import the MetaGraphDef nodes. 101 new_graph = ops.Graph() 102 with self.test_session(graph=new_graph) as sess: 103 # Import the previously export meta graph. 104 meta_graph.import_scoped_meta_graph(filename) 105 106 # Re-exports the current graph state for comparison to the original. 107 new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(filename + 108 "_new") 109 test_util.assert_meta_graph_protos_equal(self, meta_graph_def, 110 new_meta_graph_def) 111 112 # Ensures that we can still get a reference to our graph collections. 113 new_input_tensor = ops.get_collection("input_tensor")[0] 114 new_output_tensor = ops.get_collection("output_tensor")[0] 115 # Verifies that the new graph computes the same result as the original. 116 new_output_value = sess.run(new_output_tensor, 117 {new_input_tensor: input_feed_value}) 118 self.assertEqual(new_output_value, output_value) 119 120 def testStrippedOpListNestedFunctions(self): 121 with self.test_session(): 122 # Square two levels deep 123 @function.Defun(dtypes.int32) 124 def f0(x): 125 return math_ops.square(x) 126 127 @function.Defun(dtypes.int32) 128 def f1(x): 129 return f0(x) 130 131 # At this point we've defined two functions but haven't called them, so 132 # there should be no used ops. 133 op_list = meta_graph.stripped_op_list_for_graph(ops.get_default_graph() 134 .as_graph_def()) 135 self.assertEqual(len(op_list.op), 0) 136 137 # If we call the function on a constant, there should be two ops 138 _ = f1(constant_op.constant(7)) 139 op_list = meta_graph.stripped_op_list_for_graph(ops.get_default_graph() 140 .as_graph_def()) 141 self.assertEqual(["Const", "Square"], [op.name for op in op_list.op]) 142 143 def testStrippedOpListRecursiveFunctions(self): 144 # The function module doesn't support recursive functions, so we build a 145 # recursive function situation by ourselves: A calls B calls A and Const. 146 graph = graph_pb2.GraphDef() 147 a = graph.library.function.add() 148 b = graph.library.function.add() 149 a.signature.name = "A" 150 b.signature.name = "B" 151 a.node_def.add().op = "B" 152 b.node_def.add().op = "Const" 153 b.node_def.add().op = "A" 154 155 # Use A in the graph 156 graph.node.add().op = "A" 157 158 # The stripped op list should contain just Const. 159 op_list = meta_graph.stripped_op_list_for_graph(graph) 160 self.assertEqual(["Const"], [op.name for op in op_list.op]) 161 162 def testDefaultAttrStripping(self): 163 """Verifies that default attributes are stripped from a graph def.""" 164 165 # Complex Op has 2 attributes with defaults: 166 # o "T" : float32. 167 # o "Tout" : complex64. 168 169 # When inputs to the Complex Op are float32 instances, "T" maps to float32 170 # and "Tout" maps to complex64. Since these attr values map to their 171 # defaults, they must be stripped unless stripping of default attrs is 172 # disabled. 173 with self.test_session(): 174 real_num = constant_op.constant(1.0, dtype=dtypes.float32, name="real") 175 imag_num = constant_op.constant(2.0, dtype=dtypes.float32, name="imag") 176 math_ops.complex(real_num, imag_num, name="complex") 177 178 # strip_default_attrs is enabled. 179 meta_graph_def, _ = meta_graph.export_scoped_meta_graph( 180 graph_def=ops.get_default_graph().as_graph_def(), 181 strip_default_attrs=True) 182 node_def = test_util.get_node_def_from_graph("complex", 183 meta_graph_def.graph_def) 184 self.assertNotIn("T", node_def.attr) 185 self.assertNotIn("Tout", node_def.attr) 186 self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs) 187 188 # strip_default_attrs is disabled. 189 meta_graph_def, _ = meta_graph.export_scoped_meta_graph( 190 graph_def=ops.get_default_graph().as_graph_def(), 191 strip_default_attrs=False) 192 node_def = test_util.get_node_def_from_graph("complex", 193 meta_graph_def.graph_def) 194 self.assertIn("T", node_def.attr) 195 self.assertIn("Tout", node_def.attr) 196 self.assertFalse(meta_graph_def.meta_info_def.stripped_default_attrs) 197 198 # When inputs to the Complex Op are float64 instances, "T" maps to float64 199 # and "Tout" maps to complex128. Since these attr values don't map to their 200 # defaults, they must not be stripped. 201 with self.test_session(graph=ops.Graph()): 202 real_num = constant_op.constant(1.0, dtype=dtypes.float64, name="real") 203 imag_num = constant_op.constant(2.0, dtype=dtypes.float64, name="imag") 204 math_ops.complex(real_num, imag_num, name="complex") 205 meta_graph_def, _ = meta_graph.export_scoped_meta_graph( 206 graph_def=ops.get_default_graph().as_graph_def(), 207 strip_default_attrs=True) 208 node_def = test_util.get_node_def_from_graph("complex", 209 meta_graph_def.graph_def) 210 self.assertEqual(node_def.attr["T"].type, dtypes.float64) 211 self.assertEqual(node_def.attr["Tout"].type, dtypes.complex128) 212 self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs) 213 214 def testDefaultAttrStrippingNestedFunctions(self): 215 """Verifies that default attributes are stripped from function node defs.""" 216 with self.test_session(): 217 @function.Defun(dtypes.float32, dtypes.float32) 218 def f0(i, j): 219 return math_ops.complex(i, j, name="double_nested_complex") 220 221 @function.Defun(dtypes.float32, dtypes.float32) 222 def f1(i, j): 223 return f0(i, j) 224 225 _ = f1(constant_op.constant(1.0), constant_op.constant(2.0)) 226 meta_graph_def, _ = meta_graph.export_scoped_meta_graph( 227 graph_def=ops.get_default_graph().as_graph_def(), 228 strip_default_attrs=True) 229 230 double_nested_complex_node_def = None 231 for function_def in meta_graph_def.graph_def.library.function: 232 for node_def in function_def.node_def: 233 if node_def.name.startswith("double_nested_complex"): 234 double_nested_complex_node_def = node_def 235 break 236 if double_nested_complex_node_def: 237 break 238 239 self.assertIsNotNone(double_nested_complex_node_def) 240 self.assertNotIn("T", double_nested_complex_node_def.attr) 241 self.assertNotIn("Tout", double_nested_complex_node_def.attr) 242 self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs) 243 244 def testDefaultAttrStrippingUnregisteredOps(self): 245 """Verifies that nodes with un-registered ops are not stripped.""" 246 graph_def = graph_pb2.GraphDef() 247 node = graph_def.node.add() 248 node.name = "node_with_unreg_op" 249 node.op = "unreg_op" 250 node.attr["attr_1"].i = 1 251 252 meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef() 253 meta_info_def.stripped_op_list.op.add() 254 255 with self.test_session(): 256 meta_graph_def = meta_graph.create_meta_graph_def( 257 meta_info_def=meta_info_def, graph_def=graph_def, 258 strip_default_attrs=True) 259 node_def = test_util.get_node_def_from_graph("node_with_unreg_op", 260 meta_graph_def.graph_def) 261 self.assertEqual(node_def.attr["attr_1"].i, 1) 262 self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs) 263 264 def testVariableObjectsAreSharedAmongCollections(self): 265 with ops.Graph().as_default() as graph1: 266 v = variables.Variable(3.0) 267 # A single instance of Variable is shared among the collections: 268 global_vars = graph1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 269 trainable_vars = graph1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) 270 self.assertEqual(len(global_vars), 1) 271 self.assertEqual(len(trainable_vars), 1) 272 self.assertIs(global_vars[0], trainable_vars[0]) 273 self.assertIs(v, global_vars[0]) 274 275 orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(graph=graph1) 276 del graph1 # To avoid accidental references in code involving graph2. 277 278 with ops.Graph().as_default() as graph2: 279 meta_graph.import_scoped_meta_graph(orig_meta_graph) 280 global_vars = graph2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) 281 trainable_vars = graph2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) 282 self.assertEqual(len(global_vars), 1) 283 self.assertEqual(len(trainable_vars), 1) 284 # A single instance of Variable is shared among the collections: 285 self.assertIs(global_vars[0], trainable_vars[0]) 286 287 288 @test_util.with_c_api 289 class ScopedMetaGraphTest(test.TestCase): 290 291 def _testScopedExport(self, test_dir, exported_filenames): 292 graph = ops.Graph() 293 with graph.as_default(): 294 # Creates an inference graph. 295 # Hidden 1 296 colocate_constraint = constant_op.constant(1.2, name="constraint") 297 images = constant_op.constant( 298 1.2, dtypes.float32, shape=[100, 28], name="images") 299 with ops.name_scope("hidden1"): 300 with graph.colocate_with(colocate_constraint.op): 301 weights1 = variables.Variable( 302 random_ops.truncated_normal( 303 [28, 128], stddev=1.0 / math.sqrt(float(28))), 304 name="weights") 305 # The use of control_flow_ops.cond here is purely for adding test 306 # coverage the save and restore of control flow context (which doesn't 307 # make any sense here from a machine learning perspective). The typical 308 # biases is a simple Variable without the conditions. 309 biases1 = variables.Variable( 310 control_flow_ops.cond( 311 math_ops.less(random.random(), 0.5), 312 lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])), 313 name="biases") 314 hidden1 = nn_ops.relu(math_ops.matmul(images, weights1) + biases1) 315 316 # Hidden 2 317 with ops.name_scope("hidden2"): 318 weights2 = variables.Variable( 319 random_ops.truncated_normal( 320 [128, 32], stddev=1.0 / math.sqrt(float(128))), 321 name="weights") 322 323 # The use of control_flow_ops.while_loop here is purely for adding test 324 # coverage the save and restore of control flow context (which doesn't 325 # make any sense here from a machine learning perspective). The typical 326 # biases is a simple Variable without the conditions. 327 def loop_cond(it, _): 328 return it < 2 329 330 def loop_body(it, biases2): 331 biases2 += constant_op.constant(0.1, shape=[32]) 332 return it + 1, biases2 333 334 _, biases2 = control_flow_ops.while_loop( 335 loop_cond, 336 loop_body, [ 337 constant_op.constant(0), variables.Variable( 338 array_ops.zeros([32]), name="biases") 339 ]) 340 hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights2) + biases2) 341 # Linear 342 with ops.name_scope("softmax_linear"): 343 weights3 = variables.Variable( 344 random_ops.truncated_normal( 345 [32, 10], stddev=1.0 / math.sqrt(float(32))), 346 name="weights") 347 biases3 = variables.Variable(array_ops.zeros([10]), name="biases") 348 logits = math_ops.matmul(hidden2, weights3) + biases3 349 ops.add_to_collection("logits", logits) 350 351 # Exports each sub-graph. 352 # Exports the first one with unbound_inputs_col_name set to default. 353 orig_meta_graph1, var_list = meta_graph.export_scoped_meta_graph( 354 filename=os.path.join(test_dir, exported_filenames[0]), 355 graph=ops.get_default_graph(), 356 export_scope="hidden1") 357 self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys())) 358 var_names = [v.name for _, v in var_list.items()] 359 self.assertEqual(["hidden1/biases:0", "hidden1/weights:0"], 360 sorted(var_names)) 361 362 # Exports the rest with no unbound_inputs_col_name. 363 orig_meta_graph2, _ = meta_graph.export_scoped_meta_graph( 364 filename=os.path.join(test_dir, exported_filenames[1]), 365 graph=ops.get_default_graph(), 366 export_scope="hidden2", 367 unbound_inputs_col_name=None) 368 orig_meta_graph3, _ = meta_graph.export_scoped_meta_graph( 369 filename=os.path.join(test_dir, exported_filenames[2]), 370 graph=ops.get_default_graph(), 371 export_scope="softmax_linear", 372 unbound_inputs_col_name=None) 373 374 return [orig_meta_graph1, orig_meta_graph2, orig_meta_graph3] 375 376 def _testScopedImport(self, test_dir, exported_filenames): 377 graph = ops.Graph() 378 # Create all the missing inputs. 379 with graph.as_default(): 380 new_image = constant_op.constant( 381 1.2, dtypes.float32, shape=[100, 28], name="images") 382 383 with self.assertRaisesRegexp(ValueError, "Graph contains unbound inputs"): 384 meta_graph.import_scoped_meta_graph( 385 os.path.join(test_dir, exported_filenames[0]), 386 graph=graph, 387 import_scope="new_hidden1") 388 389 with self.assertRaisesRegexp(ValueError, "Graph contains unbound inputs"): 390 meta_graph.import_scoped_meta_graph( 391 os.path.join(test_dir, exported_filenames[0]), 392 graph=graph, 393 input_map={"image:0": new_image}, 394 import_scope="new_hidden1") 395 396 # Verifies we can import the original "hidden1" into "new_hidden1". 397 var_list = meta_graph.import_scoped_meta_graph( 398 os.path.join(test_dir, exported_filenames[0]), 399 graph=graph, 400 input_map={"$unbound_inputs_images": new_image}, 401 import_scope="new_hidden1") 402 403 self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys())) 404 new_var_names = [v.name for _, v in var_list.items()] 405 self.assertEqual(["new_hidden1/biases:0", "new_hidden1/weights:0"], 406 sorted(new_var_names)) 407 408 # Verifies we can import the original "hidden2" into "new_hidden2". 409 hidden1 = array_ops.identity( 410 graph.as_graph_element("new_hidden1/Relu:0"), name="hidden1/Relu") 411 var_list = meta_graph.import_scoped_meta_graph( 412 os.path.join(test_dir, exported_filenames[1]), 413 graph=graph, 414 input_map={"$unbound_inputs_hidden1/Relu": hidden1}, 415 import_scope="new_hidden2", 416 unbound_inputs_col_name=None) 417 418 self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys())) 419 new_var_names = [v.name for _, v in var_list.items()] 420 self.assertEqual(["new_hidden2/biases:0", "new_hidden2/weights:0"], 421 sorted(new_var_names)) 422 423 # Verifies we can import the original "softmax_linear" into 424 # "new_softmax_linear". 425 hidden2 = array_ops.identity( 426 graph.as_graph_element("new_hidden2/Relu:0"), name="hidden2/Relu") 427 var_list = meta_graph.import_scoped_meta_graph( 428 os.path.join(test_dir, exported_filenames[2]), 429 graph=graph, 430 input_map={"$unbound_inputs_hidden2/Relu": hidden2}, 431 import_scope="new_softmax_linear", 432 unbound_inputs_col_name=None) 433 self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys())) 434 new_var_names = [v.name for _, v in var_list.items()] 435 self.assertEqual( 436 ["new_softmax_linear/biases:0", "new_softmax_linear/weights:0"], 437 sorted(new_var_names)) 438 439 # Exports the scoped meta graphs again. 440 new_meta_graph1, var_list = meta_graph.export_scoped_meta_graph( 441 graph=graph, export_scope="new_hidden1") 442 self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys())) 443 444 new_meta_graph2, var_list = meta_graph.export_scoped_meta_graph( 445 graph=graph, export_scope="new_hidden2", unbound_inputs_col_name=None) 446 self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys())) 447 448 new_meta_graph3, var_list = meta_graph.export_scoped_meta_graph( 449 graph=graph, 450 export_scope="new_softmax_linear", 451 unbound_inputs_col_name=None) 452 self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys())) 453 454 return [new_meta_graph1, new_meta_graph2, new_meta_graph3] 455 456 # Verifies that we can export the subgraph under each layer and import 457 # them into new layers in a new graph. 458 def testScopedExportAndImport(self): 459 test_dir = _TestDir("scoped_export_import") 460 filenames = [ 461 "exported_hidden1.pbtxt", "exported_hidden2.pbtxt", 462 "exported_softmax_linear.pbtxt" 463 ] 464 orig_meta_graphs = self._testScopedExport(test_dir, filenames) 465 new_meta_graphs = self._testScopedImport(test_dir, filenames) 466 for a, b in zip(orig_meta_graphs, new_meta_graphs): 467 # The unbound input strings are slightly different with the C API enabled 468 # ("images" vs "images:0") due to the original import_graph_def code 469 # vs. ImportGraphDef in C++. 470 # TODO(skyewm): update the pbtxts once _USE_C_API is removed. 471 del a.collection_def["unbound_inputs"] 472 del b.collection_def["unbound_inputs"] 473 test_util.assert_meta_graph_protos_equal(self, a, b) 474 475 def testWhileLoopGradients(self): 476 # Create a simple while loop. 477 with ops.Graph().as_default(): 478 with ops.name_scope("export"): 479 var = variables.Variable(0) 480 var_name = var.name 481 _, output = control_flow_ops.while_loop(lambda i, x: i < 5, 482 lambda i, x: (i + 1, x + i), 483 [0, var]) 484 output_name = output.name 485 486 # Generate a MetaGraphDef containing the while loop with an export scope. 487 meta_graph_def, _ = meta_graph.export_scoped_meta_graph( 488 export_scope="export") 489 490 # Build and run the gradients of the while loop. We use this below to 491 # verify that the gradients are correct with the imported MetaGraphDef. 492 init_op = variables.global_variables_initializer() 493 grad = gradients_impl.gradients([output], [var]) 494 with session.Session() as sess: 495 sess.run(init_op) 496 expected_grad_value = sess.run(grad) 497 498 # Restore the MetaGraphDef into a new Graph with an import scope. 499 with ops.Graph().as_default(): 500 meta_graph.import_scoped_meta_graph(meta_graph_def, import_scope="import") 501 502 # Re-export and make sure we get the same MetaGraphDef. 503 new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph( 504 export_scope="import") 505 test_util.assert_meta_graph_protos_equal( 506 self, meta_graph_def, new_meta_graph_def) 507 508 # Make sure we can still build gradients and get the same result. 509 510 def new_name(tensor_name): 511 base_tensor_name = tensor_name.replace("export/", "") 512 return "import/" + base_tensor_name 513 514 var = ops.get_default_graph().get_tensor_by_name(new_name(var_name)) 515 output = ops.get_default_graph().get_tensor_by_name(new_name(output_name)) 516 grad = gradients_impl.gradients([output], [var]) 517 518 init_op = variables.global_variables_initializer() 519 520 with session.Session() as sess: 521 sess.run(init_op) 522 actual_grad_value = sess.run(grad) 523 self.assertEqual(expected_grad_value, actual_grad_value) 524 525 def testScopedImportUnderNameScope(self): 526 graph = ops.Graph() 527 with graph.as_default(): 528 variables.Variable(initial_value=1.0, trainable=True, name="myvar") 529 meta_graph_def, _ = meta_graph.export_scoped_meta_graph(graph=graph) 530 531 graph = ops.Graph() 532 with graph.as_default(): 533 with ops.name_scope("foo"): 534 imported_variables = meta_graph.import_scoped_meta_graph( 535 meta_graph_def, import_scope="bar") 536 self.assertEqual(len(imported_variables), 1) 537 self.assertEqual(list(imported_variables.values())[0].name, 538 "foo/bar/myvar:0") 539 540 def testImportsUsingSameScopeName(self): 541 with ops.Graph().as_default(): 542 variables.Variable(0, name="v") 543 meta_graph_def, _ = meta_graph.export_scoped_meta_graph() 544 with ops.Graph().as_default(): 545 for suffix in ["", "_1"]: 546 imported_variables = meta_graph.import_scoped_meta_graph( 547 meta_graph_def, import_scope="s") 548 self.assertEqual(len(imported_variables), 1) 549 self.assertEqual(list(imported_variables.keys())[0], "v:0") 550 self.assertEqual(list(imported_variables.values())[0].name, 551 "s" + suffix + "/v:0") 552 553 def testScopedImportWithSelectedCollections(self): 554 meta_graph_filename = os.path.join( 555 _TestDir("selected_collections_import"), "meta_graph.pb") 556 557 graph = ops.Graph() 558 # Add a variable to populate two collections. The functionality tested is 559 # not specific to variables, but using variables in the test is convenient. 560 with graph.as_default(): 561 variables.Variable(initial_value=1.0, trainable=True) 562 self.assertTrue( 563 all([ 564 graph.get_collection(key) 565 for key in 566 [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES] 567 ])) 568 meta_graph.export_scoped_meta_graph( 569 filename=meta_graph_filename, graph=graph) 570 571 def _test_import(include_collection_keys, omit_collection_keys): 572 assert set(include_collection_keys).isdisjoint(omit_collection_keys) 573 newgraph = ops.Graph() 574 import_scope = "some_scope_name" 575 576 def _restore_collections_predicate(collection_key): 577 return (collection_key in include_collection_keys and 578 collection_key not in omit_collection_keys) 579 580 meta_graph.import_scoped_meta_graph( 581 meta_graph_filename, 582 graph=newgraph, 583 import_scope=import_scope, 584 restore_collections_predicate=_restore_collections_predicate) 585 collection_values = [ 586 newgraph.get_collection(name=key, scope=import_scope) 587 for key in include_collection_keys 588 ] 589 self.assertTrue(all(collection_values)) 590 collection_values = [ 591 newgraph.get_collection(name=key, scope=import_scope) 592 for key in omit_collection_keys 593 ] 594 self.assertFalse(any(collection_values)) 595 596 _test_import( 597 include_collection_keys=[ 598 ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES 599 ], 600 omit_collection_keys=[]) 601 _test_import( 602 include_collection_keys=[ops.GraphKeys.GLOBAL_VARIABLES], 603 omit_collection_keys=[ops.GraphKeys.TRAINABLE_VARIABLES]) 604 _test_import( 605 include_collection_keys=[ops.GraphKeys.TRAINABLE_VARIABLES], 606 omit_collection_keys=[ops.GraphKeys.GLOBAL_VARIABLES]) 607 _test_import( 608 include_collection_keys=[], 609 omit_collection_keys=[ 610 ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES 611 ]) 612 613 def _testScopedExportWithQueue(self, test_dir, exported_filename): 614 graph = ops.Graph() 615 with graph.as_default(): 616 with ops.name_scope("queue1"): 617 input_queue = data_flow_ops.FIFOQueue(10, dtypes.float32) 618 enqueue = input_queue.enqueue((9876), name="enqueue") 619 close = input_queue.close(name="close") 620 qr = queue_runner_impl.QueueRunner(input_queue, [enqueue], close) 621 queue_runner_impl.add_queue_runner(qr) 622 input_queue.dequeue(name="dequeue") 623 624 orig_meta_graph, _ = meta_graph.export_scoped_meta_graph( 625 filename=os.path.join(test_dir, exported_filename), 626 graph=ops.get_default_graph(), 627 export_scope="queue1") 628 629 return orig_meta_graph 630 631 def _testScopedImportWithQueue(self, test_dir, exported_filename, 632 new_exported_filename): 633 graph = ops.Graph() 634 meta_graph.import_scoped_meta_graph( 635 os.path.join(test_dir, exported_filename), 636 graph=graph, 637 import_scope="new_queue1") 638 graph.as_graph_element("new_queue1/dequeue:0") 639 graph.as_graph_element("new_queue1/close") 640 with graph.as_default(): 641 new_meta_graph, _ = meta_graph.export_scoped_meta_graph( 642 filename=os.path.join(test_dir, new_exported_filename), 643 graph=graph, 644 export_scope="new_queue1") 645 646 return new_meta_graph 647 648 # Verifies that we can export the subgraph containing a FIFOQueue under 649 # "queue1" and import it into "new_queue1" in a new graph. 650 def testScopedWithQueue(self): 651 test_dir = _TestDir("scoped_with_queue") 652 orig_meta_graph = self._testScopedExportWithQueue(test_dir, 653 "exported_queue1.pbtxt") 654 new_meta_graph = self._testScopedImportWithQueue( 655 test_dir, "exported_queue1.pbtxt", "exported_new_queue1.pbtxt") 656 test_util.assert_meta_graph_protos_equal(self, orig_meta_graph, 657 new_meta_graph) 658 659 # Verifies that we can export a subgraph in a nested name scope containing a 660 # "hidden1/hidden2" and import it into "new_hidden1/new_hidden2" in a new 661 # graph. 662 def doTestExportNestedNames(self, use_resource=False): 663 graph1 = ops.Graph() 664 with graph1.as_default(): 665 with ops.name_scope("hidden1/hidden2/hidden3"): 666 images = constant_op.constant( 667 1.0, dtypes.float32, shape=[3, 2], name="images") 668 if use_resource: 669 weights1 = variables.Variable( 670 [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights") 671 biases1 = resource_variable_ops.ResourceVariable( 672 [0.1] * 3, name="biases") 673 else: 674 biases1 = variables.Variable([0.1] * 3, name="biases") 675 weights1 = variables.Variable( 676 [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights") 677 nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu") 678 679 orig_meta_graph, var_list = meta_graph.export_scoped_meta_graph( 680 export_scope="hidden1/hidden2", graph=graph1) 681 var_names = [v.name for _, v in var_list.items()] 682 self.assertEqual(["hidden3/biases:0", "hidden3/weights:0"], 683 sorted(var_list.keys())) 684 self.assertEqual([ 685 "hidden1/hidden2/hidden3/biases:0", "hidden1/hidden2/hidden3/weights:0" 686 ], sorted(var_names)) 687 for node in orig_meta_graph.graph_def.node: 688 self.assertTrue(node.name.startswith("hidden3")) 689 690 graph2 = ops.Graph() 691 new_var_list = meta_graph.import_scoped_meta_graph( 692 orig_meta_graph, import_scope="new_hidden1/new_hidden2", graph=graph2) 693 self.assertEqual(["hidden3/biases:0", "hidden3/weights:0"], 694 sorted(new_var_list.keys())) 695 new_var_names = [v.name for _, v in new_var_list.items()] 696 self.assertEqual([ 697 "new_hidden1/new_hidden2/hidden3/biases:0", 698 "new_hidden1/new_hidden2/hidden3/weights:0" 699 ], sorted(new_var_names)) 700 701 nodes = [ 702 "new_hidden1/new_hidden2/hidden3/biases/Assign", 703 "new_hidden1/new_hidden2/hidden3/weights/Assign" 704 ] 705 expected = [ 706 b"loc:@new_hidden1/new_hidden2/hidden3/biases", 707 b"loc:@new_hidden1/new_hidden2/hidden3/weights" 708 ] 709 for n, e in zip(nodes, expected): 710 self.assertEqual([e], graph2.get_operation_by_name(n).get_attr("_class")) 711 712 def testExportNestedNames(self): 713 self.doTestExportNestedNames(use_resource=False) 714 715 def testExportNestedNamesResource(self): 716 self.doTestExportNestedNames(use_resource=True) 717 718 def testPotentialCycle(self): 719 graph1 = ops.Graph() 720 with graph1.as_default(): 721 a = constant_op.constant(1.0, shape=[2, 2]) 722 b = constant_op.constant(2.0, shape=[2, 2]) 723 matmul = math_ops.matmul(a, b) 724 with ops.name_scope("hidden1"): 725 c = nn_ops.relu(matmul) 726 d = constant_op.constant(3.0, shape=[2, 2]) 727 matmul = math_ops.matmul(c, d) 728 729 orig_meta_graph, _ = meta_graph.export_scoped_meta_graph( 730 export_scope="hidden1", graph=graph1) 731 732 graph2 = ops.Graph() 733 with graph2.as_default(): 734 with self.assertRaisesRegexp(ValueError, "Graph contains unbound inputs"): 735 meta_graph.import_scoped_meta_graph( 736 orig_meta_graph, import_scope="new_hidden1") 737 738 meta_graph.import_scoped_meta_graph( 739 orig_meta_graph, 740 import_scope="new_hidden1", 741 input_map={ 742 "$unbound_inputs_MatMul": constant_op.constant( 743 4.0, shape=[2, 2]) 744 }) 745 746 def testClearDevices(self): 747 graph1 = ops.Graph() 748 with graph1.as_default(): 749 with ops.device("/device:CPU:0"): 750 a = variables.Variable( 751 constant_op.constant( 752 1.0, shape=[2, 2]), name="a") 753 with ops.device("/job:ps/replica:0/task:0/device:GPU:0"): 754 b = variables.Variable( 755 constant_op.constant( 756 2.0, shape=[2, 2]), name="b") 757 with ops.device("/job:localhost/replica:0/task:0/cpu:0"): 758 math_ops.matmul(a, b, name="matmul") 759 760 self.assertEqual("/device:CPU:0", str(graph1.as_graph_element("a").device)) 761 self.assertEqual("/job:ps/replica:0/task:0/device:GPU:0", 762 str(graph1.as_graph_element("b").device)) 763 self.assertEqual("/job:localhost/replica:0/task:0/device:CPU:0", 764 str(graph1.as_graph_element("matmul").device)) 765 766 # Verifies that devices are cleared on export. 767 orig_meta_graph, _ = meta_graph.export_scoped_meta_graph( 768 graph=graph1, clear_devices=True) 769 770 graph2 = ops.Graph() 771 with graph2.as_default(): 772 meta_graph.import_scoped_meta_graph(orig_meta_graph, clear_devices=False) 773 774 self.assertEqual("", str(graph2.as_graph_element("a").device)) 775 self.assertEqual("", str(graph2.as_graph_element("b").device)) 776 self.assertEqual("", str(graph2.as_graph_element("matmul").device)) 777 778 # Verifies that devices are cleared on export when passing in graph_def. 779 orig_meta_graph, _ = meta_graph.export_scoped_meta_graph( 780 graph_def=graph1.as_graph_def(), clear_devices=True) 781 782 graph2 = ops.Graph() 783 with graph2.as_default(): 784 meta_graph.import_scoped_meta_graph(orig_meta_graph, clear_devices=False) 785 786 self.assertEqual("", str(graph2.as_graph_element("a").device)) 787 self.assertEqual("", str(graph2.as_graph_element("b").device)) 788 self.assertEqual("", str(graph2.as_graph_element("matmul").device)) 789 790 # Verifies that devices are cleared on import. 791 orig_meta_graph, _ = meta_graph.export_scoped_meta_graph( 792 graph=graph1, clear_devices=False) 793 794 graph2 = ops.Graph() 795 with graph2.as_default(): 796 meta_graph.import_scoped_meta_graph(orig_meta_graph, clear_devices=True) 797 798 self.assertEqual("", str(graph2.as_graph_element("a").device)) 799 self.assertEqual("", str(graph2.as_graph_element("b").device)) 800 self.assertEqual("", str(graph2.as_graph_element("matmul").device)) 801 802 803 @test_util.with_c_api 804 class MetaGraphWithVariableScopeTest(test.TestCase): 805 806 def testMetricsCollection(self): 807 808 def _enqueue_vector(sess, queue, values, shape=None): 809 if not shape: 810 shape = (1, len(values)) 811 dtype = queue.dtypes[0] 812 sess.run( 813 queue.enqueue(constant_op.constant( 814 values, dtype=dtype, shape=shape))) 815 816 meta_graph_filename = os.path.join( 817 _TestDir("metrics_export"), "meta_graph.pb") 818 819 graph = ops.Graph() 820 with self.test_session(graph=graph) as sess: 821 values_queue = data_flow_ops.FIFOQueue( 822 4, dtypes.float32, shapes=(1, 2)) 823 _enqueue_vector(sess, values_queue, [0, 1]) 824 _enqueue_vector(sess, values_queue, [-4.2, 9.1]) 825 _enqueue_vector(sess, values_queue, [6.5, 0]) 826 _enqueue_vector(sess, values_queue, [-3.2, 4.0]) 827 values = values_queue.dequeue() 828 829 _, update_op = metrics.mean(values) 830 831 initializer = variables.local_variables_initializer() 832 sess.run(initializer) 833 sess.run(update_op) 834 835 meta_graph.export_scoped_meta_graph( 836 filename=meta_graph_filename, graph=graph) 837 838 # Verifies that importing a meta_graph with LOCAL_VARIABLES collection 839 # works correctly. 840 graph = ops.Graph() 841 with self.test_session(graph=graph) as sess: 842 meta_graph.import_scoped_meta_graph(meta_graph_filename) 843 initializer = variables.local_variables_initializer() 844 sess.run(initializer) 845 846 # Verifies that importing an old meta_graph where "local_variables" 847 # collection is of node_list type works, but cannot build initializer 848 # with the collection. 849 graph = ops.Graph() 850 with self.test_session(graph=graph) as sess: 851 meta_graph.import_scoped_meta_graph( 852 test.test_src_dir_path( 853 "python/framework/testdata/metrics_export_meta_graph.pb")) 854 self.assertEqual(len(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)), 855 2) 856 with self.assertRaisesRegexp( 857 AttributeError, "'Tensor' object has no attribute 'initializer'"): 858 initializer = variables.local_variables_initializer() 859 860 861 @test_util.with_c_api 862 class ExportImportAcrossScopesTest(test.TestCase): 863 864 def testPartionedVariables(self): 865 866 def make_graph_with_partitioned_variables(use_resource): 867 variable_scope.get_variable( 868 name="weights", 869 partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0), 870 initializer=random_ops.truncated_normal([100, 10]), 871 use_resource=use_resource) 872 # The next variable illustrates the necessity of restoring collections 873 # in a deterministic fashion when using ResourceVariables. 874 variable_scope.get_variable( 875 name="another", 876 shape=[], 877 collections=["a", "b", "z", "f", "e", "d", "g"], 878 use_resource=use_resource) 879 880 self._testExportImportAcrossScopes( 881 make_graph_with_partitioned_variables, use_resource=False) 882 self._testExportImportAcrossScopes( 883 make_graph_with_partitioned_variables, use_resource=True) 884 885 def _testExportImportAcrossScopes(self, graph_fn, use_resource): 886 """Tests export and importing a graph across scopes. 887 888 Args: 889 graph_fn: A closure that creates a graph on the current scope. 890 use_resource: A bool indicating whether or not to use ResourceVariables. 891 """ 892 with ops.Graph().as_default() as original_graph: 893 with variable_scope.variable_scope("dropA/dropB/keepA"): 894 graph_fn(use_resource=use_resource) 895 exported_meta_graph_def = meta_graph.export_scoped_meta_graph( 896 graph=original_graph, 897 export_scope="dropA/dropB")[0] 898 899 with ops.Graph().as_default() as imported_graph: 900 meta_graph.import_scoped_meta_graph( 901 exported_meta_graph_def, 902 import_scope="importA") 903 904 with ops.Graph().as_default() as expected_graph: 905 with variable_scope.variable_scope("importA/keepA"): 906 graph_fn(use_resource=use_resource) 907 908 if use_resource: 909 # Bringing in collections that contain ResourceVariables will adds ops 910 # to the graph the first time a variable is encountered, so mimic the 911 # same behavior. 912 seen_variables = set() 913 for collection_key in sorted([ 914 ops.GraphKeys.GLOBAL_VARIABLES, 915 ops.GraphKeys.TRAINABLE_VARIABLES, 916 ]): 917 for var in expected_graph.get_collection(collection_key): 918 if var not in seen_variables: 919 var._read_variable_op() 920 seen_variables.add(var) 921 922 result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0] 923 expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0] 924 925 if use_resource: 926 # Clear all shared_name attributes before comparing, since they are 927 # orthogonal to scopes and are not updated on export/import. 928 for meta_graph_def in [result, expected]: 929 for node in meta_graph_def.graph_def.node: 930 shared_name_attr = "shared_name" 931 shared_name_value = node.attr.get(shared_name_attr, None) 932 if shared_name_value and shared_name_value.HasField("s"): 933 if shared_name_value.s: 934 node.attr[shared_name_attr].s = b"" 935 936 test_util.assert_meta_graph_protos_equal(self, expected, result) 937 938 939 if __name__ == "__main__": 940 test.main() 941