import os
import sys

from Cython.Compiler import CmdLine
from Cython.TestUtils import TransformTest
from Cython.Compiler.ParseTreeTransforms import *
from Cython.Compiler.Nodes import *
from Cython.Compiler import Main, Symtab
# ``sys`` and ``DebugFlags`` are used in this module; import them explicitly
# rather than relying on them leaking in through the star imports above.
from Cython.Compiler import DebugFlags


class TestNormalizeTree(TransformTest):
    # Tests for NormalizeTree: statement bodies end up wrapped in a
    # StatListNode, and a lone ``pass`` leaves no statements behind.

    def test_parserbehaviour_is_what_we_coded_for(self):
        # Baseline without any transform: a single-statement ``if`` body is
        # an ExprStatNode directly under the IfClauseNode (not wrapped).
        t = self.fragment(u"if x: y").root
        self.assertLines(u"""
(root): StatListNode
  stats[0]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: ExprStatNode
        expr: NameNode
""", self.treetypes(t))

    def test_wrap_singlestat(self):
        # After NormalizeTree the single statement is wrapped in a StatListNode.
        t = self.run_pipeline([NormalizeTree(None)], u"if x: y")
        self.assertLines(u"""
(root): StatListNode
  stats[0]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: StatListNode
        stats[0]: ExprStatNode
          expr: NameNode
""", self.treetypes(t))

    def test_wrap_multistat(self):
        t = self.run_pipeline([NormalizeTree(None)], u"""
            if z:
                x
                y
        """)
        self.assertLines(u"""
(root): StatListNode
  stats[0]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: StatListNode
        stats[0]: ExprStatNode
          expr: NameNode
        stats[1]: ExprStatNode
          expr: NameNode
""", self.treetypes(t))

    def test_statinexpr(self):
        # Tuple assignment is left as a single SingleAssignmentNode.
        t = self.run_pipeline([NormalizeTree(None)], u"""
            a, b = x, y
        """)
        self.assertLines(u"""
(root): StatListNode
  stats[0]: SingleAssignmentNode
    lhs: TupleNode
      args[0]: NameNode
      args[1]: NameNode
    rhs: TupleNode
      args[0]: NameNode
      args[1]: NameNode
""", self.treetypes(t))

    def test_wrap_offagain(self):
        # Top-level statements stay flat; only the ``if`` body gets wrapped.
        t = self.run_pipeline([NormalizeTree(None)], u"""
            x
            y
            if z:
                x
        """)
        self.assertLines(u"""
(root): StatListNode
  stats[0]: ExprStatNode
    expr: NameNode
  stats[1]: ExprStatNode
    expr: NameNode
  stats[2]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: StatListNode
        stats[0]: ExprStatNode
          expr: NameNode
""", self.treetypes(t))

    def test_pass_eliminated(self):
        t = self.run_pipeline([NormalizeTree(None)], u"pass")
        # assertEqual instead of the deprecated (and in Python 3.12 removed)
        # ``assert_`` alias; it also reports the actual length on failure.
        self.assertEqual(0, len(t.stats))


class TestWithTransform(object): # (TransformTest): # Disabled!
    # Deriving from ``object`` keeps unittest from running these tests, but
    # pytest/nose would still collect a plain ``Test*`` class (and fail on
    # the missing TransformTest helpers), so opt out explicitly as well.
    __test__ = False

    def test_simplified(self):
        t = self.run_pipeline([WithTransform(None)], u"""
        with x:
            y = z ** 3
        """)

        self.assertCode(u"""

        $0_0 = x
        $0_2 = $0_0.__exit__
        $0_0.__enter__()
        $0_1 = True
        try:
            try:
                $1_0 = None
                y = z ** 3
            except:
                $0_1 = False
                if (not $0_2($1_0)):
                    raise
        finally:
            if $0_1:
                $0_2(None, None, None)

        """, t)

    def test_basic(self):
        t = self.run_pipeline([WithTransform(None)], u"""
        with x as y:
            y = z ** 3
        """)
        self.assertCode(u"""

        $0_0 = x
        $0_2 = $0_0.__exit__
        $0_3 = $0_0.__enter__()
        $0_1 = True
        try:
            try:
                $1_0 = None
                y = $0_3
                y = z ** 3
            except:
                $0_1 = False
                if (not $0_2($1_0)):
                    raise
        finally:
            if $0_1:
                $0_2(None, None, None)

        """, t)


