Home | History | Annotate | Download | only in Compiler
      1 #
      2 # TreeFragments - parsing of strings to trees
      3 #
      4 
      5 import re
      6 from StringIO import StringIO
      7 from Scanning import PyrexScanner, StringSourceDescriptor
      8 from Symtab import ModuleScope
      9 import PyrexTypes
     10 from Visitor import VisitorTransform
     11 from Nodes import Node, StatListNode
     12 from ExprNodes import NameNode
     13 import Parsing
     14 import Main
     15 import UtilNodes
     16 
     17 """
     18 Support for parsing strings into code trees.
     19 """
     20 
     21 class StringParseContext(Main.Context):
     22     def __init__(self, name, include_directories=None):
     23         if include_directories is None: include_directories = []
     24         Main.Context.__init__(self, include_directories, {},
     25                               create_testscope=False)
     26         self.module_name = name
     27 
     28     def find_module(self, module_name, relative_to = None, pos = None, need_pxd = 1):
     29         if module_name not in (self.module_name, 'cython'):
     30             raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
     31         return ModuleScope(module_name, parent_module = None, context = self)
     32 
     33 def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None,
     34                        context=None, allow_struct_enum_decorator=False):
     35     """
     36     Utility method to parse a (unicode) string of code. This is mostly
     37     used for internal Cython compiler purposes (creating code snippets
     38     that transforms should emit, as well as unit testing).
     39 
     40     code - a unicode string containing Cython (module-level) code
     41     name - a descriptive name for the code source (to use in error messages etc.)
     42 
     43     RETURNS
     44 
     45     The tree, i.e. a ModuleNode. The ModuleNode's scope attribute is
     46     set to the scope used when parsing.
     47     """
     48     if context is None:
     49         context = StringParseContext(name)
     50     # Since source files carry an encoding, it makes sense in this context
     51     # to use a unicode string so that code fragments don't have to bother
     52     # with encoding. This means that test code passed in should not have an
     53     # encoding header.
     54     assert isinstance(code, unicode), "unicode code snippets only please"
     55     encoding = "UTF-8"
     56 
     57     module_name = name
     58     if initial_pos is None:
     59         initial_pos = (name, 1, 0)
     60     code_source = StringSourceDescriptor(name, code)
     61 
     62     scope = context.find_module(module_name, pos = initial_pos, need_pxd = 0)
     63 
     64     buf = StringIO(code)
     65 
     66     scanner = PyrexScanner(buf, code_source, source_encoding = encoding,
     67                      scope = scope, context = context, initial_pos = initial_pos)
     68     ctx = Parsing.Ctx(allow_struct_enum_decorator=allow_struct_enum_decorator)
     69 
     70     if level is None:
     71         tree = Parsing.p_module(scanner, 0, module_name, ctx=ctx)
     72         tree.scope = scope
     73         tree.is_pxd = False
     74     else:
     75         tree = Parsing.p_code(scanner, level=level, ctx=ctx)
     76 
     77     tree.scope = scope
     78     return tree
     79 
     80 class TreeCopier(VisitorTransform):
     81     def visit_Node(self, node):
     82         if node is None:
     83             return node
     84         else:
     85             c = node.clone_node()
     86             self.visitchildren(c)
     87             return c
     88 
     89 class ApplyPositionAndCopy(TreeCopier):
     90     def __init__(self, pos):
     91         super(ApplyPositionAndCopy, self).__init__()
     92         self.pos = pos
     93 
     94     def visit_Node(self, node):
     95         copy = super(ApplyPositionAndCopy, self).visit_Node(node)
     96         copy.pos = self.pos
     97         return copy
     98 
     99 class TemplateTransform(VisitorTransform):
    100     """
    101     Makes a copy of a template tree while doing substitutions.
    102 
    103     A dictionary "substitutions" should be passed in when calling
    104     the transform; mapping names to replacement nodes. Then replacement
    105     happens like this:
    106      - If an ExprStatNode contains a single NameNode, whose name is
    107        a key in the substitutions dictionary, the ExprStatNode is
    108        replaced with a copy of the tree given in the dictionary.
    109        It is the responsibility of the caller that the replacement
    110        node is a valid statement.
    111      - If a single NameNode is otherwise encountered, it is replaced
    112        if its name is listed in the substitutions dictionary in the
    113        same way. It is the responsibility of the caller to make sure
    114        that the replacement nodes is a valid expression.
    115 
    116     Also a list "temps" should be passed. Any names listed will
    117     be transformed into anonymous, temporary names.
    118 
    119     Currently supported for tempnames is:
    120     NameNode
    121     (various function and class definition nodes etc. should be added to this)
    122 
    123     Each replacement node gets the position of the substituted node
    124     recursively applied to every member node.
    125     """
    126 
    127     temp_name_counter = 0
    128 
    129     def __call__(self, node, substitutions, temps, pos):
    130         self.substitutions = substitutions
    131         self.pos = pos
    132         tempmap = {}
    133         temphandles = []
    134         for temp in temps:
    135             TemplateTransform.temp_name_counter += 1
    136             handle = UtilNodes.TempHandle(PyrexTypes.py_object_type)
    137             tempmap[temp] = handle
    138             temphandles.append(handle)
    139         self.tempmap = tempmap
    140         result = super(TemplateTransform, self).__call__(node)
    141         if temps:
    142             result = UtilNodes.TempsBlockNode(self.get_pos(node),
    143                                               temps=temphandles,
    144                                               body=result)
    145         return result
    146 
    147     def get_pos(self, node):
    148         if self.pos:
    149             return self.pos
    150         else:
    151             return node.pos
    152 
    153     def visit_Node(self, node):
    154         if node is None:
    155             return None
    156         else:
    157             c = node.clone_node()
    158             if self.pos is not None:
    159                 c.pos = self.pos
    160             self.visitchildren(c)
    161             return c
    162 
    163     def try_substitution(self, node, key):
    164         sub = self.substitutions.get(key)
    165         if sub is not None:
    166             pos = self.pos
    167             if pos is None: pos = node.pos
    168             return ApplyPositionAndCopy(pos)(sub)
    169         else:
    170             return self.visit_Node(node) # make copy as usual
    171 
    172     def visit_NameNode(self, node):
    173         temphandle = self.tempmap.get(node.name)
    174         if temphandle:
    175             # Replace name with temporary
    176             return temphandle.ref(self.get_pos(node))
    177         else:
    178             return self.try_substitution(node, node.name)
    179 
    180     def visit_ExprStatNode(self, node):
    181         # If an expression-as-statement consists of only a replaceable
    182         # NameNode, we replace the entire statement, not only the NameNode
    183         if isinstance(node.expr, NameNode):
    184             return self.try_substitution(node, node.expr.name)
    185         else:
    186             return self.visit_Node(node)
    187 
    188 def copy_code_tree(node):
    189     return TreeCopier()(node)
    190 
    191 INDENT_RE = re.compile(ur"^ *")
    192 def strip_common_indent(lines):
    193     "Strips empty lines and common indentation from the list of strings given in lines"
    194     # TODO: Facilitate textwrap.indent instead
    195     lines = [x for x in lines if x.strip() != u""]
    196     minindent = min([len(INDENT_RE.match(x).group(0)) for x in lines])
    197     lines = [x[minindent:] for x in lines]
    198     return lines
    199 
    200 class TreeFragment(object):
    201     def __init__(self, code, name="(tree fragment)", pxds={}, temps=[], pipeline=[], level=None, initial_pos=None):
    202         if isinstance(code, unicode):
    203             def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n")))
    204 
    205             fmt_code = fmt(code)
    206             fmt_pxds = {}
    207             for key, value in pxds.iteritems():
    208                 fmt_pxds[key] = fmt(value)
    209             mod = t = parse_from_strings(name, fmt_code, fmt_pxds, level=level, initial_pos=initial_pos)
    210             if level is None:
    211                 t = t.body # Make sure a StatListNode is at the top
    212             if not isinstance(t, StatListNode):
    213                 t = StatListNode(pos=mod.pos, stats=[t])
    214             for transform in pipeline:
    215                 if transform is None:
    216                     continue
    217                 t = transform(t)
    218             self.root = t
    219         elif isinstance(code, Node):
    220             if pxds != {}: raise NotImplementedError()
    221             self.root = code
    222         else:
    223             raise ValueError("Unrecognized code format (accepts unicode and Node)")
    224         self.temps = temps
    225 
    226     def copy(self):
    227         return copy_code_tree(self.root)
    228 
    229     def substitute(self, nodes={}, temps=[], pos = None):
    230         return TemplateTransform()(self.root,
    231                                    substitutions = nodes,
    232                                    temps = self.temps + temps, pos = pos)
    233 
    234 class SetPosTransform(VisitorTransform):
    235     def __init__(self, pos):
    236         super(SetPosTransform, self).__init__()
    237         self.pos = pos
    238 
    239     def visit_Node(self, node):
    240         node.pos = self.pos
    241         self.visitchildren(node)
    242         return node
    243