Home | History | Annotate | Download | only in unittest2
      1 """TestSuite"""
      2 
      3 import sys
      4 import unittest
      5 from unittest2 import case, util
      6 
# Module-level marker flag.  NOTE(review): presumably the result/traceback
# machinery looks for this global to hide this module's frames from reported
# failures — confirm against unittest2.result.
__unittest = True
      8 
      9 
     10 class BaseTestSuite(unittest.TestSuite):
     11     """A simple test suite that doesn't provide class or module shared fixtures.
     12     """
     13     def __init__(self, tests=()):
     14         self._tests = []
     15         self.addTests(tests)
     16 
     17     def __repr__(self):
     18         return "<%s tests=%s>" % (util.strclass(self.__class__), list(self))
     19 
     20     def __eq__(self, other):
     21         if not isinstance(other, self.__class__):
     22             return NotImplemented
     23         return list(self) == list(other)
     24 
     25     def __ne__(self, other):
     26         return not self == other
     27 
     28     # Can't guarantee hash invariant, so flag as unhashable
     29     __hash__ = None
     30 
     31     def __iter__(self):
     32         return iter(self._tests)
     33 
     34     def countTestCases(self):
     35         cases = 0
     36         for test in self:
     37             cases += test.countTestCases()
     38         return cases
     39 
     40     def addTest(self, test):
     41         # sanity checks
     42         if not hasattr(test, '__call__'):
     43             raise TypeError("%r is not callable" % (repr(test),))
     44         if isinstance(test, type) and issubclass(test,
     45                                                  (case.TestCase, TestSuite)):
     46             raise TypeError("TestCases and TestSuites must be instantiated "
     47                             "before passing them to addTest()")
     48         self._tests.append(test)
     49 
     50     def addTests(self, tests):
     51         if isinstance(tests, basestring):
     52             raise TypeError("tests must be an iterable of tests, not a string")
     53         for test in tests:
     54             self.addTest(test)
     55 
     56     def run(self, result):
     57         for test in self:
     58             if result.shouldStop:
     59                 break
     60             test(result)
     61         return result
     62 
     63     def __call__(self, *args, **kwds):
     64         return self.run(*args, **kwds)
     65 
     66     def debug(self):
     67         """Run the tests without collecting errors in a TestResult"""
     68         for test in self:
     69             test.debug()
     70 
     71 
     72 class TestSuite(BaseTestSuite):
     73     """A test suite is a composite test consisting of a number of TestCases.
     74 
     75     For use, create an instance of TestSuite, then add test case instances.
     76     When all tests have been added, the suite can be passed to a test
     77     runner, such as TextTestRunner. It will run the individual test cases
     78     in the order in which they were added, aggregating the results. When
     79     subclassing, do not forget to call the base class constructor.
     80     """
     81     
     82 
     83     def run(self, result):
     84         self._wrapped_run(result)
     85         self._tearDownPreviousClass(None, result)
     86         self._handleModuleTearDown(result)
     87         return result
     88 
     89     def debug(self):
     90         """Run the tests without collecting errors in a TestResult"""
     91         debug = _DebugResult()
     92         self._wrapped_run(debug, True)
     93         self._tearDownPreviousClass(None, debug)
     94         self._handleModuleTearDown(debug)
     95 
     96     ################################
     97     # private methods
     98     def _wrapped_run(self, result, debug=False):
     99         for test in self:
    100             if result.shouldStop:
    101                 break
    102             
    103             if _isnotsuite(test):
    104                 self._tearDownPreviousClass(test, result)
    105                 self._handleModuleFixture(test, result)
    106                 self._handleClassSetUp(test, result)
    107                 result._previousTestClass = test.__class__
    108                 
    109                 if (getattr(test.__class__, '_classSetupFailed', False) or 
    110                     getattr(result, '_moduleSetUpFailed', False)):
    111                     continue
    112             
    113             if hasattr(test, '_wrapped_run'):
    114                 test._wrapped_run(result, debug)
    115             elif not debug:
    116                 test(result)
    117             else:
    118                 test.debug()
    119     
    120     def _handleClassSetUp(self, test, result):
    121         previousClass = getattr(result, '_previousTestClass', None)
    122         currentClass = test.__class__
    123         if currentClass == previousClass:
    124             return
    125         if result._moduleSetUpFailed:
    126             return
    127         if getattr(currentClass, "__unittest_skip__", False):
    128             return
    129         
    130         try:
    131             currentClass._classSetupFailed = False
    132         except TypeError:
    133             # test may actually be a function
    134             # so its class will be a builtin-type
    135             pass
    136             
    137         setUpClass = getattr(currentClass, 'setUpClass', None)
    138         if setUpClass is not None:
    139             try:
    140                 setUpClass()
    141             except Exception, e:
    142                 if isinstance(result, _DebugResult):
    143                     raise
    144                 currentClass._classSetupFailed = True
    145                 className = util.strclass(currentClass)
    146                 errorName = 'setUpClass (%s)' % className
    147                 self._addClassOrModuleLevelException(result, e, errorName)
    148     
    149     def _get_previous_module(self, result):
    150         previousModule = None
    151         previousClass = getattr(result, '_previousTestClass', None)
    152         if previousClass is not None:
    153             previousModule = previousClass.__module__
    154         return previousModule
    155         
    156         
    157     def _handleModuleFixture(self, test, result):
    158         previousModule = self._get_previous_module(result)
    159         currentModule = test.__class__.__module__
    160         if currentModule == previousModule:
    161             return
    162         
    163         self._handleModuleTearDown(result)
    164 
    165         
    166         result._moduleSetUpFailed = False
    167         try:
    168             module = sys.modules[currentModule]
    169         except KeyError:
    170             return
    171         setUpModule = getattr(module, 'setUpModule', None)
    172         if setUpModule is not None:
    173             try:
    174                 setUpModule()
    175             except Exception, e:
    176                 if isinstance(result, _DebugResult):
    177                     raise
    178                 result._moduleSetUpFailed = True
    179                 errorName = 'setUpModule (%s)' % currentModule
    180                 self._addClassOrModuleLevelException(result, e, errorName)
    181 
    182     def _addClassOrModuleLevelException(self, result, exception, errorName):
    183         error = _ErrorHolder(errorName)
    184         addSkip = getattr(result, 'addSkip', None)
    185         if addSkip is not None and isinstance(exception, case.SkipTest):
    186             addSkip(error, str(exception))
    187         else:
    188             result.addError(error, sys.exc_info())
    189     
    190     def _handleModuleTearDown(self, result):
    191         previousModule = self._get_previous_module(result)
    192         if previousModule is None:
    193             return
    194         if result._moduleSetUpFailed:
    195             return
    196             
    197         try:
    198             module = sys.modules[previousModule]
    199         except KeyError:
    200             return
    201 
    202         tearDownModule = getattr(module, 'tearDownModule', None)
    203         if tearDownModule is not None:
    204             try:
    205                 tearDownModule()
    206             except Exception, e:
    207                 if isinstance(result, _DebugResult):
    208                     raise
    209                 errorName = 'tearDownModule (%s)' % previousModule
    210                 self._addClassOrModuleLevelException(result, e, errorName)
    211     
    212     def _tearDownPreviousClass(self, test, result):
    213         previousClass = getattr(result, '_previousTestClass', None)
    214         currentClass = test.__class__
    215         if currentClass == previousClass:
    216             return
    217         if getattr(previousClass, '_classSetupFailed', False):
    218             return
    219         if getattr(result, '_moduleSetUpFailed', False):
    220             return
    221         if getattr(previousClass, "__unittest_skip__", False):
    222             return
    223         
    224         tearDownClass = getattr(previousClass, 'tearDownClass', None)
    225         if tearDownClass is not None:
    226             try:
    227                 tearDownClass()
    228             except Exception, e:
    229                 if isinstance(result, _DebugResult):
    230                     raise
    231                 className = util.strclass(previousClass)
    232                 errorName = 'tearDownClass (%s)' % className
    233                 self._addClassOrModuleLevelException(result, e, errorName)
    234 
    235 
    236 class _ErrorHolder(object):
    237     """
    238     Placeholder for a TestCase inside a result. As far as a TestResult
    239     is concerned, this looks exactly like a unit test. Used to insert
    240     arbitrary errors into a test suite run.
    241     """
    242     # Inspired by the ErrorHolder from Twisted:
    243     # http://twistedmatrix.com/trac/browser/trunk/twisted/trial/runner.py
    244 
    245     # attribute used by TestResult._exc_info_to_string
    246     failureException = None
    247 
    248     def __init__(self, description):
    249         self.description = description
    250 
    251     def id(self):
    252         return self.description
    253 
    254     def shortDescription(self):
    255         return None
    256 
    257     def __repr__(self):
    258         return "<ErrorHolder description=%r>" % (self.description,)
    259 
    260     def __str__(self):
    261         return self.id()
    262 
    263     def run(self, result):
    264         # could call result.addError(...) - but this test-like object
    265         # shouldn't be run anyway
    266         pass
    267 
    268     def __call__(self, result):
    269         return self.run(result)
    270 
    271     def countTestCases(self):
    272         return 0
    273 
    274 def _isnotsuite(test):
    275     "A crude way to tell apart testcases and suites with duck-typing"
    276     try:
    277         iter(test)
    278     except TypeError:
    279         return True
    280     return False
    281 
    282 
    283 class _DebugResult(object):
    284     "Used by the TestSuite to hold previous class when running in debug."
    285     _previousTestClass = None
    286     _moduleSetUpFailed = False
    287     shouldStop = False
    288