# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Print an AST tree in a form more readable than ast.dump."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast
import termcolor


class PrettyPrinter(gast.NodeVisitor):
    """Render AST nodes as an indented, optionally colored, text tree.

    Every emitted line is appended to `result`, terminated by a newline.

    Attributes:
      indent_lvl: current nesting depth; each level is drawn as one '| '.
      result: the text rendered so far.
      color: when truthy, fragments are styled with `termcolor.colored`.
    """

    def __init__(self, color):
        self.indent_lvl = 0
        self.result = ''
        self.color = color

    def _color(self, string, color, attrs=None):
        """Return `string` styled via termcolor, or untouched if color is off."""
        if self.color:
            return termcolor.colored(string, color, attrs=attrs)
        return string

    def _type(self, node):
        """Return the class name of `node`, styled bold."""
        return self._color(node.__class__.__name__, None, ['bold'])

    def _field(self, name):
        """Return a field name, styled blue."""
        return self._color(name, 'blue')

    def _value(self, name):
        """Return a leaf value, styled magenta."""
        return self._color(name, 'magenta')

    def _warning(self, name):
        """Return a warning fragment, styled red."""
        return self._color(name, 'red')

    def _indent(self):
        """Return the indentation prefix for the current nesting depth."""
        return self._color('| ' * self.indent_lvl, None, ['dark'])

    def _print(self, s):
        """Append `s` and a trailing newline to `result`."""
        self.result += s
        self.result += '\n'

    def _visit_value(self, value, name=None):
        """Render one field value or one sequence element.

        Lists and tuples are rendered bracketed with one element per line,
        objects that are `gast.AST` instances are rendered recursively through
        `generic_visit`, strings are double-quoted, and anything else is
        rendered with `%s` formatting.

        Args:
          value: the object to render.
          name: field name to prefix as `name=`; omitted for sequence elements.
        """
        prefix = ''
        if name:
            prefix = '%s=' % self._field(name)

        if isinstance(value, (list, tuple)):
            if isinstance(value, list):
                opener, closer = '[', ']'
            else:
                opener, closer = '(', ')'
            if not value:
                self._print(
                    '%s%s%s%s' % (self._indent(), prefix, opener, closer))
                return
            self._print('%s%s%s' % (self._indent(), prefix, opener))
            self.indent_lvl += 1
            # Elements are not guaranteed to be nodes (e.g. a list of plain
            # strings or None placeholders), so dispatch on each element's type
            # rather than assuming it has `_fields`.
            for item in value:
                self._visit_value(item)
            self.indent_lvl -= 1
            self._print('%s%s' % (self._indent(), closer))
        elif isinstance(value, gast.AST):
            self.generic_visit(value, name)
        elif isinstance(value, str):
            self._print('%s%s%s' % (self._indent(), prefix,
                                    self._value('"%s"' % value)))
        else:
            self._print('%s%s%s' % (self._indent(), prefix,
                                    self._value(value)))

    def generic_visit(self, node, name=None):
        """Render `node` and, recursively, everything listed in its `_fields`.

        Args:
          node: the object to render. Objects without a `_fields` attribute
            are rendered as plain values instead of raising AttributeError.
          name: optional name of the field that holds `node`; when given, the
            line is rendered as `name=Type`.
        """
        fields = getattr(node, '_fields', None)
        if fields is None:
            # Not a node-like object: render it as a leaf or sequence.
            self._visit_value(node, name)
            return

        # Nodes with fields get a ':' and an indented body; field-less nodes
        # are rendered as 'Type()'.
        cont = ':' if fields else '()'
        if name:
            self._print('%s%s=%s%s' % (self._indent(), self._field(name),
                                       self._type(node), cont))
        else:
            self._print('%s%s%s' % (self._indent(), self._type(node), cont))

        self.indent_lvl += 1
        for f in fields:
            if not hasattr(node, f):
                # Declared in `_fields` but never assigned on this instance.
                self._print('%s%s' % (self._indent(),
                                      self._warning('%s=<unset>' % f)))
                continue
            self._visit_value(getattr(node, f), f)
        self.indent_lvl -= 1


def fmt(node, color=True):
    """Return a readable, indented rendering of an AST node or nodes.

    Args:
      node: a single node, or a list/tuple of nodes rendered one after another.
      color: whether to style the output with terminal colors.

    Returns:
      The rendered text as a string, one line per node or field.
    """
    printer = PrettyPrinter(color)
    if isinstance(node, (list, tuple)):
        for n in node:
            printer.visit(n)
    else:
        printer.visit(node)
    return printer.result