"""Utility functions, node construction macros, etc."""
# Author: Collin Winter

from itertools import islice

# Local imports
from .pgen2 import token
from .pytree import Leaf, Node
from .pygram import python_symbols as syms
from . import patcomp


###########################################################
### Common node-construction "macros"
###########################################################

def KeywordArg(keyword, value):
    """Build an ``argument`` node of the form ``keyword=value``."""
    return Node(syms.argument,
                [keyword, Leaf(token.EQUAL, u"="), value])

def LParen():
    """A left parenthesis leaf"""
    return Leaf(token.LPAR, u"(")

def RParen():
    """A right parenthesis leaf"""
    return Leaf(token.RPAR, u")")

def Assign(target, source):
    """Build an assignment statement

    Both *target* and *source* may be a single node or a list of nodes.
    A non-list *source* has its prefix set to a single space in place.
    """
    if not isinstance(target, list):
        target = [target]
    if not isinstance(source, list):
        source.prefix = u" "
        source = [source]

    return Node(syms.atom,
                target + [Leaf(token.EQUAL, u"=", prefix=u" ")] + source)

def Name(name, prefix=None):
    """Return a NAME leaf"""
    return Leaf(token.NAME, name, prefix=prefix)

def Attr(obj, attr):
    """A node tuple for obj.attr

    Returns a two-element list: *obj* itself followed by a trailer node
    holding the dot and *attr*.
    """
    return [obj, Node(syms.trailer, [Dot(), attr])]

def Comma():
    """A comma leaf"""
    return Leaf(token.COMMA, u",")

def Dot():
    """A period (.) leaf"""
    return Leaf(token.DOT, u".")

def ArgList(args, lparen=LParen(), rparen=RParen()):
    """A parenthesised argument list, used by Call()"""
    # The default leaves are created once at definition time; they are
    # cloned here so that every call gets its own parenthesis leaves.
    node = Node(syms.trailer, [lparen.clone(), rparen.clone()])
    if args:
        node.insert_child(1, Node(syms.arglist, args))
    return node

def Call(func_name, args=None, prefix=None):
    """A function call"""
    node = Node(syms.power, [func_name, ArgList(args)])
    if prefix is not None:
        node.prefix = prefix
    return node

def Newline():
    """A newline literal"""
    return Leaf(token.NEWLINE, u"\n")

def BlankLine():
    """A blank line"""
    return Leaf(token.NEWLINE, u"")

def Number(n, prefix=None):
    """A NUMBER leaf"""
    return Leaf(token.NUMBER, n, prefix=prefix)

def Subscript(index_node):
    """A numeric or string subscript"""
    # Square brackets are LSQB/RSQB; LBRACE/RBRACE are the token types
    # of the curly braces and must not be paired with "[" / "]".
    return Node(syms.trailer, [Leaf(token.LSQB, u"["),
                               index_node,
                               Leaf(token.RSQB, u"]")])

def String(string, prefix=None):
    """A string leaf"""
    return Leaf(token.STRING, string, prefix=prefix)

def ListComp(xp, fp, it, test=None):
    """A list comprehension of the form [xp for fp in it if test].

    If test is None, the "if test" part is omitted.

    The prefixes of xp, fp, it (and test, when given) are overwritten
    in place.
    """
    xp.prefix = u""
    fp.prefix = u" "
    it.prefix = u" "
    for_leaf = Leaf(token.NAME, u"for")
    for_leaf.prefix = u" "
    in_leaf = Leaf(token.NAME, u"in")
    in_leaf.prefix = u" "
    inner_args = [for_leaf, fp, in_leaf, it]
    if test:
        test.prefix = u" "
        if_leaf = Leaf(token.NAME, u"if")
        if_leaf.prefix = u" "
        inner_args.append(Node(syms.comp_if, [if_leaf, test]))
    inner = Node(syms.listmaker, [xp, Node(syms.comp_for, inner_args)])
    # Square brackets are LSQB/RSQB (LBRACE/RBRACE are curly braces).
    return Node(syms.atom,
                [Leaf(token.LSQB, u"["),
                 inner,
                 Leaf(token.RSQB, u"]")])

def FromImport(package_name, name_leafs):
    """ Return an import statement in the form:
        from package import name_leafs"""
    # XXX: May not handle dotted imports properly (eg, package_name='foo.bar')
    #assert package_name == '.' or '.' not in package_name, "FromImport has "\
    #       "not been tested with dotted package names -- use at your own "\
    #       "peril!"

    for leaf in name_leafs:
        # Pull the leaves out of their old tree
        leaf.remove()

    children = [Leaf(token.NAME, u"from"),
                Leaf(token.NAME, package_name, prefix=u" "),
                Leaf(token.NAME, u"import", prefix=u" "),
                Node(syms.import_as_names, name_leafs)]
    imp = Node(syms.import_from, children)
    return imp


###########################################################
### Determine whether a node represents a given literal
###########################################################

