Home | History | Annotate | Download | only in tools
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 r"""Converts checkpoint variables into Const ops in a standalone GraphDef file.
     16 
     17 This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
     18 variable values stored in a checkpoint file, and output a GraphDef with all of
     19 the variable ops converted into const ops containing the values of the
     20 variables.
     21 
     22 It's useful to do this when we need to load a single file in C++, especially in
     23 environments like mobile or embedded where we may not have access to the
     24 RestoreTensor ops and file loading calls that they rely on.
     25 
     26 An example of command-line usage is:
     27 bazel build tensorflow/python/tools:freeze_graph && \
     28 bazel-bin/tensorflow/python/tools/freeze_graph \
     29 --input_graph=some_graph_def.pb \
     30 --input_checkpoint=model.ckpt-8361242 \
     31 --output_graph=/tmp/frozen_graph.pb --output_node_names=softmax
     32 
     33 You can also look at freeze_graph_test.py for an example of how to use it.
     34 
     35 """
     36 from __future__ import absolute_import
     37 from __future__ import division
     38 from __future__ import print_function
     39 
     40 import argparse
     41 import sys
     42 
     43 from google.protobuf import text_format
     44 
     45 from tensorflow.core.framework import graph_pb2
     46 from tensorflow.core.protobuf import saver_pb2
     47 from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
     48 from tensorflow.python import pywrap_tensorflow
     49 from tensorflow.python.client import session
     50 from tensorflow.python.framework import graph_util
     51 from tensorflow.python.framework import importer
     52 from tensorflow.python.platform import app
     53 from tensorflow.python.platform import gfile
     54 from tensorflow.python.saved_model import loader
     55 from tensorflow.python.saved_model import tag_constants
     56 from tensorflow.python.tools import saved_model_utils
     57 from tensorflow.python.training import saver as saver_lib
     58 
