1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests the graph freezing tool.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import os 22 23 from tensorflow.core.example import example_pb2 24 from tensorflow.core.framework import graph_pb2 25 from tensorflow.core.protobuf import saver_pb2 26 from tensorflow.python.client import session 27 from tensorflow.python.framework import dtypes 28 from tensorflow.python.framework import graph_io 29 from tensorflow.python.framework import importer 30 from tensorflow.python.framework import ops 31 from tensorflow.python.framework import test_util 32 from tensorflow.python.ops import array_ops 33 from tensorflow.python.ops import math_ops 34 from tensorflow.python.ops import parsing_ops 35 from tensorflow.python.ops import variables 36 from tensorflow.python.platform import test 37 from tensorflow.python.saved_model import builder as saved_model_builder 38 from tensorflow.python.saved_model import signature_constants 39 from tensorflow.python.saved_model import signature_def_utils 40 from tensorflow.python.saved_model import tag_constants 41 from tensorflow.python.tools import freeze_graph 42 from tensorflow.python.training import saver as saver_lib 43 44 45 class FreezeGraphTest(test_util.TensorFlowTestCase): 46 47 def _testFreezeGraph(self, saver_write_version): 48 49 checkpoint_prefix = os.path.join(self.get_temp_dir(), "saved_checkpoint") 50 checkpoint_meta_graph_file = os.path.join(self.get_temp_dir(), 51 "saved_checkpoint.meta") 52 checkpoint_state_name = "checkpoint_state" 53 input_graph_name = "input_graph.pb" 54 output_graph_name = "output_graph.pb" 55 56 # We'll create an input graph that has a single variable containing 1.0, 57 # and that then multiplies it by 2. 58 with ops.Graph().as_default(): 59 variable_node = variables.Variable(1.0, name="variable_node") 60 output_node = math_ops.multiply(variable_node, 2.0, name="output_node") 61 sess = session.Session() 62 init = variables.global_variables_initializer() 63 sess.run(init) 64 output = sess.run(output_node) 65 self.assertNear(2.0, output, 0.00001) 66 saver = saver_lib.Saver(write_version=saver_write_version) 67 checkpoint_path = saver.save( 68 sess, 69 checkpoint_prefix, 70 global_step=0, 71 latest_filename=checkpoint_state_name) 72 graph_io.write_graph(sess.graph, self.get_temp_dir(), input_graph_name) 73 74 # We save out the graph to disk, and then call the const conversion 75 # routine. 76 input_graph_path = os.path.join(self.get_temp_dir(), input_graph_name) 77 input_saver_def_path = "" 78 input_binary = False 79 output_node_names = "output_node" 80 restore_op_name = "save/restore_all" 81 filename_tensor_name = "save/Const:0" 82 output_graph_path = os.path.join(self.get_temp_dir(), output_graph_name) 83 clear_devices = False 84 input_meta_graph = checkpoint_meta_graph_file 85 86 freeze_graph.freeze_graph( 87 input_graph_path, 88 input_saver_def_path, 89 input_binary, 90 checkpoint_path, 91 output_node_names, 92 restore_op_name, 93 filename_tensor_name, 94 output_graph_path, 95 clear_devices, 96 "", 97 "", 98 input_meta_graph, 99 checkpoint_version=saver_write_version) 100 101 # Now we make sure the variable is now a constant, and that the graph still 102 # produces the expected result. 103 with ops.Graph().as_default(): 104 output_graph_def = graph_pb2.GraphDef() 105 with open(output_graph_path, "rb") as f: 106 output_graph_def.ParseFromString(f.read()) 107 _ = importer.import_graph_def(output_graph_def, name="") 108 109 self.assertEqual(4, len(output_graph_def.node)) 110 for node in output_graph_def.node: 111 self.assertNotEqual("VariableV2", node.op) 112 self.assertNotEqual("Variable", node.op) 113 114 with session.Session() as sess: 115 output_node = sess.graph.get_tensor_by_name("output_node:0") 116 output = sess.run(output_node) 117 self.assertNear(2.0, output, 0.00001) 118 119 def _createTFExampleString(self, feature_name, feature_value): 120 """Create a serialized tensorflow example.""" 121 example = example_pb2.Example() 122 example.features.feature[feature_name].float_list.value.extend([ 123 feature_value]) 124 return example.SerializeToString() 125 126 def _writeDummySavedModel(self, path, feature_name): 127 """Writes a classifier with two input features to the given path.""" 128 with ops.Graph().as_default(): 129 examples = array_ops.placeholder(dtypes.string, name="input_node") 130 feature_configs = { 131 feature_name: parsing_ops.FixedLenFeature(shape=[], 132 dtype=dtypes.float32), 133 } 134 features = parsing_ops.parse_example(examples, feature_configs) 135 feature = features[feature_name] 136 137 variable_node = variables.Variable(1.0, name="variable_node") 138 scores = math_ops.multiply(variable_node, feature, name="output_node") 139 class_feature = array_ops.fill(array_ops.shape(feature), 140 "class_%s" % feature_name) 141 classes = array_ops.transpose(class_feature) 142 143 with session.Session() as sess: 144 sess.run(variables.global_variables_initializer()) 145 signature = ( 146 signature_def_utils.classification_signature_def( 147 examples=examples, 148 classes=classes, 149 scores=scores,)) 150 builder = saved_model_builder.SavedModelBuilder(path) 151 builder.add_meta_graph_and_variables( 152 sess, 153 [tag_constants.SERVING], 154 signature_def_map={ 155 signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 156 signature, 157 },) 158 builder.save(as_text=True) 159 160 def testFreezeGraphV1(self): 161 self._testFreezeGraph(saver_pb2.SaverDef.V1) 162 163 def testFreezeGraphV2(self): 164 self._testFreezeGraph(saver_pb2.SaverDef.V2) 165 166 def testFreezeMetaGraph(self): 167 tmp_dir = self.get_temp_dir() 168 checkpoint_prefix = os.path.join(tmp_dir, "meta_graph_checkpoint") 169 checkpoint_state_name = "checkpoint_state" 170 output_graph_filename = os.path.join(tmp_dir, "output_graph.pb") 171 172 with ops.Graph().as_default(): 173 variable_node = variables.Variable(1.0, name="variable_node") 174 output_node = math_ops.multiply(variable_node, 2.0, name="output_node") 175 sess = session.Session() 176 init = variables.global_variables_initializer() 177 sess.run(init) 178 output = sess.run(output_node) 179 self.assertNear(2.0, output, 0.00001) 180 saver = saver_lib.Saver() 181 checkpoint_path = saver.save( 182 sess, 183 checkpoint_prefix, 184 global_step=0, 185 latest_filename=checkpoint_state_name) 186 187 input_saver_def_path = "" 188 input_binary = True 189 output_node_names = "output_node" 190 restore_op_name = "save/restore_all" 191 filename_tensor_name = "save/Const:0" 192 clear_devices = False 193 input_meta_graph = checkpoint_path + ".meta" 194 195 freeze_graph.freeze_graph( 196 "", input_saver_def_path, input_binary, checkpoint_path, 197 output_node_names, restore_op_name, filename_tensor_name, 198 output_graph_filename, clear_devices, "", "", "", input_meta_graph) 199 200 # Now we make sure the variable is now a constant, and that the graph still 201 # produces the expected result. 202 with ops.Graph().as_default(): 203 output_graph_def = graph_pb2.GraphDef() 204 with open(output_graph_filename, "rb") as f: 205 output_graph_def.ParseFromString(f.read()) 206 _ = importer.import_graph_def(output_graph_def, name="") 207 208 self.assertEqual(4, len(output_graph_def.node)) 209 for node in output_graph_def.node: 210 self.assertNotEqual("VariableV2", node.op) 211 self.assertNotEqual("Variable", node.op) 212 213 with session.Session() as sess: 214 output_node = sess.graph.get_tensor_by_name("output_node:0") 215 output = sess.run(output_node) 216 self.assertNear(2.0, output, 0.00001) 217 218 def testFreezeSavedModel(self): 219 tmp_dir = self.get_temp_dir() 220 saved_model_dir = os.path.join(tmp_dir, "saved_model_dir") 221 feature_name = "feature" 222 self._writeDummySavedModel(saved_model_dir, feature_name) 223 output_graph_filename = os.path.join(tmp_dir, "output_graph.pb") 224 225 input_saved_model_dir = saved_model_dir 226 output_node_names = "output_node" 227 input_binary = False 228 input_saver_def_path = False 229 restore_op_name = None 230 filename_tensor_name = None 231 clear_devices = False 232 input_meta_graph = False 233 checkpoint_path = None 234 input_graph_filename = None 235 saved_model_tags = tag_constants.SERVING 236 237 freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path, 238 input_binary, checkpoint_path, output_node_names, 239 restore_op_name, filename_tensor_name, 240 output_graph_filename, clear_devices, "", "", "", 241 input_meta_graph, input_saved_model_dir, 242 saved_model_tags) 243 244 # Now we make sure the variable is now a constant, and that the graph still 245 # produces the expected result. 246 with ops.Graph().as_default(): 247 output_graph_def = graph_pb2.GraphDef() 248 with open(output_graph_filename, "rb") as f: 249 output_graph_def.ParseFromString(f.read()) 250 _ = importer.import_graph_def(output_graph_def, name="") 251 252 self.assertEqual(8, len(output_graph_def.node)) 253 for node in output_graph_def.node: 254 self.assertNotEqual("VariableV2", node.op) 255 self.assertNotEqual("Variable", node.op) 256 257 feature_value = 2.0 258 example = self._createTFExampleString(feature_name, feature_value) 259 with session.Session() as sess: 260 input_node = sess.graph.get_tensor_by_name("input_node:0") 261 output_node = sess.graph.get_tensor_by_name("output_node:0") 262 output = sess.run(output_node, feed_dict={input_node: [example]}) 263 self.assertNear(feature_value, output, 0.00001) 264 265 266 if __name__ == "__main__": 267 test.main() 268