class TestInterpretCompilerDirectives(TransformTest):
    """
    This class tests the parallel directives AST-rewriting and importing.
    """

    # Test the parallel directives (c)importing

    import_code = u"""
        cimport cython.parallel
        cimport cython.parallel as par
        from cython cimport parallel as par2
        from cython cimport parallel

        from cython.parallel cimport threadid as tid
        from cython.parallel cimport threadavailable as tavail
        from cython.parallel cimport prange
    """

    # Maps each local name bound by ``import_code`` to the fully qualified
    # parallel directive it should resolve to.
    expected_directives_dict = {
        u'cython.parallel': u'cython.parallel',
        u'par': u'cython.parallel',
        u'par2': u'cython.parallel',
        u'parallel': u'cython.parallel',

        u"tid": u"cython.parallel.threadid",
        u"tavail": u"cython.parallel.threadavailable",
        u"prange": u"cython.parallel.prange",
    }

    def setUp(self):
        super(TestInterpretCompilerDirectives, self).setUp()

        compilation_options = Main.CompilationOptions(Main.default_options)
        ctx = compilation_options.create_context()

        transform = InterpretCompilerDirectives(ctx, ctx.compiler_directives)
        transform.module_scope = Symtab.ModuleScope('__main__', None, ctx)
        self.pipeline = [transform]

        # Remember the global debug flag so tearDown can restore it.
        self.debug_exception_on_error = DebugFlags.debug_exception_on_error

    def tearDown(self):
        DebugFlags.debug_exception_on_error = self.debug_exception_on_error
        # setUp chains to the base class, so tearDown must as well; otherwise
        # whatever state the base test class restores here leaks into later tests.
        super(TestInterpretCompilerDirectives, self).tearDown()

    def test_parallel_directives_cimports(self):
        self.run_pipeline(self.pipeline, self.import_code)
        parallel_directives = self.pipeline[0].parallel_directives
        self.assertEqual(parallel_directives, self.expected_directives_dict)

    def test_parallel_directives_imports(self):
        # Same directives must be recognised via plain ``import`` too.
        self.run_pipeline(self.pipeline,
                          self.import_code.replace(u'cimport', u'import'))
        parallel_directives = self.pipeline[0].parallel_directives
        self.assertEqual(parallel_directives, self.expected_directives_dict)


# TODO: Re-enable once they're more robust.
if sys.version_info[:2] >= (2, 5) and False:
    from Cython.Debugger import DebugWriter
    from Cython.Debugger.Tests.TestLibCython import DebuggerTestCase
else:
    # skip test, don't let it inherit unittest.TestCase
    DebuggerTestCase = object


class TestDebugTransform(DebuggerTestCase):
    """Check the XML debug information written to ``self.debug_dest``.

    NOTE(review): ``debug_dest`` is presumably provided by
    ``DebuggerTestCase``; it is not defined in this module.
    """

    # When the debugger tests are disabled the class derives from ``object``.
    # unittest then ignores it, but pytest/nose still collect plain ``Test*``
    # classes and would fail on the missing fixtures, so opt out explicitly.
    __test__ = DebuggerTestCase is not object

    def elem_hasattrs(self, elem, attrs):
        """Return True if *elem* has every attribute named in *attrs*."""
        return all(attr in elem.attrib for attr in attrs)

    def test_debug_info(self):
        try:
            assert os.path.exists(self.debug_dest)

            t = DebugWriter.etree.parse(self.debug_dest)
            # The xpath support of the standard ElementTree is primitive, so
            # stick to simple paths. They are written relative to the root
            # ('./...'): an absolute '/...' path on an ElementTree emits a
            # FutureWarning and is rewritten to this relative form anyway.
            global_elems = list(t.find('./Module/Globals'))
            # Plain assert statements are used for the simple non-emptiness
            # checks in this test.
            assert global_elems
            xml_globals = dict(
                [(e.attrib['name'], e.attrib['type']) for e in global_elems])
            # Equal lengths mean no global name appears twice.
            self.assertEqual(len(global_elems), len(xml_globals))

            func_elems = list(t.find('./Module/Functions'))
            assert func_elems
            xml_funcs = dict(
                [(e.attrib['qualified_name'], e) for e in func_elems])
            self.assertEqual(len(func_elems), len(xml_funcs))

            # test globals
            self.assertEqual('CObject', xml_globals.get('c_var'))
            self.assertEqual('PythonObject', xml_globals.get('python_var'))

            # test functions
            funcnames = ('codefile.spam', 'codefile.ham', 'codefile.eggs',
                         'codefile.closure', 'codefile.inner')
            required_xml_attrs = 'name', 'cname', 'qualified_name'
            assert all(f in xml_funcs for f in funcnames)
            # Only 'spam' is inspected in detail. (Unpacking all five entries
            # into three names, as before, always raised ValueError.)
            spam = xml_funcs['codefile.spam']

            self.assertEqual(spam.attrib['name'], 'spam')
            self.assertNotEqual('spam', spam.attrib['cname'])
            assert self.elem_hasattrs(spam, required_xml_attrs)

            # test locals of functions
            spam_locals = list(spam.find('Locals'))
            assert spam_locals
            spam_locals.sort(key=lambda e: e.attrib['name'])
            names = [e.attrib['name'] for e in spam_locals]
            self.assertEqual(list('abcd'), names)
            assert self.elem_hasattrs(spam_locals[0], required_xml_attrs)

            # test arguments of functions
            spam_arguments = list(spam.find('Arguments'))
            assert spam_arguments
            self.assertEqual(1, len(spam_arguments))

            # test step-into functions
            step_into = spam.find('StepIntoFunctions')
            spam_stepinto = [x.attrib['name'] for x in step_into]
            assert spam_stepinto
            self.assertEqual(2, len(spam_stepinto))
            assert 'puts' in spam_stepinto
            assert 'some_c_function' in spam_stepinto
        except Exception:
            # Dump the debug file to help diagnose the failure, but only if
            # it exists: opening a missing file here would raise IOError and
            # mask the original failure (e.g. the existence assert above).
            if os.path.exists(self.debug_dest):
                f = open(self.debug_dest)
                try:
                    print(f.read())
                finally:
                    f.close()
            raise


if __name__ == "__main__":
    import unittest
    unittest.main()