# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts checkpoint variables into Const ops in a standalone GraphDef file.

This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
variable values stored in a checkpoint file, and output a GraphDef with all of
the variable ops converted into const ops containing the values of the
variables.

Instead of a GraphDef file, the graph may also be supplied as a MetaGraphDef
file (--input_meta_graph) or as a SavedModel directory
(--input_saved_model_dir).

It's useful to do this when we need to load a single file in C++, especially in
environments like mobile or embedded where we may not have access to the
RestoreTensor ops and file loading calls that they rely on.

An example of command-line usage is:
bazel build tensorflow/python/tools:freeze_graph && \
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=some_graph_def.pb \
--input_checkpoint=model.ckpt-8361242 \
--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax

You can also look at freeze_graph_test.py for an example of how to use it.

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

from google.protobuf import text_format

from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import saver as saver_lib

FLAGS = None


def _split_names(names):
  """Splits a comma separated string into a list of names, dropping spaces."""
  return names.replace(" ", "").split(",")


def _is_error_code(value):
  """Returns True if `value` is an integer error code rather than a proto."""
  return isinstance(value, int)


def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  The variable values are restored in one of four ways, checked in this order:
  from `input_saver_def`, from `input_meta_graph_def`, from
  `input_saved_model_dir`, or otherwise by matching the names stored in
  `input_checkpoint` against the tensors of the imported graph.

  Args:
    input_graph_def: A `GraphDef` proto (or None). When set it is imported into
      the default graph.
    input_saver_def: A `SaverDef` proto (or None) used to build the restoring
      `Saver`.
    input_checkpoint: Path (or, for the V2 format, prefix) of the checkpoint
      holding the variable values.
    output_node_names: Comma separated string with the names of the output
      nodes. Must not be empty.
    restore_op_name: Unused. Kept for backward compatibility.
    filename_tensor_name: Unused. Kept for backward compatibility.
    output_graph: Path the frozen `GraphDef` is written to. Nothing is written
      when it is empty.
    clear_devices: Whether to blank the `device` field of every node of the
      input graph.
    initializer_nodes: Comma separated string with the names of nodes to run
      before freezing. Only honoured when restoring from a meta graph or from
      the bare checkpoint.
    variable_names_whitelist: Comma separated string of variable names
      forwarded to `convert_variables_to_constants` (None when empty).
    variable_names_blacklist: Comma separated string of variable names
      forwarded to `convert_variables_to_constants` (None when empty).
    input_meta_graph_def: A `MetaGraphDef` proto (or None).
    input_saved_model_dir: Directory of a SavedModel (or None). When set, the
      existence check on `input_checkpoint` is skipped.
    saved_model_tags: List of tags handed to the SavedModel loader. None is
      treated as an empty list.
    checkpoint_version: `SaverDef` version passed as `write_version` to the
      `Saver` objects created here.

  Returns:
    The frozen `GraphDef`, or -1 if the checkpoint does not exist or no output
    node names were supplied.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not saver_lib.checkpoint_exists(input_checkpoint)):
    print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
    return -1

  if not output_node_names:
    print("You need to supply the name of a node to --output_node_names.")
    return -1

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    importer.import_graph_def(input_graph_def, name="")

  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(_split_names(initializer_nodes))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      # No saver information available: restore every checkpoint entry that
      # has a same-named tensor in the graph.
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor
      saver = saver_lib.Saver(
          var_list=var_list, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(_split_names(initializer_nodes))

    whitelist = (
        _split_names(variable_names_whitelist)
        if variable_names_whitelist else None)
    blacklist = (
        _split_names(variable_names_blacklist)
        if variable_names_blacklist else None)

    # The meta graph, when given, takes precedence over the plain GraphDef.
    if input_meta_graph_def:
      graph_def_to_freeze = input_meta_graph_def.graph_def
    else:
      graph_def_to_freeze = input_graph_def

    # Must run while the session is still open: the variable values are read
    # through `sess`.
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        graph_def_to_freeze,
        _split_names(output_node_names),
        variable_names_whitelist=whitelist,
        variable_names_blacklist=blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def


def _parse_proto_file(path, input_binary, proto, description):
  """Fills `proto` from the binary or text serialization stored at `path`.

  Args:
    path: Path of the file to read.
    input_binary: True to parse the file as a binary proto, False to parse it
      as a text proto.
    proto: Empty proto message that receives the parsed content.
    description: Short name of the file kind, used in the error message.

  Returns:
    `proto`, or -1 (after printing a message) if `path` does not exist.
  """
  if not gfile.Exists(path):
    print("Input " + description + " file '" + path + "' does not exist!")
    return -1
  mode = "rb" if input_binary else "r"
  with gfile.GFile(path, mode) as f:
    if input_binary:
      proto.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), proto)
  return proto


