Home | History | Annotate | Download | only in lib2to3
      1 """Utility functions, node construction macros, etc."""
      2 # Author: Collin Winter
      3 
      4 from itertools import islice
      5 
      6 # Local imports
      7 from .pgen2 import token
      8 from .pytree import Leaf, Node
      9 from .pygram import python_symbols as syms
     10 from . import patcomp
     11 
     12 
     13 ###########################################################
     14 ### Common node-construction "macros"
     15 ###########################################################
     16 
     17 def KeywordArg(keyword, value):
     18     return Node(syms.argument,
     19                 [keyword, Leaf(token.EQUAL, u"="), value])
     20 
     21 def LParen():
     22     return Leaf(token.LPAR, u"(")
     23 
     24 def RParen():
     25     return Leaf(token.RPAR, u")")
     26 
     27 def Assign(target, source):
     28     """Build an assignment statement"""
     29     if not isinstance(target, list):
     30         target = [target]
     31     if not isinstance(source, list):
     32         source.prefix = u" "
     33         source = [source]
     34 
     35     return Node(syms.atom,
     36                 target + [Leaf(token.EQUAL, u"=", prefix=u" ")] + source)
     37 
     38 def Name(name, prefix=None):
     39     """Return a NAME leaf"""
     40     return Leaf(token.NAME, name, prefix=prefix)
     41 
     42 def Attr(obj, attr):
     43     """A node tuple for obj.attr"""
     44     return [obj, Node(syms.trailer, [Dot(), attr])]
     45 
     46 def Comma():
     47     """A comma leaf"""
     48     return Leaf(token.COMMA, u",")
     49 
     50 def Dot():
     51     """A period (.) leaf"""
     52     return Leaf(token.DOT, u".")
     53 
     54 def ArgList(args, lparen=LParen(), rparen=RParen()):
     55     """A parenthesised argument list, used by Call()"""
     56     node = Node(syms.trailer, [lparen.clone(), rparen.clone()])
     57     if args:
     58         node.insert_child(1, Node(syms.arglist, args))
     59     return node
     60 
     61 def Call(func_name, args=None, prefix=None):
     62     """A function call"""
     63     node = Node(syms.power, [func_name, ArgList(args)])
     64     if prefix is not None:
     65         node.prefix = prefix
     66     return node
     67 
     68 def Newline():
     69     """A newline literal"""
     70     return Leaf(token.NEWLINE, u"\n")
     71 
     72 def BlankLine():
     73     """A blank line"""
     74     return Leaf(token.NEWLINE, u"")
     75 
     76 def Number(n, prefix=None):
     77     return Leaf(token.NUMBER, n, prefix=prefix)
     78 
     79 def Subscript(index_node):
     80     """A numeric or string subscript"""
     81     return Node(syms.trailer, [Leaf(token.LBRACE, u"["),
     82                                index_node,
     83                                Leaf(token.RBRACE, u"]")])
     84 
     85 def String(string, prefix=None):
     86     """A string leaf"""
     87     return Leaf(token.STRING, string, prefix=prefix)
     88 
     89 def ListComp(xp, fp, it, test=None):
     90     """A list comprehension of the form [xp for fp in it if test].
     91 
     92     If test is None, the "if test" part is omitted.
     93     """
     94     xp.prefix = u""
     95     fp.prefix = u" "
     96     it.prefix = u" "
     97     for_leaf = Leaf(token.NAME, u"for")
     98     for_leaf.prefix = u" "
     99     in_leaf = Leaf(token.NAME, u"in")
    100     in_leaf.prefix = u" "
    101     inner_args = [for_leaf, fp, in_leaf, it]
    102     if test:
    103         test.prefix = u" "
    104         if_leaf = Leaf(token.NAME, u"if")
    105         if_leaf.prefix = u" "
    106         inner_args.append(Node(syms.comp_if, [if_leaf, test]))
    107     inner = Node(syms.listmaker, [xp, Node(syms.comp_for, inner_args)])
    108     return Node(syms.atom,
    109                        [Leaf(token.LBRACE, u"["),
    110                         inner,
    111                         Leaf(token.RBRACE, u"]")])
    112 
    113 def FromImport(package_name, name_leafs):
    114     """ Return an import statement in the form:
    115         from package import name_leafs"""
    116     # XXX: May not handle dotted imports properly (eg, package_name='foo.bar')
    117     #assert package_name == '.' or '.' not in package_name, "FromImport has "\
    118     #       "not been tested with dotted package names -- use at your own "\
    119     #       "peril!"
    120 
    121     for leaf in name_leafs:
    122         # Pull the leaves out of their old tree
    123         leaf.remove()
    124 
    125     children = [Leaf(token.NAME, u"from"),
    126                 Leaf(token.NAME, package_name, prefix=u" "),
    127                 Leaf(token.NAME, u"import", prefix=u" "),
    128                 Node(syms.import_as_names, name_leafs)]
    129     imp = Node(syms.import_from, children)
    130     return imp
    131 
    132 
    133 ###########################################################
    134 ### Determine whether a node represents a given literal
    135 ###########################################################
    136 
    137 def is_tuple(node):
    138     """Does the node represent a tuple literal?"""
    139     if isinstance(node, Node) and node.children == [LParen(), RParen()]:
    140         return True
    141     return (isinstance(node, Node)
    142             and len(node.children) == 3
    143             and isinstance(node.children[0], Leaf)
    144             and isinstance(node.children[1], Node)
    145             and isinstance(node.children[2], Leaf)
    146             and node.children[0].value == u"("
    147             and node.children[2].value == u")")
    148 
    149 def is_list(node):
    150     """Does the node represent a list literal?"""
    151     return (isinstance(node, Node)
    152             and len(node.children) > 1
    153             and isinstance(node.children[0], Leaf)
    154             and isinstance(node.children[-1], Leaf)
    155             and node.children[0].value == u"["
    156             and node.children[-1].value == u"]")
    157 
    158 
    159 ###########################################################
    160 ### Misc
    161 ###########################################################
    162 
    163 def parenthesize(node):
    164     return Node(syms.atom, [LParen(), node, RParen()])
    165 
    166 
    167 consuming_calls = set(["sorted", "list", "set", "any", "all", "tuple", "sum",
    168                        "min", "max", "enumerate"])
    169 
    170 def attr_chain(obj, attr):
    171     """Follow an attribute chain.
    172 
    173     If you have a chain of objects where a.foo -> b, b.foo-> c, etc,
    174     use this to iterate over all objects in the chain. Iteration is
    175     terminated by getattr(x, attr) is None.
    176 
    177     Args:
    178         obj: the starting object
    179         attr: the name of the chaining attribute
    180 
    181     Yields:
    182         Each successive object in the chain.
    183     """
    184     next = getattr(obj, attr)
    185     while next:
    186         yield next
    187         next = getattr(next, attr)
    188 