def is_tuple(node):
    """Does the node represent a tuple literal?"""
    # The empty tuple: exactly "(" followed by ")".
    if isinstance(node, Node) and node.children == [LParen(), RParen()]:
        return True
    # Otherwise: "(" <inner Node> ")".
    return (isinstance(node, Node)
            and len(node.children) == 3
            and isinstance(node.children[0], Leaf)
            and isinstance(node.children[1], Node)
            and isinstance(node.children[2], Leaf)
            and node.children[0].value == u"("
            and node.children[2].value == u")")

def is_list(node):
    """Does the node represent a list literal?"""
    # Only the leaf values are inspected, not their token types.
    return (isinstance(node, Node)
            and len(node.children) > 1
            and isinstance(node.children[0], Leaf)
            and isinstance(node.children[-1], Leaf)
            and node.children[0].value == u"["
            and node.children[-1].value == u"]")


###########################################################
### Misc
###########################################################

def parenthesize(node):
    """Wrap *node* in an atom node surrounded by parentheses."""
    return Node(syms.atom, [LParen(), node, RParen()])


consuming_calls = set(["sorted", "list", "set", "any", "all", "tuple", "sum",
                       "min", "max", "enumerate"])

def attr_chain(obj, attr):
    """Follow an attribute chain.

    If you have a chain of objects where a.foo -> b, b.foo-> c, etc,
    use this to iterate over all objects in the chain. Iteration is
    terminated when getattr(x, attr) is falsy (normally None).

    Args:
        obj: the starting object
        attr: the name of the chaining attribute

    Yields:
        Each successive object in the chain.
    """
    # Named "link" rather than "next" so the builtin is not shadowed.
    link = getattr(obj, attr)
    while link:
        yield link
        link = getattr(link, attr)

p0 = """for_stmt< 'for' any 'in' node=any ':' any* >
        | comp_for< 'for' any 'in' node=any any* >
     """
p1 = """
power<
    ( 'iter' | 'list' | 'tuple' | 'sorted' | 'set' | 'sum' |
      'any' | 'all' | 'enumerate' | (any* trailer< '.' 'join' >) )
    trailer< '(' node=any ')' >
    any*
>
"""
p2 = """
power<
    ( 'sorted' | 'enumerate' )
    trailer< '(' arglist<node=any any*> ')' >
    any*
>
"""
pats_built = False
def in_special_context(node):
    """ Returns true if node is in an environment where all that is required
        of it is being iterable (ie, it doesn't matter if it returns a list
        or an iterator).
        See test_map_nochange in test_fixers.py for some examples and tests.
        """
    global p0, p1, p2, pats_built
    # The pattern strings are compiled lazily on first use and the
    # module-level names are rebound to the compiled patterns.
    if not pats_built:
        p0 = patcomp.compile_pattern(p0)
        p1 = patcomp.compile_pattern(p1)
        p2 = patcomp.compile_pattern(p2)
        pats_built = True
    patterns = [p0, p1, p2]
    # zip() pairs p0 with the parent, p1 with the grandparent and p2 with
    # the great-grandparent; it stops at the shorter of the two sequences.
    for pattern, parent in zip(patterns, attr_chain(node, "parent")):
        results = {}
        if pattern.match(parent, results) and results["node"] is node:
            return True
    return False

def is_probably_builtin(node):
    """
    Check that something isn't an attribute or function name etc.
    """
    prev = node.prev_sibling
    if prev is not None and prev.type == token.DOT:
        # Attribute lookup.
        return False
    parent = node.parent
    if parent.type in (syms.funcdef, syms.classdef):
        return False
    if parent.type == syms.expr_stmt and parent.children[0] is node:
        # Assignment.
        return False
    if parent.type == syms.parameters or \
            (parent.type == syms.typedargslist and (
                (prev is not None and prev.type == token.COMMA) or
                parent.children[0] is node
            )):
        # The name of an argument.
        return False
    return True

def find_indentation(node):
    """Find the indentation of *node*.

    Walks up through the parents; returns the value of the INDENT leaf
    of the first enclosing suite that has one, or u"" if none is found.
    """
    while node is not None:
        if node.type == syms.suite and len(node.children) > 2:
            indent = node.children[1]
            if indent.type == token.INDENT:
                return indent.value
        node = node.parent
    return u""

###########################################################
### The following functions are to find bindings in a suite
###########################################################

def make_suite(node):
    """Return *node* if it is a suite, else a new suite wrapping a clone."""
    if node.type == syms.suite:
        return node
    node = node.clone()
    # NOTE(review): node.parent is read from the clone, not the original;
    # if clone() returns a detached copy, parent is always None here --
    # confirm whether the original's parent was intended.
    parent, node.parent = node.parent, None
    suite = Node(syms.suite, [node])
    suite.parent = parent
    return suite

def find_root(node):
    """Find the top level namespace.

    Raises ValueError if the chain of parents ends before a file_input
    node is reached.
    """
    # Scamper up to the top level namespace
    while node.type != syms.file_input:
        node = node.parent
        if not node:
            raise ValueError("root found before file_input node was found.")
    return node

def does_tree_import(package, name, node):
    """ Returns true if name is imported from package at the
        top level of the tree which node belongs to.
        To cover the case of an import like 'import foo', use
        None for the package and 'foo' for the name. """
    binding = find_binding(name, find_root(node), package)
    return bool(binding)

