Home | History | Annotate | Download | only in Lib
      1 """
      2     ast
      3     ~~~
      4 
      5     The `ast` module helps Python applications to process trees of the Python
      6     abstract syntax grammar.  The abstract syntax itself might change with
      7     each Python release; this module helps to find out programmatically what
      8     the current grammar looks like and allows modifications of it.
      9 
     10     An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
     11     a flag to the `compile()` builtin function or by using the `parse()`
     12     function from this module.  The result will be a tree of objects whose
     13     classes all inherit from `ast.AST`.
     14 
     15     A modified abstract syntax tree can be compiled into a Python code object
     16     using the built-in `compile()` function.
     17 
    Additionally, various helper functions are provided that make working with
    the trees simpler.  The main intention of the helper functions and this
    module in general is to provide an easy-to-use interface for libraries
    that work closely with the Python syntax (template engines, for example).
     22 
     23 
     24     :copyright: Copyright 2008 by Armin Ronacher.
     25     :license: Python License.
     26 """
     27 from _ast import *
     28 
     29 
     30 def parse(source, filename='<unknown>', mode='exec'):
     31     """
     32     Parse the source into an AST node.
     33     Equivalent to compile(source, filename, mode, PyCF_ONLY_AST).
     34     """
     35     return compile(source, filename, mode, PyCF_ONLY_AST)
     36 
     37 
     38 _NUM_TYPES = (int, float, complex)
     39 
     40 def literal_eval(node_or_string):
     41     """
     42     Safely evaluate an expression node or a string containing a Python
     43     expression.  The string or node provided may only consist of the following
     44     Python literal structures: strings, bytes, numbers, tuples, lists, dicts,
     45     sets, booleans, and None.
     46     """
     47     if isinstance(node_or_string, str):
     48         node_or_string = parse(node_or_string, mode='eval')
     49     if isinstance(node_or_string, Expression):
     50         node_or_string = node_or_string.body
     51     def _convert(node):
     52         if isinstance(node, Constant):
     53             return node.value
     54         elif isinstance(node, (Str, Bytes)):
     55             return node.s
     56         elif isinstance(node, Num):
     57             return node.n
     58         elif isinstance(node, Tuple):
     59             return tuple(map(_convert, node.elts))
     60         elif isinstance(node, List):
     61             return list(map(_convert, node.elts))
     62         elif isinstance(node, Set):
     63             return set(map(_convert, node.elts))
     64         elif isinstance(node, Dict):
     65             return dict((_convert(k), _convert(v)) for k, v
     66                         in zip(node.keys, node.values))
     67         elif isinstance(node, NameConstant):
     68             return node.value
     69         elif isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):
     70             operand = _convert(node.operand)
     71             if isinstance(operand, _NUM_TYPES):
     72                 if isinstance(node.op, UAdd):
     73                     return + operand
     74                 else:
     75                     return - operand
     76         elif isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)):
     77             left = _convert(node.left)
     78             right = _convert(node.right)
     79             if isinstance(left, _NUM_TYPES) and isinstance(right, _NUM_TYPES):
     80                 if isinstance(node.op, Add):
     81                     return left + right
     82                 else:
     83                     return left - right
     84         raise ValueError('malformed node or string: ' + repr(node))
     85     return _convert(node_or_string)
     86 
     87 
     88 def dump(node, annotate_fields=True, include_attributes=False):
     89     """
     90     Return a formatted dump of the tree in *node*.  This is mainly useful for
     91     debugging purposes.  The returned string will show the names and the values
     92     for fields.  This makes the code impossible to evaluate, so if evaluation is
     93     wanted *annotate_fields* must be set to False.  Attributes such as line
     94     numbers and column offsets are not dumped by default.  If this is wanted,
     95     *include_attributes* can be set to True.
     96     """
     97     def _format(node):
     98         if isinstance(node, AST):
     99             fields = [(a, _format(b)) for a, b in iter_fields(node)]
    100             rv = '%s(%s' % (node.__class__.__name__, ', '.join(
    101                 ('%s=%s' % field for field in fields)
    102                 if annotate_fields else
    103                 (b for a, b in fields)
    104             ))
    105             if include_attributes and node._attributes:
    106                 rv += fields and ', ' or ' '
    107                 rv += ', '.join('%s=%s' % (a, _format(getattr(node, a)))
    108                                 for a in node._attributes)
    109             return rv + ')'
    110         elif isinstance(node, list):
    111             return '[%s]' % ', '.join(_format(x) for x in node)
    112         return repr(node)
    113     if not isinstance(node, AST):
    114         raise TypeError('expected AST, got %r' % node.__class__.__name__)
    115     return _format(node)
    116 
    117 
    118 def copy_location(new_node, old_node):
    119     """
    120     Copy source location (`lineno` and `col_offset` attributes) from
    121     *old_node* to *new_node* if possible, and return *new_node*.
    122     """
    123     for attr in 'lineno', 'col_offset':
    124         if attr in old_node._attributes and attr in new_node._attributes \
    125            and hasattr(old_node, attr):
    126             setattr(new_node, attr, getattr(old_node, attr))
    127     return new_node
    128 
    129 
    130 def fix_missing_locations(node):
    131     """
    132     When you compile a node tree with compile(), the compiler expects lineno and
    133     col_offset attributes for every node that supports them.  This is rather
    134     tedious to fill in for generated nodes, so this helper adds these attributes
    135     recursively where not already set, by setting them to the values of the
    136     parent node.  It works recursively starting at *node*.
    137     """
    138     def _fix(node, lineno, col_offset):
    139         if 'lineno' in node._attributes:
    140             if not hasattr(node, 'lineno'):
    141                 node.lineno = lineno
    142             else:
    143                 lineno = node.lineno
    144         if 'col_offset' in node._attributes:
    145             if not hasattr(node, 'col_offset'):
    146                 node.col_offset = col_offset
    147             else:
    148                 col_offset = node.col_offset
    149         for child in iter_child_nodes(node):
    150             _fix(child, lineno, col_offset)
    151     _fix(node, 1, 0)
    152     return node
    153 
    154 
    155 def increment_lineno(node, n=1):
    156     """
    157     Increment the line number of each node in the tree starting at *node* by *n*.
    158     This is useful to "move code" to a different location in a file.
    159     """
    160     for child in walk(node):
    161         if 'lineno' in child._attributes:
    162             child.lineno = getattr(child, 'lineno', 0) + n
    163     return node
    164 
    165 
    166 def iter_fields(node):
    167     """
    168     Yield a tuple of ``(fieldname, value)`` for each field in ``node._fields``
    169     that is present on *node*.
    170     """
    171     for field in node._fields:
    172         try:
    173             yield field, getattr(node, field)
    174         except AttributeError:
    175             pass
    176 
    177 
    178 def iter_child_nodes(node):
    179     """
    180     Yield all direct child nodes of *node*, that is, all fields that are nodes
    181     and all items of fields that are lists of nodes.
    182     """
    183     for name, field in iter_fields(node):
    184         if isinstance(field, AST):
    185             yield field
    186         elif isinstance(field, list):
    187             for item in field:
    188                 if isinstance(item, AST):
    189                     yield item
    190 
    191 
    192 def get_docstring(node, clean=True):
    193     """
    194     Return the docstring for the given node or None if no docstring can
    195     be found.  If the node provided does not have docstrings a TypeError
    196     will be raised.
    197     """
    198     if not isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)):
    199         raise TypeError("%r can't have docstrings" % node.__class__.__name__)
    200     if not(node.body and isinstance(node.body[0], Expr)):
    201         return
    202     node = node.body[0].value
    203     if isinstance(node, Str):
    204         text = node.s
    205     elif isinstance(node, Constant) and isinstance(node.value, str):
    206         text = node.value
    207     else:
    208         return
    209     if clean:
    210         import inspect
    211         text = inspect.cleandoc(text)
    212     return text
    213 
    214 
    215 def walk(node):
    216     """
    217     Recursively yield all descendant nodes in the tree starting at *node*
    218     (including *node* itself), in no specified order.  This is useful if you
    219     only want to modify nodes in place and don't care about the context.
    220     """
    221     from collections import deque
    222     todo = deque([node])
    223     while todo:
    224         node = todo.popleft()
    225         todo.extend(iter_child_nodes(node))
    226         yield node
    227 
    228 
    229 class NodeVisitor(object):
    230     """
    231     A node visitor base class that walks the abstract syntax tree and calls a
    232     visitor function for every node found.  This function may return a value
    233     which is forwarded by the `visit` method.
    234 
    235     This class is meant to be subclassed, with the subclass adding visitor
    236     methods.
    237 
    238     Per default the visitor functions for the nodes are ``'visit_'`` +
    239     class name of the node.  So a `TryFinally` node visit function would
    240     be `visit_TryFinally`.  This behavior can be changed by overriding
    241     the `visit` method.  If no visitor function exists for a node
    242     (return value `None`) the `generic_visit` visitor is used instead.
    243 
    244     Don't use the `NodeVisitor` if you want to apply changes to nodes during
    245     traversing.  For this a special visitor exists (`NodeTransformer`) that
    246     allows modifications.
    247     """
    248 
    249     def visit(self, node):
    250         """Visit a node."""
    251         method = 'visit_' + node.__class__.__name__
    252         visitor = getattr(self, method, self.generic_visit)
    253         return visitor(node)
    254 
    255     def generic_visit(self, node):
    256         """Called if no explicit visitor function exists for a node."""
    257         for field, value in iter_fields(node):
    258             if isinstance(value, list):
    259                 for item in value:
    260                     if isinstance(item, AST):
    261                         self.visit(item)
    262             elif isinstance(value, AST):
    263                 self.visit(value)
    264 
    265 
    266 class NodeTransformer(NodeVisitor):
    267     """
    268     A :class:`NodeVisitor` subclass that walks the abstract syntax tree and
    269     allows modification of nodes.
    270 
    271     The `NodeTransformer` will walk the AST and use the return value of the
    272     visitor methods to replace or remove the old node.  If the return value of
    273     the visitor method is ``None``, the node will be removed from its location,
    274     otherwise it is replaced with the return value.  The return value may be the
    275     original node in which case no replacement takes place.
    276 
    277     Here is an example transformer that rewrites all occurrences of name lookups
    278     (``foo``) to ``data['foo']``::
    279 
    280        class RewriteName(NodeTransformer):
    281 
    282            def visit_Name(self, node):
    283                return copy_location(Subscript(
    284                    value=Name(id='data', ctx=Load()),
    285                    slice=Index(value=Str(s=node.id)),
    286                    ctx=node.ctx
    287                ), node)
    288 
    289     Keep in mind that if the node you're operating on has child nodes you must
    290     either transform the child nodes yourself or call the :meth:`generic_visit`
    291     method for the node first.
    292 
    293     For nodes that were part of a collection of statements (that applies to all
    294     statement nodes), the visitor may also return a list of nodes rather than
    295     just a single node.
    296 
    297     Usually you use the transformer like this::
    298 
    299        node = YourTransformer().visit(node)
    300     """
    301 
    302     def generic_visit(self, node):
    303         for field, old_value in iter_fields(node):
    304             if isinstance(old_value, list):
    305                 new_values = []
    306                 for value in old_value:
    307                     if isinstance(value, AST):
    308                         value = self.visit(value)
    309                         if value is None:
    310                             continue
    311                         elif not isinstance(value, AST):
    312                             new_values.extend(value)
    313                             continue
    314                     new_values.append(value)
    315                 old_value[:] = new_values
    316             elif isinstance(old_value, AST):
    317                 new_node = self.visit(old_value)
    318                 if new_node is None:
    319                     delattr(node, field)
    320                 else:
    321                     setattr(node, field, new_node)
    322         return node
    323