import Cython.Compiler.Errors as Errors
from Cython.CodeWriter import CodeWriter
from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
from Cython.Compiler.Visitor import TreeVisitor, VisitorTransform
from Cython.Compiler import TreePath

import unittest
import os, sys
import tempfile


class NodeTypeWriter(TreeVisitor):
    """Visitor that renders a tree as indented ``name: ClassName`` lines.

    Each visited node appends one line to ``self.result``, indented by one
    space per nesting level below the root.
    """

    def __init__(self):
        super(NodeTypeWriter, self).__init__()
        self._indents = 0
        self.result = []

    def visit_Node(self, node):
        if not self.access_path:
            name = u"(root)"
        else:
            # The last access-path entry describes how this node was reached:
            # tip[1] is rendered as the attribute name and tip[2], when not
            # None, as an integer index into that attribute.
            tip = self.access_path[-1]
            if tip[2] is not None:
                name = u"%s[%d]" % tip[1:3]
            else:
                name = tip[1]

        self.result.append(u" " * self._indents +
                           u"%s: %s" % (name, node.__class__.__name__))
        self._indents += 1
        self.visitchildren(node)
        self._indents -= 1


def treetypes(root):
    """Return a string describing the tree under *root* by node class names.

    The result starts and ends with a newline so that it can be compared by
    plain string comparison against triple-quoted literals in test cases.
    """
    w = NodeTypeWriter()
    w.visit(root)
    return u"\n".join([u""] + w.result + [u""])


