Home | History | Annotate | Download | only in framework
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # =============================================================================
     15 """Tests for tensorflow.python.framework.meta_graph.py."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import math
     22 import os.path
     23 import random
     24 import shutil
     25 
     26 from tensorflow.core.framework import graph_pb2
     27 from tensorflow.core.protobuf import meta_graph_pb2
     28 from tensorflow.python.client import session
     29 from tensorflow.python.framework import constant_op
     30 from tensorflow.python.framework import dtypes
     31 from tensorflow.python.framework import function
     32 from tensorflow.python.framework import meta_graph
     33 from tensorflow.python.framework import ops
     34 from tensorflow.python.framework import test_util
     35 from tensorflow.python.ops import array_ops
     36 from tensorflow.python.ops import control_flow_ops
     37 from tensorflow.python.ops import data_flow_ops
     38 from tensorflow.python.ops import gradients_impl
     39 from tensorflow.python.ops import math_ops
     40 from tensorflow.python.ops import metrics
     41 from tensorflow.python.ops import nn_ops
     42 from tensorflow.python.ops import partitioned_variables
     43 from tensorflow.python.ops import random_ops
     44 from tensorflow.python.ops import resource_variable_ops
     45 from tensorflow.python.ops import variable_scope
     46 from tensorflow.python.ops import variables
     47 from tensorflow.python.platform import gfile
     48 from tensorflow.python.platform import test
     49 from tensorflow.python.training import queue_runner_impl
     50 
     51 
     52 # pylint: disable=invalid-name
     53 def _TestDir(test_name):
     54   test_dir = os.path.join(test.get_temp_dir(), test_name)
     55   if os.path.exists(test_dir):
     56     shutil.rmtree(test_dir)
     57   gfile.MakeDirs(test_dir)
     58   return test_dir
     59 
     60 
     61 # pylint: enable=invalid-name
     62 
     63 
     64 @test_util.with_c_api
     65 class SimpleMetaGraphTest(test.TestCase):
     66 
     67   def testNoVariables(self):
     68     test_dir = _TestDir("no_variables")
     69     filename = os.path.join(test_dir, "metafile")
     70 
     71     input_feed_value = -10  # Arbitrary input value for feed_dict.
     72 
     73     orig_graph = ops.Graph()
     74     with self.test_session(graph=orig_graph) as sess:
     75       # Create a minimal graph with zero variables.
     76       input_tensor = array_ops.placeholder(
     77           dtypes.float32, shape=[], name="input")
     78       offset = constant_op.constant(42, dtype=dtypes.float32, name="offset")
     79       output_tensor = math_ops.add(input_tensor, offset, name="add_offset")
     80 
     81       # Add input and output tensors to graph collections.
     82       ops.add_to_collection("input_tensor", input_tensor)
     83       ops.add_to_collection("output_tensor", output_tensor)
     84 
     85       output_value = sess.run(output_tensor, {input_tensor: input_feed_value})
     86       self.assertEqual(output_value, 32)
     87 
     88       # Generates MetaGraphDef.
     89       meta_graph_def, var_list = meta_graph.export_scoped_meta_graph(
     90           filename=filename,
     91           graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
     92           collection_list=["input_tensor", "output_tensor"],
     93           saver_def=None)
     94       self.assertTrue(meta_graph_def.HasField("meta_info_def"))
     95       self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_version, "")
     96       self.assertNotEqual(meta_graph_def.meta_info_def.tensorflow_git_version,
     97                           "")
     98       self.assertEqual({}, var_list)
     99 
    100     # Create a clean graph and import the MetaGraphDef nodes.
    101     new_graph = ops.Graph()
    102     with self.test_session(graph=new_graph) as sess:
    103       # Import the previously export meta graph.
    104       meta_graph.import_scoped_meta_graph(filename)
    105 
    106       # Re-exports the current graph state for comparison to the original.
    107       new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(filename +
    108                                                                   "_new")
    109       test_util.assert_meta_graph_protos_equal(self, meta_graph_def,
    110                                                new_meta_graph_def)
    111 
    112       # Ensures that we can still get a reference to our graph collections.
    113       new_input_tensor = ops.get_collection("input_tensor")[0]
    114       new_output_tensor = ops.get_collection("output_tensor")[0]
    115       # Verifies that the new graph computes the same result as the original.
    116       new_output_value = sess.run(new_output_tensor,
    117                                   {new_input_tensor: input_feed_value})
    118       self.assertEqual(new_output_value, output_value)
    119 
    120   def testStrippedOpListNestedFunctions(self):
    121     with self.test_session():
    122       # Square two levels deep
    123       @function.Defun(dtypes.int32)
    124       def f0(x):
    125         return math_ops.square(x)
    126 
    127       @function.Defun(dtypes.int32)
    128       def f1(x):
    129         return f0(x)
    130 
    131       # At this point we've defined two functions but haven't called them, so
    132       # there should be no used ops.
    133       op_list = meta_graph.stripped_op_list_for_graph(ops.get_default_graph()
    134                                                       .as_graph_def())
    135       self.assertEqual(len(op_list.op), 0)
    136 
    137       # If we call the function on a constant, there should be two ops
    138       _ = f1(constant_op.constant(7))
    139       op_list = meta_graph.stripped_op_list_for_graph(ops.get_default_graph()
    140                                                       .as_graph_def())
    141       self.assertEqual(["Const", "Square"], [op.name for op in op_list.op])
    142 
    143   def testStrippedOpListRecursiveFunctions(self):
    144     # The function module doesn't support recursive functions, so we build a
    145     # recursive function situation by ourselves: A calls B calls A and Const.
    146     graph = graph_pb2.GraphDef()
    147     a = graph.library.function.add()
    148     b = graph.library.function.add()
    149     a.signature.name = "A"
    150     b.signature.name = "B"
    151     a.node_def.add().op = "B"
    152     b.node_def.add().op = "Const"
    153     b.node_def.add().op = "A"
    154 
    155     # Use A in the graph
    156     graph.node.add().op = "A"
    157 
    158     # The stripped op list should contain just Const.
    159     op_list = meta_graph.stripped_op_list_for_graph(graph)
    160     self.assertEqual(["Const"], [op.name for op in op_list.op])
    161 
    162   def testDefaultAttrStripping(self):
    163     """Verifies that default attributes are stripped from a graph def."""
    164 
    165     # Complex Op has 2 attributes with defaults:
    166     #   o "T"    : float32.
    167     #   o "Tout" : complex64.
    168 
    169     # When inputs to the Complex Op are float32 instances, "T" maps to float32
    170     # and "Tout" maps to complex64. Since these attr values map to their
    171     # defaults, they must be stripped unless stripping of default attrs is
    172     # disabled.
    173     with self.test_session():
    174       real_num = constant_op.constant(1.0, dtype=dtypes.float32, name="real")
    175       imag_num = constant_op.constant(2.0, dtype=dtypes.float32, name="imag")
    176       math_ops.complex(real_num, imag_num, name="complex")
    177 
    178       # strip_default_attrs is enabled.
    179       meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
    180           graph_def=ops.get_default_graph().as_graph_def(),
    181           strip_default_attrs=True)
    182       node_def = test_util.get_node_def_from_graph("complex",
    183                                                    meta_graph_def.graph_def)
    184       self.assertNotIn("T", node_def.attr)
    185       self.assertNotIn("Tout", node_def.attr)
    186       self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
    187 
    188       # strip_default_attrs is disabled.
    189       meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
    190           graph_def=ops.get_default_graph().as_graph_def(),
    191           strip_default_attrs=False)
    192       node_def = test_util.get_node_def_from_graph("complex",
    193                                                    meta_graph_def.graph_def)
    194       self.assertIn("T", node_def.attr)
    195       self.assertIn("Tout", node_def.attr)
    196       self.assertFalse(meta_graph_def.meta_info_def.stripped_default_attrs)
    197 
    198     # When inputs to the Complex Op are float64 instances, "T" maps to float64
    199     # and "Tout" maps to complex128. Since these attr values don't map to their
    200     # defaults, they must not be stripped.
    201     with self.test_session(graph=ops.Graph()):
    202       real_num = constant_op.constant(1.0, dtype=dtypes.float64, name="real")
    203       imag_num = constant_op.constant(2.0, dtype=dtypes.float64, name="imag")
    204       math_ops.complex(real_num, imag_num, name="complex")
    205       meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
    206           graph_def=ops.get_default_graph().as_graph_def(),
    207           strip_default_attrs=True)
    208       node_def = test_util.get_node_def_from_graph("complex",
    209                                                    meta_graph_def.graph_def)
    210       self.assertEqual(node_def.attr["T"].type, dtypes.float64)
    211       self.assertEqual(node_def.attr["Tout"].type, dtypes.complex128)
    212       self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
    213 
    214   def testDefaultAttrStrippingNestedFunctions(self):
    215     """Verifies that default attributes are stripped from function node defs."""
    216     with self.test_session():
    217       @function.Defun(dtypes.float32, dtypes.float32)
    218       def f0(i, j):
    219         return math_ops.complex(i, j, name="double_nested_complex")
    220 
    221       @function.Defun(dtypes.float32, dtypes.float32)
    222       def f1(i, j):
    223         return f0(i, j)
    224 
    225       _ = f1(constant_op.constant(1.0), constant_op.constant(2.0))
    226       meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
    227           graph_def=ops.get_default_graph().as_graph_def(),
    228           strip_default_attrs=True)
    229 
    230       double_nested_complex_node_def = None
    231       for function_def in meta_graph_def.graph_def.library.function:
    232         for node_def in function_def.node_def:
    233           if node_def.name.startswith("double_nested_complex"):
    234             double_nested_complex_node_def = node_def
    235             break
    236         if double_nested_complex_node_def:
    237           break
    238 
    239       self.assertIsNotNone(double_nested_complex_node_def)
    240       self.assertNotIn("T", double_nested_complex_node_def.attr)
    241       self.assertNotIn("Tout", double_nested_complex_node_def.attr)
    242       self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
    243 
    244   def testDefaultAttrStrippingUnregisteredOps(self):
    245     """Verifies that nodes with un-registered ops are not stripped."""
    246     graph_def = graph_pb2.GraphDef()
    247     node = graph_def.node.add()
    248     node.name = "node_with_unreg_op"
    249     node.op = "unreg_op"
    250     node.attr["attr_1"].i = 1
    251 
    252     meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
    253     meta_info_def.stripped_op_list.op.add()
    254 
    255     with self.test_session():
    256       meta_graph_def = meta_graph.create_meta_graph_def(
    257           meta_info_def=meta_info_def, graph_def=graph_def,
    258           strip_default_attrs=True)
    259       node_def = test_util.get_node_def_from_graph("node_with_unreg_op",
    260                                                    meta_graph_def.graph_def)
    261       self.assertEqual(node_def.attr["attr_1"].i, 1)
    262       self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
    263 
    264   def testVariableObjectsAreSharedAmongCollections(self):
    265     with ops.Graph().as_default() as graph1:
    266       v = variables.Variable(3.0)
    267       # A single instance of Variable is shared among the collections:
    268       global_vars = graph1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    269       trainable_vars = graph1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
    270       self.assertEqual(len(global_vars), 1)
    271       self.assertEqual(len(trainable_vars), 1)
    272       self.assertIs(global_vars[0], trainable_vars[0])
    273       self.assertIs(v, global_vars[0])
    274 
    275     orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(graph=graph1)
    276     del graph1  # To avoid accidental references in code involving graph2.
    277 
    278     with ops.Graph().as_default() as graph2:
    279       meta_graph.import_scoped_meta_graph(orig_meta_graph)
    280       global_vars = graph2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
    281       trainable_vars = graph2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
    282       self.assertEqual(len(global_vars), 1)
    283       self.assertEqual(len(trainable_vars), 1)
    284       # A single instance of Variable is shared among the collections:
    285       self.assertIs(global_vars[0], trainable_vars[0])
    286 
    287 
    288 @test_util.with_c_api
    289 class ScopedMetaGraphTest(test.TestCase):
    290 
  def _testScopedExport(self, test_dir, exported_filenames):
    """Builds a 3-layer inference graph and exports each layer's name scope.

    The layers live under the name scopes "hidden1", "hidden2" and
    "softmax_linear". Each scope is exported to its own file.

    Args:
      test_dir: Directory the exported MetaGraphDefs are written into.
      exported_filenames: Three file names, used in order for the "hidden1",
        "hidden2" and "softmax_linear" exports.

    Returns:
      A list with the three MetaGraphDefs returned by
      `meta_graph.export_scoped_meta_graph`, in the same order.
    """
    graph = ops.Graph()
    with graph.as_default():
      # Creates an inference graph.
      # Hidden 1
      # "constraint" and "images" are created outside every layer scope. The
      # import side later feeds "images" through the "$unbound_inputs_images"
      # input_map key.
      colocate_constraint = constant_op.constant(1.2, name="constraint")
      images = constant_op.constant(
          1.2, dtypes.float32, shape=[100, 28], name="images")
      with ops.name_scope("hidden1"):
        with graph.colocate_with(colocate_constraint.op):
          weights1 = variables.Variable(
              random_ops.truncated_normal(
                  [28, 128], stddev=1.0 / math.sqrt(float(28))),
              name="weights")
        # The use of control_flow_ops.cond here is purely to add test coverage
        # of saving and restoring a control flow context (it doesn't make any
        # sense here from a machine learning perspective). Typically biases
        # would be a simple Variable without the condition.
        # The predicate comes from Python's `random`, so the taken branch can
        # differ between runs. NOTE(review): presumably harmless because both
        # branches have the same structure; confirm.
        biases1 = variables.Variable(
            control_flow_ops.cond(
                math_ops.less(random.random(), 0.5),
                lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
            name="biases")
        hidden1 = nn_ops.relu(math_ops.matmul(images, weights1) + biases1)

      # Hidden 2
      with ops.name_scope("hidden2"):
        weights2 = variables.Variable(
            random_ops.truncated_normal(
                [128, 32], stddev=1.0 / math.sqrt(float(128))),
            name="weights")

        # The use of control_flow_ops.while_loop here is purely to add test
        # coverage of saving and restoring a control flow context (it doesn't
        # make any sense here from a machine learning perspective). Typically
        # biases would be a simple Variable without the loop.
        def loop_cond(it, _):
          # Run the loop body twice.
          return it < 2

        def loop_body(it, biases2):
          biases2 += constant_op.constant(0.1, shape=[32])
          return it + 1, biases2

        # The loop's initial value is the "biases" Variable. That is why the
        # hidden2 export still reports a "biases:0" variable.
        _, biases2 = control_flow_ops.while_loop(
            loop_cond,
            loop_body, [
                constant_op.constant(0), variables.Variable(
                    array_ops.zeros([32]), name="biases")
            ])
        hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights2) + biases2)
      # Linear
      with ops.name_scope("softmax_linear"):
        weights3 = variables.Variable(
            random_ops.truncated_normal(
                [32, 10], stddev=1.0 / math.sqrt(float(32))),
            name="weights")
        biases3 = variables.Variable(array_ops.zeros([10]), name="biases")
        logits = math_ops.matmul(hidden2, weights3) + biases3
        ops.add_to_collection("logits", logits)

      # Exports each sub-graph.
      # Exports the first one with unbound_inputs_col_name set to default.
      orig_meta_graph1, var_list = meta_graph.export_scoped_meta_graph(
          filename=os.path.join(test_dir, exported_filenames[0]),
          graph=ops.get_default_graph(),
          export_scope="hidden1")
      # var_list keys are relative to the export scope; values keep their
      # full in-graph names.
      self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
      var_names = [v.name for _, v in var_list.items()]
      self.assertEqual(["hidden1/biases:0", "hidden1/weights:0"],
                       sorted(var_names))

      # Exports the rest with no unbound_inputs_col_name.
      orig_meta_graph2, _ = meta_graph.export_scoped_meta_graph(
          filename=os.path.join(test_dir, exported_filenames[1]),
          graph=ops.get_default_graph(),
          export_scope="hidden2",
          unbound_inputs_col_name=None)
      orig_meta_graph3, _ = meta_graph.export_scoped_meta_graph(
          filename=os.path.join(test_dir, exported_filenames[2]),
          graph=ops.get_default_graph(),
          export_scope="softmax_linear",
          unbound_inputs_col_name=None)

    return [orig_meta_graph1, orig_meta_graph2, orig_meta_graph3]
    375 
  def _testScopedImport(self, test_dir, exported_filenames):
    """Imports the files written by `_testScopedExport` under new scopes.

    "hidden1", "hidden2" and "softmax_linear" are imported into a fresh graph
    as "new_hidden1", "new_hidden2" and "new_softmax_linear". Each layer's
    unbound input is wired to the previous layer's output. Every new scope is
    then re-exported.

    Args:
      test_dir: Directory holding the exported MetaGraphDef files.
      exported_filenames: The same three file names passed to
        `_testScopedExport`.

    Returns:
      A list of the three re-exported MetaGraphDefs, in layer order.
    """
    graph = ops.Graph()
    # Create all the missing inputs.
    with graph.as_default():
      new_image = constant_op.constant(
          1.2, dtypes.float32, shape=[100, 28], name="images")

    # Importing "hidden1" without mapping its unbound input must fail.
    with self.assertRaisesRegexp(ValueError, "Graph contains unbound inputs"):
      meta_graph.import_scoped_meta_graph(
          os.path.join(test_dir, exported_filenames[0]),
          graph=graph,
          import_scope="new_hidden1")

    # Mapping under the wrong key ("image:0" instead of the
    # "$unbound_inputs_images" key used below) must also fail.
    with self.assertRaisesRegexp(ValueError, "Graph contains unbound inputs"):
      meta_graph.import_scoped_meta_graph(
          os.path.join(test_dir, exported_filenames[0]),
          graph=graph,
          input_map={"image:0": new_image},
          import_scope="new_hidden1")

    # Verifies we can import the original "hidden1" into "new_hidden1".
    var_list = meta_graph.import_scoped_meta_graph(
        os.path.join(test_dir, exported_filenames[0]),
        graph=graph,
        input_map={"$unbound_inputs_images": new_image},
        import_scope="new_hidden1")

    self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
    new_var_names = [v.name for _, v in var_list.items()]
    self.assertEqual(["new_hidden1/biases:0", "new_hidden1/weights:0"],
                     sorted(new_var_names))

    # Verifies we can import the original "hidden2" into "new_hidden2".
    # NOTE(review): this identity is built outside `graph.as_default()`.
    # Presumably it lands in `graph` because its input tensor belongs to that
    # graph. The name mirrors the original "hidden1/Relu" tensor, presumably
    # so the re-exported graph references the same input name; confirm.
    hidden1 = array_ops.identity(
        graph.as_graph_element("new_hidden1/Relu:0"), name="hidden1/Relu")
    var_list = meta_graph.import_scoped_meta_graph(
        os.path.join(test_dir, exported_filenames[1]),
        graph=graph,
        input_map={"$unbound_inputs_hidden1/Relu": hidden1},
        import_scope="new_hidden2",
        unbound_inputs_col_name=None)

    self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
    new_var_names = [v.name for _, v in var_list.items()]
    self.assertEqual(["new_hidden2/biases:0", "new_hidden2/weights:0"],
                     sorted(new_var_names))

    # Verifies we can import the original "softmax_linear" into
    # "new_softmax_linear". The input is wired up the same way as above.
    hidden2 = array_ops.identity(
        graph.as_graph_element("new_hidden2/Relu:0"), name="hidden2/Relu")
    var_list = meta_graph.import_scoped_meta_graph(
        os.path.join(test_dir, exported_filenames[2]),
        graph=graph,
        input_map={"$unbound_inputs_hidden2/Relu": hidden2},
        import_scope="new_softmax_linear",
        unbound_inputs_col_name=None)
    self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))
    new_var_names = [v.name for _, v in var_list.items()]
    self.assertEqual(
        ["new_softmax_linear/biases:0", "new_softmax_linear/weights:0"],
        sorted(new_var_names))

    # Exports the scoped meta graphs again, using the same
    # unbound_inputs_col_name settings as the original exports.
    new_meta_graph1, var_list = meta_graph.export_scoped_meta_graph(
        graph=graph, export_scope="new_hidden1")
    self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))

    new_meta_graph2, var_list = meta_graph.export_scoped_meta_graph(
        graph=graph, export_scope="new_hidden2", unbound_inputs_col_name=None)
    self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))

    new_meta_graph3, var_list = meta_graph.export_scoped_meta_graph(
        graph=graph,
        export_scope="new_softmax_linear",
        unbound_inputs_col_name=None)
    self.assertEqual(["biases:0", "weights:0"], sorted(var_list.keys()))

    return [new_meta_graph1, new_meta_graph2, new_meta_graph3]
    455 
    456   # Verifies that we can export the subgraph under each layer and import
    457   # them into new layers in a new graph.
    458   def testScopedExportAndImport(self):
    459     test_dir = _TestDir("scoped_export_import")
    460     filenames = [
    461         "exported_hidden1.pbtxt", "exported_hidden2.pbtxt",
    462         "exported_softmax_linear.pbtxt"
    463     ]
    464     orig_meta_graphs = self._testScopedExport(test_dir, filenames)
    465     new_meta_graphs = self._testScopedImport(test_dir, filenames)
    466     for a, b in zip(orig_meta_graphs, new_meta_graphs):
    467       # The unbound input strings are slightly different with the C API enabled
    468       # ("images" vs "images:0") due to the original import_graph_def code
    469       # vs. ImportGraphDef in C++.
    470       # TODO(skyewm): update the pbtxts once _USE_C_API is removed.
    471       del a.collection_def["unbound_inputs"]
    472       del b.collection_def["unbound_inputs"]
    473       test_util.assert_meta_graph_protos_equal(self, a, b)
    474 
    475   def testWhileLoopGradients(self):
    476     # Create a simple while loop.
    477     with ops.Graph().as_default():
    478       with ops.name_scope("export"):
    479         var = variables.Variable(0)
    480         var_name = var.name
    481         _, output = control_flow_ops.while_loop(lambda i, x: i < 5,
    482                                                 lambda i, x: (i + 1, x + i),
    483                                                 [0, var])
    484         output_name = output.name
    485 
    486       # Generate a MetaGraphDef containing the while loop with an export scope.
    487       meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
    488           export_scope="export")
    489 
    490       # Build and run the gradients of the while loop. We use this below to
    491       # verify that the gradients are correct with the imported MetaGraphDef.
    492       init_op = variables.global_variables_initializer()
    493       grad = gradients_impl.gradients([output], [var])
    494       with session.Session() as sess:
    495         sess.run(init_op)
    496         expected_grad_value = sess.run(grad)
    497 
    498     # Restore the MetaGraphDef into a new Graph with an import scope.
    499     with ops.Graph().as_default():
    500       meta_graph.import_scoped_meta_graph(meta_graph_def, import_scope="import")
    501 
    502       # Re-export and make sure we get the same MetaGraphDef.
    503       new_meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
    504           export_scope="import")
    505       test_util.assert_meta_graph_protos_equal(
    506           self, meta_graph_def, new_meta_graph_def)
    507 
    508       # Make sure we can still build gradients and get the same result.
    509 
    510       def new_name(tensor_name):
    511         base_tensor_name = tensor_name.replace("export/", "")
    512         return "import/" + base_tensor_name
    513 
    514       var = ops.get_default_graph().get_tensor_by_name(new_name(var_name))
    515       output = ops.get_default_graph().get_tensor_by_name(new_name(output_name))
    516       grad = gradients_impl.gradients([output], [var])
    517 
    518       init_op = variables.global_variables_initializer()
    519 
    520       with session.Session() as sess:
    521         sess.run(init_op)
    522         actual_grad_value = sess.run(grad)
    523         self.assertEqual(expected_grad_value, actual_grad_value)
    524 
    525   def testScopedImportUnderNameScope(self):
    526     graph = ops.Graph()
    527     with graph.as_default():
    528       variables.Variable(initial_value=1.0, trainable=True, name="myvar")
    529     meta_graph_def, _ = meta_graph.export_scoped_meta_graph(graph=graph)
    530 
    531     graph = ops.Graph()
    532     with graph.as_default():
    533       with ops.name_scope("foo"):
    534         imported_variables = meta_graph.import_scoped_meta_graph(
    535             meta_graph_def, import_scope="bar")
    536         self.assertEqual(len(imported_variables), 1)
    537         self.assertEqual(list(imported_variables.values())[0].name,
    538                          "foo/bar/myvar:0")
    539 
    540   def testImportsUsingSameScopeName(self):
    541     with ops.Graph().as_default():
    542       variables.Variable(0, name="v")
    543       meta_graph_def, _ = meta_graph.export_scoped_meta_graph()
    544     with ops.Graph().as_default():
    545       for suffix in ["", "_1"]:
    546         imported_variables = meta_graph.import_scoped_meta_graph(
    547             meta_graph_def, import_scope="s")
    548         self.assertEqual(len(imported_variables), 1)
    549         self.assertEqual(list(imported_variables.keys())[0], "v:0")
    550         self.assertEqual(list(imported_variables.values())[0].name,
    551                          "s" + suffix + "/v:0")
    552 
    553   def testScopedImportWithSelectedCollections(self):
    554     meta_graph_filename = os.path.join(
    555         _TestDir("selected_collections_import"), "meta_graph.pb")
    556 
    557     graph = ops.Graph()
    558     # Add a variable to populate two collections. The functionality tested is
    559     # not specific to variables, but using variables in the test is convenient.
    560     with graph.as_default():
    561       variables.Variable(initial_value=1.0, trainable=True)
    562     self.assertTrue(
    563         all([
    564             graph.get_collection(key)
    565             for key in
    566             [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES]
    567         ]))
    568     meta_graph.export_scoped_meta_graph(
    569         filename=meta_graph_filename, graph=graph)
    570 
    571     def _test_import(include_collection_keys, omit_collection_keys):
    572       assert set(include_collection_keys).isdisjoint(omit_collection_keys)
    573       newgraph = ops.Graph()
    574       import_scope = "some_scope_name"
    575 
    576       def _restore_collections_predicate(collection_key):
    577         return (collection_key in include_collection_keys and
    578                 collection_key not in omit_collection_keys)
    579 
    580       meta_graph.import_scoped_meta_graph(
    581           meta_graph_filename,
    582           graph=newgraph,
    583           import_scope=import_scope,
    584           restore_collections_predicate=_restore_collections_predicate)
    585       collection_values = [
    586           newgraph.get_collection(name=key, scope=import_scope)
    587           for key in include_collection_keys
    588       ]
    589       self.assertTrue(all(collection_values))
    590       collection_values = [
    591           newgraph.get_collection(name=key, scope=import_scope)
    592           for key in omit_collection_keys
    593       ]
    594       self.assertFalse(any(collection_values))
    595 
    596     _test_import(
    597         include_collection_keys=[
    598             ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES
    599         ],
    600         omit_collection_keys=[])
    601     _test_import(
    602         include_collection_keys=[ops.GraphKeys.GLOBAL_VARIABLES],
    603         omit_collection_keys=[ops.GraphKeys.TRAINABLE_VARIABLES])
    604     _test_import(
    605         include_collection_keys=[ops.GraphKeys.TRAINABLE_VARIABLES],
    606         omit_collection_keys=[ops.GraphKeys.GLOBAL_VARIABLES])
    607     _test_import(
    608         include_collection_keys=[],
    609         omit_collection_keys=[
    610             ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES
    611         ])
    612 
    613   def _testScopedExportWithQueue(self, test_dir, exported_filename):
    614     graph = ops.Graph()
    615     with graph.as_default():
    616       with ops.name_scope("queue1"):
    617         input_queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
    618         enqueue = input_queue.enqueue((9876), name="enqueue")
    619         close = input_queue.close(name="close")
    620         qr = queue_runner_impl.QueueRunner(input_queue, [enqueue], close)
    621         queue_runner_impl.add_queue_runner(qr)
    622         input_queue.dequeue(name="dequeue")
    623 
    624       orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(
    625           filename=os.path.join(test_dir, exported_filename),
    626           graph=ops.get_default_graph(),
    627           export_scope="queue1")
    628 
    629     return orig_meta_graph
    630 
    631   def _testScopedImportWithQueue(self, test_dir, exported_filename,
    632                                  new_exported_filename):
    633     graph = ops.Graph()
    634     meta_graph.import_scoped_meta_graph(
    635         os.path.join(test_dir, exported_filename),
    636         graph=graph,
    637         import_scope="new_queue1")
    638     graph.as_graph_element("new_queue1/dequeue:0")
    639     graph.as_graph_element("new_queue1/close")
    640     with graph.as_default():
    641       new_meta_graph, _ = meta_graph.export_scoped_meta_graph(
    642           filename=os.path.join(test_dir, new_exported_filename),
    643           graph=graph,
    644           export_scope="new_queue1")
    645 
    646     return new_meta_graph
    647 
    648   # Verifies that we can export the subgraph containing a FIFOQueue under
    649   # "queue1" and import it into "new_queue1" in a new graph.
    650   def testScopedWithQueue(self):
    651     test_dir = _TestDir("scoped_with_queue")
    652     orig_meta_graph = self._testScopedExportWithQueue(test_dir,
    653                                                       "exported_queue1.pbtxt")
    654     new_meta_graph = self._testScopedImportWithQueue(
    655         test_dir, "exported_queue1.pbtxt", "exported_new_queue1.pbtxt")
    656     test_util.assert_meta_graph_protos_equal(self, orig_meta_graph,
    657                                              new_meta_graph)
    658 
  # Verifies that we can export a subgraph nested under the name scope
  # "hidden1/hidden2" and import it into "new_hidden1/new_hidden2" in a new
  # graph.
    662   def doTestExportNestedNames(self, use_resource=False):
    663     graph1 = ops.Graph()
    664     with graph1.as_default():
    665       with ops.name_scope("hidden1/hidden2/hidden3"):
    666         images = constant_op.constant(
    667             1.0, dtypes.float32, shape=[3, 2], name="images")
    668         if use_resource:
    669           weights1 = variables.Variable(
    670               [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
    671           biases1 = resource_variable_ops.ResourceVariable(
    672               [0.1] * 3, name="biases")
    673         else:
    674           biases1 = variables.Variable([0.1] * 3, name="biases")
    675           weights1 = variables.Variable(
    676               [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
    677         nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
    678 
    679     orig_meta_graph, var_list = meta_graph.export_scoped_meta_graph(
    680         export_scope="hidden1/hidden2", graph=graph1)
    681     var_names = [v.name for _, v in var_list.items()]
    682     self.assertEqual(["hidden3/biases:0", "hidden3/weights:0"],
    683                      sorted(var_list.keys()))
    684     self.assertEqual([
    685         "hidden1/hidden2/hidden3/biases:0", "hidden1/hidden2/hidden3/weights:0"
    686     ], sorted(var_names))
    687     for node in orig_meta_graph.graph_def.node:
    688       self.assertTrue(node.name.startswith("hidden3"))
    689 
    690     graph2 = ops.Graph()
    691     new_var_list = meta_graph.import_scoped_meta_graph(
    692         orig_meta_graph, import_scope="new_hidden1/new_hidden2", graph=graph2)
    693     self.assertEqual(["hidden3/biases:0", "hidden3/weights:0"],
    694                      sorted(new_var_list.keys()))
    695     new_var_names = [v.name for _, v in new_var_list.items()]
    696     self.assertEqual([
    697         "new_hidden1/new_hidden2/hidden3/biases:0",
    698         "new_hidden1/new_hidden2/hidden3/weights:0"
    699     ], sorted(new_var_names))
    700 
    701     nodes = [
    702         "new_hidden1/new_hidden2/hidden3/biases/Assign",
    703         "new_hidden1/new_hidden2/hidden3/weights/Assign"
    704     ]
    705     expected = [
    706         b"loc:@new_hidden1/new_hidden2/hidden3/biases",
    707         b"loc:@new_hidden1/new_hidden2/hidden3/weights"
    708     ]
    709     for n, e in zip(nodes, expected):
    710       self.assertEqual([e], graph2.get_operation_by_name(n).get_attr("_class"))
    711 
    712   def testExportNestedNames(self):
    713     self.doTestExportNestedNames(use_resource=False)
    714 
    715   def testExportNestedNamesResource(self):
    716     self.doTestExportNestedNames(use_resource=True)
    717 
    718   def testPotentialCycle(self):
    719     graph1 = ops.Graph()
    720     with graph1.as_default():
    721       a = constant_op.constant(1.0, shape=[2, 2])
    722       b = constant_op.constant(2.0, shape=[2, 2])
    723       matmul = math_ops.matmul(a, b)
    724       with ops.name_scope("hidden1"):
    725         c = nn_ops.relu(matmul)
    726         d = constant_op.constant(3.0, shape=[2, 2])
    727         matmul = math_ops.matmul(c, d)
    728 
    729     orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(
    730         export_scope="hidden1", graph=graph1)
    731 
    732     graph2 = ops.Graph()
    733     with graph2.as_default():
    734       with self.assertRaisesRegexp(ValueError, "Graph contains unbound inputs"):
    735         meta_graph.import_scoped_meta_graph(
    736             orig_meta_graph, import_scope="new_hidden1")
    737 
    738       meta_graph.import_scoped_meta_graph(
    739           orig_meta_graph,
    740           import_scope="new_hidden1",
    741           input_map={
    742               "$unbound_inputs_MatMul": constant_op.constant(
    743                   4.0, shape=[2, 2])
    744           })
    745 
    746   def testClearDevices(self):
    747     graph1 = ops.Graph()
    748     with graph1.as_default():
    749       with ops.device("/device:CPU:0"):
    750         a = variables.Variable(
    751             constant_op.constant(
    752                 1.0, shape=[2, 2]), name="a")
    753       with ops.device("/job:ps/replica:0/task:0/device:GPU:0"):
    754         b = variables.Variable(
    755             constant_op.constant(
    756                 2.0, shape=[2, 2]), name="b")
    757       with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
    758         math_ops.matmul(a, b, name="matmul")
    759 
    760     self.assertEqual("/device:CPU:0", str(graph1.as_graph_element("a").device))
    761     self.assertEqual("/job:ps/replica:0/task:0/device:GPU:0",
    762                      str(graph1.as_graph_element("b").device))
    763     self.assertEqual("/job:localhost/replica:0/task:0/device:CPU:0",
    764                      str(graph1.as_graph_element("matmul").device))
    765 
    766     # Verifies that devices are cleared on export.
    767     orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(
    768         graph=graph1, clear_devices=True)
    769 
    770     graph2 = ops.Graph()
    771     with graph2.as_default():
    772       meta_graph.import_scoped_meta_graph(orig_meta_graph, clear_devices=False)
    773 
    774     self.assertEqual("", str(graph2.as_graph_element("a").device))
    775     self.assertEqual("", str(graph2.as_graph_element("b").device))
    776     self.assertEqual("", str(graph2.as_graph_element("matmul").device))
    777 
    778     # Verifies that devices are cleared on export when passing in graph_def.
    779     orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(
    780         graph_def=graph1.as_graph_def(), clear_devices=True)
    781 
    782     graph2 = ops.Graph()
    783     with graph2.as_default():
    784       meta_graph.import_scoped_meta_graph(orig_meta_graph, clear_devices=False)
    785 
    786     self.assertEqual("", str(graph2.as_graph_element("a").device))
    787     self.assertEqual("", str(graph2.as_graph_element("b").device))
    788     self.assertEqual("", str(graph2.as_graph_element("matmul").device))
    789 
    790     # Verifies that devices are cleared on import.
    791     orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(
    792         graph=graph1, clear_devices=False)
    793 
    794     graph2 = ops.Graph()
    795     with graph2.as_default():
    796       meta_graph.import_scoped_meta_graph(orig_meta_graph, clear_devices=True)
    797 
    798     self.assertEqual("", str(graph2.as_graph_element("a").device))
    799     self.assertEqual("", str(graph2.as_graph_element("b").device))
    800     self.assertEqual("", str(graph2.as_graph_element("matmul").device))
    801 
    802 
    803 @test_util.with_c_api
    804 class MetaGraphWithVariableScopeTest(test.TestCase):
    805 
    806   def testMetricsCollection(self):
    807 
    808     def _enqueue_vector(sess, queue, values, shape=None):
    809       if not shape:
    810         shape = (1, len(values))
    811       dtype = queue.dtypes[0]
    812       sess.run(
    813           queue.enqueue(constant_op.constant(
    814               values, dtype=dtype, shape=shape)))
    815 
    816     meta_graph_filename = os.path.join(
    817         _TestDir("metrics_export"), "meta_graph.pb")
    818 
    819     graph = ops.Graph()
    820     with self.test_session(graph=graph) as sess:
    821       values_queue = data_flow_ops.FIFOQueue(
    822           4, dtypes.float32, shapes=(1, 2))
    823       _enqueue_vector(sess, values_queue, [0, 1])
    824       _enqueue_vector(sess, values_queue, [-4.2, 9.1])
    825       _enqueue_vector(sess, values_queue, [6.5, 0])
    826       _enqueue_vector(sess, values_queue, [-3.2, 4.0])
    827       values = values_queue.dequeue()
    828 
    829       _, update_op = metrics.mean(values)
    830 
    831       initializer = variables.local_variables_initializer()
    832       sess.run(initializer)
    833       sess.run(update_op)
    834 
    835     meta_graph.export_scoped_meta_graph(
    836         filename=meta_graph_filename, graph=graph)
    837 
    838     # Verifies that importing a meta_graph with LOCAL_VARIABLES collection
    839     # works correctly.
    840     graph = ops.Graph()
    841     with self.test_session(graph=graph) as sess:
    842       meta_graph.import_scoped_meta_graph(meta_graph_filename)
    843       initializer = variables.local_variables_initializer()
    844       sess.run(initializer)
    845 
    846     # Verifies that importing an old meta_graph where "local_variables"
    847     # collection is of node_list type works, but cannot build initializer
    848     # with the collection.
    849     graph = ops.Graph()
    850     with self.test_session(graph=graph) as sess:
    851       meta_graph.import_scoped_meta_graph(
    852           test.test_src_dir_path(
    853               "python/framework/testdata/metrics_export_meta_graph.pb"))
    854       self.assertEqual(len(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES)),
    855                        2)
    856       with self.assertRaisesRegexp(
    857           AttributeError, "'Tensor' object has no attribute 'initializer'"):
    858         initializer = variables.local_variables_initializer()
    859 
    860 
    861 @test_util.with_c_api
    862 class ExportImportAcrossScopesTest(test.TestCase):
    863 
    864   def testPartionedVariables(self):
    865 
    866     def make_graph_with_partitioned_variables(use_resource):
    867       variable_scope.get_variable(
    868           name="weights",
    869           partitioner=partitioned_variables.fixed_size_partitioner(3, axis=0),
    870           initializer=random_ops.truncated_normal([100, 10]),
    871           use_resource=use_resource)
    872       # The next variable illustrates the necessity of restoring collections
    873       # in a deterministic fashion when using ResourceVariables.
    874       variable_scope.get_variable(
    875           name="another",
    876           shape=[],
    877           collections=["a", "b", "z", "f", "e", "d", "g"],
    878           use_resource=use_resource)
    879 
    880     self._testExportImportAcrossScopes(
    881         make_graph_with_partitioned_variables, use_resource=False)
    882     self._testExportImportAcrossScopes(
    883         make_graph_with_partitioned_variables, use_resource=True)
    884 
    885   def _testExportImportAcrossScopes(self, graph_fn, use_resource):
    886     """Tests export and importing a graph across scopes.
    887 
    888     Args:
    889       graph_fn: A closure that creates a graph on the current scope.
    890       use_resource: A bool indicating whether or not to use ResourceVariables.
    891     """
    892     with ops.Graph().as_default() as original_graph:
    893       with variable_scope.variable_scope("dropA/dropB/keepA"):
    894         graph_fn(use_resource=use_resource)
    895     exported_meta_graph_def = meta_graph.export_scoped_meta_graph(
    896         graph=original_graph,
    897         export_scope="dropA/dropB")[0]
    898 
    899     with ops.Graph().as_default() as imported_graph:
    900       meta_graph.import_scoped_meta_graph(
    901           exported_meta_graph_def,
    902           import_scope="importA")
    903 
    904     with ops.Graph().as_default() as expected_graph:
    905       with variable_scope.variable_scope("importA/keepA"):
    906         graph_fn(use_resource=use_resource)
    907 
    908       if use_resource:
    909         # Bringing in collections that contain ResourceVariables will adds ops
    910         # to the graph the first time a variable is encountered, so mimic the
    911         # same behavior.
    912         seen_variables = set()
    913         for collection_key in sorted([
    914             ops.GraphKeys.GLOBAL_VARIABLES,
    915             ops.GraphKeys.TRAINABLE_VARIABLES,
    916         ]):
    917           for var in expected_graph.get_collection(collection_key):
    918             if var not in seen_variables:
    919               var._read_variable_op()
    920               seen_variables.add(var)
    921 
    922     result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0]
    923     expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0]
    924 
    925     if use_resource:
    926       # Clear all shared_name attributes before comparing, since they are
    927       # orthogonal to scopes and are not updated on export/import.
    928       for meta_graph_def in [result, expected]:
    929         for node in meta_graph_def.graph_def.node:
    930           shared_name_attr = "shared_name"
    931           shared_name_value = node.attr.get(shared_name_attr, None)
    932           if shared_name_value and shared_name_value.HasField("s"):
    933             if shared_name_value.s:
    934               node.attr[shared_name_attr].s = b""
    935 
    936     test_util.assert_meta_graph_protos_equal(self, expected, result)
    937 
    938 
if __name__ == "__main__":
  # Script entry point: hand control to the `test` module's runner for the
  # test cases defined in this file.
  test.main()
    941