Home | History | Annotate | Download | only in Cython
      1 import Cython.Compiler.Errors as Errors
      2 from Cython.CodeWriter import CodeWriter
      3 from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
      4 from Cython.Compiler.Visitor import TreeVisitor, VisitorTransform
      5 from Cython.Compiler import TreePath
      6 
      7 import unittest
      8 import os, sys
      9 import tempfile
     10 
     11 
     12 class NodeTypeWriter(TreeVisitor):
     13     def __init__(self):
     14         super(NodeTypeWriter, self).__init__()
     15         self._indents = 0
     16         self.result = []
     17 
     18     def visit_Node(self, node):
     19         if not self.access_path:
     20             name = u"(root)"
     21         else:
     22             tip = self.access_path[-1]
     23             if tip[2] is not None:
     24                 name = u"%s[%d]" % tip[1:3]
     25             else:
     26                 name = tip[1]
     27 
     28         self.result.append(u"  " * self._indents +
     29                            u"%s: %s" % (name, node.__class__.__name__))
     30         self._indents += 1
     31         self.visitchildren(node)
     32         self._indents -= 1
     33 
     34 
     35 def treetypes(root):
     36     """Returns a string representing the tree by class names.
     37     There's a leading and trailing whitespace so that it can be
     38     compared by simple string comparison while still making test
     39     cases look ok."""
     40     w = NodeTypeWriter()
     41     w.visit(root)
     42     return u"\n".join([u""] + w.result + [u""])
     43 
     44 
     45 class CythonTest(unittest.TestCase):
     46 
     47     def setUp(self):
     48         self.listing_file = Errors.listing_file
     49         self.echo_file = Errors.echo_file
     50         Errors.listing_file = Errors.echo_file = None
     51 
     52     def tearDown(self):
     53         Errors.listing_file = self.listing_file
     54         Errors.echo_file = self.echo_file
     55 
     56     def assertLines(self, expected, result):
     57         "Checks that the given strings or lists of strings are equal line by line"
     58         if not isinstance(expected, list): expected = expected.split(u"\n")
     59         if not isinstance(result, list): result = result.split(u"\n")
     60         for idx, (expected_line, result_line) in enumerate(zip(expected, result)):
     61             self.assertEqual(expected_line, result_line, "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line))
     62         self.assertEqual(len(expected), len(result),
     63             "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result)))
     64 
     65     def codeToLines(self, tree):
     66         writer = CodeWriter()
     67         writer.write(tree)
     68         return writer.result.lines
     69 
     70     def codeToString(self, tree):
     71         return "\n".join(self.codeToLines(tree))
     72 
     73     def assertCode(self, expected, result_tree):
     74         result_lines = self.codeToLines(result_tree)
     75 
     76         expected_lines = strip_common_indent(expected.split("\n"))
     77 
     78         for idx, (line, expected_line) in enumerate(zip(result_lines, expected_lines)):
     79             self.assertEqual(expected_line, line, "Line %d:\nGot: %s\nExp: %s" % (idx, line, expected_line))
     80         self.assertEqual(len(result_lines), len(expected_lines),
     81             "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))
     82 
     83     def assertNodeExists(self, path, result_tree):
     84         self.assertNotEqual(TreePath.find_first(result_tree, path), None,
     85                             "Path '%s' not found in result tree" % path)
     86 
     87     def fragment(self, code, pxds={}, pipeline=[]):
     88         "Simply create a tree fragment using the name of the test-case in parse errors."
     89         name = self.id()
     90         if name.startswith("__main__."): name = name[len("__main__."):]
     91         name = name.replace(".", "_")
     92         return TreeFragment(code, name, pxds, pipeline=pipeline)
     93 
     94     def treetypes(self, root):
     95         return treetypes(root)
     96 
     97     def should_fail(self, func, exc_type=Exception):
     98         """Calls "func" and fails if it doesn't raise the right exception
     99         (any exception by default). Also returns the exception in question.
    100         """
    101         try:
    102             func()
    103             self.fail("Expected an exception of type %r" % exc_type)
    104         except exc_type, e:
    105             self.assert_(isinstance(e, exc_type))
    106             return e
    107 
    108     def should_not_fail(self, func):
    109         """Calls func and succeeds if and only if no exception is raised
    110         (i.e. converts exception raising into a failed testcase). Returns
    111         the return value of func."""
    112         try:
    113             return func()
    114         except:
    115             self.fail(str(sys.exc_info()[1]))
    116 
    117 
    118 class TransformTest(CythonTest):
    119     """
    120     Utility base class for transform unit tests. It is based around constructing
    121     test trees (either explicitly or by parsing a Cython code string); running
    122     the transform, serialize it using a customized Cython serializer (with
    123     special markup for nodes that cannot be represented in Cython),
    124     and do a string-comparison line-by-line of the result.
    125 
    126     To create a test case:
    127      - Call run_pipeline. The pipeline should at least contain the transform you
    128        are testing; pyx should be either a string (passed to the parser to
    129        create a post-parse tree) or a node representing input to pipeline.
    130        The result will be a transformed result.
    131 
    132      - Check that the tree is correct. If wanted, assertCode can be used, which
    133        takes a code string as expected, and a ModuleNode in result_tree
    134        (it serializes the ModuleNode to a string and compares line-by-line).
    135 
    136     All code strings are first stripped for whitespace lines and then common
    137     indentation.
    138 
    139     Plans: One could have a pxd dictionary parameter to run_pipeline.
    140     """
    141 
    142     def run_pipeline(self, pipeline, pyx, pxds={}):
    143         tree = self.fragment(pyx, pxds).root
    144         # Run pipeline
    145         for T in pipeline:
    146             tree = T(tree)
    147         return tree
    148 
    149 
    150 class TreeAssertVisitor(VisitorTransform):
    151     # actually, a TreeVisitor would be enough, but this needs to run
    152     # as part of the compiler pipeline
    153 
    154     def visit_CompilerDirectivesNode(self, node):
    155         directives = node.directives
    156         if 'test_assert_path_exists' in directives:
    157             for path in directives['test_assert_path_exists']:
    158                 if TreePath.find_first(node, path) is None:
    159                     Errors.error(
    160                         node.pos,
    161                         "Expected path '%s' not found in result tree" % path)
    162         if 'test_fail_if_path_exists' in directives:
    163             for path in directives['test_fail_if_path_exists']:
    164                 if TreePath.find_first(node, path) is not None:
    165                     Errors.error(
    166                         node.pos,
    167                         "Unexpected path '%s' found in result tree" %  path)
    168         self.visitchildren(node)
    169         return node
    170 
    171     visit_Node = VisitorTransform.recurse_to_children
    172 
    173 
    174 def unpack_source_tree(tree_file, dir=None):
    175     if dir is None:
    176         dir = tempfile.mkdtemp()
    177     header = []
    178     cur_file = None
    179     f = open(tree_file)
    180     try:
    181         lines = f.readlines()
    182     finally:
    183         f.close()
    184     del f
    185     try:
    186         for line in lines:
    187             if line[:5] == '#####':
    188                 filename = line.strip().strip('#').strip().replace('/', os.path.sep)
    189                 path = os.path.join(dir, filename)
    190                 if not os.path.exists(os.path.dirname(path)):
    191                     os.makedirs(os.path.dirname(path))
    192                 if cur_file is not None:
    193                     f, cur_file = cur_file, None
    194                     f.close()
    195                 cur_file = open(path, 'w')
    196             elif cur_file is not None:
    197                 cur_file.write(line)
    198             elif line.strip() and not line.lstrip().startswith('#'):
    199                 if line.strip() not in ('"""', "'''"):
    200                     header.append(line)
    201     finally:
    202         if cur_file is not None:
    203             cur_file.close()
    204     return dir, ''.join(header)
    205