class CythonTest(unittest.TestCase):
    """Base TestCase with helpers for comparing Cython trees and code."""

    def setUp(self):
        # Save the compiler's error listing/echo targets and clear them for
        # the duration of the test; tearDown restores the saved values.
        self.listing_file = Errors.listing_file
        self.echo_file = Errors.echo_file
        Errors.listing_file = Errors.echo_file = None

    def tearDown(self):
        Errors.listing_file = self.listing_file
        Errors.echo_file = self.echo_file

    def assertLines(self, expected, result):
        """Assert that two strings (or lists of lines) are equal line by line.

        Strings are split on newlines; lists are compared as given.  Lines
        are compared pairwise first so the failure message points at the
        first differing line, then the line counts are compared.
        """
        if not isinstance(expected, list):
            expected = expected.split(u"\n")
        if not isinstance(result, list):
            result = result.split(u"\n")
        for idx, (expected_line, result_line) in enumerate(zip(expected, result)):
            self.assertEqual(expected_line, result_line,
                             "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line))
        self.assertEqual(len(expected), len(result),
                         "Unmatched lines. Got:\n%s\nExpected:\n%s" % (
                             u"\n".join(result), u"\n".join(expected)))

    def codeToLines(self, tree):
        """Serialize *tree* with CodeWriter and return the list of output lines."""
        writer = CodeWriter()
        writer.write(tree)
        return writer.result.lines

    def codeToString(self, tree):
        """Serialize *tree* with CodeWriter and return it as one string."""
        return "\n".join(self.codeToLines(tree))

    def assertCode(self, expected, result_tree):
        """Assert that *result_tree* serializes to the code string *expected*.

        *expected* is split into lines and passed through strip_common_indent
        before the line-by-line comparison.
        """
        result_lines = self.codeToLines(result_tree)

        expected_lines = strip_common_indent(expected.split("\n"))

        for idx, (line, expected_line) in enumerate(zip(result_lines, expected_lines)):
            self.assertEqual(expected_line, line,
                             "Line %d:\nGot: %s\nExp: %s" % (idx, line, expected_line))
        self.assertEqual(len(result_lines), len(expected_lines),
                         "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))

    def assertNodeExists(self, path, result_tree):
        """Assert that the TreePath expression *path* matches in *result_tree*."""
        self.assertNotEqual(TreePath.find_first(result_tree, path), None,
                            "Path '%s' not found in result tree" % path)

    def fragment(self, code, pxds=None, pipeline=None):
        """Create a TreeFragment from *code*, named after the current test case.

        The test id (minus a leading ``__main__.``, with dots replaced by
        underscores) is used as the fragment name so it shows up in parse
        errors.  *pxds* and *pipeline* are forwarded to TreeFragment; when
        omitted, a fresh empty dict / list is used for each call.
        """
        # None defaults avoid sharing one mutable dict/list across all calls.
        if pxds is None:
            pxds = {}
        if pipeline is None:
            pipeline = []
        name = self.id()
        if name.startswith("__main__."):
            name = name[len("__main__."):]
        name = name.replace(".", "_")
        return TreeFragment(code, name, pxds, pipeline=pipeline)

    def treetypes(self, root):
        """Return the class-name rendering of *root* (see module-level treetypes)."""
        return treetypes(root)

    def should_fail(self, func, exc_type=Exception):
        """Call *func* and fail unless it raises an instance of *exc_type*.

        Returns the raised exception.  Exceptions that are not instances of
        *exc_type* propagate unchanged.
        """
        try:
            func()
        except exc_type:
            # sys.exc_info() keeps this valid on both Python 2 and Python 3.
            return sys.exc_info()[1]
        # Fail outside the try block: self.fail() raises AssertionError, which
        # is itself an Exception and would otherwise be caught by the handler
        # above (with the default exc_type) and returned as a "success".
        self.fail("Expected an exception of type %r" % exc_type)

    def should_not_fail(self, func):
        """Call *func* and return its result.

        Any Exception raised by *func* is converted into a test failure
        carrying the exception's string form.  KeyboardInterrupt and
        SystemExit are not intercepted.
        """
        try:
            return func()
        except Exception:
            self.fail(str(sys.exc_info()[1]))


class TransformTest(CythonTest):
    """
    Utility base class for transform unit tests.  A test constructs a tree
    (typically by parsing a Cython code string), runs the transform(s) on it,
    serializes the result with a customized Cython serializer (with special
    markup for nodes that cannot be represented in Cython), and compares the
    output line by line as strings.

    To create a test case:
     - Call run_pipeline.  The pipeline should contain at least the transform
       you are testing.  pyx is handed to fragment() to build the input tree,
       and the optional pxds mapping is forwarded to fragment() as well.
       The transformed tree is returned.

     - Check that the tree is correct.  If wanted, assertCode can be used,
       which takes the expected code string and a ModuleNode in result_tree
       (it serializes the ModuleNode to a string and compares line by line).

    Expected code strings given to assertCode are passed through
    strip_common_indent before comparison.
    """

    def run_pipeline(self, pipeline, pyx, pxds=None):
        """Build a fragment from *pyx* and feed its root through *pipeline*.

        Each element of *pipeline* is called with the current tree and must
        return the tree for the next step; the final tree is returned.
        """
        if pxds is None:
            pxds = {}
        tree = self.fragment(pyx, pxds).root
        for T in pipeline:
            tree = T(tree)
        return tree


class TreeAssertVisitor(VisitorTransform):
    """Pipeline step that checks test-path directives on CompilerDirectivesNodes.

    For ``test_assert_path_exists``, every listed TreePath must match below
    the node; for ``test_fail_if_path_exists``, none may match.  Violations
    are reported via Errors.error at the node's position.  The tree itself is
    returned unchanged.
    """
    # actually, a TreeVisitor would be enough, but this needs to run
    # as part of the compiler pipeline

    def visit_CompilerDirectivesNode(self, node):
        directives = node.directives
        if 'test_assert_path_exists' in directives:
            for path in directives['test_assert_path_exists']:
                if TreePath.find_first(node, path) is None:
                    Errors.error(
                        node.pos,
                        "Expected path '%s' not found in result tree" % path)
        if 'test_fail_if_path_exists' in directives:
            for path in directives['test_fail_if_path_exists']:
                if TreePath.find_first(node, path) is not None:
                    Errors.error(
                        node.pos,
                        "Unexpected path '%s' found in result tree" % path)
        self.visitchildren(node)
        return node

    visit_Node = VisitorTransform.recurse_to_children


def unpack_source_tree(tree_file, dir=None):
    """Split a combined "source tree" file into separate files under *dir*.

    A line starting with ``#####`` begins a new output file; the rest of the
    line (with surrounding ``#`` and whitespace removed, ``/`` mapped to the
    OS separator) is the path relative to *dir*.  Missing directories are
    created.  Following lines are copied into that file until the next
    marker.  Before the first marker, lines that are non-blank, are not
    ``#`` comments and are not a bare triple quote are collected as the
    header.

    If *dir* is None a new temporary directory is created.
    Returns ``(dir, header_text)``.
    """
    if dir is None:
        dir = tempfile.mkdtemp()
    header = []
    cur_file = None
    f = open(tree_file)
    try:
        lines = f.readlines()
    finally:
        f.close()
    try:
        for line in lines:
            if line.startswith('#####'):
                filename = line.strip().strip('#').strip().replace('/', os.path.sep)
                path = os.path.join(dir, filename)
                target_dir = os.path.dirname(path)
                if not os.path.exists(target_dir):
                    os.makedirs(target_dir)
                if cur_file is not None:
                    # Clear cur_file before closing so the finally clause
                    # does not close the same file a second time if close()
                    # raises.
                    prev_file, cur_file = cur_file, None
                    prev_file.close()
                cur_file = open(path, 'w')
            elif cur_file is not None:
                cur_file.write(line)
            elif line.strip() and not line.lstrip().startswith('#'):
                if line.strip() not in ('"""', "'''"):
                    header.append(line)
    finally:
        if cur_file is not None:
            cur_file.close()
    return dir, ''.join(header)