1 # pylint: disable=g-bad-file-header 2 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3 # 4 # Licensed under the Apache License, Version 2.0 (the "License"); 5 # you may not use this file except in compliance with the License. 6 # You may obtain a copy of the License at 7 # 8 # http://www.apache.org/licenses/LICENSE-2.0 9 # 10 # Unless required by applicable law or agreed to in writing, software 11 # distributed under the License is distributed on an "AS IS" BASIS, 12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 # See the License for the specific language governing permissions and 14 # limitations under the License. 15 # ============================================================================== 16 r"""Removes parts of a graph that are only needed for training. 17 18 There are several common transformations that can be applied to GraphDefs 19 created to train a model, that help reduce the amount of computation needed when 20 the network is used only for inference. These include: 21 22 - Removing training-only operations like checkpoint saving. 23 24 - Stripping out parts of the graph that are never reached. 25 26 - Removing debug operations like CheckNumerics. 27 28 - Folding batch normalization ops into the pre-calculated weights. 29 30 - Fusing common operations into unified versions. 31 32 This script takes either a frozen binary GraphDef file (where the weight 33 variables have been converted into constants by the freeze_graph script), or a 34 text GraphDef proto file (the weight variables are stored in a separate 35 checkpoint file), and outputs a new GraphDef with the optimizations applied. 36 37 If the input graph is a text graph file, make sure to include the node that 38 restores the variable weights in output_names. That node is usually named 39 "restore_all". 40 41 An example of command-line usage is: 42 43 bazel build tensorflow/python/tools:optimize_for_inference && \ 44 bazel-bin/tensorflow/python/tools/optimize_for_inference \ 45 --input=frozen_inception_graph.pb \ 46 --output=optimized_inception_graph.pb \ 47 --frozen_graph=True \ 48 --input_names=Mul \ 49 --output_names=softmax 50 51 52 """ 53 54 from __future__ import absolute_import 55 from __future__ import division 56 from __future__ import print_function 57 58 import argparse 59 import os 60 import sys 61 62 from google.protobuf import text_format 63 64 from tensorflow.core.framework import graph_pb2 65 from tensorflow.python.framework import dtypes 66 from tensorflow.python.framework import graph_io 67 from tensorflow.python.platform import app 68 from tensorflow.python.platform import gfile 69 from tensorflow.python.tools import optimize_for_inference_lib 70 71 FLAGS = None 72 73 74 def main(unused_args): 75 if not gfile.Exists(FLAGS.input): 76 print("Input graph file '" + FLAGS.input + "' does not exist!") 77 return -1 78 79 input_graph_def = graph_pb2.GraphDef() 80 with gfile.Open(FLAGS.input, "rb") as f: 81 data = f.read() 82 if FLAGS.frozen_graph: 83 input_graph_def.ParseFromString(data) 84 else: 85 text_format.Merge(data.decode("utf-8"), input_graph_def) 86 87 output_graph_def = optimize_for_inference_lib.optimize_for_inference( 88 input_graph_def, 89 FLAGS.input_names.split(","), 90 FLAGS.output_names.split(","), FLAGS.placeholder_type_enum) 91 92 if FLAGS.frozen_graph: 93 f = gfile.FastGFile(FLAGS.output, "w") 94 f.write(output_graph_def.SerializeToString()) 95 else: 96 graph_io.write_graph(output_graph_def, 97 os.path.dirname(FLAGS.output), 98 os.path.basename(FLAGS.output)) 99 return 0 100 101 102 def parse_args(): 103 """Parses command line arguments.""" 104 parser = argparse.ArgumentParser() 105 parser.register("type", "bool", lambda v: v.lower() == "true") 106 parser.add_argument( 107 "--input", 108 type=str, 109 default="", 110 help="TensorFlow \'GraphDef\' file to load.") 111 parser.add_argument( 112 "--output", 113 type=str, 114 default="", 115 help="File to save the output graph to.") 116 parser.add_argument( 117 "--input_names", 118 type=str, 119 default="", 120 help="Input node names, comma separated.") 121 parser.add_argument( 122 "--output_names", 123 type=str, 124 default="", 125 help="Output node names, comma separated.") 126 parser.add_argument( 127 "--frozen_graph", 128 nargs="?", 129 const=True, 130 type="bool", 131 default=True, 132 help="""\ 133 If true, the input graph is a binary frozen GraphDef 134 file; if false, it is a text GraphDef proto file.\ 135 """) 136 parser.add_argument( 137 "--placeholder_type_enum", 138 type=int, 139 default=dtypes.float32.as_datatype_enum, 140 help="The AttrValue enum to use for placeholders.") 141 return parser.parse_known_args() 142 143 144 if __name__ == "__main__": 145 FLAGS, unparsed = parse_args() 146 app.run(main=main, argv=[sys.argv[0]] + unparsed) 147