Home | History | Annotate | Download | only in tools
      1 # pylint: disable=g-bad-file-header
      2 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      3 #
      4 # Licensed under the Apache License, Version 2.0 (the "License");
      5 # you may not use this file except in compliance with the License.
      6 # You may obtain a copy of the License at
      7 #
      8 #     http://www.apache.org/licenses/LICENSE-2.0
      9 #
     10 # Unless required by applicable law or agreed to in writing, software
     11 # distributed under the License is distributed on an "AS IS" BASIS,
     12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13 # See the License for the specific language governing permissions and
     14 # limitations under the License.
     15 # ==============================================================================
     16 r"""Removes parts of a graph that are only needed for training.
     17 
     18 There are several common transformations that can be applied to GraphDefs
     19 created to train a model, that help reduce the amount of computation needed when
     20 the network is used only for inference. These include:
     21 
     22  - Removing training-only operations like checkpoint saving.
     23 
     24  - Stripping out parts of the graph that are never reached.
     25 
     26  - Removing debug operations like CheckNumerics.
     27 
     28  - Folding batch normalization ops into the pre-calculated weights.
     29 
     30  - Fusing common operations into unified versions.
     31 
     32 This script takes either a frozen binary GraphDef file (where the weight
     33 variables have been converted into constants by the freeze_graph script), or a
     34 text GraphDef proto file (the weight variables are stored in a separate
     35 checkpoint file), and outputs a new GraphDef with the optimizations applied.
     36 
     37 If the input graph is a text graph file, make sure to include the node that
     38 restores the variable weights in output_names. That node is usually named
     39 "restore_all".
     40 
     41 An example of command-line usage is:
     42 
     43 bazel build tensorflow/python/tools:optimize_for_inference && \
     44 bazel-bin/tensorflow/python/tools/optimize_for_inference \
     45 --input=frozen_inception_graph.pb \
     46 --output=optimized_inception_graph.pb \
     47 --frozen_graph=True \
     48 --input_names=Mul \
     49 --output_names=softmax
     50 
     51 
     52 """
     53 
     54 from __future__ import absolute_import
     55 from __future__ import division
     56 from __future__ import print_function
     57 
     58 import argparse
     59 import os
     60 import sys
     61 
     62 from google.protobuf import text_format
     63 
     64 from tensorflow.core.framework import graph_pb2
     65 from tensorflow.python.framework import dtypes
     66 from tensorflow.python.framework import graph_io
     67 from tensorflow.python.platform import app
     68 from tensorflow.python.platform import gfile
     69 from tensorflow.python.tools import optimize_for_inference_lib
     70 
     71 FLAGS = None
     72 
     73 
     74 def main(unused_args):
     75   if not gfile.Exists(FLAGS.input):
     76     print("Input graph file '" + FLAGS.input + "' does not exist!")
     77     return -1
     78 
     79   input_graph_def = graph_pb2.GraphDef()
     80   with gfile.Open(FLAGS.input, "rb") as f:
     81     data = f.read()
     82     if FLAGS.frozen_graph:
     83       input_graph_def.ParseFromString(data)
     84     else:
     85       text_format.Merge(data.decode("utf-8"), input_graph_def)
     86 
     87   output_graph_def = optimize_for_inference_lib.optimize_for_inference(
     88       input_graph_def,
     89       FLAGS.input_names.split(","),
     90       FLAGS.output_names.split(","), FLAGS.placeholder_type_enum)
     91 
     92   if FLAGS.frozen_graph:
     93     f = gfile.FastGFile(FLAGS.output, "w")
     94     f.write(output_graph_def.SerializeToString())
     95   else:
     96     graph_io.write_graph(output_graph_def,
     97                          os.path.dirname(FLAGS.output),
     98                          os.path.basename(FLAGS.output))
     99   return 0
    100 
    101 
    102 def parse_args():
    103   """Parses command line arguments."""
    104   parser = argparse.ArgumentParser()
    105   parser.register("type", "bool", lambda v: v.lower() == "true")
    106   parser.add_argument(
    107       "--input",
    108       type=str,
    109       default="",
    110       help="TensorFlow \'GraphDef\' file to load.")
    111   parser.add_argument(
    112       "--output",
    113       type=str,
    114       default="",
    115       help="File to save the output graph to.")
    116   parser.add_argument(
    117       "--input_names",
    118       type=str,
    119       default="",
    120       help="Input node names, comma separated.")
    121   parser.add_argument(
    122       "--output_names",
    123       type=str,
    124       default="",
    125       help="Output node names, comma separated.")
    126   parser.add_argument(
    127       "--frozen_graph",
    128       nargs="?",
    129       const=True,
    130       type="bool",
    131       default=True,
    132       help="""\
    133       If true, the input graph is a binary frozen GraphDef
    134       file; if false, it is a text GraphDef proto file.\
    135       """)
    136   parser.add_argument(
    137       "--placeholder_type_enum",
    138       type=int,
    139       default=dtypes.float32.as_datatype_enum,
    140       help="The AttrValue enum to use for placeholders.")
    141   return parser.parse_known_args()
    142 
    143 
    144 if __name__ == "__main__":
    145   FLAGS, unparsed = parse_args()
    146   app.run(main=main, argv=[sys.argv[0]] + unparsed)
    147