1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Converts a GraphDef file into a DOT format suitable for visualization. 16 17 This script takes a GraphDef representing a network, and produces a DOT file 18 that can then be visualized by GraphViz tools like dot and xdot. 19 20 """ 21 from __future__ import absolute_import 22 from __future__ import division 23 from __future__ import print_function 24 25 import re 26 27 from google.protobuf import text_format 28 29 from tensorflow.core.framework import graph_pb2 30 from tensorflow.python.platform import app 31 from tensorflow.python.platform import flags 32 from tensorflow.python.platform import gfile 33 34 FLAGS = flags.FLAGS 35 36 flags.DEFINE_string("graph", "", """TensorFlow 'GraphDef' file to load.""") 37 flags.DEFINE_bool("input_binary", True, 38 """Whether the input files are in binary format.""") 39 flags.DEFINE_string("dot_output", "", """Where to write the DOT output.""") 40 41 42 def main(unused_args): 43 if not gfile.Exists(FLAGS.graph): 44 print("Input graph file '" + FLAGS.graph + "' does not exist!") 45 return -1 46 47 graph = graph_pb2.GraphDef() 48 with open(FLAGS.graph, "r") as f: 49 if FLAGS.input_binary: 50 graph.ParseFromString(f.read()) 51 else: 52 text_format.Merge(f.read(), graph) 53 54 with open(FLAGS.dot_output, "wb") as f: 55 print("digraph graphname {", file=f) 56 for node in graph.node: 57 output_name = node.name 58 print(" \"" + output_name + "\" [label=\"" + node.op + "\"];", file=f) 59 for input_full_name in node.input: 60 parts = input_full_name.split(":") 61 input_name = re.sub(r"^\^", "", parts[0]) 62 print(" \"" + input_name + "\" -> \"" + output_name + "\";", file=f) 63 print("}", file=f) 64 print("Created DOT file '" + FLAGS.dot_output + "'.") 65 66 67 if __name__ == "__main__": 68 app.run() 69