Home | History | Annotate | Download | only in lib2to3
      1 """Utility functions, node construction macros, etc."""
      2 # Author: Collin Winter
      3 
      4 # Local imports
      5 from .pgen2 import token
      6 from .pytree import Leaf, Node
      7 from .pygram import python_symbols as syms
      8 from . import patcomp
      9 
     10 
     11 ###########################################################
     12 ### Common node-construction "macros"
     13 ###########################################################
     14 
     15 def KeywordArg(keyword, value):
     16     return Node(syms.argument,
     17                 [keyword, Leaf(token.EQUAL, "="), value])
     18 
     19 def LParen():
     20     return Leaf(token.LPAR, "(")
     21 
     22 def RParen():
     23     return Leaf(token.RPAR, ")")
     24 
     25 def Assign(target, source):
     26     """Build an assignment statement"""
     27     if not isinstance(target, list):
     28         target = [target]
     29     if not isinstance(source, list):
     30         source.prefix = " "
     31         source = [source]
     32 
     33     return Node(syms.atom,
     34                 target + [Leaf(token.EQUAL, "=", prefix=" ")] + source)
     35 
     36 def Name(name, prefix=None):
     37     """Return a NAME leaf"""
     38     return Leaf(token.NAME, name, prefix=prefix)
     39 
     40 def Attr(obj, attr):
     41     """A node tuple for obj.attr"""
     42     return [obj, Node(syms.trailer, [Dot(), attr])]
     43 
     44 def Comma():
     45     """A comma leaf"""
     46     return Leaf(token.COMMA, ",")
     47 
     48 def Dot():
     49     """A period (.) leaf"""
     50     return Leaf(token.DOT, ".")
     51 
     52 def ArgList(args, lparen=LParen(), rparen=RParen()):
     53     """A parenthesised argument list, used by Call()"""
     54     node = Node(syms.trailer, [lparen.clone(), rparen.clone()])
     55     if args:
     56         node.insert_child(1, Node(syms.arglist, args))
     57     return node
     58 
     59 def Call(func_name, args=None, prefix=None):
     60     """A function call"""
     61     node = Node(syms.power, [func_name, ArgList(args)])
     62     if prefix is not None:
     63         node.prefix = prefix
     64     return node
     65 
     66 def Newline():
     67     """A newline literal"""
     68     return Leaf(token.NEWLINE, "\n")
     69 
     70 def BlankLine():
     71     """A blank line"""
     72     return Leaf(token.NEWLINE, "")
     73 
     74 def Number(n, prefix=None):
     75     return Leaf(token.NUMBER, n, prefix=prefix)
     76 
     77 def Subscript(index_node):
     78     """A numeric or string subscript"""
     79     return Node(syms.trailer, [Leaf(token.LBRACE, "["),
     80                                index_node,
     81                                Leaf(token.RBRACE, "]")])
     82 
     83 def String(string, prefix=None):
     84     """A string leaf"""
     85     return Leaf(token.STRING, string, prefix=prefix)
     86 
     87 def ListComp(xp, fp, it, test=None):
     88     """A list comprehension of the form [xp for fp in it if test].
     89 
     90     If test is None, the "if test" part is omitted.
     91     """
     92     xp.prefix = ""
     93     fp.prefix = " "
     94     it.prefix = " "
     95     for_leaf = Leaf(token.NAME, "for")
     96     for_leaf.prefix = " "
     97     in_leaf = Leaf(token.NAME, "in")
     98     in_leaf.prefix = " "
     99     inner_args = [for_leaf, fp, in_leaf, it]
    100     if test:
    101         test.prefix = " "
    102         if_leaf = Leaf(token.NAME, "if")
    103         if_leaf.prefix = " "
    104         inner_args.append(Node(syms.comp_if, [if_leaf, test]))
    105     inner = Node(syms.listmaker, [xp, Node(syms.comp_for, inner_args)])
    106     return Node(syms.atom,
    107                        [Leaf(token.LBRACE, "["),
    108                         inner,
    109                         Leaf(token.RBRACE, "]")])
    110 
    111 def FromImport(package_name, name_leafs):
    112     """ Return an import statement in the form:
    113         from package import name_leafs"""
    114     # XXX: May not handle dotted imports properly (eg, package_name='foo.bar')
    115     #assert package_name == '.' or '.' not in package_name, "FromImport has "\
    116     #       "not been tested with dotted package names -- use at your own "\
    117     #       "peril!"
    118 
    119     for leaf in name_leafs:
    120         # Pull the leaves out of their old tree
    121         leaf.remove()
    122 
    123     children = [Leaf(token.NAME, "from"),
    124                 Leaf(token.NAME, package_name, prefix=" "),
    125                 Leaf(token.NAME, "import", prefix=" "),
    126                 Node(syms.import_as_names, name_leafs)]
    127     imp = Node(syms.import_from, children)
    128     return imp
    129 
    130 def ImportAndCall(node, results, names):
    131     """Returns an import statement and calls a method
    132     of the module:
    133 
    134     import module
    135     module.name()"""
    136     obj = results["obj"].clone()
    137     if obj.type == syms.arglist:
    138         newarglist = obj.clone()
    139     else:
    140         newarglist = Node(syms.arglist, [obj.clone()])
    141     after = results["after"]
    142     if after:
    143         after = [n.clone() for n in after]
    144     new = Node(syms.power,
    145                Attr(Name(names[0]), Name(names[1])) +
    146                [Node(syms.trailer,
    147                      [results["lpar"].clone(),
    148                       newarglist,
    149                       results["rpar"].clone()])] + after)
    150     new.prefix = node.prefix
    151     return new
    152 
    153 
    154 ###########################################################
    155 ### Determine whether a node represents a given literal
    156 ###########################################################
    157 
    158 def is_tuple(node):
    159     """Does the node represent a tuple literal?"""
    160     if isinstance(node, Node) and node.children == [LParen(), RParen()]:
    161         return True
    162     return (isinstance(node, Node)
    163             and len(node.children) == 3
    164             and isinstance(node.children[0], Leaf)
    165             and isinstance(node.children[1], Node)
    166             and isinstance(node.children[2], Leaf)
    167             and node.children[0].value == "("
    168             and node.children[2].value == ")")
    169 
    170 def is_list(node):
    171     """Does the node represent a list literal?"""
    172     return (isinstance(node, Node)
    173             and len(node.children) > 1
    174             and isinstance(node.children[0], Leaf)
    175             and isinstance(node.children[-1], Leaf)
    176             and node.children[0].value == "["
    177             and node.children[-1].value == "]")
    178 
    179 
    180 ###########################################################
    181 ### Misc
    182 ###########################################################
    183 
    184 def parenthesize(node):
    185     return Node(syms.atom, [LParen(), node, RParen()])
    186 
    187 
    188 consuming_calls = {"sorted", "list", "set", "any", "all", "tuple", "sum",
    189                    "min", "max", "enumerate"}
    190 
    191 def attr_chain(obj, attr):
    192     """Follow an attribute chain.
    193 
    194     If you have a chain of objects where a.foo -> b, b.foo-> c, etc,
    195     use this to iterate over all objects in the chain. Iteration is
    196     terminated by getattr(x, attr) is None.
    197 
    198     Args:
    199         obj: the starting object
    200         attr: the name of the chaining attribute
    201 
    202     Yields:
    203         Each successive object in the chain.
    204     """
    205     next = getattr(obj, attr)
    206     while next:
    207         yield next
    208         next = getattr(next, attr)
    209 
