Home | History | Annotate | Download | only in test
      1 import sys
      2 
      3 from cStringIO import StringIO
      4 
      5 import unittest
      6 
      7 
      8 def resultFactory(*_):
      9     return unittest.TestResult()
     10 
     11 
     12 class TestSetups(unittest.TestCase):
     13 
     14     def getRunner(self):
     15         return unittest.TextTestRunner(resultclass=resultFactory,
     16                                           stream=StringIO())
     17     def runTests(self, *cases):
     18         suite = unittest.TestSuite()
     19         for case in cases:
     20             tests = unittest.defaultTestLoader.loadTestsFromTestCase(case)
     21             suite.addTests(tests)
     22 
     23         runner = self.getRunner()
     24 
     25         # creating a nested suite exposes some potential bugs

     26         realSuite = unittest.TestSuite()
     27         realSuite.addTest(suite)
     28         # adding empty suites to the end exposes potential bugs

     29         suite.addTest(unittest.TestSuite())
     30         realSuite.addTest(unittest.TestSuite())
     31         return runner.run(realSuite)
     32 
     33     def test_setup_class(self):
     34         class Test(unittest.TestCase):
     35             setUpCalled = 0
     36             @classmethod
     37             def setUpClass(cls):
     38                 Test.setUpCalled += 1
     39                 unittest.TestCase.setUpClass()
     40             def test_one(self):
     41                 pass
     42             def test_two(self):
     43                 pass
     44 
     45         result = self.runTests(Test)
     46 
     47         self.assertEqual(Test.setUpCalled, 1)
     48         self.assertEqual(result.testsRun, 2)
     49         self.assertEqual(len(result.errors), 0)
     50 
     51     def test_teardown_class(self):
     52         class Test(unittest.TestCase):
     53             tearDownCalled = 0
     54             @classmethod
     55             def tearDownClass(cls):
     56                 Test.tearDownCalled += 1
     57                 unittest.TestCase.tearDownClass()
     58             def test_one(self):
     59                 pass
     60             def test_two(self):
     61                 pass
     62 
     63         result = self.runTests(Test)
     64 
     65         self.assertEqual(Test.tearDownCalled, 1)
     66         self.assertEqual(result.testsRun, 2)
     67         self.assertEqual(len(result.errors), 0)
     68 
     69     def test_teardown_class_two_classes(self):
     70         class Test(unittest.TestCase):
     71             tearDownCalled = 0
     72             @classmethod
     73             def tearDownClass(cls):
     74                 Test.tearDownCalled += 1
     75                 unittest.TestCase.tearDownClass()
     76             def test_one(self):
     77                 pass
     78             def test_two(self):
     79                 pass
     80 
     81         class Test2(unittest.TestCase):
     82             tearDownCalled = 0
     83             @classmethod
     84             def tearDownClass(cls):
     85                 Test2.tearDownCalled += 1
     86                 unittest.TestCase.tearDownClass()
     87             def test_one(self):
     88                 pass
     89             def test_two(self):
     90                 pass
     91 
     92         result = self.runTests(Test, Test2)
     93 
     94         self.assertEqual(Test.tearDownCalled, 1)
     95         self.assertEqual(Test2.tearDownCalled, 1)
     96         self.assertEqual(result.testsRun, 4)
     97         self.assertEqual(len(result.errors), 0)
     98 
     99     def test_error_in_setupclass(self):
    100         class BrokenTest(unittest.TestCase):
    101             @classmethod
    102             def setUpClass(cls):
    103                 raise TypeError('foo')
    104             def test_one(self):
    105                 pass
    106             def test_two(self):
    107                 pass
    108 
    109         result = self.runTests(BrokenTest)
    110 
    111         self.assertEqual(result.testsRun, 0)
    112         self.assertEqual(len(result.errors), 1)
    113         error, _ = result.errors[0]
    114         self.assertEqual(str(error),
    115                     'setUpClass (%s.BrokenTest)' % __name__)
    116 
    117     def test_error_in_teardown_class(self):
    118         class Test(unittest.TestCase):
    119             tornDown = 0
    120             @classmethod
    121             def tearDownClass(cls):
    122                 Test.tornDown += 1
    123                 raise TypeError('foo')
    124             def test_one(self):
    125                 pass
    126             def test_two(self):
    127                 pass
    128 
    129         class Test2(unittest.TestCase):
    130             tornDown = 0
    131             @classmethod
    132             def tearDownClass(cls):
    133                 Test2.tornDown += 1
    134                 raise TypeError('foo')
    135             def test_one(self):
    136                 pass
    137             def test_two(self):
    138                 pass
    139 
    140         result = self.runTests(Test, Test2)
    141         self.assertEqual(result.testsRun, 4)
    142         self.assertEqual(len(result.errors), 2)
    143         self.assertEqual(Test.tornDown, 1)
    144         self.assertEqual(Test2.tornDown, 1)
    145 
    146         error, _ = result.errors[0]
    147         self.assertEqual(str(error),
    148                     'tearDownClass (%s.Test)' % __name__)
    149 
    150     def test_class_not_torndown_when_setup_fails(self):
    151         class Test(unittest.TestCase):
    152             tornDown = False
    153             @classmethod
    154             def setUpClass(cls):
    155                 raise TypeError
    156             @classmethod
    157             def tearDownClass(cls):
    158                 Test.tornDown = True
    159                 raise TypeError('foo')
    160             def test_one(self):
    161                 pass
    162 
    163         self.runTests(Test)
    164         self.assertFalse(Test.tornDown)
    165 
    166     def test_class_not_setup_or_torndown_when_skipped(self):
    167         class Test(unittest.TestCase):
    168             classSetUp = False
    169             tornDown = False
    170             @classmethod
    171             def setUpClass(cls):
    172                 Test.classSetUp = True
    173             @classmethod
    174             def tearDownClass(cls):
    175                 Test.tornDown = True
    176             def test_one(self):
    177                 pass
    178 
    179         Test = unittest.skip("hop")(Test)
    180         self.runTests(Test)
    181         self.assertFalse(Test.classSetUp)
    182         self.assertFalse(Test.tornDown)
    183 
    184     def test_setup_teardown_order_with_pathological_suite(self):
    185         results = []
    186 
    187         class Module1(object):
    188             @staticmethod
    189             def setUpModule():
    190                 results.append('Module1.setUpModule')
    191             @staticmethod
    192             def tearDownModule():
    193                 results.append('Module1.tearDownModule')
    194 
    195         class Module2(object):
    196             @staticmethod
    197             def setUpModule():
    198                 results.append('Module2.setUpModule')
    199             @staticmethod
    200             def tearDownModule():
    201                 results.append('Module2.tearDownModule')
    202 
    203         class Test1(unittest.TestCase):
    204             @classmethod
    205             def setUpClass(cls):
    206                 results.append('setup 1')
    207             @classmethod
    208             def tearDownClass(cls):
    209                 results.append('teardown 1')
    210             def testOne(self):
    211                 results.append('Test1.testOne')
    212             def testTwo(self):
    213                 results.append('Test1.testTwo')
    214 
    215         class Test2(unittest.TestCase):
    216             @classmethod
    217             def setUpClass(cls):
    218                 results.append('setup 2')
    219             @classmethod
    220             def tearDownClass(cls):
    221                 results.append('teardown 2')
    222             def testOne(self):
    223                 results.append('Test2.testOne')
    224             def testTwo(self):
    225                 results.append('Test2.testTwo')
    226 
    227         class Test3(unittest.TestCase):
    228             @classmethod
    229             def setUpClass(cls):
    230                 results.append('setup 3')
    231             @classmethod
    232             def tearDownClass(cls):
    233                 results.append('teardown 3')
    234             def testOne(self):
    235                 results.append('Test3.testOne')
    236             def testTwo(self):
    237                 results.append('Test3.testTwo')
    238 
    239         Test1.__module__ = Test2.__module__ = 'Module'
    240         Test3.__module__ = 'Module2'
    241         sys.modules['Module'] = Module1
    242         sys.modules['Module2'] = Module2
    243 
    244         first = unittest.TestSuite((Test1('testOne'),))
    245         second = unittest.TestSuite((Test1('testTwo'),))
    246         third = unittest.TestSuite((Test2('testOne'),))
    247         fourth = unittest.TestSuite((Test2('testTwo'),))
    248         fifth = unittest.TestSuite((Test3('testOne'),))
    249         sixth = unittest.TestSuite((Test3('testTwo'),))
    250         suite = unittest.TestSuite((first, second, third, fourth, fifth, sixth))
    251 
    252         runner = self.getRunner()
    253         result = runner.run(suite)
    254         self.assertEqual(result.testsRun, 6)
    255         self.assertEqual(len(result.errors), 0)
    256 
    257         self.assertEqual(results,
    258                          ['Module1.setUpModule', 'setup 1',
    259                           'Test1.testOne', 'Test1.testTwo', 'teardown 1',
    260                           'setup 2', 'Test2.testOne', 'Test2.testTwo',
    261                           'teardown 2', 'Module1.tearDownModule',
    262                           'Module2.setUpModule', 'setup 3',
    263                           'Test3.testOne', 'Test3.testTwo',
    264                           'teardown 3', 'Module2.tearDownModule'])
    265 
    266     def test_setup_module(self):
    267         class Module(object):
    268             moduleSetup = 0
    269             @staticmethod
    270             def setUpModule():
    271                 Module.moduleSetup += 1
    272 
    273         class Test(unittest.TestCase):
    274             def test_one(self):
    275                 pass
    276             def test_two(self):
    277                 pass
    278         Test.__module__ = 'Module'
    279         sys.modules['Module'] = Module
    280 
    281         result = self.runTests(Test)
    282         self.assertEqual(Module.moduleSetup, 1)
    283         self.assertEqual(result.testsRun, 2)
    284         self.assertEqual(len(result.errors), 0)
    285 
    286     def test_error_in_setup_module(self):
    287         class Module(object):
    288             moduleSetup = 0
    289             moduleTornDown = 0
    290             @staticmethod
    291             def setUpModule():
    292                 Module.moduleSetup += 1
    293                 raise TypeError('foo')
    294             @staticmethod
    295             def tearDownModule():
    296                 Module.moduleTornDown += 1
    297 
    298         class Test(unittest.TestCase):
    299             classSetUp = False
    300             classTornDown = False
    301             @classmethod
    302             def setUpClass(cls):
    303                 Test.classSetUp = True
    304             @classmethod
    305             def tearDownClass(cls):
    306                 Test.classTornDown = True
    307             def test_one(self):
    308                 pass
    309             def test_two(self):
    310                 pass
    311 
    312         class Test2(unittest.TestCase):
    313             def test_one(self):
    314                 pass
    315             def test_two(self):
    316                 pass
    317         Test.__module__ = 'Module'
    318         Test2.__module__ = 'Module'
    319         sys.modules['Module'] = Module
    320 
    321         result = self.runTests(Test, Test2)
    322         self.assertEqual(Module.moduleSetup, 1)
    323         self.assertEqual(Module.moduleTornDown, 0)
    324         self.assertEqual(result.testsRun, 0)
    325         self.assertFalse(Test.classSetUp)
    326         self.assertFalse(Test.classTornDown)
    327         self.assertEqual(len(result.errors), 1)
    328         error, _ = result.errors[0]
    329         self.assertEqual(str(error), 'setUpModule (Module)')
    330 
    331     def test_testcase_with_missing_module(self):
    332         class Test(unittest.TestCase):
    333             def test_one(self):
    334                 pass
    335             def test_two(self):
    336                 pass
    337         Test.__module__ = 'Module'
    338         sys.modules.pop('Module', None)
    339 
    340         result = self.runTests(Test)
    341         self.assertEqual(result.testsRun, 2)
    342 
    343     def test_teardown_module(self):
    344         class Module(object):
    345             moduleTornDown = 0
    346             @staticmethod
    347             def tearDownModule():
    348                 Module.moduleTornDown += 1
    349 
    350         class Test(unittest.TestCase):
    351             def test_one(self):
    352                 pass
    353             def test_two(self):
    354                 pass
    355         Test.__module__ = 'Module'
    356         sys.modules['Module'] = Module
    357 
    358         result = self.runTests(Test)
    359         self.assertEqual(Module.moduleTornDown, 1)
    360         self.assertEqual(result.testsRun, 2)
    361         self.assertEqual(len(result.errors), 0)
    362 
    363     def test_error_in_teardown_module(self):
    364         class Module(object):
    365             moduleTornDown = 0
    366             @staticmethod
    367             def tearDownModule():
    368                 Module.moduleTornDown += 1
    369                 raise TypeError('foo')
    370 
    371         class Test(unittest.TestCase):
    372             classSetUp = False
    373             classTornDown = False
    374             @classmethod
    375             def setUpClass(cls):
    376                 Test.classSetUp = True
    377             @classmethod
    378             def tearDownClass(cls):
    379                 Test.classTornDown = True
    380             def test_one(self):
    381                 pass
    382             def test_two(self):
    383                 pass
    384 
    385         class Test2(unittest.TestCase):
    386             def test_one(self):
    387                 pass
    388             def test_two(self):
    389                 pass
    390         Test.__module__ = 'Module'
    391         Test2.__module__ = 'Module'
    392         sys.modules['Module'] = Module
    393 
    394         result = self.runTests(Test, Test2)
    395         self.assertEqual(Module.moduleTornDown, 1)
    396         self.assertEqual(result.testsRun, 4)
    397         self.assertTrue(Test.classSetUp)
    398         self.assertTrue(Test.classTornDown)
    399         self.assertEqual(len(result.errors), 1)
    400         error, _ = result.errors[0]
    401         self.assertEqual(str(error), 'tearDownModule (Module)')
    402 
    403     def test_skiptest_in_setupclass(self):
    404         class Test(unittest.TestCase):
    405             @classmethod
    406             def setUpClass(cls):
    407                 raise unittest.SkipTest('foo')
    408             def test_one(self):
    409                 pass
    410             def test_two(self):
    411                 pass
    412 
    413         result = self.runTests(Test)
    414         self.assertEqual(result.testsRun, 0)
    415         self.assertEqual(len(result.errors), 0)
    416         self.assertEqual(len(result.skipped), 1)
    417         skipped = result.skipped[0][0]
    418         self.assertEqual(str(skipped), 'setUpClass (%s.Test)' % __name__)
    419 
    420     def test_skiptest_in_setupmodule(self):
    421         class Test(unittest.TestCase):
    422             def test_one(self):
    423                 pass
    424             def test_two(self):
    425                 pass
    426 
    427         class Module(object):
    428             @staticmethod
    429             def setUpModule():
    430                 raise unittest.SkipTest('foo')
    431 
    432         Test.__module__ = 'Module'
    433         sys.modules['Module'] = Module
    434 
    435         result = self.runTests(Test)
    436         self.assertEqual(result.testsRun, 0)
    437         self.assertEqual(len(result.errors), 0)
    438         self.assertEqual(len(result.skipped), 1)
    439         skipped = result.skipped[0][0]
    440         self.assertEqual(str(skipped), 'setUpModule (Module)')
    441 
    442     def test_suite_debug_executes_setups_and_teardowns(self):
    443         ordering = []
    444 
    445         class Module(object):
    446             @staticmethod
    447             def setUpModule():
    448                 ordering.append('setUpModule')
    449             @staticmethod
    450             def tearDownModule():
    451                 ordering.append('tearDownModule')
    452 
    453         class Test(unittest.TestCase):
    454             @classmethod
    455             def setUpClass(cls):
    456                 ordering.append('setUpClass')
    457             @classmethod
    458             def tearDownClass(cls):
    459                 ordering.append('tearDownClass')
    460             def test_something(self):
    461                 ordering.append('test_something')
    462 
    463         Test.__module__ = 'Module'
    464         sys.modules['Module'] = Module
    465 
    466         suite = unittest.defaultTestLoader.loadTestsFromTestCase(Test)
    467         suite.debug()
    468         expectedOrder = ['setUpModule', 'setUpClass', 'test_something', 'tearDownClass', 'tearDownModule']
    469         self.assertEqual(ordering, expectedOrder)
    470 
    471     def test_suite_debug_propagates_exceptions(self):
    472         class Module(object):
    473             @staticmethod
    474             def setUpModule():
    475                 if phase == 0:
    476                     raise Exception('setUpModule')
    477             @staticmethod
    478             def tearDownModule():
    479                 if phase == 1:
    480                     raise Exception('tearDownModule')
    481 
    482         class Test(unittest.TestCase):
    483             @classmethod
    484             def setUpClass(cls):
    485                 if phase == 2:
    486                     raise Exception('setUpClass')
    487             @classmethod
    488             def tearDownClass(cls):
    489                 if phase == 3:
    490                     raise Exception('tearDownClass')
    491             def test_something(self):
    492                 if phase == 4:
    493                     raise Exception('test_something')
    494 
    495         Test.__module__ = 'Module'
    496         sys.modules['Module'] = Module
    497 
    498         _suite = unittest.defaultTestLoader.loadTestsFromTestCase(Test)
    499         suite = unittest.TestSuite()
    500         suite.addTest(_suite)
    501 
    502         messages = ('setUpModule', 'tearDownModule', 'setUpClass', 'tearDownClass', 'test_something')
    503         for phase, msg in enumerate(messages):
    504             with self.assertRaisesRegexp(Exception, msg):
    505                 suite.debug()
    506 
    507 if __name__ == '__main__':
    508     unittest.main()
    509