Home | History | Annotate | Download | only in python2.7
      1 # -*- coding: utf-8 -*-
      2 """
      3     ast
      4     ~~~
      5 
      6     The `ast` module helps Python applications to process trees of the Python
      7     abstract syntax grammar.  The abstract syntax itself might change with
      8     each Python release; this module helps to find out programmatically what
      9     the current grammar looks like and allows modifications of it.
     10 
     11     An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as
     12     a flag to the `compile()` builtin function or by using the `parse()`
     13     function from this module.  The result will be a tree of objects whose
     14     classes all inherit from `ast.AST`.
     15 
     16     A modified abstract syntax tree can be compiled into a Python code object
     17     using the built-in `compile()` function.
     18 
     19     Additionally various helper functions are provided that make working with
     20     the trees simpler.  The main intention of the helper functions and this
     21     module in general is to provide an easy to use interface for libraries
     22     that work tightly with the python syntax (template engines for example).
     23 
     24 
     25     :copyright: Copyright 2008 by Armin Ronacher.
     26     :license: Python License.
     27 """
     28 from _ast import *
     29 from _ast import __version__
     30 
     31 
     32 def parse(source, filename='<unknown>', mode='exec'):
     33     """
     34     Parse the source into an AST node.
     35     Equivalent to compile(source, filename, mode, PyCF_ONLY_AST).
     36     """
     37     return compile(source, filename, mode, PyCF_ONLY_AST)
     38 
     39 
     40 def literal_eval(node_or_string):
     41     """
     42     Safely evaluate an expression node or a string containing a Python
     43     expression.  The string or node provided may only consist of the following
     44     Python literal structures: strings, numbers, tuples, lists, dicts, booleans,
     45     and None.
     46     """
     47     _safe_names = {'None': None, 'True': True, 'False': False}
     48     if isinstance(node_or_string, basestring):
     49         node_or_string = parse(node_or_string, mode='eval')
     50     if isinstance(node_or_string, Expression):
     51         node_or_string = node_or_string.body
     52     def _convert(node):
     53         if isinstance(node, Str):
     54             return node.s
     55         elif isinstance(node, Num):
     56             return node.n
     57         elif isinstance(node, Tuple):
     58             return tuple(map(_convert, node.elts))
     59         elif isinstance(node, List):
     60             return list(map(_convert, node.elts))
     61         elif isinstance(node, Dict):
     62             return dict((_convert(k), _convert(v)) for k, v
     63                         in zip(node.keys, node.values))
     64         elif isinstance(node, Name):
     65             if node.id in _safe_names:
     66                 return _safe_names[node.id]
     67         elif isinstance(node, BinOp) and \
     68              isinstance(node.op, (Add, Sub)) and \
     69              isinstance(node.right, Num) and \
     70              isinstance(node.right.n, complex) and \
     71              isinstance(node.left, Num) and \
     72              isinstance(node.left.n, (int, long, float)):
     73             left = node.left.n
     74             right = node.right.n
     75             if isinstance(node.op, Add):
     76                 return left + right
     77             else:
     78                 return left - right
     79         raise ValueError('malformed string')
     80     return _convert(node_or_string)
     81 
     82 
     83 def dump(node, annotate_fields=True, include_attributes=False):
     84     """
     85     Return a formatted dump of the tree in *node*.  This is mainly useful for
     86     debugging purposes.  The returned string will show the names and the values
     87     for fields.  This makes the code impossible to evaluate, so if evaluation is
     88     wanted *annotate_fields* must be set to False.  Attributes such as line
     89     numbers and column offsets are not dumped by default.  If this is wanted,
     90     *include_attributes* can be set to True.
     91     """
     92     def _format(node):
     93         if isinstance(node, AST):
     94             fields = [(a, _format(b)) for a, b in iter_fields(node)]
     95             rv = '%s(%s' % (node.__class__.__name__, ', '.join(
     96                 ('%s=%s' % field for field in fields)
     97                 if annotate_fields else
     98                 (b for a, b in fields)
     99             ))
    100             if include_attributes and node._attributes:
    101                 rv += fields and ', ' or ' '
    102                 rv += ', '.join('%s=%s' % (a, _format(getattr(node, a)))
    103                                 for a in node._attributes)
    104             return rv + ')'
    105         elif isinstance(node, list):
    106             return '[%s]' % ', '.join(_format(x) for x in node)
    107         return repr(node)
    108     if not isinstance(node, AST):
    109         raise TypeError('expected AST, got %r' % node.__class__.__name__)
    110     return _format(node)
    111 
    112 
    113 def copy_location(new_node, old_node):
    114     """
    115     Copy source location (`lineno` and `col_offset` attributes) from
    116     *old_node* to *new_node* if possible, and return *new_node*.
    117     """
    118     for attr in 'lineno', 'col_offset':
    119         if attr in old_node._attributes and attr in new_node._attributes \
    120            and hasattr(old_node, attr):
    121             setattr(new_node, attr, getattr(old_node, attr))
    122     return new_node
    123 
    124 
    125 def fix_missing_locations(node):
    126     """
    127     When you compile a node tree with compile(), the compiler expects lineno and
    128     col_offset attributes for every node that supports them.  This is rather
    129     tedious to fill in for generated nodes, so this helper adds these attributes
    130     recursively where not already set, by setting them to the values of the
    131     parent node.  It works recursively starting at *node*.
    132     """
    133     def _fix(node, lineno, col_offset):
    134         if 'lineno' in node._attributes:
    135             if not hasattr(node, 'lineno'):
    136                 node.lineno = lineno
    137             else:
    138                 lineno = node.lineno
    139         if 'col_offset' in node._attributes:
    140             if not hasattr(node, 'col_offset'):
    141                 node.col_offset = col_offset
    142             else:
    143                 col_offset = node.col_offset
    144         for child in iter_child_nodes(node):
    145             _fix(child, lineno, col_offset)
    146     _fix(node, 1, 0)
    147     return node
    148 
    149 
    150 def increment_lineno(node, n=1):
    151     """
    152     Increment the line number of each node in the tree starting at *node* by *n*.
    153     This is useful to "move code" to a different location in a file.
    154     """
    155     for child in walk(node):
    156         if 'lineno' in child._attributes:
    157             child.lineno = getattr(child, 'lineno', 0) + n
    158     return node
    159 
    160 
    161 def iter_fields(node):
    162     """
    163     Yield a tuple of ``(fieldname, value)`` for each field in ``node._fields``
    164     that is present on *node*.
    165     """
    166     for field in node._fields:
    167         try:
    168             yield field, getattr(node, field)
    169         except AttributeError:
    170             pass
    171 
    172 
    173 def iter_child_nodes(node):
    174     """
    175     Yield all direct child nodes of *node*, that is, all fields that are nodes
    176     and all items of fields that are lists of nodes.
    177     """
    178     for name, field in iter_fields(node):
    179         if isinstance(field, AST):
    180             yield field
    181         elif isinstance(field, list):
    182             for item in field:
    183                 if isinstance(item, AST):
    184                     yield item
    185 
    186 
    187 def get_docstring(node, clean=True):
    188     """
    189     Return the docstring for the given node or None if no docstring can
    190     be found.  If the node provided does not have docstrings a TypeError
    191     will be raised.
    192     """
    193     if not isinstance(node, (FunctionDef, ClassDef, Module)):
    194         raise TypeError("%r can't have docstrings" % node.__class__.__name__)
    195     if node.body and isinstance(node.body[0], Expr) and \
    196        isinstance(node.body[0].value, Str):
    197         if clean:
    198             import inspect
    199             return inspect.cleandoc(node.body[0].value.s)
    200         return node.body[0].value.s
    201 
    202 
    203 def walk(node):
    204     """
    205     Recursively yield all descendant nodes in the tree starting at *node*
    206     (including *node* itself), in no specified order.  This is useful if you
    207     only want to modify nodes in place and don't care about the context.
    208     """
    209     from collections import deque
    210     todo = deque([node])
    211     while todo:
    212         node = todo.popleft()
    213         todo.extend(iter_child_nodes(node))
    214         yield node
    215 
    216 
    217 class NodeVisitor(object):
    218     """
    219     A node visitor base class that walks the abstract syntax tree and calls a
    220     visitor function for every node found.  This function may return a value
    221     which is forwarded by the `visit` method.
    222 
    223     This class is meant to be subclassed, with the subclass adding visitor
    224     methods.
    225 
    226     Per default the visitor functions for the nodes are ``'visit_'`` +
    227     class name of the node.  So a `TryFinally` node visit function would
    228     be `visit_TryFinally`.  This behavior can be changed by overriding
    229     the `visit` method.  If no visitor function exists for a node
    230     (return value `None`) the `generic_visit` visitor is used instead.
    231 
    232     Don't use the `NodeVisitor` if you want to apply changes to nodes during
    233     traversing.  For this a special visitor exists (`NodeTransformer`) that
    234     allows modifications.
    235     """
    236 
    237     def visit(self, node):
    238         """Visit a node."""
    239         method = 'visit_' + node.__class__.__name__
    240         visitor = getattr(self, method, self.generic_visit)
    241         return visitor(node)
    242 
    243     def generic_visit(self, node):
    244         """Called if no explicit visitor function exists for a node."""
    245         for field, value in iter_fields(node):
    246             if isinstance(value, list):
    247                 for item in value:
    248                     if isinstance(item, AST):
    249                         self.visit(item)
    250             elif isinstance(value, AST):
    251                 self.visit(value)
    252 
    253 
    254 class NodeTransformer(NodeVisitor):
    255     """
    256     A :class:`NodeVisitor` subclass that walks the abstract syntax tree and
    257     allows modification of nodes.
    258 
    259     The `NodeTransformer` will walk the AST and use the return value of the
    260     visitor methods to replace or remove the old node.  If the return value of
    261     the visitor method is ``None``, the node will be removed from its location,
    262     otherwise it is replaced with the return value.  The return value may be the
    263     original node in which case no replacement takes place.
    264 
    265     Here is an example transformer that rewrites all occurrences of name lookups
    266     (``foo``) to ``data['foo']``::
    267 
    268        class RewriteName(NodeTransformer):
    269 
    270            def visit_Name(self, node):
    271                return copy_location(Subscript(
    272                    value=Name(id='data', ctx=Load()),
    273                    slice=Index(value=Str(s=node.id)),
    274                    ctx=node.ctx
    275                ), node)
    276 
    277     Keep in mind that if the node you're operating on has child nodes you must
    278     either transform the child nodes yourself or call the :meth:`generic_visit`
    279     method for the node first.
    280 
    281     For nodes that were part of a collection of statements (that applies to all
    282     statement nodes), the visitor may also return a list of nodes rather than
    283     just a single node.
    284 
    285     Usually you use the transformer like this::
    286 
    287        node = YourTransformer().visit(node)
    288     """
    289 
    290     def generic_visit(self, node):
    291         for field, old_value in iter_fields(node):
    292             old_value = getattr(node, field, None)
    293             if isinstance(old_value, list):
    294                 new_values = []
    295                 for value in old_value:
    296                     if isinstance(value, AST):
    297                         value = self.visit(value)
    298                         if value is None:
    299                             continue
    300                         elif not isinstance(value, AST):
    301                             new_values.extend(value)
    302                             continue
    303                     new_values.append(value)
    304                 old_value[:] = new_values
    305             elif isinstance(old_value, AST):
    306                 new_node = self.visit(old_value)
    307                 if new_node is None:
    308                     delattr(node, field)
    309                 else:
    310                     setattr(node, field, new_node)
    311         return node
    312