Home | History | Annotate | Download | only in quantization
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Converts a GraphDef file into a DOT format suitable for visualization.
     16 
     17 This script takes a GraphDef representing a network, and produces a DOT file
     18 that can then be visualized by GraphViz tools like dot and xdot.
     19 
     20 """
     21 from __future__ import absolute_import
     22 from __future__ import division
     23 from __future__ import print_function
     24 
     25 import re
     26 
     27 from google.protobuf import text_format
     28 
     29 from tensorflow.core.framework import graph_pb2
     30 from tensorflow.python.platform import app
     31 from tensorflow.python.platform import flags
     32 from tensorflow.python.platform import gfile
     33 
     34 FLAGS = flags.FLAGS
     35 
     36 flags.DEFINE_string("graph", "", """TensorFlow 'GraphDef' file to load.""")
     37 flags.DEFINE_bool("input_binary", True,
     38                   """Whether the input files are in binary format.""")
     39 flags.DEFINE_string("dot_output", "", """Where to write the DOT output.""")
     40 
     41 
     42 def main(unused_args):
     43   if not gfile.Exists(FLAGS.graph):
     44     print("Input graph file '" + FLAGS.graph + "' does not exist!")
     45     return -1
     46 
     47   graph = graph_pb2.GraphDef()
     48   with open(FLAGS.graph, "r") as f:
     49     if FLAGS.input_binary:
     50       graph.ParseFromString(f.read())
     51     else:
     52       text_format.Merge(f.read(), graph)
     53 
     54   with open(FLAGS.dot_output, "wb") as f:
     55     print("digraph graphname {", file=f)
     56     for node in graph.node:
     57       output_name = node.name
     58       print("  \"" + output_name + "\" [label=\"" + node.op + "\"];", file=f)
     59       for input_full_name in node.input:
     60         parts = input_full_name.split(":")
     61         input_name = re.sub(r"^\^", "", parts[0])
     62         print("  \"" + input_name + "\" -> \"" + output_name + "\";", file=f)
     63     print("}", file=f)
     64   print("Created DOT file '" + FLAGS.dot_output + "'.")
     65 
     66 
     67 if __name__ == "__main__":
     68   app.run()
     69