Home | History | Annotate | Download | only in unittest
      1 """TestSuite"""
      2 
      3 import sys
      4 
      5 from . import case
      6 from . import util
      7 
# NOTE(review): module-level marker; presumably unittest looks for
# ``__unittest`` in frame globals to hide this module's frames from
# reported tracebacks — confirm against unittest.result.
__unittest = True
      9 
     10 
     11 def _call_if_exists(parent, attr):
     12     func = getattr(parent, attr, lambda: None)
     13     func()
     14 
     15 
     16 class BaseTestSuite(object):
     17     """A simple test suite that doesn't provide class or module shared fixtures.
     18     """
     19     _cleanup = True
     20 
     21     def __init__(self, tests=()):
     22         self._tests = []
     23         self._removed_tests = 0
     24         self.addTests(tests)
     25 
     26     def __repr__(self):
     27         return "<%s tests=%s>" % (util.strclass(self.__class__), list(self))
     28 
     29     def __eq__(self, other):
     30         if not isinstance(other, self.__class__):
     31             return NotImplemented
     32         return list(self) == list(other)
     33 
     34     def __iter__(self):
     35         return iter(self._tests)
     36 
     37     def countTestCases(self):
     38         cases = self._removed_tests
     39         for test in self:
     40             if test:
     41                 cases += test.countTestCases()
     42         return cases
     43 
     44     def addTest(self, test):
     45         # sanity checks
     46         if not callable(test):
     47             raise TypeError("{} is not callable".format(repr(test)))
     48         if isinstance(test, type) and issubclass(test,
     49                                                  (case.TestCase, TestSuite)):
     50             raise TypeError("TestCases and TestSuites must be instantiated "
     51                             "before passing them to addTest()")
     52         self._tests.append(test)
     53 
     54     def addTests(self, tests):
     55         if isinstance(tests, str):
     56             raise TypeError("tests must be an iterable of tests, not a string")
     57         for test in tests:
     58             self.addTest(test)
     59 
     60     def run(self, result):
     61         for index, test in enumerate(self):
     62             if result.shouldStop:
     63                 break
     64             test(result)
     65             if self._cleanup:
     66                 self._removeTestAtIndex(index)
     67         return result
     68 
     69     def _removeTestAtIndex(self, index):
     70         """Stop holding a reference to the TestCase at index."""
     71         try:
     72             test = self._tests[index]
     73         except TypeError:
     74             # support for suite implementations that have overridden self._tests
     75             pass
     76         else:
     77             # Some unittest tests add non TestCase/TestSuite objects to
     78             # the suite.
     79             if hasattr(test, 'countTestCases'):
     80                 self._removed_tests += test.countTestCases()
     81             self._tests[index] = None
     82 
     83     def __call__(self, *args, **kwds):
     84         return self.run(*args, **kwds)
     85 
     86     def debug(self):
     87         """Run the tests without collecting errors in a TestResult"""
     88         for test in self:
     89             test.debug()
     90 
     91 