# Parsed command-line flags. Assigned by the argument parser in the
# `if __name__ == "__main__":` section and read by main(); it remains None when
# this module is imported as a library.
FLAGS = None
     60 
     61 
     62 def freeze_graph_with_def_protos(input_graph_def,
     63                                  input_saver_def,
     64                                  input_checkpoint,
     65                                  output_node_names,
     66                                  restore_op_name,
     67                                  filename_tensor_name,
     68                                  output_graph,
     69                                  clear_devices,
     70                                  initializer_nodes,
     71                                  variable_names_whitelist="",
     72                                  variable_names_blacklist="",
     73                                  input_meta_graph_def=None,
     74                                  input_saved_model_dir=None,
     75                                  saved_model_tags=None,
     76                                  checkpoint_version=saver_pb2.SaverDef.V2):
     77   """Converts all variables in a graph and checkpoint into constants."""
     78   del restore_op_name, filename_tensor_name  # Unused by updated loading code.
     79 
     80   # 'input_checkpoint' may be a prefix if we're using Saver V2 format
     81   if (not input_saved_model_dir and
     82       not saver_lib.checkpoint_exists(input_checkpoint)):
     83     print("Input checkpoint '" + input_checkpoint + "' doesn't exist!")
     84     return -1
     85 
     86   if not output_node_names:
     87     print("You need to supply the name of a node to --output_node_names.")
     88     return -1
     89 
     90   # Remove all the explicit device specifications for this node. This helps to
     91   # make the graph more portable.
     92   if clear_devices:
     93     if input_meta_graph_def:
     94       for node in input_meta_graph_def.graph_def.node:
     95         node.device = ""
     96     elif input_graph_def:
     97       for node in input_graph_def.node:
     98         node.device = ""
     99 
    100   if input_graph_def:
    101     _ = importer.import_graph_def(input_graph_def, name="")
    102   with session.Session() as sess:
    103     if input_saver_def:
    104       saver = saver_lib.Saver(
    105           saver_def=input_saver_def, write_version=checkpoint_version)
    106       saver.restore(sess, input_checkpoint)
    107     elif input_meta_graph_def:
    108       restorer = saver_lib.import_meta_graph(
    109           input_meta_graph_def, clear_devices=True)
    110       restorer.restore(sess, input_checkpoint)
    111       if initializer_nodes:
    112         sess.run(initializer_nodes.replace(' ', '').split(","))
    113     elif input_saved_model_dir:
    114       if saved_model_tags is None:
    115         saved_model_tags = []
    116       loader.load(sess, saved_model_tags, input_saved_model_dir)
    117     else:
    118       var_list = {}
    119       reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
    120       var_to_shape_map = reader.get_variable_to_shape_map()
    121       for key in var_to_shape_map:
    122         try:
    123           tensor = sess.graph.get_tensor_by_name(key + ":0")
    124         except KeyError:
    125           # This tensor doesn't exist in the graph (for example it's
    126           # 'global_step' or a similar housekeeping element) so skip it.
    127           continue
    128         var_list[key] = tensor
    129       saver = saver_lib.Saver(
    130           var_list=var_list, write_version=checkpoint_version)
    131       saver.restore(sess, input_checkpoint)
    132       if initializer_nodes:
    133         sess.run(initializer_nodes.replace(' ', '').split(","))
    134 
    135     variable_names_whitelist = (
    136         variable_names_whitelist.replace(' ', '').split(",")
    137         if variable_names_whitelist else None)
    138     variable_names_blacklist = (
    139         variable_names_blacklist.replace(' ', '').split(",")
    140         if variable_names_blacklist else None)
    141 
    142     if input_meta_graph_def:
    143       output_graph_def = graph_util.convert_variables_to_constants(
    144           sess,
    145           input_meta_graph_def.graph_def,
    146           output_node_names.replace(' ', '').split(","),
    147           variable_names_whitelist=variable_names_whitelist,
    148           variable_names_blacklist=variable_names_blacklist)
    149     else:
    150       output_graph_def = graph_util.convert_variables_to_constants(
    151           sess,
    152           input_graph_def,
    153           output_node_names.replace(' ', '').split(","),
    154           variable_names_whitelist=variable_names_whitelist,
    155           variable_names_blacklist=variable_names_blacklist)
    156 
    157   # Write GraphDef to file if output path has been given.
    158   if output_graph:
    159     with gfile.GFile(output_graph, "wb") as f:
    160       f.write(output_graph_def.SerializeToString())
    161 
    162   return output_graph_def
    163 
    164 
    165 def _parse_input_graph_proto(input_graph, input_binary):
    166   """Parser input tensorflow graph into GraphDef proto."""
    167   if not gfile.Exists(input_graph):
    168     print("Input graph file '" + input_graph + "' does not exist!")
    169     return -1
    170   input_graph_def = graph_pb2.GraphDef()
    171   mode = "rb" if input_binary else "r"
    172   with gfile.FastGFile(input_graph, mode) as f:
    173     if input_binary:
    174       input_graph_def.ParseFromString(f.read())
    175     else:
    176       text_format.Merge(f.read(), input_graph_def)
    177   return input_graph_def
    178 
    179 
    180 def _parse_input_meta_graph_proto(input_graph, input_binary):
    181   """Parser input tensorflow graph into MetaGraphDef proto."""
    182   if not gfile.Exists(input_graph):
    183     print("Input meta graph file '" + input_graph + "' does not exist!")
    184     return -1
    185   input_meta_graph_def = MetaGraphDef()
    186   mode = "rb" if input_binary else "r"
    187   with gfile.FastGFile(input_graph, mode) as f:
    188     if input_binary:
    189       input_meta_graph_def.ParseFromString(f.read())
    190     else:
    191       text_format.Merge(f.read(), input_meta_graph_def)
    192   print("Loaded meta graph file '" + input_graph)
    193   return input_meta_graph_def
    194 
    195 
    196 def _parse_input_saver_proto(input_saver, input_binary):
    197   """Parser input tensorflow Saver into SaverDef proto."""
    198   if not gfile.Exists(input_saver):
    199     print("Input saver file '" + input_saver + "' does not exist!")
    200     return -1
    201   mode = "rb" if input_binary else "r"
    202   with gfile.FastGFile(input_saver, mode) as f:
    203     saver_def = saver_pb2.SaverDef()
    204     if input_binary:
    205       saver_def.ParseFromString(f.read())
    206     else:
    207       text_format.Merge(f.read(), saver_def)
    208   return saver_def
    209 
    210 
    211 def freeze_graph(input_graph,
    212                  input_saver,
    213                  input_binary,
    214                  input_checkpoint,
    215                  output_node_names,
    216                  restore_op_name,
    217                  filename_tensor_name,
    218                  output_graph,
    219                  clear_devices,
    220                  initializer_nodes,
    221                  variable_names_whitelist="",
    222                  variable_names_blacklist="",
    223                  input_meta_graph=None,
    224                  input_saved_model_dir=None,
    225                  saved_model_tags=tag_constants.SERVING,
    226                  checkpoint_version=saver_pb2.SaverDef.V2):
    227   """Converts all variables in a graph and checkpoint into constants."""
    228   input_graph_def = None
    229   if input_saved_model_dir:
    230     input_graph_def = saved_model_utils.get_meta_graph_def(
    231         input_saved_model_dir, saved_model_tags).graph_def
    232   elif input_graph:
    233     input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
    234   input_meta_graph_def = None
    235   if input_meta_graph:
    236     input_meta_graph_def = _parse_input_meta_graph_proto(
    237         input_meta_graph, input_binary)
    238   input_saver_def = None
    239   if input_saver:
    240     input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
    241   freeze_graph_with_def_protos(
    242       input_graph_def,
    243       input_saver_def,
    244       input_checkpoint,
    245       output_node_names,
    246       restore_op_name,
    247       filename_tensor_name,
    248       output_graph,
    249       clear_devices,
    250       initializer_nodes,
    251       variable_names_whitelist,
    252       variable_names_blacklist,
    253       input_meta_graph_def,
    254       input_saved_model_dir,
    255       saved_model_tags.replace(' ', '').split(","),
    256       checkpoint_version=checkpoint_version)
    257 
    258 
    259 def main(unused_args):
    260   if FLAGS.checkpoint_version == 1:
    261     checkpoint_version = saver_pb2.SaverDef.V1
    262   elif FLAGS.checkpoint_version == 2:
    263     checkpoint_version = saver_pb2.SaverDef.V2
    264   else:
    265     print("Invalid checkpoint version (must be '1' or '2'): %d" %
    266           FLAGS.checkpoint_version)
    267     return -1
    268   freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary,
    269                FLAGS.input_checkpoint, FLAGS.output_node_names,
    270                FLAGS.restore_op_name, FLAGS.filename_tensor_name,
    271                FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes,
    272                FLAGS.variable_names_whitelist, FLAGS.variable_names_blacklist,
    273                FLAGS.input_meta_graph, FLAGS.input_saved_model_dir,
    274                FLAGS.saved_model_tags, checkpoint_version)
    275 
    276 
    277 if __name__ == "__main__":
    278   parser = argparse.ArgumentParser()
    279   parser.register("type", "bool", lambda v: v.lower() == "true")
    280   parser.add_argument(
    281       "--input_graph",
    282       type=str,
    283       default="",
    284       help="TensorFlow \'GraphDef\' file to load.")
    285   parser.add_argument(
    286       "--input_saver",
    287       type=str,
    288       default="",
    289       help="TensorFlow saver file to load.")
    290   parser.add_argument(
    291       "--input_checkpoint",
    292       type=str,
    293       default="",
    294       help="TensorFlow variables file to load.")
    295   parser.add_argument(
    296       "--checkpoint_version",
    297       type=int,
    298       default=2,
    299       help="Tensorflow variable file format")
    300   parser.add_argument(
    301       "--output_graph",
    302       type=str,
    303       default="",
    304       help="Output \'GraphDef\' file name.")
    305   parser.add_argument(
    306       "--input_binary",
    307       nargs="?",
    308       const=True,
    309       type="bool",
    310       default=False,
    311       help="Whether the input files are in binary format.")
    312   parser.add_argument(
    313       "--output_node_names",
    314       type=str,
    315       default="",
    316       help="The name of the output nodes, comma separated.")
    317   parser.add_argument(
    318       "--restore_op_name",
    319       type=str,
    320       default="save/restore_all",
    321       help="""\
    322       The name of the master restore operator. Deprecated, unused by updated \
    323       loading code.
    324       """)
    325   parser.add_argument(
    326       "--filename_tensor_name",
    327       type=str,
    328       default="save/Const:0",
    329       help="""\
    330       The name of the tensor holding the save path. Deprecated, unused by \
    331       updated loading code.
    332       """)
    333   parser.add_argument(
    334       "--clear_devices",
    335       nargs="?",
    336       const=True,
    337       type="bool",
    338       default=True,
    339       help="Whether to remove device specifications.")
    340   parser.add_argument(
    341       "--initializer_nodes",
    342       type=str,
    343       default="",
    344       help="Comma separated list of initializer nodes to run before freezing.")
    345   parser.add_argument(
    346       "--variable_names_whitelist",
    347       type=str,
    348       default="",
    349       help="""\
    350       Comma separated list of variables to convert to constants. If specified, \
    351       only those variables will be converted to constants.\
    352       """)
    353   parser.add_argument(
    354       "--variable_names_blacklist",
    355       type=str,
    356       default="",
    357       help="""\
    358       Comma separated list of variables to skip converting to constants.\
    359       """)
    360   parser.add_argument(
    361       "--input_meta_graph",
    362       type=str,
    363       default="",
    364       help="TensorFlow \'MetaGraphDef\' file to load.")
    365   parser.add_argument(
    366       "--input_saved_model_dir",
    367       type=str,
    368       default="",
    369       help="Path to the dir with TensorFlow \'SavedModel\' file and variables.")
    370   parser.add_argument(
    371       "--saved_model_tags",
    372       type=str,
    373       default="serve",
    374       help="""\
    375       Group of tag(s) of the MetaGraphDef to load, in string format,\
    376       separated by \',\'. For tag-set contains multiple tags, all tags \
    377       must be passed in.\
    378       """)
    379   FLAGS, unparsed = parser.parse_known_args()
    380   app.run(main=main, argv=[sys.argv[0]] + unparsed)
    381