def is_import(node):
    """Returns true if the node is an import statement."""
    return node.type in (syms.import_name, syms.import_from)

def touch_import(package, name, node):
    """ Works like `does_tree_import` but adds an import statement
        if it was not imported. """
    def is_import_stmt(node):
        return (node.type == syms.simple_stmt and node.children and
                is_import(node.children[0]))

    root = find_root(node)

    if does_tree_import(package, name, root):
        return

    # figure out where to insert the new import.  First try to find
    # the first import and then skip to the last one.
    # (Loop variables are deliberately not called "node" so the
    # parameter is not clobbered.)
    insert_pos = offset = 0
    for idx, stmt in enumerate(root.children):
        if not is_import_stmt(stmt):
            continue
        for offset, following in enumerate(root.children[idx:]):
            if not is_import_stmt(following):
                break
        insert_pos = idx + offset
        break

    # if there are no imports where we can insert, find the docstring.
    # if that also fails, we stick to the beginning of the file
    if insert_pos == 0:
        for idx, stmt in enumerate(root.children):
            if (stmt.type == syms.simple_stmt and stmt.children and
                    stmt.children[0].type == token.STRING):
                insert_pos = idx + 1
                break

    if package is None:
        import_ = Node(syms.import_name, [
            Leaf(token.NAME, u"import"),
            Leaf(token.NAME, name, prefix=u" ")
        ])
    else:
        import_ = FromImport(package, [Leaf(token.NAME, name, prefix=u" ")])

    children = [import_, Newline()]
    root.insert_child(insert_pos, Node(syms.simple_stmt, children))


_def_syms = set([syms.classdef, syms.funcdef])
def find_binding(name, node, package=None):
    """ Returns the node which binds variable name, otherwise None.
        If optional argument package is supplied, only imports will
        be returned.
        See test cases for examples."""
    for child in node.children:
        ret = None
        if child.type == syms.for_stmt:
            # The loop target binds the name directly.
            if _find(name, child.children[1]):
                return child
            n = find_binding(name, make_suite(child.children[-1]), package)
            if n: ret = n
        elif child.type in (syms.if_stmt, syms.while_stmt):
            n = find_binding(name, make_suite(child.children[-1]), package)
            if n: ret = n
        elif child.type == syms.try_stmt:
            n = find_binding(name, make_suite(child.children[2]), package)
            if n:
                ret = n
            else:
                for i, kid in enumerate(child.children[3:]):
                    if kid.type == token.COLON and kid.value == ":":
                        # i+3 is the colon, i+4 is the suite
                        n = find_binding(name, make_suite(child.children[i+4]), package)
                        if n: ret = n
        elif child.type in _def_syms and child.children[1].value == name:
            ret = child
        elif _is_import_binding(child, name, package):
            ret = child
        elif child.type == syms.simple_stmt:
            ret = find_binding(name, child, package)
        elif child.type == syms.expr_stmt:
            if _find(name, child.children[0]):
                ret = child

        if ret:
            if not package:
                return ret
            # With a package given, only import statements count.
            if is_import(ret):
                return ret
    return None

_block_syms = set([syms.funcdef, syms.classdef, syms.trailer])
def _find(name, node):
    """Return the first NAME leaf with value *name* found under *node*.

    Children of nodes whose type is in _block_syms are not searched.
    Returns None when no such leaf is found.
    """
    nodes = [node]
    while nodes:
        node = nodes.pop()
        # presumably types above 256 are grammar symbols (interior
        # nodes with children) -- confirm against the pytree module.
        if node.type > 256 and node.type not in _block_syms:
            nodes.extend(node.children)
        elif node.type == token.NAME and node.value == name:
            return node
    return None

def _is_import_binding(node, name, package=None):
    """ Will return node if node will import name, or node
        will import * from package. None is returned otherwise.
        See test cases for examples. """

    if node.type == syms.import_name and not package:
        imp = node.children[1]
        if imp.type == syms.dotted_as_names:
            for child in imp.children:
                if child.type == syms.dotted_as_name:
                    if child.children[2].value == name:
                        return node
                elif child.type == token.NAME and child.value == name:
                    return node
        elif imp.type == syms.dotted_as_name:
            last = imp.children[-1]
            if last.type == token.NAME and last.value == name:
                return node
        elif imp.type == token.NAME and imp.value == name:
            return node
    elif node.type == syms.import_from:
        # unicode(...) is used to make life easier here, because
        # from a.b import parses to ['import', ['a', '.', 'b'], ...]
        if package and unicode(node.children[1]).strip() != package:
            return None
        n = node.children[3]
        if package and _find(u"as", n):
            # See test_from_import_as for explanation
            return None
        elif n.type == syms.import_as_names and _find(name, n):
            return node
        elif n.type == syms.import_as_name:
            child = n.children[2]
            if child.type == token.NAME and child.value == name:
                return node
        elif n.type == token.NAME and n.value == name:
            return node
        elif package and n.type == token.STAR:
            return node
    return None