Home | History | Annotate | Download | only in tools
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests the graph freezing tool."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import os
     22 
     23 from tensorflow.core.example import example_pb2
     24 from tensorflow.core.framework import graph_pb2
     25 from tensorflow.core.protobuf import saver_pb2
     26 from tensorflow.python.client import session
     27 from tensorflow.python.framework import dtypes
     28 from tensorflow.python.framework import graph_io
     29 from tensorflow.python.framework import importer
     30 from tensorflow.python.framework import ops
     31 from tensorflow.python.framework import test_util
     32 from tensorflow.python.ops import array_ops
     33 from tensorflow.python.ops import math_ops
     34 from tensorflow.python.ops import parsing_ops
     35 from tensorflow.python.ops import variables
     36 from tensorflow.python.platform import test
     37 from tensorflow.python.saved_model import builder as saved_model_builder
     38 from tensorflow.python.saved_model import signature_constants
     39 from tensorflow.python.saved_model import signature_def_utils
     40 from tensorflow.python.saved_model import tag_constants
     41 from tensorflow.python.tools import freeze_graph
     42 from tensorflow.python.training import saver as saver_lib
     43 
     44 
     45 class FreezeGraphTest(test_util.TensorFlowTestCase):
     46 
     47   def _testFreezeGraph(self, saver_write_version):
     48 
     49     checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint")
     50     checkpoint_meta_graph_file = os.path.join(self.get_temp_dir(),
     51                                               "saved_checkpoint.meta")
     52     checkpoint_state_name = "checkpoint_state"
     53     input_graph_name = "input_graph.pb"
     54     output_graph_name = "output_graph.pb"
     55 
     56     # We'll create an input graph that has a single variable containing 1.0,
     57     # and that then multiplies it by 2.
     58     with ops.Graph().as_default():
     59       variable_node = variables.Variable(1.0, name="variable_node")
     60       output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
     61       sess = session.Session()
     62       init = variables.global_variables_initializer()
     63       sess.run(init)
     64       output = sess.run(output_node)
     65       self.assertNear(2.0, output, 0.00001)
     66       saver = saver_lib.Saver(write_version=saver_write_version)
     67       checkpoint_path = saver.save(
     68           sess,
     69           checkpoint_prefix,
     70           global_step=0,
     71           latest_filename=checkpoint_state_name)
     72       graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name)
     73 
     74     # We save out the graph to disk, and then call the const conversion
     75     # routine.
     76     input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name)
     77     input_saver_def_path = ""
     78     input_binary = False
     79     output_node_names = "output_node"
     80     restore_op_name = "save/restore_all"
     81     filename_tensor_name = "save/Const:0"
     82     output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name)
     83     clear_devices = False
     84     input_meta_graph = checkpoint_meta_graph_file
     85 
     86     freeze_graph.freeze_graph(
     87         input_graph_path,
     88         input_saver_def_path,
     89         input_binary,
     90         checkpoint_path,
     91         output_node_names,
     92         restore_op_name,
     93         filename_tensor_name,
     94         output_graph_path,
     95         clear_devices,
     96         "",
     97         "",
     98         input_meta_graph,
     99         checkpoint_version=saver_write_version)
    100 
    101     # Now we make sure the variable is now a constant, and that the graph still
    102     # produces the expected result.
    103     with ops.Graph().as_default():
    104       output_graph_def = graph_pb2.GraphDef()
    105       with open(output_graph_path, "rb") as f:
    106         output_graph_def.ParseFromString(f.read())
    107         _ = importer.import_graph_def(output_graph_def, name="")
    108 
    109       self.assertEqual(4, len(output_graph_def.node))
    110       for node in output_graph_def.node:
    111         self.assertNotEqual("VariableV2", node.op)
    112         self.assertNotEqual("Variable", node.op)
    113 
    114       with session.Session() as sess:
    115         output_node = sess.graph.get_tensor_by_name("output_node:0")
    116         output = sess.run(output_node)
    117         self.assertNear(2.0, output, 0.00001)
    118 
    119   def _createTFExampleString(self, feature_name, feature_value):
    120     """Create a serialized tensorflow example."""
    121     example = example_pb2.Example()
    122     example.features.feature[feature_name].float_list.value.extend([
    123         feature_value])
    124     return example.SerializeToString()
    125 
    126   def _writeDummySavedModel(self, path, feature_name):
    127     """Writes a classifier with two input features to the given path."""
    128     with ops.Graph().as_default():
    129       examples = array_ops.placeholder(dtypes.string, name="input_node")
    130       feature_configs = {
    131           feature_name: parsing_ops.FixedLenFeature(shape=[],
    132                                                     dtype=dtypes.float32),
    133       }
    134       features = parsing_ops.parse_example(examples, feature_configs)
    135       feature = features[feature_name]
    136 
    137       variable_node = variables.Variable(1.0, name="variable_node")
    138       scores = math_ops.multiply(variable_node, feature, name="output_node")
    139       class_feature = array_ops.fill(array_ops.shape(feature),
    140                                      "class_%s" % feature_name)
    141       classes = array_ops.transpose(class_feature)
    142 
    143       with session.Session() as sess:
    144         sess.run(variables.global_variables_initializer())
    145         signature = (
    146             signature_def_utils.classification_signature_def(
    147                 examples=examples,
    148                 classes=classes,
    149                 scores=scores,))
    150         builder = saved_model_builder.SavedModelBuilder(path)
    151         builder.add_meta_graph_and_variables(
    152             sess,
    153             [tag_constants.SERVING],
    154             signature_def_map={
    155                 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
    156                     signature,
    157             },)
    158         builder.save(as_text=True)
    159 
    160   def testFreezeGraphV1(self):
    161     self._testFreezeGraph(saver_pb2.SaverDef.V1)
    162 
    163   def testFreezeGraphV2(self):
    164     self._testFreezeGraph(saver_pb2.SaverDef.V2)
    165 
    166   def testFreezeMetaGraph(self):
    167     tmp_dir = self.get_temp_dir()
    168     checkpoint_prefix = os.path.join(tmp_dir, "meta_graph_checkpoint")
    169     checkpoint_state_name = "checkpoint_state"
    170     output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")
    171 
    172     with ops.Graph().as_default():
    173       variable_node = variables.Variable(1.0, name="variable_node")
    174       output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
    175       sess = session.Session()
    176       init = variables.global_variables_initializer()
    177       sess.run(init)
    178       output = sess.run(output_node)
    179       self.assertNear(2.0, output, 0.00001)
    180       saver = saver_lib.Saver()
    181       checkpoint_path = saver.save(
    182           sess,
    183           checkpoint_prefix,
    184           global_step=0,
    185           latest_filename=checkpoint_state_name)
    186 
    187     input_saver_def_path = ""
    188     input_binary = True
    189     output_node_names = "output_node"
    190     restore_op_name = "save/restore_all"
    191     filename_tensor_name = "save/Const:0"
    192     clear_devices = False
    193     input_meta_graph = checkpoint_path + ".meta"
    194 
    195     freeze_graph.freeze_graph(
    196         "", input_saver_def_path, input_binary, checkpoint_path,
    197         output_node_names, restore_op_name, filename_tensor_name,
    198         output_graph_filename, clear_devices, "", "", "", input_meta_graph)
    199 
    200     # Now we make sure the variable is now a constant, and that the graph still
    201     # produces the expected result.
    202     with ops.Graph().as_default():
    203       output_graph_def = graph_pb2.GraphDef()
    204       with open(output_graph_filename, "rb") as f:
    205         output_graph_def.ParseFromString(f.read())
    206         _ = importer.import_graph_def(output_graph_def, name="")
    207 
    208       self.assertEqual(4, len(output_graph_def.node))
    209       for node in output_graph_def.node:
    210         self.assertNotEqual("VariableV2", node.op)
    211         self.assertNotEqual("Variable", node.op)
    212 
    213       with session.Session() as sess:
    214         output_node = sess.graph.get_tensor_by_name("output_node:0")
    215         output = sess.run(output_node)
    216         self.assertNear(2.0, output, 0.00001)
    217 
    218   def testFreezeSavedModel(self):
    219     tmp_dir = self.get_temp_dir()
    220     saved_model_dir = os.path.join(tmp_dir, "saved_model_dir")
    221     feature_name = "feature"
    222     self._writeDummySavedModel(saved_model_dir, feature_name)
    223     output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")
    224 
    225     input_saved_model_dir = saved_model_dir
    226     output_node_names = "output_node"
    227     input_binary = False
    228     input_saver_def_path = False
    229     restore_op_name = None
    230     filename_tensor_name = None
    231     clear_devices = False
    232     input_meta_graph = False
    233     checkpoint_path = None
    234     input_graph_filename = None
    235     saved_model_tags = tag_constants.SERVING
    236 
    237     freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
    238                               input_binary, checkpoint_path, output_node_names,
    239                               restore_op_name, filename_tensor_name,
    240                               output_graph_filename, clear_devices, "", "", "",
    241                               input_meta_graph, input_saved_model_dir,
    242                               saved_model_tags)
    243 
    244     # Now we make sure the variable is now a constant, and that the graph still
    245     # produces the expected result.
    246     with ops.Graph().as_default():
    247       output_graph_def = graph_pb2.GraphDef()
    248       with open(output_graph_filename, "rb") as f:
    249         output_graph_def.ParseFromString(f.read())
    250         _ = importer.import_graph_def(output_graph_def, name="")
    251 
    252       self.assertEqual(8, len(output_graph_def.node))
    253       for node in output_graph_def.node:
    254         self.assertNotEqual("VariableV2", node.op)
    255         self.assertNotEqual("Variable", node.op)
    256 
    257       feature_value = 2.0
    258       example = self._createTFExampleString(feature_name, feature_value)
    259       with session.Session() as sess:
    260         input_node = sess.graph.get_tensor_by_name("input_node:0")
    261         output_node = sess.graph.get_tensor_by_name("output_node:0")
    262         output = sess.run(output_node, feed_dict={input_node: [example]})
    263         self.assertNear(feature_value, output, 0.00001)
    264 
    265 
if __name__ == "__main__":
  # Executed as a script: hand control to the TensorFlow test runner.
  test.main()
    268