# Pattern sources used by in_special_context(). They are kept as strings
# here; the first call to in_special_context() replaces p0, p1 and p2 with
# the result of patcomp.compile_pattern() and sets pats_built to True.
# Each pattern binds the node under test as ``node``.
#
# p0 is matched against the node's parent: the node is the iterable of a
# for statement or of a comprehension's ``for`` clause.
p0 = """for_stmt< 'for' any 'in' node=any ':' any* >
        | comp_for< 'for' any 'in' node=any any* >
     """
# p1 is matched against the node's grandparent: the node is the single
# argument of one of the listed builtins, or of a ``<expr>.join(...)`` call.
p1 = """
power<
    ( 'iter' | 'list' | 'tuple' | 'sorted' | 'set' | 'sum' |
      'any' | 'all' | 'enumerate' | (any* trailer< '.' 'join' >) )
    trailer< '(' node=any ')' >
    any*
>
"""
# p2 is matched against the node's great-grandparent: the node is the first
# entry of an arglist passed to sorted() or enumerate().
p2 = """
power<
    ( 'sorted' | 'enumerate' )
    trailer< '(' arglist<node=any any*> ')' >
    any*
>
"""
# False until in_special_context() has compiled the patterns above.
pats_built = False
    208 def in_special_context(node):
    209     """ Returns true if node is in an environment where all that is required
    210         of it is being iterable (ie, it doesn't matter if it returns a list
    211         or an iterator).
    212         See test_map_nochange in test_fixers.py for some examples and tests.
    213         """
    214     global p0, p1, p2, pats_built
    215     if not pats_built:
    216         p0 = patcomp.compile_pattern(p0)
    217         p1 = patcomp.compile_pattern(p1)
    218         p2 = patcomp.compile_pattern(p2)
    219         pats_built = True
    220     patterns = [p0, p1, p2]
    221     for pattern, parent in zip(patterns, attr_chain(node, "parent")):
    222         results = {}
    223         if pattern.match(parent, results) and results["node"] is node:
    224             return True
    225     return False
    226 
    227 def is_probably_builtin(node):
    228     """
    229     Check that something isn't an attribute or function name etc.
    230     """
    231     prev = node.prev_sibling
    232     if prev is not None and prev.type == token.DOT:
    233         # Attribute lookup.
    234         return False
    235     parent = node.parent
    236     if parent.type in (syms.funcdef, syms.classdef):
    237         return False
    238     if parent.type == syms.expr_stmt and parent.children[0] is node:
    239         # Assignment.
    240         return False
    241     if parent.type == syms.parameters or \
    242             (parent.type == syms.typedargslist and (
    243             (prev is not None and prev.type == token.COMMA) or
    244             parent.children[0] is node
    245             )):
    246         # The name of an argument.
    247         return False
    248     return True
    249 
    250 def find_indentation(node):
    251     """Find the indentation of *node*."""
    252     while node is not None:
    253         if node.type == syms.suite and len(node.children) > 2:
    254             indent = node.children[1]
    255             if indent.type == token.INDENT:
    256                 return indent.value
    257         node = node.parent
    258     return u""
    259 
    260 ###########################################################
    261 ### The following functions are to find bindings in a suite
    262 ###########################################################
    263 
    264 def make_suite(node):
    265     if node.type == syms.suite:
    266         return node
    267     node = node.clone()
    268     parent, node.parent = node.parent, None
    269     suite = Node(syms.suite, [node])
    270     suite.parent = parent
    271     return suite
    272 
    273 def find_root(node):
    274     """Find the top level namespace."""
    275     # Scamper up to the top level namespace
    276     while node.type != syms.file_input:
    277         node = node.parent
    278         if not node:
    279             raise ValueError("root found before file_input node was found.")
    280     return node
    281 
    282 def does_tree_import(package, name, node):
    283     """ Returns true if name is imported from package at the
    284         top level of the tree which node belongs to.
    285         To cover the case of an import like 'import foo', use
    286         None for the package and 'foo' for the name. """
    287     binding = find_binding(name, find_root(node), package)
    288     return bool(binding)
    289 
    290 def is_import(node):
    291     """Returns true if the node is an import statement."""
    292     return node.type in (syms.import_name, syms.import_from)
    293 
    294 def touch_import(package, name, node):
    295     """ Works like `does_tree_import` but adds an import statement
    296         if it was not imported. """
    297     def is_import_stmt(node):
    298         return (node.type == syms.simple_stmt and node.children and
    299                 is_import(node.children[0]))
    300 
    301     root = find_root(node)
    302 
    303     if does_tree_import(package, name, root):
    304         return
    305 
    306     # figure out where to insert the new import.  First try to find
    307     # the first import and then skip to the last one.
    308     insert_pos = offset = 0
    309     for idx, node in enumerate(root.children):
    310         if not is_import_stmt(node):
    311             continue
    312         for offset, node2 in enumerate(root.children[idx:]):
    313             if not is_import_stmt(node2):
    314                 break
    315         insert_pos = idx + offset
    316         break
    317 
    318     # if there are no imports where we can insert, find the docstring.
    319     # if that also fails, we stick to the beginning of the file
    320     if insert_pos == 0:
    321         for idx, node in enumerate(root.children):
    322             if (node.type == syms.simple_stmt and node.children and
    323                node.children[0].type == token.STRING):
    324                 insert_pos = idx + 1
    325                 break
    326 
    327     if package is None:
    328         import_ = Node(syms.import_name, [
    329             Leaf(token.NAME, u"import"),
    330             Leaf(token.NAME, name, prefix=u" ")
    331         ])
    332     else:
    333         import_ = FromImport(package, [Leaf(token.NAME, name, prefix=u" ")])
    334 
    335     children = [import_, Newline()]
    336     root.insert_child(insert_pos, Node(syms.simple_stmt, children))
    337 
    338 
    339 _def_syms = set([syms.classdef, syms.funcdef])
    340 def find_binding(name, node, package=None):
    341     """ Returns the node which binds variable name, otherwise None.
    342         If optional argument package is supplied, only imports will
    343         be returned.
    344         See test cases for examples."""
    345     for child in node.children:
    346         ret = None
    347         if child.type == syms.for_stmt:
    348             if _find(name, child.children[1]):
    349                 return child
    350             n = find_binding(name, make_suite(child.children[-1]), package)
    351             if n: ret = n
    352         elif child.type in (syms.if_stmt, syms.while_stmt):
    353             n = find_binding(name, make_suite(child.children[-1]), package)
    354             if n: ret = n
    355         elif child.type == syms.try_stmt:
    356             n = find_binding(name, make_suite(child.children[2]), package)
    357             if n:
    358                 ret = n
    359             else:
    360                 for i, kid in enumerate(child.children[3:]):
    361                     if kid.type == token.COLON and kid.value == ":":
    362                         # i+3 is the colon, i+4 is the suite
    363                         n = find_binding(name, make_suite(child.children[i+4]), package)
    364                         if n: ret = n
    365         elif child.type in _def_syms and child.children[1].value == name:
    366             ret = child
    367         elif _is_import_binding(child, name, package):
    368             ret = child
    369         elif child.type == syms.simple_stmt:
    370             ret = find_binding(name, child, package)
    371         elif child.type == syms.expr_stmt:
    372             if _find(name, child.children[0]):
    373                 ret = child
    374 
    375         if ret:
    376             if not package:
    377                 return ret
    378             if is_import(ret):
    379                 return ret
    380     return None
    381 
    382 _block_syms = set([syms.funcdef, syms.classdef, syms.trailer])
    383 def _find(name, node):
    384     nodes = [node]
    385     while nodes:
    386         node = nodes.pop()
    387         if node.type > 256 and node.type not in _block_syms:
    388             nodes.extend(node.children)
    389         elif node.type == token.NAME and node.value == name:
    390             return node
    391     return None
    392 
    393 def _is_import_binding(node, name, package=None):
    394     """ Will reuturn node if node will import name, or node
    395         will import * from package.  None is returned otherwise.
    396         See test cases for examples. """
    397 
    398     if node.type == syms.import_name and not package:
    399         imp = node.children[1]
    400         if imp.type == syms.dotted_as_names:
    401             for child in imp.children:
    402                 if child.type == syms.dotted_as_name:
    403                     if child.children[2].value == name:
    404                         return node
    405                 elif child.type == token.NAME and child.value == name:
    406                     return node
    407         elif imp.type == syms.dotted_as_name:
    408             last = imp.children[-1]
    409             if last.type == token.NAME and last.value == name:
    410                 return node
    411         elif imp.type == token.NAME and imp.value == name:
    412             return node
    413     elif node.type == syms.import_from:
    414         # unicode(...) is used to make life easier here, because
    415         # from a.b import parses to ['import', ['a', '.', 'b'], ...]
    416         if package and unicode(node.children[1]).strip() != package:
    417             return None
    418         n = node.children[3]
    419         if package and _find(u"as", n):
    420             # See test_from_import_as for explanation
    421             return None
    422         elif n.type == syms.import_as_names and _find(name, n):
    423             return node
    424         elif n.type == syms.import_as_name:
    425             child = n.children[2]
    426             if child.type == token.NAME and child.value == name:
    427                 return node
    428         elif n.type == token.NAME and n.value == name:
    429             return node
    430         elif package and n.type == token.STAR:
    431             return node
    432     return None
    433