# Patterns used by in_special_context() to recognise contexts where a value
# only needs to be iterable.  They start out as pattern source strings;
# in_special_context() replaces each with its compiled form (via
# patcomp.compile_pattern) on first use and then sets pats_built to True.

# Matched against node.parent: node is the iterable of a for statement or of
# a comprehension's for clause.
p0 = """for_stmt< 'for' any 'in' node=any ':' any* >
        | comp_for< 'for' any 'in' node=any any* >
     """
# Matched against node's grandparent: node is the single argument of a call
# to one of the listed builtins, or of a call to some ``<expr>.join``.
p1 = """
power<
    ( 'iter' | 'list' | 'tuple' | 'sorted' | 'set' | 'sum' |
      'any' | 'all' | 'enumerate' | (any* trailer< '.' 'join' >) )
    trailer< '(' node=any ')' >
    any*
>
"""
# Matched against node's great-grandparent: node is the first of several
# arguments to sorted() or enumerate().
p2 = """
power<
    ( 'sorted' | 'enumerate' )
    trailer< '(' arglist<node=any any*> ')' >
    any*
>
"""
# True once p0, p1 and p2 have been replaced by their compiled patterns.
pats_built = False
    229 def in_special_context(node):
    230     """ Returns true if node is in an environment where all that is required
    231         of it is being iterable (ie, it doesn't matter if it returns a list
    232         or an iterator).
    233         See test_map_nochange in test_fixers.py for some examples and tests.
    234         """
    235     global p0, p1, p2, pats_built
    236     if not pats_built:
    237         p0 = patcomp.compile_pattern(p0)
    238         p1 = patcomp.compile_pattern(p1)
    239         p2 = patcomp.compile_pattern(p2)
    240         pats_built = True
    241     patterns = [p0, p1, p2]
    242     for pattern, parent in zip(patterns, attr_chain(node, "parent")):
    243         results = {}
    244         if pattern.match(parent, results) and results["node"] is node:
    245             return True
    246     return False
    247 
    248 def is_probably_builtin(node):
    249     """
    250     Check that something isn't an attribute or function name etc.
    251     """
    252     prev = node.prev_sibling
    253     if prev is not None and prev.type == token.DOT:
    254         # Attribute lookup.
    255         return False
    256     parent = node.parent
    257     if parent.type in (syms.funcdef, syms.classdef):
    258         return False
    259     if parent.type == syms.expr_stmt and parent.children[0] is node:
    260         # Assignment.
    261         return False
    262     if parent.type == syms.parameters or \
    263             (parent.type == syms.typedargslist and (
    264             (prev is not None and prev.type == token.COMMA) or
    265             parent.children[0] is node
    266             )):
    267         # The name of an argument.
    268         return False
    269     return True
    270 
    271 def find_indentation(node):
    272     """Find the indentation of *node*."""
    273     while node is not None:
    274         if node.type == syms.suite and len(node.children) > 2:
    275             indent = node.children[1]
    276             if indent.type == token.INDENT:
    277                 return indent.value
    278         node = node.parent
    279     return ""
    280 
    281 ###########################################################
    282 ### The following functions are to find bindings in a suite
    283 ###########################################################
    284 
    285 def make_suite(node):
    286     if node.type == syms.suite:
    287         return node
    288     node = node.clone()
    289     parent, node.parent = node.parent, None
    290     suite = Node(syms.suite, [node])
    291     suite.parent = parent
    292     return suite
    293 
    294 def find_root(node):
    295     """Find the top level namespace."""
    296     # Scamper up to the top level namespace
    297     while node.type != syms.file_input:
    298         node = node.parent
    299         if not node:
    300             raise ValueError("root found before file_input node was found.")
    301     return node
    302 
    303 def does_tree_import(package, name, node):
    304     """ Returns true if name is imported from package at the
    305         top level of the tree which node belongs to.
    306         To cover the case of an import like 'import foo', use
    307         None for the package and 'foo' for the name. """
    308     binding = find_binding(name, find_root(node), package)
    309     return bool(binding)
    310 
    311 def is_import(node):
    312     """Returns true if the node is an import statement."""
    313     return node.type in (syms.import_name, syms.import_from)
    314 
    315 def touch_import(package, name, node):
    316     """ Works like `does_tree_import` but adds an import statement
    317         if it was not imported. """
    318     def is_import_stmt(node):
    319         return (node.type == syms.simple_stmt and node.children and
    320                 is_import(node.children[0]))
    321 
    322     root = find_root(node)
    323 
    324     if does_tree_import(package, name, root):
    325         return
    326 
    327     # figure out where to insert the new import.  First try to find
    328     # the first import and then skip to the last one.
    329     insert_pos = offset = 0
    330     for idx, node in enumerate(root.children):
    331         if not is_import_stmt(node):
    332             continue
    333         for offset, node2 in enumerate(root.children[idx:]):
    334             if not is_import_stmt(node2):
    335                 break
    336         insert_pos = idx + offset
    337         break
    338 
    339     # if there are no imports where we can insert, find the docstring.
    340     # if that also fails, we stick to the beginning of the file
    341     if insert_pos == 0:
    342         for idx, node in enumerate(root.children):
    343             if (node.type == syms.simple_stmt and node.children and
    344                node.children[0].type == token.STRING):
    345                 insert_pos = idx + 1
    346                 break
    347 
    348     if package is None:
    349         import_ = Node(syms.import_name, [
    350             Leaf(token.NAME, "import"),
    351             Leaf(token.NAME, name, prefix=" ")
    352         ])
    353     else:
    354         import_ = FromImport(package, [Leaf(token.NAME, name, prefix=" ")])
    355 
    356     children = [import_, Newline()]
    357     root.insert_child(insert_pos, Node(syms.simple_stmt, children))
    358 
    359 
