Home | History | Annotate | Download | only in Tests
      1 import os
      2 
      3 from Cython.Compiler import CmdLine
      4 from Cython.TestUtils import TransformTest
      5 from Cython.Compiler.ParseTreeTransforms import *
      6 from Cython.Compiler.Nodes import *
      7 from Cython.Compiler import Main, Symtab
      8 
      9 
     10 class TestNormalizeTree(TransformTest):
     11     def test_parserbehaviour_is_what_we_coded_for(self):
     12         t = self.fragment(u"if x: y").root
     13         self.assertLines(u"""
     14 (root): StatListNode
     15   stats[0]: IfStatNode
     16     if_clauses[0]: IfClauseNode
     17       condition: NameNode
     18       body: ExprStatNode
     19         expr: NameNode
     20 """, self.treetypes(t))
     21 
     22     def test_wrap_singlestat(self):
     23         t = self.run_pipeline([NormalizeTree(None)], u"if x: y")
     24         self.assertLines(u"""
     25 (root): StatListNode
     26   stats[0]: IfStatNode
     27     if_clauses[0]: IfClauseNode
     28       condition: NameNode
     29       body: StatListNode
     30         stats[0]: ExprStatNode
     31           expr: NameNode
     32 """, self.treetypes(t))
     33 
     34     def test_wrap_multistat(self):
     35         t = self.run_pipeline([NormalizeTree(None)], u"""
     36             if z:
     37                 x
     38                 y
     39         """)
     40         self.assertLines(u"""
     41 (root): StatListNode
     42   stats[0]: IfStatNode
     43     if_clauses[0]: IfClauseNode
     44       condition: NameNode
     45       body: StatListNode
     46         stats[0]: ExprStatNode
     47           expr: NameNode
     48         stats[1]: ExprStatNode
     49           expr: NameNode
     50 """, self.treetypes(t))
     51 
     52     def test_statinexpr(self):
     53         t = self.run_pipeline([NormalizeTree(None)], u"""
     54             a, b = x, y
     55         """)
     56         self.assertLines(u"""
     57 (root): StatListNode
     58   stats[0]: SingleAssignmentNode
     59     lhs: TupleNode
     60       args[0]: NameNode
     61       args[1]: NameNode
     62     rhs: TupleNode
     63       args[0]: NameNode
     64       args[1]: NameNode
     65 """, self.treetypes(t))
     66 
     67     def test_wrap_offagain(self):
     68         t = self.run_pipeline([NormalizeTree(None)], u"""
     69             x
     70             y
     71             if z:
     72                 x
     73         """)
     74         self.assertLines(u"""
     75 (root): StatListNode
     76   stats[0]: ExprStatNode
     77     expr: NameNode
     78   stats[1]: ExprStatNode
     79     expr: NameNode
     80   stats[2]: IfStatNode
     81     if_clauses[0]: IfClauseNode
     82       condition: NameNode
     83       body: StatListNode
     84         stats[0]: ExprStatNode
     85           expr: NameNode
     86 """, self.treetypes(t))
     87 
     88 
     89     def test_pass_eliminated(self):
     90         t = self.run_pipeline([NormalizeTree(None)], u"pass")
     91         self.assert_(len(t.stats) == 0)
     92 
     93 class TestWithTransform(object): # (TransformTest): # Disabled!
     94 
     95     def test_simplified(self):
     96         t = self.run_pipeline([WithTransform(None)], u"""
     97         with x:
     98             y = z ** 3
     99         """)
    100 
    101         self.assertCode(u"""
    102 
    103         $0_0 = x
    104         $0_2 = $0_0.__exit__
    105         $0_0.__enter__()
    106         $0_1 = True
    107         try:
    108             try:
    109                 $1_0 = None
    110                 y = z ** 3
    111             except:
    112                 $0_1 = False
    113                 if (not $0_2($1_0)):
    114                     raise
    115         finally:
    116             if $0_1:
    117                 $0_2(None, None, None)
    118 
    119         """, t)
    120 
    121     def test_basic(self):
    122         t = self.run_pipeline([WithTransform(None)], u"""
    123         with x as y:
    124             y = z ** 3
    125         """)
    126         self.assertCode(u"""
    127 
    128         $0_0 = x
    129         $0_2 = $0_0.__exit__
    130         $0_3 = $0_0.__enter__()
    131         $0_1 = True
    132         try:
    133             try:
    134                 $1_0 = None
    135                 y = $0_3
    136                 y = z ** 3
    137             except:
    138                 $0_1 = False
    139                 if (not $0_2($1_0)):
    140                     raise
    141         finally:
    142             if $0_1:
    143                 $0_2(None, None, None)
    144 
    145         """, t)
    146 
    147 
    148 class TestInterpretCompilerDirectives(TransformTest):
    149     """
    150     This class tests the parallel directives AST-rewriting and importing.
    151     """
    152 
    153     # Test the parallel directives (c)importing
    154 
    155     import_code = u"""
    156         cimport cython.parallel
    157         cimport cython.parallel as par
    158         from cython cimport parallel as par2
    159         from cython cimport parallel
    160 
    161         from cython.parallel cimport threadid as tid
    162         from cython.parallel cimport threadavailable as tavail
    163         from cython.parallel cimport prange
    164     """
    165 
    166     expected_directives_dict = {
    167         u'cython.parallel': u'cython.parallel',
    168         u'par': u'cython.parallel',
    169         u'par2': u'cython.parallel',
    170         u'parallel': u'cython.parallel',
    171 
    172         u"tid": u"cython.parallel.threadid",
    173         u"tavail": u"cython.parallel.threadavailable",
    174         u"prange": u"cython.parallel.prange",
    175     }
    176 
    177 
    178     def setUp(self):
    179         super(TestInterpretCompilerDirectives, self).setUp()
    180 
    181         compilation_options = Main.CompilationOptions(Main.default_options)
    182         ctx = compilation_options.create_context()
    183 
    184         transform = InterpretCompilerDirectives(ctx, ctx.compiler_directives)
    185         transform.module_scope = Symtab.ModuleScope('__main__', None, ctx)
    186         self.pipeline = [transform]
    187 
    188         self.debug_exception_on_error = DebugFlags.debug_exception_on_error
    189 
    190     def tearDown(self):
    191         DebugFlags.debug_exception_on_error = self.debug_exception_on_error
    192 
    193     def test_parallel_directives_cimports(self):
    194         self.run_pipeline(self.pipeline, self.import_code)
    195         parallel_directives = self.pipeline[0].parallel_directives
    196         self.assertEqual(parallel_directives, self.expected_directives_dict)
    197 
    198     def test_parallel_directives_imports(self):
    199         self.run_pipeline(self.pipeline,
    200                           self.import_code.replace(u'cimport', u'import'))
    201         parallel_directives = self.pipeline[0].parallel_directives
    202         self.assertEqual(parallel_directives, self.expected_directives_dict)
    203 
    204 
    205 # TODO: Re-enable once they're more robust.
    206 if sys.version_info[:2] >= (2, 5) and False:
    207     from Cython.Debugger import DebugWriter
    208     from Cython.Debugger.Tests.TestLibCython import DebuggerTestCase
    209 else:
    210     # skip test, don't let it inherit unittest.TestCase
    211     DebuggerTestCase = object
    212 
    213 class TestDebugTransform(DebuggerTestCase):
    214 
    215     def elem_hasattrs(self, elem, attrs):
    216         # we shall supporteth python 2.3 !
    217         return all([attr in elem.attrib for attr in attrs])
    218 
    219     def test_debug_info(self):
    220         try:
    221             assert os.path.exists(self.debug_dest)
    222 
    223             t = DebugWriter.etree.parse(self.debug_dest)
    224             # the xpath of the standard ElementTree is primitive, don't use
    225             # anything fancy
    226             L = list(t.find('/Module/Globals'))
    227             # assertTrue is retarded, use the normal assert statement
    228             assert L
    229             xml_globals = dict(
    230                             [(e.attrib['name'], e.attrib['type']) for e in L])
    231             self.assertEqual(len(L), len(xml_globals))
    232 
    233             L = list(t.find('/Module/Functions'))
    234             assert L
    235             xml_funcs = dict([(e.attrib['qualified_name'], e) for e in L])
    236             self.assertEqual(len(L), len(xml_funcs))
    237 
    238             # test globals
    239             self.assertEqual('CObject', xml_globals.get('c_var'))
    240             self.assertEqual('PythonObject', xml_globals.get('python_var'))
    241 
    242             # test functions
    243             funcnames = ('codefile.spam', 'codefile.ham', 'codefile.eggs',
    244                          'codefile.closure', 'codefile.inner')
    245             required_xml_attrs = 'name', 'cname', 'qualified_name'
    246             assert all([f in xml_funcs for f in funcnames])
    247             spam, ham, eggs = [xml_funcs[funcname] for funcname in funcnames]
    248 
    249             self.assertEqual(spam.attrib['name'], 'spam')
    250             self.assertNotEqual('spam', spam.attrib['cname'])
    251             assert self.elem_hasattrs(spam, required_xml_attrs)
    252 
    253             # test locals of functions
    254             spam_locals = list(spam.find('Locals'))
    255             assert spam_locals
    256             spam_locals.sort(key=lambda e: e.attrib['name'])
    257             names = [e.attrib['name'] for e in spam_locals]
    258             self.assertEqual(list('abcd'), names)
    259             assert self.elem_hasattrs(spam_locals[0], required_xml_attrs)
    260 
    261             # test arguments of functions
    262             spam_arguments = list(spam.find('Arguments'))
    263             assert spam_arguments
    264             self.assertEqual(1, len(list(spam_arguments)))
    265 
    266             # test step-into functions
    267             step_into = spam.find('StepIntoFunctions')
    268             spam_stepinto = [x.attrib['name'] for x in step_into]
    269             assert spam_stepinto
    270             self.assertEqual(2, len(spam_stepinto))
    271             assert 'puts' in spam_stepinto
    272             assert 'some_c_function' in spam_stepinto
    273         except:
    274             f = open(self.debug_dest)
    275             try:
    276                 print(f.read())
    277             finally:
    278                 f.close()
    279             raise
    280 
    281 
    282 
    283 if __name__ == "__main__":
    284     import unittest
    285     unittest.main()
    286