Home | History | Annotate | Download | only in antlr3
      1 """ @package antlr3.tree
      2 @brief ANTLR3 runtime package, treewizard module
      3 
      4 A utility module to create ASTs at runtime.
      5 See <http://www.antlr.org/wiki/display/~admin/2007/07/02/Exploring+Concept+of+TreeWizard> for an overview. Note that the API of the Python implementation is slightly different.
      6 
      7 """
      8 
      9 # begin[licence]
     10 #
     11 # [The "BSD licence"]
     12 # Copyright (c) 2005-2008 Terence Parr
     13 # All rights reserved.
     14 #
     15 # Redistribution and use in source and binary forms, with or without
     16 # modification, are permitted provided that the following conditions
     17 # are met:
     18 # 1. Redistributions of source code must retain the above copyright
     19 #    notice, this list of conditions and the following disclaimer.
     20 # 2. Redistributions in binary form must reproduce the above copyright
     21 #    notice, this list of conditions and the following disclaimer in the
     22 #    documentation and/or other materials provided with the distribution.
     23 # 3. The name of the author may not be used to endorse or promote products
     24 #    derived from this software without specific prior written permission.
     25 #
     26 # THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
     27 # IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
     28 # OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
     29 # IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
     30 # INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
     31 # NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     32 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     33 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     34 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
     35 # THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     36 #
     37 # end[licence]
     38 
     39 from antlr3.constants import INVALID_TOKEN_TYPE
     40 from antlr3.tokens import CommonToken
     41 from antlr3.tree import CommonTree, CommonTreeAdaptor
     42 
     43 
     44 def computeTokenTypes(tokenNames):
     45     """
     46     Compute a dict that is an inverted index of
     47     tokenNames (which maps int token types to names).
     48     """
     49 
     50     if tokenNames is None:
     51         return {}
     52 
     53     return dict((name, type) for type, name in enumerate(tokenNames))
     54 
     55 
     56 ## token types for pattern parser
     57 EOF = -1
     58 BEGIN = 1
     59 END = 2
     60 ID = 3
     61 ARG = 4
     62 PERCENT = 5
     63 COLON = 6
     64 DOT = 7
     65 
     66 class TreePatternLexer(object):
     67     def __init__(self, pattern):
     68         ## The tree pattern to lex like "(A B C)"
     69         self.pattern = pattern
     70 
     71 	## Index into input string
     72         self.p = -1
     73 
     74 	## Current char
     75         self.c = None
     76 
     77 	## How long is the pattern in char?
     78         self.n = len(pattern)
     79 
     80 	## Set when token type is ID or ARG
     81         self.sval = None
     82 
     83         self.error = False
     84 
     85         self.consume()
     86 
     87 
     88     __idStartChar = frozenset(
     89         'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ_'
     90         )
     91     __idChar = __idStartChar | frozenset('0123456789')
     92 
     93     def nextToken(self):
     94         self.sval = ""
     95         while self.c != EOF:
     96             if self.c in (' ', '\n', '\r', '\t'):
     97                 self.consume()
     98                 continue
     99 
    100             if self.c in self.__idStartChar:
    101                 self.sval += self.c
    102                 self.consume()
    103                 while self.c in self.__idChar:
    104                     self.sval += self.c
    105                     self.consume()
    106 
    107                 return ID
    108 
    109             if self.c == '(':
    110                 self.consume()
    111                 return BEGIN
    112 
    113             if self.c == ')':
    114                 self.consume()
    115                 return END
    116 
    117             if self.c == '%':
    118                 self.consume()
    119                 return PERCENT
    120 
    121             if self.c == ':':
    122                 self.consume()
    123                 return COLON
    124 
    125             if self.c == '.':
    126                 self.consume()
    127                 return DOT
    128 
    129             if self.c == '[': # grab [x] as a string, returning x
    130                 self.consume()
    131                 while self.c != ']':
    132                     if self.c == '\\':
    133                         self.consume()
    134                         if self.c != ']':
    135                             self.sval += '\\'
    136 
    137                         self.sval += self.c
    138 
    139                     else:
    140                         self.sval += self.c
    141 
    142                     self.consume()
    143 
    144                 self.consume()
    145                 return ARG
    146 
    147             self.consume()
    148             self.error = True
    149             return EOF
    150 
    151         return EOF
    152 
    153 
    154     def consume(self):
    155         self.p += 1
    156         if self.p >= self.n:
    157             self.c = EOF
    158 
    159         else:
    160             self.c = self.pattern[self.p]
    161 
    162 
    163 class TreePatternParser(object):
    164     def __init__(self, tokenizer, wizard, adaptor):
    165         self.tokenizer = tokenizer
    166         self.wizard = wizard
    167         self.adaptor = adaptor
    168         self.ttype = tokenizer.nextToken() # kickstart
    169 
    170 
    171     def pattern(self):
    172         if self.ttype == BEGIN:
    173             return self.parseTree()
    174 
    175         elif self.ttype == ID:
    176             node = self.parseNode()
    177             if self.ttype == EOF:
    178                 return node
    179 
    180             return None # extra junk on end
    181 
    182         return None
    183 
    184 
    185     def parseTree(self):
    186         if self.ttype != BEGIN:
    187             return None
    188 
    189         self.ttype = self.tokenizer.nextToken()
    190         root = self.parseNode()
    191         if root is None:
    192             return None
    193 
    194         while self.ttype in (BEGIN, ID, PERCENT, DOT):
    195             if self.ttype == BEGIN:
    196                 subtree = self.parseTree()
    197                 self.adaptor.addChild(root, subtree)
    198 
    199             else:
    200                 child = self.parseNode()
    201                 if child is None:
    202                     return None
    203 
    204                 self.adaptor.addChild(root, child)
    205 
    206         if self.ttype != END:
    207             return None
    208 
    209         self.ttype = self.tokenizer.nextToken()
    210         return root
    211 
    212 
    213     def parseNode(self):
    214         # "%label:" prefix
    215         label = None
    216 
    217         if self.ttype == PERCENT:
    218             self.ttype = self.tokenizer.nextToken()
    219             if self.ttype != ID:
    220                 return None
    221 
    222             label = self.tokenizer.sval
    223             self.ttype = self.tokenizer.nextToken()
    224             if self.ttype != COLON:
    225                 return None
    226 
    227             self.ttype = self.tokenizer.nextToken() # move to ID following colon
    228 
    229         # Wildcard?
    230         if self.ttype == DOT:
    231             self.ttype = self.tokenizer.nextToken()
    232             wildcardPayload = CommonToken(0, ".")
    233             node = WildcardTreePattern(wildcardPayload)
    234             if label is not None:
    235                 node.label = label
    236             return node
    237 
    238         # "ID" or "ID[arg]"
    239         if self.ttype != ID:
    240             return None
    241 
    242         tokenName = self.tokenizer.sval
    243         self.ttype = self.tokenizer.nextToken()
    244 
    245         if tokenName == "nil":
    246             return self.adaptor.nil()
    247 
    248         text = tokenName
    249         # check for arg
    250         arg = None
    251         if self.ttype == ARG:
    252             arg = self.tokenizer.sval
    253             text = arg
    254             self.ttype = self.tokenizer.nextToken()
    255 
    256         # create node
    257         treeNodeType = self.wizard.getTokenType(tokenName)
    258         if treeNodeType == INVALID_TOKEN_TYPE:
    259             return None
    260 
    261         node = self.adaptor.createFromType(treeNodeType, text)
    262         if label is not None and isinstance(node, TreePattern):
    263             node.label = label
    264 
    265         if arg is not None and isinstance(node, TreePattern):
    266             node.hasTextArg = True
    267 
    268         return node
    269 
    270 
    271 class TreePattern(CommonTree):
    272     """
    273     When using %label:TOKENNAME in a tree for parse(), we must
    274     track the label.
    275     """
    276 
    277     def __init__(self, payload):
    278         CommonTree.__init__(self, payload)
    279 
    280         self.label = None
    281         self.hasTextArg = None
    282 
    283 
    284     def toString(self):
    285         if self.label is not None:
    286             return '%' + self.label + ':' + CommonTree.toString(self)
    287 
    288         else:
    289             return CommonTree.toString(self)
    290 
    291 
    292 class WildcardTreePattern(TreePattern):
    293     pass
    294 
    295 
    296 class TreePatternTreeAdaptor(CommonTreeAdaptor):
    297     """This adaptor creates TreePattern objects for use during scan()"""
    298 
    299     def createWithPayload(self, payload):
    300         return TreePattern(payload)
    301 
    302 
    303 class TreeWizard(object):
    304     """
    305     Build and navigate trees with this object.  Must know about the names
    306     of tokens so you have to pass in a map or array of token names (from which
    307     this class can build the map).  I.e., Token DECL means nothing unless the
    308     class can translate it to a token type.
    309 
    310     In order to create nodes and navigate, this class needs a TreeAdaptor.
    311 
    312     This class can build a token type -> node index for repeated use or for
    313     iterating over the various nodes with a particular type.
    314 
    315     This class works in conjunction with the TreeAdaptor rather than moving
    316     all this functionality into the adaptor.  An adaptor helps build and
    317     navigate trees using methods.  This class helps you do it with string
    318     patterns like "(A B C)".  You can create a tree from that pattern or
    319     match subtrees against it.
    320     """
    321 
    322     def __init__(self, adaptor=None, tokenNames=None, typeMap=None):
    323         if adaptor is None:
    324             self.adaptor = CommonTreeAdaptor()
    325 
    326         else:
    327             self.adaptor = adaptor
    328 
    329         if typeMap is None:
    330             self.tokenNameToTypeMap = computeTokenTypes(tokenNames)
    331 
    332         else:
    333             if tokenNames is not None:
    334                 raise ValueError("Can't have both tokenNames and typeMap")
    335 
    336             self.tokenNameToTypeMap = typeMap
    337 
    338 
    339     def getTokenType(self, tokenName):
    340         """Using the map of token names to token types, return the type."""
    341 
    342         try:
    343             return self.tokenNameToTypeMap[tokenName]
    344         except KeyError:
    345             return INVALID_TOKEN_TYPE
    346 
    347 
    348     def create(self, pattern):
    349         """
    350         Create a tree or node from the indicated tree pattern that closely
    351         follows ANTLR tree grammar tree element syntax:
    352 
    353         (root child1 ... child2).
    354 
    355         You can also just pass in a node: ID
    356 
    357         Any node can have a text argument: ID[foo]
    358         (notice there are no quotes around foo--it's clear it's a string).
    359 
    360         nil is a special name meaning "give me a nil node".  Useful for
    361         making lists: (nil A B C) is a list of A B C.
    362         """
    363 
    364         tokenizer = TreePatternLexer(pattern)
    365         parser = TreePatternParser(tokenizer, self, self.adaptor)
    366         return parser.pattern()
    367 
    368 
    369     def index(self, tree):
    370         """Walk the entire tree and make a node name to nodes mapping.
    371 
    372         For now, use recursion but later nonrecursive version may be
    373         more efficient.  Returns a dict int -> list where the list is
    374         of your AST node type.  The int is the token type of the node.
    375         """
    376 
    377         m = {}
    378         self._index(tree, m)
    379         return m
    380 
    381 
    382     def _index(self, t, m):
    383         """Do the work for index"""
    384 
    385         if t is None:
    386             return
    387 
    388         ttype = self.adaptor.getType(t)
    389         elements = m.get(ttype)
    390         if elements is None:
    391             m[ttype] = elements = []
    392 
    393         elements.append(t)
    394         for i in range(self.adaptor.getChildCount(t)):
    395             child = self.adaptor.getChild(t, i)
    396             self._index(child, m)
    397 
    398 
    399     def find(self, tree, what):
    400         """Return a list of matching token.
    401 
    402         what may either be an integer specifzing the token type to find or
    403         a string with a pattern that must be matched.
    404 
    405         """
    406 
    407         if isinstance(what, (int, long)):
    408             return self._findTokenType(tree, what)
    409 
    410         elif isinstance(what, basestring):
    411             return self._findPattern(tree, what)
    412 
    413         else:
    414             raise TypeError("'what' must be string or integer")
    415 
    416 
    417     def _findTokenType(self, t, ttype):
    418         """Return a List of tree nodes with token type ttype"""
    419 
    420         nodes = []
    421 
    422         def visitor(tree, parent, childIndex, labels):
    423             nodes.append(tree)
    424 
    425         self.visit(t, ttype, visitor)
    426 
    427         return nodes
    428 
    429 
    430     def _findPattern(self, t, pattern):
    431         """Return a List of subtrees matching pattern."""
    432 
    433         subtrees = []
    434 
    435         # Create a TreePattern from the pattern
    436         tokenizer = TreePatternLexer(pattern)
    437         parser = TreePatternParser(tokenizer, self, TreePatternTreeAdaptor())
    438         tpattern = parser.pattern()
    439 
    440         # don't allow invalid patterns
    441         if (tpattern is None or tpattern.isNil()
    442             or isinstance(tpattern, WildcardTreePattern)):
    443             return None
    444 
    445         rootTokenType = tpattern.getType()
    446 
    447         def visitor(tree, parent, childIndex, label):
    448             if self._parse(tree, tpattern, None):
    449                 subtrees.append(tree)
    450 
    451         self.visit(t, rootTokenType, visitor)
    452 
    453         return subtrees
    454 
    455 
    456     def visit(self, tree, what, visitor):
    457         """Visit every node in tree matching what, invoking the visitor.
    458 
    459         If what is a string, it is parsed as a pattern and only matching
    460         subtrees will be visited.
    461         The implementation uses the root node of the pattern in combination
    462         with visit(t, ttype, visitor) so nil-rooted patterns are not allowed.
    463         Patterns with wildcard roots are also not allowed.
    464 
    465         If what is an integer, it is used as a token type and visit will match
    466         all nodes of that type (this is faster than the pattern match).
    467         The labels arg of the visitor action method is never set (it's None)
    468         since using a token type rather than a pattern doesn't let us set a
    469         label.
    470         """
    471 
    472         if isinstance(what, (int, long)):
    473             self._visitType(tree, None, 0, what, visitor)
    474 
    475         elif isinstance(what, basestring):
    476             self._visitPattern(tree, what, visitor)
    477 
    478         else:
    479             raise TypeError("'what' must be string or integer")
    480 
    481 
    482     def _visitType(self, t, parent, childIndex, ttype, visitor):
    483         """Do the recursive work for visit"""
    484 
    485         if t is None:
    486             return
    487 
    488         if self.adaptor.getType(t) == ttype:
    489             visitor(t, parent, childIndex, None)
    490 
    491         for i in range(self.adaptor.getChildCount(t)):
    492             child = self.adaptor.getChild(t, i)
    493             self._visitType(child, t, i, ttype, visitor)
    494 
    495 
    496     def _visitPattern(self, tree, pattern, visitor):
    497         """
    498         For all subtrees that match the pattern, execute the visit action.
    499         """
    500 
    501         # Create a TreePattern from the pattern
    502         tokenizer = TreePatternLexer(pattern)
    503         parser = TreePatternParser(tokenizer, self, TreePatternTreeAdaptor())
    504         tpattern = parser.pattern()
    505 
    506         # don't allow invalid patterns
    507         if (tpattern is None or tpattern.isNil()
    508             or isinstance(tpattern, WildcardTreePattern)):
    509             return
    510 
    511         rootTokenType = tpattern.getType()
    512 
    513         def rootvisitor(tree, parent, childIndex, labels):
    514             labels = {}
    515             if self._parse(tree, tpattern, labels):
    516                 visitor(tree, parent, childIndex, labels)
    517 
    518         self.visit(tree, rootTokenType, rootvisitor)
    519 
    520 
    521     def parse(self, t, pattern, labels=None):
    522         """
    523         Given a pattern like (ASSIGN %lhs:ID %rhs:.) with optional labels
    524         on the various nodes and '.' (dot) as the node/subtree wildcard,
    525         return true if the pattern matches and fill the labels Map with
    526         the labels pointing at the appropriate nodes.  Return false if
    527         the pattern is malformed or the tree does not match.
    528 
    529         If a node specifies a text arg in pattern, then that must match
    530         for that node in t.
    531         """
    532 
    533         tokenizer = TreePatternLexer(pattern)
    534         parser = TreePatternParser(tokenizer, self, TreePatternTreeAdaptor())
    535         tpattern = parser.pattern()
    536 
    537         return self._parse(t, tpattern, labels)
    538 
    539 
    540     def _parse(self, t1, tpattern, labels):
    541         """
    542         Do the work for parse. Check to see if the tpattern fits the
    543         structure and token types in t1.  Check text if the pattern has
    544         text arguments on nodes.  Fill labels map with pointers to nodes
    545         in tree matched against nodes in pattern with labels.
    546 	"""
    547 
    548         # make sure both are non-null
    549         if t1 is None or tpattern is None:
    550             return False
    551 
    552         # check roots (wildcard matches anything)
    553         if not isinstance(tpattern, WildcardTreePattern):
    554             if self.adaptor.getType(t1) != tpattern.getType():
    555                 return False
    556 
    557             # if pattern has text, check node text
    558             if (tpattern.hasTextArg
    559                 and self.adaptor.getText(t1) != tpattern.getText()):
    560                 return False
    561 
    562         if tpattern.label is not None and labels is not None:
    563             # map label in pattern to node in t1
    564             labels[tpattern.label] = t1
    565 
    566         # check children
    567         n1 = self.adaptor.getChildCount(t1)
    568         n2 = tpattern.getChildCount()
    569         if n1 != n2:
    570             return False
    571 
    572         for i in range(n1):
    573             child1 = self.adaptor.getChild(t1, i)
    574             child2 = tpattern.getChild(i)
    575             if not self._parse(child1, child2, labels):
    576                 return False
    577 
    578         return True
    579 
    580 
    581     def equals(self, t1, t2, adaptor=None):
    582         """
    583         Compare t1 and t2; return true if token types/text, structure match
    584         exactly.
    585         The trees are examined in their entirety so that (A B) does not match
    586         (A B C) nor (A (B C)).
    587         """
    588 
    589         if adaptor is None:
    590             adaptor = self.adaptor
    591 
    592         return self._equals(t1, t2, adaptor)
    593 
    594 
    595     def _equals(self, t1, t2, adaptor):
    596         # make sure both are non-null
    597         if t1 is None or t2 is None:
    598             return False
    599 
    600         # check roots
    601         if adaptor.getType(t1) != adaptor.getType(t2):
    602             return False
    603 
    604         if adaptor.getText(t1) != adaptor.getText(t2):
    605             return False
    606 
    607         # check children
    608         n1 = adaptor.getChildCount(t1)
    609         n2 = adaptor.getChildCount(t2)
    610         if n1 != n2:
    611             return False
    612 
    613         for i in range(n1):
    614             child1 = adaptor.getChild(t1, i)
    615             child2 = adaptor.getChild(t2, i)
    616             if not self._equals(child1, child2, adaptor):
    617                 return False
    618 
    619         return True
    620