# Definition statements whose children[1] is compared against the name.
_def_syms = {syms.classdef, syms.funcdef}
def find_binding(name, node, package=None):
    """Return the node that binds *name* under *node*, otherwise None.

    Scans node.children in order and returns the first binding found.  It
    recurses into for/if/while/try bodies (wrapped with make_suite) and into
    simple statements.  A child counts as a binding when it is:

      * a for statement whose target (children[1]) contains *name*;
        this child is returned immediately, even when *package* is given,
      * a class/def whose children[1] value is *name*,
      * an import that binds *name* (see _is_import_binding),
      * an expr_stmt whose first child (presumably the assignment target)
        contains *name*.

    If optional argument *package* is supplied, only imports will be
    returned; other bindings are skipped (apart from the for-loop case
    above).  See test cases for examples.
    """
    for child in node.children:
        ret = None
        if child.type == syms.for_stmt:
            if _find(name, child.children[1]):
                return child
            # Only the last child's suite is searched.  NOTE(review): when a
            # for/else is present this is presumably the else suite, and the
            # loop body is skipped -- confirm intended.
            n = find_binding(name, make_suite(child.children[-1]), package)
            if n: ret = n
        elif child.type in (syms.if_stmt, syms.while_stmt):
            # As above: only the final suite (e.g. the last elif/else clause)
            # is searched.  NOTE(review): earlier branches are not.
            n = find_binding(name, make_suite(child.children[-1]), package)
            if n: ret = n
        elif child.type == syms.try_stmt:
            # children[2]: presumably the suite right after "try:".
            n = find_binding(name, make_suite(child.children[2]), package)
            if n:
                ret = n
            else:
                # Otherwise search the suite after every later ':' (the
                # except/else/finally clauses); the last match wins.
                for i, kid in enumerate(child.children[3:]):
                    if kid.type == token.COLON and kid.value == ":":
                        # i+3 is the colon, i+4 is the suite
                        n = find_binding(name, make_suite(child.children[i+4]), package)
                        if n: ret = n
        elif child.type in _def_syms and child.children[1].value == name:
            ret = child
        elif _is_import_binding(child, name, package):
            ret = child
        elif child.type == syms.simple_stmt:
            ret = find_binding(name, child, package)
        elif child.type == syms.expr_stmt:
            if _find(name, child.children[0]):
                ret = child

        if ret:
            # With a package, non-import bindings are ignored and the scan
            # continues with the next child.
            if not package:
                return ret
            if is_import(ret):
                return ret
    return None
    402 
    403 _block_syms = {syms.funcdef, syms.classdef, syms.trailer}
    404 def _find(name, node):
    405     nodes = [node]
    406     while nodes:
    407         node = nodes.pop()
    408         if node.type > 256 and node.type not in _block_syms:
    409             nodes.extend(node.children)
    410         elif node.type == token.NAME and node.value == name:
    411             return node
    412     return None
    413 
    414 def _is_import_binding(node, name, package=None):
    415     """ Will reuturn node if node will import name, or node
    416         will import * from package.  None is returned otherwise.
    417         See test cases for examples. """
    418 
    419     if node.type == syms.import_name and not package:
    420         imp = node.children[1]
    421         if imp.type == syms.dotted_as_names:
    422             for child in imp.children:
    423                 if child.type == syms.dotted_as_name:
    424                     if child.children[2].value == name:
    425                         return node
    426                 elif child.type == token.NAME and child.value == name:
    427                     return node
    428         elif imp.type == syms.dotted_as_name:
    429             last = imp.children[-1]
    430             if last.type == token.NAME and last.value == name:
    431                 return node
    432         elif imp.type == token.NAME and imp.value == name:
    433             return node
    434     elif node.type == syms.import_from:
    435         # str(...) is used to make life easier here, because
    436         # from a.b import parses to ['import', ['a', '.', 'b'], ...]
    437         if package and str(node.children[1]).strip() != package:
    438             return None
    439         n = node.children[3]
    440         if package and _find("as", n):
    441             # See test_from_import_as for explanation
    442             return None
    443         elif n.type == syms.import_as_names and _find(name, n):
    444             return node
    445         elif n.type == syms.import_as_name:
    446             child = n.children[2]
    447             if child.type == token.NAME and child.value == name:
    448                 return node
    449         elif n.type == token.NAME and n.value == name:
    450             return node
    451         elif package and n.type == token.STAR:
    452             return node
    453     return None
    454