class TestSuite(BaseTestSuite):
    """A test suite is a composite test consisting of a number of TestCases.

    For use, create an instance of TestSuite, then add test case instances.
    When all tests have been added, the suite can be passed to a test
    runner, such as TextTestRunner. It will run the individual test cases
    in the order in which they were added, aggregating the results. When
    subclassing, do not forget to call the base class constructor.

    Unlike BaseTestSuite, this suite also drives class- and module-level
    fixtures (setUpClass/tearDownClass, setUpModule/tearDownModule) as the
    run moves from one test class or module to the next. The bookkeeping
    lives on the result object (_previousTestClass, _moduleSetUpFailed,
    _testRunEntered), so nested suites that share one result see the same
    state.
    """

    def run(self, result, debug=False):
        """Run every test in the suite against *result* and return *result*.

        When *debug* is true, each test's debug() is called instead of
        calling the test with *result*.
        """
        # Only the outermost run() -- the first to mark the result -- does the
        # final class/module teardown; nested runs find _testRunEntered set.
        topLevel = False
        if getattr(result, '_testRunEntered', False) is False:
            result._testRunEntered = topLevel = True

        for index, test in enumerate(self):
            if result.shouldStop:
                break

            # Fixture handling applies only to non-iterable members; nested
            # suites (iterable) handle their own members when called.
            if _isnotsuite(test):
                self._tearDownPreviousClass(test, result)
                self._handleModuleFixture(test, result)
                self._handleClassSetUp(test, result)
                result._previousTestClass = test.__class__

                # Do not run a test whose class or module setup failed. Note
                # that `continue` also skips _removeTestAtIndex below.
                if (getattr(test.__class__, '_classSetupFailed', False) or
                    getattr(result, '_moduleSetUpFailed', False)):
                    continue

            if not debug:
                test(result)
            else:
                test.debug()

            if self._cleanup:
                self._removeTestAtIndex(index)

        if topLevel:
            # Passing None makes the "current" class NoneType, which never
            # matches the last class seen, so its tearDownClass runs.
            self._tearDownPreviousClass(None, result)
            self._handleModuleTearDown(result)
            result._testRunEntered = False
        return result

    def debug(self):
        """Run the tests without collecting errors in a TestResult"""
        # With a _DebugResult, the fixture handlers below re-raise exceptions
        # instead of reporting them.
        debug = _DebugResult()
        self.run(debug, True)

    ################################

    def _handleClassSetUp(self, test, result):
        """Call setUpClass for *test*'s class when the run enters a new class.

        Does nothing if the class is the same as the previous test's, if the
        module setup failed, or if the class has a truthy __unittest_skip__.
        A failing setUpClass sets ``_classSetupFailed`` on the class and is
        reported on *result* (or re-raised for a _DebugResult).
        """
        previousClass = getattr(result, '_previousTestClass', None)
        currentClass = test.__class__
        if currentClass == previousClass:
            return
        if result._moduleSetUpFailed:
            return
        # presumably set by unittest's skip decorators on skipped classes
        if getattr(currentClass, "__unittest_skip__", False):
            return

        try:
            currentClass._classSetupFailed = False
        except TypeError:
            # test may actually be a function
            # so its class will be a builtin-type
            pass

        setUpClass = getattr(currentClass, 'setUpClass', None)
        if setUpClass is not None:
            # Optional result hooks bracket the fixture call; presumably they
            # set up / restore stdout capture.
            _call_if_exists(result, '_setupStdout')
            try:
                setUpClass()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                currentClass._classSetupFailed = True
                className = util.strclass(currentClass)
                errorName = 'setUpClass (%s)' % className
                self._addClassOrModuleLevelException(result, e, errorName)
            finally:
                _call_if_exists(result, '_restoreStdout')

    def _get_previous_module(self, result):
        """Return the __module__ of the previously run test class, or None."""
        previousModule = None
        previousClass = getattr(result, '_previousTestClass', None)
        if previousClass is not None:
            previousModule = previousClass.__module__
        return previousModule


    def _handleModuleFixture(self, test, result):
        """Call setUpModule when the run enters a new module.

        The previous module is torn down first. ``result._moduleSetUpFailed``
        is reset to False, then set to True if setUpModule raises (the error
        is reported on *result*, or re-raised for a _DebugResult). A module
        missing from sys.modules is silently skipped.
        """
        previousModule = self._get_previous_module(result)
        currentModule = test.__class__.__module__
        if currentModule == previousModule:
            return

        self._handleModuleTearDown(result)


        result._moduleSetUpFailed = False
        try:
            module = sys.modules[currentModule]
        except KeyError:
            return
        setUpModule = getattr(module, 'setUpModule', None)
        if setUpModule is not None:
            _call_if_exists(result, '_setupStdout')
            try:
                setUpModule()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                result._moduleSetUpFailed = True
                errorName = 'setUpModule (%s)' % currentModule
                self._addClassOrModuleLevelException(result, e, errorName)
            finally:
                _call_if_exists(result, '_restoreStdout')

    def _addClassOrModuleLevelException(self, result, exception, errorName):
        """Report a fixture exception on *result* under the name *errorName*.

        A SkipTest is reported via result.addSkip when the result has one;
        anything else goes to result.addError. Must be called from inside the
        ``except`` block handling *exception*: the error details passed to
        addError come from sys.exc_info().
        """
        error = _ErrorHolder(errorName)
        addSkip = getattr(result, 'addSkip', None)
        if addSkip is not None and isinstance(exception, case.SkipTest):
            addSkip(error, str(exception))
        else:
            result.addError(error, sys.exc_info())

    def _handleModuleTearDown(self, result):
        """Call tearDownModule for the previously run test class's module.

        Skipped when there is no previous class, when module setup failed,
        or when the module is not in sys.modules. Errors are reported on
        *result* (or re-raised for a _DebugResult).
        """
        previousModule = self._get_previous_module(result)
        if previousModule is None:
            return
        if result._moduleSetUpFailed:
            return

        try:
            module = sys.modules[previousModule]
        except KeyError:
            return

        tearDownModule = getattr(module, 'tearDownModule', None)
        if tearDownModule is not None:
            _call_if_exists(result, '_setupStdout')
            try:
                tearDownModule()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                errorName = 'tearDownModule (%s)' % previousModule
                self._addClassOrModuleLevelException(result, e, errorName)
            finally:
                _call_if_exists(result, '_restoreStdout')

    def _tearDownPreviousClass(self, test, result):
        """Call tearDownClass on the previous test class if *test*'s differs.

        Skipped when the class has not changed, when the previous class's
        setUpClass failed, when module setup failed, or when the previous
        class has a truthy __unittest_skip__. *test* may be None at the end
        of a run. If there is no previous class, getattr(None, ...) simply
        finds no tearDownClass. Errors are reported on *result* (or
        re-raised for a _DebugResult).
        """
        previousClass = getattr(result, '_previousTestClass', None)
        currentClass = test.__class__
        if currentClass == previousClass:
            return
        if getattr(previousClass, '_classSetupFailed', False):
            return
        if getattr(result, '_moduleSetUpFailed', False):
            return
        if getattr(previousClass, "__unittest_skip__", False):
            return

        tearDownClass = getattr(previousClass, 'tearDownClass', None)
        if tearDownClass is not None:
            _call_if_exists(result, '_setupStdout')
            try:
                tearDownClass()
            except Exception as e:
                if isinstance(result, _DebugResult):
                    raise
                className = util.strclass(previousClass)
                errorName = 'tearDownClass (%s)' % className
                self._addClassOrModuleLevelException(result, e, errorName)
            finally:
                _call_if_exists(result, '_restoreStdout')
    268 
    269 
    270 class _ErrorHolder(object):
    271     """
    272     Placeholder for a TestCase inside a result. As far as a TestResult
    273     is concerned, this looks exactly like a unit test. Used to insert
    274     arbitrary errors into a test suite run.
    275     """
    276     # Inspired by the ErrorHolder from Twisted:
    277     # http://twistedmatrix.com/trac/browser/trunk/twisted/trial/runner.py
    278 
    279     # attribute used by TestResult._exc_info_to_string
    280     failureException = None
    281 
    282     def __init__(self, description):
    283         self.description = description
    284 
    285     def id(self):
    286         return self.description
    287 
    288     def shortDescription(self):
    289         return None
    290 
    291     def __repr__(self):
    292         return "<ErrorHolder description=%r>" % (self.description,)
    293 
    294     def __str__(self):
    295         return self.id()
    296 
    297     def run(self, result):
    298         # could call result.addError(...) - but this test-like object
    299         # shouldn't be run anyway
    300         pass
    301 
    302     def __call__(self, result):
    303         return self.run(result)
    304 
    305     def countTestCases(self):
    306         return 0
    307 
    308 def _isnotsuite(test):
    309     "A crude way to tell apart testcases and suites with duck-typing"
    310     try:
    311         iter(test)
    312     except TypeError:
    313         return True
    314     return False
    315 
    316 
    317 class _DebugResult(object):
    318     "Used by the TestSuite to hold previous class when running in debug."
    319     _previousTestClass = None
    320     _moduleSetUpFailed = False
    321     shouldStop = False
    322