Home | History | Annotate | Download | only in pyct
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Print an AST tree in a form more readable than ast.dump."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import gast
     22 import termcolor
     23 
     24 
     25 class PrettyPrinter(gast.NodeVisitor):
     26   """Print AST nodes."""
     27 
     28   def __init__(self, color):
     29     self.indent_lvl = 0
     30     self.result = ''
     31     self.color = color
     32 
     33   def _color(self, string, color, attrs=None):
     34     if self.color:
     35       return termcolor.colored(string, color, attrs=attrs)
     36     return string
     37 
     38   def _type(self, node):
     39     return self._color(node.__class__.__name__, None, ['bold'])
     40 
     41   def _field(self, name):
     42     return self._color(name, 'blue')
     43 
     44   def _value(self, name):
     45     return self._color(name, 'magenta')
     46 
     47   def _warning(self, name):
     48     return self._color(name, 'red')
     49 
     50   def _indent(self):
     51     return self._color('| ' * self.indent_lvl, None, ['dark'])
     52 
     53   def _print(self, s):
     54     self.result += s
     55     self.result += '\n'
     56 
     57   def generic_visit(self, node, name=None):
     58     if node._fields:
     59       cont = ':'
     60     else:
     61       cont = '()'
     62 
     63     if name:
     64       self._print('%s%s=%s%s' % (self._indent(), self._field(name),
     65                                  self._type(node), cont))
     66     else:
     67       self._print('%s%s%s' % (self._indent(), self._type(node), cont))
     68 
     69     self.indent_lvl += 1
     70     for f in node._fields:
     71       if not hasattr(node, f):
     72         self._print('%s%s' % (self._indent(), self._warning('%s=<unset>' % f)))
     73         continue
     74       v = getattr(node, f)
     75       if isinstance(v, list):
     76         if v:
     77           self._print('%s%s=[' % (self._indent(), self._field(f)))
     78           self.indent_lvl += 1
     79           for n in v:
     80             self.generic_visit(n)
     81           self.indent_lvl -= 1
     82           self._print('%s]' % (self._indent()))
     83         else:
     84           self._print('%s%s=[]' % (self._indent(), self._field(f)))
     85       elif isinstance(v, tuple):
     86         if v:
     87           self._print('%s%s=(' % (self._indent(), self._field(f)))
     88           self.indent_lvl += 1
     89           for n in v:
     90             self.generic_visit(n)
     91           self.indent_lvl -= 1
     92           self._print('%s)' % (self._indent()))
     93         else:
     94           self._print('%s%s=()' % (self._indent(), self._field(f)))
     95       elif isinstance(v, gast.AST):
     96         self.generic_visit(v, f)
     97       elif isinstance(v, str):
     98         self._print('%s%s=%s' % (self._indent(), self._field(f),
     99                                  self._value('"%s"' % v)))
    100       else:
    101         self._print('%s%s=%s' % (self._indent(), self._field(f),
    102                                  self._value(v)))
    103     self.indent_lvl -= 1
    104 
    105 
    106 def fmt(node, color=True):
    107   printer = PrettyPrinter(color)
    108   if isinstance(node, (list, tuple)):
    109     for n in node:
    110       printer.visit(n)
    111   else:
    112     printer.visit(node)
    113   return printer.result
    114