def _parse_input_graph_proto(input_graph, input_binary):
  """Parses an input TensorFlow graph file into a GraphDef proto.

  Returns:
    The parsed `GraphDef`, or -1 if `input_graph` does not exist.
  """
  return _parse_proto_file(input_graph, input_binary, graph_pb2.GraphDef(),
                           "graph")


def _parse_input_meta_graph_proto(input_graph, input_binary):
  """Parses an input TensorFlow meta graph file into a MetaGraphDef proto.

  Returns:
    The parsed `MetaGraphDef`, or -1 if `input_graph` does not exist.
  """
  input_meta_graph_def = _parse_proto_file(input_graph, input_binary,
                                           MetaGraphDef(), "meta graph")
  if not _is_error_code(input_meta_graph_def):
    print("Loaded meta graph file '" + input_graph + "'")
  return input_meta_graph_def


def _parse_input_saver_proto(input_saver, input_binary):
  """Parses an input TensorFlow saver file into a SaverDef proto.

  Returns:
    The parsed `SaverDef`, or -1 if `input_saver` does not exist.
  """
  return _parse_proto_file(input_saver, input_binary, saver_pb2.SaverDef(),
                           "saver")


def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_blacklist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING,
                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  File based front end of `freeze_graph_with_def_protos`: loads the protos
  named by the path arguments and forwards everything else unchanged.

  Args:
    input_graph: Path of a `GraphDef` file. Ignored when
      `input_saved_model_dir` is set.
    input_saver: Path of a `SaverDef` file, or empty.
    input_binary: Whether the input files are binary (True) or text (False)
      protos.
    input_checkpoint: Path or prefix of the checkpoint to read.
    output_node_names: Comma separated string of output node names.
    restore_op_name: Unused. Kept for backward compatibility.
    filename_tensor_name: Unused. Kept for backward compatibility.
    output_graph: Path the frozen `GraphDef` is written to, or empty.
    clear_devices: Whether to remove device specifications.
    initializer_nodes: Comma separated string of nodes to run before freezing.
    variable_names_whitelist: Comma separated string of variables to convert.
    variable_names_blacklist: Comma separated string of variables to skip.
    input_meta_graph: Path of a `MetaGraphDef` file, or empty.
    input_saved_model_dir: Directory of a SavedModel, or empty.
    saved_model_tags: Comma separated string of SavedModel tags.
    checkpoint_version: `SaverDef` checkpoint format version.

  Returns:
    The frozen `GraphDef`, or -1 if one of the input files or the checkpoint
    does not exist or no output node names were given.
  """
  input_graph_def = None
  if input_saved_model_dir:
    input_graph_def = saved_model_utils.get_meta_graph_def(
        input_saved_model_dir, saved_model_tags).graph_def
  elif input_graph:
    input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
    # The helper already printed the reason; do not hand the error code on as
    # if it were a proto.
    if _is_error_code(input_graph_def):
      return -1

  input_meta_graph_def = None
  if input_meta_graph:
    input_meta_graph_def = _parse_input_meta_graph_proto(
        input_meta_graph, input_binary)
    if _is_error_code(input_meta_graph_def):
      return -1

  input_saver_def = None
  if input_saver:
    input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
    if _is_error_code(input_saver_def):
      return -1

  return freeze_graph_with_def_protos(
      input_graph_def,
      input_saver_def,
      input_checkpoint,
      output_node_names,
      restore_op_name,
      filename_tensor_name,
      output_graph,
      clear_devices,
      initializer_nodes,
      variable_names_whitelist,
      variable_names_blacklist,
      input_meta_graph_def,
      input_saved_model_dir,
      _split_names(saved_model_tags),
      checkpoint_version=checkpoint_version)


def main(unused_args):
  """Command-line entry point; reads its configuration from `FLAGS`.

  Returns:
    -1 if the checkpoint version flag is invalid or freezing failed, otherwise
    None.
  """
  if FLAGS.checkpoint_version == 1:
    checkpoint_version = saver_pb2.SaverDef.V1
  elif FLAGS.checkpoint_version == 2:
    checkpoint_version = saver_pb2.SaverDef.V2
  else:
    print("Invalid checkpoint version (must be '1' or '2'): %d" %
          FLAGS.checkpoint_version)
    return -1
  result = freeze_graph(FLAGS.input_graph, FLAGS.input_saver,
                        FLAGS.input_binary, FLAGS.input_checkpoint,
                        FLAGS.output_node_names, FLAGS.restore_op_name,
                        FLAGS.filename_tensor_name, FLAGS.output_graph,
                        FLAGS.clear_devices, FLAGS.initializer_nodes,
                        FLAGS.variable_names_whitelist,
                        FLAGS.variable_names_blacklist, FLAGS.input_meta_graph,
                        FLAGS.input_saved_model_dir, FLAGS.saved_model_tags,
                        checkpoint_version)
  # Only the error code becomes the exit status; the GraphDef returned on
  # success must not be passed on as one.
  if _is_error_code(result):
    return result
  return None


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  # Boolean flags accept "true"/"false" (case-insensitive); anything other
  # than "true" is False.
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--input_graph",
      type=str,
      default="",
      help="TensorFlow 'GraphDef' file to load.")
  parser.add_argument(
      "--input_saver",
      type=str,
      default="",
      help="TensorFlow saver file to load.")
  parser.add_argument(
      "--input_checkpoint",
      type=str,
      default="",
      help="TensorFlow variables file to load.")
  parser.add_argument(
      "--checkpoint_version",
      type=int,
      default=2,
      help="Tensorflow variable file format")
  parser.add_argument(
      "--output_graph",
      type=str,
      default="",
      help="Output 'GraphDef' file name.")
  parser.add_argument(
      "--input_binary",
      nargs="?",
      const=True,
      type="bool",
      default=False,
      help="Whether the input files are in binary format.")
  parser.add_argument(
      "--output_node_names",
      type=str,
      default="",
      help="The name of the output nodes, comma separated.")
  parser.add_argument(
      "--restore_op_name",
      type=str,
      default="save/restore_all",
      help=("The name of the master restore operator. Deprecated, unused by "
            "updated loading code."))
  parser.add_argument(
      "--filename_tensor_name",
      type=str,
      default="save/Const:0",
      help=("The name of the tensor holding the save path. Deprecated, unused "
            "by updated loading code."))
  parser.add_argument(
      "--clear_devices",
      nargs="?",
      const=True,
      type="bool",
      default=True,
      help="Whether to remove device specifications.")
  parser.add_argument(
      "--initializer_nodes",
      type=str,
      default="",
      help="Comma separated list of initializer nodes to run before freezing.")
  parser.add_argument(
      "--variable_names_whitelist",
      type=str,
      default="",
      help=("Comma separated list of variables to convert to constants. If "
            "specified, only those variables will be converted to constants."))
  parser.add_argument(
      "--variable_names_blacklist",
      type=str,
      default="",
      help="Comma separated list of variables to skip converting to constants.")
  parser.add_argument(
      "--input_meta_graph",
      type=str,
      default="",
      help="TensorFlow 'MetaGraphDef' file to load.")
  parser.add_argument(
      "--input_saved_model_dir",
      type=str,
      default="",
      help="Path to the dir with TensorFlow 'SavedModel' file and variables.")
  parser.add_argument(
      "--saved_model_tags",
      type=str,
      default="serve",
      help=("Group of tag(s) of the MetaGraphDef to load, in string format, "
            "separated by ','. For tag-set contains multiple tags, all tags "
            "must be passed in."))
  FLAGS, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)