1 # cython: infer_types=True 2 3 # 4 # Tree visitor and transform framework 5 # 6 import inspect 7 8 from Cython.Compiler import TypeSlots 9 from Cython.Compiler import Builtin 10 from Cython.Compiler import Nodes 11 from Cython.Compiler import ExprNodes 12 from Cython.Compiler import Errors 13 from Cython.Compiler import DebugFlags 14 15 import cython 16 17 18 class TreeVisitor(object): 19 """ 20 Base class for writing visitors for a Cython tree, contains utilities for 21 recursing such trees using visitors. Each node is 22 expected to have a child_attrs iterable containing the names of attributes 23 containing child nodes or lists of child nodes. Lists are not considered 24 part of the tree structure (i.e. contained nodes are considered direct 25 children of the parent node). 26 27 visit_children visits each of the children of a given node (see the visit_children 28 documentation). When recursing the tree using visit_children, an attribute 29 access_path is maintained which gives information about the current location 30 in the tree as a stack of tuples: (parent_node, attrname, index), representing 31 the node, attribute and optional list index that was taken in each step in the path to 32 the current node. 33 34 Example: 35 36 >>> class SampleNode(object): 37 ... child_attrs = ["head", "body"] 38 ... def __init__(self, value, head=None, body=None): 39 ... self.value = value 40 ... self.head = head 41 ... self.body = body 42 ... def __repr__(self): return "SampleNode(%s)" % self.value 43 ... 44 >>> tree = SampleNode(0, SampleNode(1), [SampleNode(2), SampleNode(3)]) 45 >>> class MyVisitor(TreeVisitor): 46 ... def visit_SampleNode(self, node): 47 ... print "in", node.value, self.access_path 48 ... self.visitchildren(node) 49 ... print "out", node.value 50 ... 51 >>> MyVisitor().visit(tree) 52 in 0 [] 53 in 1 [(SampleNode(0), 'head', None)] 54 out 1 55 in 2 [(SampleNode(0), 'body', 0)] 56 out 2 57 in 3 [(SampleNode(0), 'body', 1)] 58 out 3 59 out 0 60 """ 61 def __init__(self): 62 super(TreeVisitor, self).__init__() 63 self.dispatch_table = {} 64 self.access_path = [] 65 66 def dump_node(self, node, indent=0): 67 ignored = list(node.child_attrs or []) + [u'child_attrs', u'pos', 68 u'gil_message', u'cpp_message', 69 u'subexprs'] 70 values = [] 71 pos = getattr(node, 'pos', None) 72 if pos: 73 source = pos[0] 74 if source: 75 import os.path 76 source = os.path.basename(source.get_description()) 77 values.append(u'%s:%s:%s' % (source, pos[1], pos[2])) 78 attribute_names = dir(node) 79 attribute_names.sort() 80 for attr in attribute_names: 81 if attr in ignored: 82 continue 83 if attr.startswith(u'_') or attr.endswith(u'_'): 84 continue 85 try: 86 value = getattr(node, attr) 87 except AttributeError: 88 continue 89 if value is None or value == 0: 90 continue 91 elif isinstance(value, list): 92 value = u'[...]/%d' % len(value) 93 elif not isinstance(value, (str, unicode, long, int, float)): 94 continue 95 else: 96 value = repr(value) 97 values.append(u'%s = %s' % (attr, value)) 98 return u'%s(%s)' % (node.__class__.__name__, 99 u',\n '.join(values)) 100 101 def _find_node_path(self, stacktrace): 102 import os.path 103 last_traceback = stacktrace 104 nodes = [] 105 while hasattr(stacktrace, 'tb_frame'): 106 frame = stacktrace.tb_frame 107 node = frame.f_locals.get(u'self') 108 if isinstance(node, Nodes.Node): 109 code = frame.f_code 110 method_name = code.co_name 111 pos = (os.path.basename(code.co_filename), 112 frame.f_lineno) 113 nodes.append((node, method_name, pos)) 114 last_traceback = stacktrace 115 stacktrace = stacktrace.tb_next 116 return (last_traceback, nodes) 117 118 def _raise_compiler_error(self, child, e): 119 import sys 120 trace = [''] 121 for parent, attribute, index in self.access_path: 122 node = getattr(parent, attribute) 123 if index is None: 124 index = '' 125 else: 126 node = node[index] 127 index = u'[%d]' % index 128 trace.append(u'%s.%s%s = %s' % ( 129 parent.__class__.__name__, attribute, index, 130 self.dump_node(node))) 131 stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2]) 132 last_node = child 133 for node, method_name, pos in called_nodes: 134 last_node = node 135 trace.append(u"File '%s', line %d, in %s: %s" % ( 136 pos[0], pos[1], method_name, self.dump_node(node))) 137 raise Errors.CompilerCrash( 138 getattr(last_node, 'pos', None), self.__class__.__name__, 139 u'\n'.join(trace), e, stacktrace) 140 141 @cython.final 142 def find_handler(self, obj): 143 # to resolve, try entire hierarchy 144 cls = type(obj) 145 pattern = "visit_%s" 146 mro = inspect.getmro(cls) 147 handler_method = None 148 for mro_cls in mro: 149 handler_method = getattr(self, pattern % mro_cls.__name__, None) 150 if handler_method is not None: 151 return handler_method 152 print type(self), cls 153 if self.access_path: 154 print self.access_path 155 print self.access_path[-1][0].pos 156 print self.access_path[-1][0].__dict__ 157 raise RuntimeError("Visitor %r does not accept object: %s" % (self, obj)) 158 159 def visit(self, obj): 160 return self._visit(obj) 161 162 @cython.final 163 def _visit(self, obj): 164 try: 165 try: 166 handler_method = self.dispatch_table[type(obj)] 167 except KeyError: 168 handler_method = self.find_handler(obj) 169 self.dispatch_table[type(obj)] = handler_method 170 return handler_method(obj) 171 except Errors.CompileError: 172 raise 173 except Errors.AbortError: 174 raise 175 except Exception, e: 176 if DebugFlags.debug_no_exception_intercept: 177 raise 178 self._raise_compiler_error(obj, e) 179 180 @cython.final 181 def _visitchild(self, child, parent, attrname, idx): 182 self.access_path.append((parent, attrname, idx)) 183 result = self._visit(child) 184 self.access_path.pop() 185 return result 186 187 def visitchildren(self, parent, attrs=None): 188 return self._visitchildren(parent, attrs) 189 190 @cython.final 191 @cython.locals(idx=int) 192 def _visitchildren(self, parent, attrs): 193 """ 194 Visits the children of the given parent. If parent is None, returns 195 immediately (returning None). 196 197 The return value is a dictionary giving the results for each 198 child (mapping the attribute name to either the return value 199 or a list of return values (in the case of multiple children 200 in an attribute)). 201 """ 202 if parent is None: return None 203 result = {} 204 for attr in parent.child_attrs: 205 if attrs is not None and attr not in attrs: continue 206 child = getattr(parent, attr) 207 if child is not None: 208 if type(child) is list: 209 childretval = [self._visitchild(x, parent, attr, idx) for idx, x in enumerate(child)] 210 else: 211 childretval = self._visitchild(child, parent, attr, None) 212 assert not isinstance(childretval, list), 'Cannot insert list here: %s in %r' % (attr, parent) 213 result[attr] = childretval 214 return result 215 216 217 class VisitorTransform(TreeVisitor): 218 """ 219 A tree transform is a base class for visitors that wants to do stream 220 processing of the structure (rather than attributes etc.) of a tree. 221 222 It implements __call__ to simply visit the argument node. 223 224 It requires the visitor methods to return the nodes which should take 225 the place of the visited node in the result tree (which can be the same 226 or one or more replacement). Specifically, if the return value from 227 a visitor method is: 228 229 - [] or None; the visited node will be removed (set to None if an attribute and 230 removed if in a list) 231 - A single node; the visited node will be replaced by the returned node. 232 - A list of nodes; the visited nodes will be replaced by all the nodes in the 233 list. This will only work if the node was already a member of a list; if it 234 was not, an exception will be raised. (Typically you want to ensure that you 235 are within a StatListNode or similar before doing this.) 236 """ 237 def visitchildren(self, parent, attrs=None): 238 result = self._visitchildren(parent, attrs) 239 for attr, newnode in result.iteritems(): 240 if type(newnode) is not list: 241 setattr(parent, attr, newnode) 242 else: 243 # Flatten the list one level and remove any None 244 newlist = [] 245 for x in newnode: 246 if x is not None: 247 if type(x) is list: 248 newlist += x 249 else: 250 newlist.append(x) 251 setattr(parent, attr, newlist) 252 return result 253 254 def recurse_to_children(self, node): 255 self.visitchildren(node) 256 return node 257 258 def __call__(self, root): 259 return self._visit(root) 260 261 class CythonTransform(VisitorTransform): 262 """ 263 Certain common conventions and utilities for Cython transforms. 264 265 - Sets up the context of the pipeline in self.context 266 - Tracks directives in effect in self.current_directives 267 """ 268 def __init__(self, context): 269 super(CythonTransform, self).__init__() 270 self.context = context 271 272 def __call__(self, node): 273 import ModuleNode 274 if isinstance(node, ModuleNode.ModuleNode): 275 self.current_directives = node.directives 276 return super(CythonTransform, self).__call__(node) 277 278 def visit_CompilerDirectivesNode(self, node): 279 old = self.current_directives 280 self.current_directives = node.directives 281 self.visitchildren(node) 282 self.current_directives = old 283 return node 284 285 def visit_Node(self, node): 286 self.visitchildren(node) 287 return node 288 289 class ScopeTrackingTransform(CythonTransform): 290 # Keeps track of type of scopes 291 #scope_type: can be either of 'module', 'function', 'cclass', 'pyclass', 'struct' 292 #scope_node: the node that owns the current scope 293 294 def visit_ModuleNode(self, node): 295 self.scope_type = 'module' 296 self.scope_node = node 297 self.visitchildren(node) 298 return node 299 300 def visit_scope(self, node, scope_type): 301 prev = self.scope_type, self.scope_node 302 self.scope_type = scope_type 303 self.scope_node = node 304 self.visitchildren(node) 305 self.scope_type, self.scope_node = prev 306 return node 307 308 def visit_CClassDefNode(self, node): 309 return self.visit_scope(node, 'cclass') 310 311 def visit_PyClassDefNode(self, node): 312 return self.visit_scope(node, 'pyclass') 313 314 def visit_FuncDefNode(self, node): 315 return self.visit_scope(node, 'function') 316 317 def visit_CStructOrUnionDefNode(self, node): 318 return self.visit_scope(node, 'struct') 319 320 321 class EnvTransform(CythonTransform): 322 """ 323 This transformation keeps a stack of the environments. 324 """ 325 def __call__(self, root): 326 self.env_stack = [] 327 self.enter_scope(root, root.scope) 328 return super(EnvTransform, self).__call__(root) 329 330 def current_env(self): 331 return self.env_stack[-1][1] 332 333 def current_scope_node(self): 334 return self.env_stack[-1][0] 335 336 def global_scope(self): 337 return self.current_env().global_scope() 338 339 def enter_scope(self, node, scope): 340 self.env_stack.append((node, scope)) 341 342 def exit_scope(self): 343 self.env_stack.pop() 344 345 def visit_FuncDefNode(self, node): 346 self.enter_scope(node, node.local_scope) 347 self.visitchildren(node) 348 self.exit_scope() 349 return node 350 351 def visit_GeneratorBodyDefNode(self, node): 352 self.visitchildren(node) 353 return node 354 355 def visit_ClassDefNode(self, node): 356 self.enter_scope(node, node.scope) 357 self.visitchildren(node) 358 self.exit_scope() 359 return node 360 361 def visit_CStructOrUnionDefNode(self, node): 362 self.enter_scope(node, node.scope) 363 self.visitchildren(node) 364 self.exit_scope() 365 return node 366 367 def visit_ScopedExprNode(self, node): 368 if node.expr_scope: 369 self.enter_scope(node, node.expr_scope) 370 self.visitchildren(node) 371 self.exit_scope() 372 else: 373 self.visitchildren(node) 374 return node 375 376 def visit_CArgDeclNode(self, node): 377 # default arguments are evaluated in the outer scope 378 if node.default: 379 attrs = [ attr for attr in node.child_attrs if attr != 'default' ] 380 self.visitchildren(node, attrs) 381 self.enter_scope(node, self.current_env().outer_scope) 382 self.visitchildren(node, ('default',)) 383 self.exit_scope() 384 else: 385 self.visitchildren(node) 386 return node 387 388 389 class NodeRefCleanupMixin(object): 390 """ 391 Clean up references to nodes that were replaced. 392 393 NOTE: this implementation assumes that the replacement is 394 done first, before hitting any further references during 395 normal tree traversal. This needs to be arranged by calling 396 "self.visitchildren()" at a proper place in the transform 397 and by ordering the "child_attrs" of nodes appropriately. 398 """ 399 def __init__(self, *args): 400 super(NodeRefCleanupMixin, self).__init__(*args) 401 self._replacements = {} 402 403 def visit_CloneNode(self, node): 404 arg = node.arg 405 if arg not in self._replacements: 406 self.visitchildren(node) 407 arg = node.arg 408 node.arg = self._replacements.get(arg, arg) 409 return node 410 411 def visit_ResultRefNode(self, node): 412 expr = node.expression 413 if expr is None or expr not in self._replacements: 414 self.visitchildren(node) 415 expr = node.expression 416 if expr is not None: 417 node.expression = self._replacements.get(expr, expr) 418 return node 419 420 def replace(self, node, replacement): 421 self._replacements[node] = replacement 422 return replacement 423 424 425 find_special_method_for_binary_operator = { 426 '<': '__lt__', 427 '<=': '__le__', 428 '==': '__eq__', 429 '!=': '__ne__', 430 '>=': '__ge__', 431 '>': '__gt__', 432 '+': '__add__', 433 '&': '__and__', 434 '/': '__truediv__', 435 '//': '__floordiv__', 436 '<<': '__lshift__', 437 '%': '__mod__', 438 '*': '__mul__', 439 '|': '__or__', 440 '**': '__pow__', 441 '>>': '__rshift__', 442 '-': '__sub__', 443 '^': '__xor__', 444 'in': '__contains__', 445 }.get 446 447 448 find_special_method_for_unary_operator = { 449 'not': '__not__', 450 '~': '__inv__', 451 '-': '__neg__', 452 '+': '__pos__', 453 }.get 454 455 456 class MethodDispatcherTransform(EnvTransform): 457 """ 458 Base class for transformations that want to intercept on specific 459 builtin functions or methods of builtin types, including special 460 methods triggered by Python operators. Must run after declaration 461 analysis when entries were assigned. 462 463 Naming pattern for handler methods is as follows: 464 465 * builtin functions: _handle_(general|simple|any)_function_NAME 466 467 * builtin methods: _handle_(general|simple|any)_method_TYPENAME_METHODNAME 468 """ 469 # only visit call nodes and Python operations 470 def visit_GeneralCallNode(self, node): 471 self.visitchildren(node) 472 function = node.function 473 if not function.type.is_pyobject: 474 return node 475 arg_tuple = node.positional_args 476 if not isinstance(arg_tuple, ExprNodes.TupleNode): 477 return node 478 keyword_args = node.keyword_args 479 if keyword_args and not isinstance(keyword_args, ExprNodes.DictNode): 480 # can't handle **kwargs 481 return node 482 args = arg_tuple.args 483 return self._dispatch_to_handler(node, function, args, keyword_args) 484 485 def visit_SimpleCallNode(self, node): 486 self.visitchildren(node) 487 function = node.function 488 if function.type.is_pyobject: 489 arg_tuple = node.arg_tuple 490 if not isinstance(arg_tuple, ExprNodes.TupleNode): 491 return node 492 args = arg_tuple.args 493 else: 494 args = node.args 495 return self._dispatch_to_handler(node, function, args, None) 496 497 def visit_PrimaryCmpNode(self, node): 498 if node.cascade: 499 # not currently handled below 500 self.visitchildren(node) 501 return node 502 return self._visit_binop_node(node) 503 504 def visit_BinopNode(self, node): 505 return self._visit_binop_node(node) 506 507 def _visit_binop_node(self, node): 508 self.visitchildren(node) 509 # FIXME: could special case 'not_in' 510 special_method_name = find_special_method_for_binary_operator(node.operator) 511 if special_method_name: 512 operand1, operand2 = node.operand1, node.operand2 513 if special_method_name == '__contains__': 514 operand1, operand2 = operand2, operand1 515 obj_type = operand1.type 516 if obj_type.is_builtin_type: 517 type_name = obj_type.name 518 else: 519 type_name = "object" # safety measure 520 node = self._dispatch_to_method_handler( 521 special_method_name, None, False, type_name, 522 node, None, [operand1, operand2], None) 523 return node 524 525 def visit_UnopNode(self, node): 526 self.visitchildren(node) 527 special_method_name = find_special_method_for_unary_operator(node.operator) 528 if special_method_name: 529 operand = node.operand 530 obj_type = operand.type 531 if obj_type.is_builtin_type: 532 type_name = obj_type.name 533 else: 534 type_name = "object" # safety measure 535 node = self._dispatch_to_method_handler( 536 special_method_name, None, False, type_name, 537 node, None, [operand], None) 538 return node 539 540 ### dispatch to specific handlers 541 542 def _find_handler(self, match_name, has_kwargs): 543 call_type = has_kwargs and 'general' or 'simple' 544 handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None) 545 if handler is None: 546 handler = getattr(self, '_handle_any_%s' % match_name, None) 547 return handler 548 549 def _delegate_to_assigned_value(self, node, function, arg_list, kwargs): 550 assignment = function.cf_state[0] 551 value = assignment.rhs 552 if value.is_name: 553 if not value.entry or len(value.entry.cf_assignments) > 1: 554 # the variable might have been reassigned => play safe 555 return node 556 elif value.is_attribute and value.obj.is_name: 557 if not value.obj.entry or len(value.obj.entry.cf_assignments) > 1: 558 # the underlying variable might have been reassigned => play safe 559 return node 560 else: 561 return node 562 return self._dispatch_to_handler( 563 node, value, arg_list, kwargs) 564 565 def _dispatch_to_handler(self, node, function, arg_list, kwargs): 566 if function.is_name: 567 # we only consider functions that are either builtin 568 # Python functions or builtins that were already replaced 569 # into a C function call (defined in the builtin scope) 570 if not function.entry: 571 return node 572 is_builtin = ( 573 function.entry.is_builtin or 574 function.entry is self.current_env().builtin_scope().lookup_here(function.name)) 575 if not is_builtin: 576 if function.cf_state and function.cf_state.is_single: 577 # we know the value of the variable 578 # => see if it's usable instead 579 return self._delegate_to_assigned_value( 580 node, function, arg_list, kwargs) 581 return node 582 function_handler = self._find_handler( 583 "function_%s" % function.name, kwargs) 584 if function_handler is None: 585 return self._handle_function(node, function.name, function, arg_list, kwargs) 586 if kwargs: 587 return function_handler(node, function, arg_list, kwargs) 588 else: 589 return function_handler(node, function, arg_list) 590 elif function.is_attribute and function.type.is_pyobject: 591 attr_name = function.attribute 592 self_arg = function.obj 593 obj_type = self_arg.type 594 is_unbound_method = False 595 if obj_type.is_builtin_type: 596 if (obj_type is Builtin.type_type and self_arg.is_name and 597 arg_list and arg_list[0].type.is_pyobject): 598 # calling an unbound method like 'list.append(L,x)' 599 # (ignoring 'type.mro()' here ...) 600 type_name = self_arg.name 601 self_arg = None 602 is_unbound_method = True 603 else: 604 type_name = obj_type.name 605 else: 606 type_name = "object" # safety measure 607 return self._dispatch_to_method_handler( 608 attr_name, self_arg, is_unbound_method, type_name, 609 node, function, arg_list, kwargs) 610 else: 611 return node 612 613 def _dispatch_to_method_handler(self, attr_name, self_arg, 614 is_unbound_method, type_name, 615 node, function, arg_list, kwargs): 616 method_handler = self._find_handler( 617 "method_%s_%s" % (type_name, attr_name), kwargs) 618 if method_handler is None: 619 if (attr_name in TypeSlots.method_name_to_slot 620 or attr_name == '__new__'): 621 method_handler = self._find_handler( 622 "slot%s" % attr_name, kwargs) 623 if method_handler is None: 624 return self._handle_method( 625 node, type_name, attr_name, function, 626 arg_list, is_unbound_method, kwargs) 627 if self_arg is not None: 628 arg_list = [self_arg] + list(arg_list) 629 if kwargs: 630 return method_handler( 631 node, function, arg_list, is_unbound_method, kwargs) 632 else: 633 return method_handler( 634 node, function, arg_list, is_unbound_method) 635 636 def _handle_function(self, node, function_name, function, arg_list, kwargs): 637 """Fallback handler""" 638 return node 639 640 def _handle_method(self, node, type_name, attr_name, function, 641 arg_list, is_unbound_method, kwargs): 642 """Fallback handler""" 643 return node 644 645 646 class RecursiveNodeReplacer(VisitorTransform): 647 """ 648 Recursively replace all occurrences of a node in a subtree by 649 another node. 650 """ 651 def __init__(self, orig_node, new_node): 652 super(RecursiveNodeReplacer, self).__init__() 653 self.orig_node, self.new_node = orig_node, new_node 654 655 def visit_Node(self, node): 656 self.visitchildren(node) 657 if node is self.orig_node: 658 return self.new_node 659 else: 660 return node 661 662 def recursively_replace_node(tree, old_node, new_node): 663 replace_in = RecursiveNodeReplacer(old_node, new_node) 664 replace_in(tree) 665 666 667 class NodeFinder(TreeVisitor): 668 """ 669 Find out if a node appears in a subtree. 670 """ 671 def __init__(self, node): 672 super(NodeFinder, self).__init__() 673 self.node = node 674 self.found = False 675 676 def visit_Node(self, node): 677 if self.found: 678 pass # short-circuit 679 elif node is self.node: 680 self.found = True 681 else: 682 self._visitchildren(node, None) 683 684 def tree_contains(tree, node): 685 finder = NodeFinder(node) 686 finder.visit(tree) 687 return finder.found 688 689 690 # Utils 691 def replace_node(ptr, value): 692 """Replaces a node. ptr is of the form used on the access path stack 693 (parent, attrname, listidx|None) 694 """ 695 parent, attrname, listidx = ptr 696 if listidx is None: 697 setattr(parent, attrname, value) 698 else: 699 getattr(parent, attrname)[listidx] = value 700 701 class PrintTree(TreeVisitor): 702 """Prints a representation of the tree to standard output. 703 Subclass and override repr_of to provide more information 704 about nodes. """ 705 def __init__(self): 706 TreeVisitor.__init__(self) 707 self._indent = "" 708 709 def indent(self): 710 self._indent += " " 711 def unindent(self): 712 self._indent = self._indent[:-2] 713 714 def __call__(self, tree, phase=None): 715 print("Parse tree dump at phase '%s'" % phase) 716 self.visit(tree) 717 return tree 718 719 # Don't do anything about process_list, the defaults gives 720 # nice-looking name[idx] nodes which will visually appear 721 # under the parent-node, not displaying the list itself in 722 # the hierarchy. 723 def visit_Node(self, node): 724 if len(self.access_path) == 0: 725 name = "(root)" 726 else: 727 parent, attr, idx = self.access_path[-1] 728 if idx is not None: 729 name = "%s[%d]" % (attr, idx) 730 else: 731 name = attr 732 print("%s- %s: %s" % (self._indent, name, self.repr_of(node))) 733 self.indent() 734 self.visitchildren(node) 735 self.unindent() 736 return node 737 738 def repr_of(self, node): 739 if node is None: 740 return "(none)" 741 else: 742 result = node.__class__.__name__ 743 if isinstance(node, ExprNodes.NameNode): 744 result += "(type=%s, name=\"%s\")" % (repr(node.type), node.name) 745 elif isinstance(node, Nodes.DefNode): 746 result += "(name=\"%s\")" % node.name 747 elif isinstance(node, ExprNodes.ExprNode): 748 t = node.type 749 result += "(type=%s)" % repr(t) 750 elif node.pos: 751 pos = node.pos 752 path = pos[0].get_description() 753 if '/' in path: 754 path = path.split('/')[-1] 755 if '\\' in path: 756 path = path.split('\\')[-1] 757 result += "(pos=(%s:%s:%s))" % (path, pos[1], pos[2]) 758 759 return result 760 761 if __name__ == "__main__": 762 import doctest 763 doctest.testmod() 764