Home | History | Annotate | Download | only in unittest2
      1 """Loading unittests."""
      2 
      3 import os
      4 import re
      5 import sys
      6 import traceback
      7 import types
      8 import unittest
      9 
     10 from fnmatch import fnmatch
     11 
     12 from unittest2 import case, suite
     13 
     14 try:
     15     from os.path import relpath
     16 except ImportError:
     17     from unittest2.compatibility import relpath
     18 
# Module-level marker flag.
# NOTE(review): presumably this is the flag the unittest result/traceback
# machinery looks for in a frame's globals to hide framework frames from
# failure tracebacks -- confirm against the result module.
__unittest = True
     20 
     21 
     22 def _CmpToKey(mycmp):
     23     'Convert a cmp= function into a key= function'
     24     class K(object):
     25         def __init__(self, obj):
     26             self.obj = obj
     27         def __lt__(self, other):
     28             return mycmp(self.obj, other.obj) == -1
     29     return K
     30 
     31 
# Matches file names that look like importable Python source modules: a
# leading letter or underscore, then word characters, ending in '.py'
# (case-insensitively).
#
# Open question: what about .pyc or .pyo (etc)?  Supporting them would
# require avoiding loading the same tests multiple times from '.py',
# '.pyc' *and* '.pyo'.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
     36 
     37 
     38 def _make_failed_import_test(name, suiteClass):
     39     message = 'Failed to import test module: %s' % name
     40     if hasattr(traceback, 'format_exc'):
     41         # Python 2.3 compatibility
     42         # format_exc returns two frames of discover.py as well
     43         message += '\n%s' % traceback.format_exc()
     44     return _make_failed_test('ModuleImportFailure', name, ImportError(message),
     45                              suiteClass)
     46 
     47 def _make_failed_load_tests(name, exception, suiteClass):
     48     return _make_failed_test('LoadTestsFailure', name, exception, suiteClass)
     49 
     50 def _make_failed_test(classname, methodname, exception, suiteClass):
     51     def testFailure(self):
     52         raise exception
     53     attrs = {methodname: testFailure}
     54     TestClass = type(classname, (case.TestCase,), attrs)
     55     return suiteClass((TestClass(methodname),))
     56 
     57 
     58 class TestLoader(unittest.TestLoader):
     59     """
     60     This class is responsible for loading tests according to various criteria
     61     and returning them wrapped in a TestSuite
     62     """
     63     testMethodPrefix = 'test'
     64     sortTestMethodsUsing = cmp
     65     suiteClass = suite.TestSuite
     66     _top_level_dir = None
     67 
     68     def loadTestsFromTestCase(self, testCaseClass):
     69         """Return a suite of all tests cases contained in testCaseClass"""
     70         if issubclass(testCaseClass, suite.TestSuite):
     71             raise TypeError("Test cases should not be derived from TestSuite."
     72                             " Maybe you meant to derive from TestCase?")
     73         testCaseNames = self.getTestCaseNames(testCaseClass)
     74         if not testCaseNames and hasattr(testCaseClass, 'runTest'):
     75             testCaseNames = ['runTest']
     76         loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
     77         return loaded_suite
     78 
     79     def loadTestsFromModule(self, module, use_load_tests=True):
     80         """Return a suite of all tests cases contained in the given module"""
     81         tests = []
     82         for name in dir(module):
     83             obj = getattr(module, name)
     84             if isinstance(obj, type) and issubclass(obj, unittest.TestCase):
     85                 tests.append(self.loadTestsFromTestCase(obj))
     86 
     87         load_tests = getattr(module, 'load_tests', None)
     88         tests = self.suiteClass(tests)
     89         if use_load_tests and load_tests is not None:
     90             try:
     91                 return load_tests(self, tests, None)
     92             except Exception, e:
     93                 return _make_failed_load_tests(module.__name__, e,
     94                                                self.suiteClass)
     95         return tests
     96 
     97     def loadTestsFromName(self, name, module=None):
     98         """Return a suite of all tests cases given a string specifier.
     99 
    100         The name may resolve either to a module, a test case class, a
    101         test method within a test case class, or a callable object which
    102         returns a TestCase or TestSuite instance.
    103 
    104         The method optionally resolves the names relative to a given module.
    105         """
    106         parts = name.split('.')
    107         if module is None:
    108             parts_copy = parts[:]
    109             while parts_copy:
    110                 try:
    111                     module = __import__('.'.join(parts_copy))
    112                     break
    113                 except ImportError:
    114                     del parts_copy[-1]
    115                     if not parts_copy:
    116                         raise
    117             parts = parts[1:]
    118         obj = module
    119         for part in parts:
    120             parent, obj = obj, getattr(obj, part)
    121 
    122         if isinstance(obj, types.ModuleType):
    123             return self.loadTestsFromModule(obj)
    124         elif isinstance(obj, type) and issubclass(obj, unittest.TestCase):
    125             return self.loadTestsFromTestCase(obj)
    126         elif (isinstance(obj, types.UnboundMethodType) and
    127               isinstance(parent, type) and
    128               issubclass(parent, case.TestCase)):
    129             return self.suiteClass([parent(obj.__name__)])
    130         elif isinstance(obj, unittest.TestSuite):
    131             return obj
    132         elif hasattr(obj, '__call__'):
    133             test = obj()
    134             if isinstance(test, unittest.TestSuite):
    135                 return test
    136             elif isinstance(test, unittest.TestCase):
    137                 return self.suiteClass([test])
    138             else:
    139                 raise TypeError("calling %s returned %s, not a test" %
    140                                 (obj, test))
    141         else:
    142             raise TypeError("don't know how to make test from: %s" % obj)
    143 
    144     def loadTestsFromNames(self, names, module=None):
    145         """Return a suite of all tests cases found using the given sequence
    146         of string specifiers. See 'loadTestsFromName()'.
    147         """
    148         suites = [self.loadTestsFromName(name, module) for name in names]
    149         return self.suiteClass(suites)
    150 
    151     def getTestCaseNames(self, testCaseClass):
    152         """Return a sorted sequence of method names found within testCaseClass
    153         """
    154         def isTestMethod(attrname, testCaseClass=testCaseClass,
    155                          prefix=self.testMethodPrefix):
    156             return attrname.startswith(prefix) and \
    157                 hasattr(getattr(testCaseClass, attrname), '__call__')
    158         testFnNames = filter(isTestMethod, dir(testCaseClass))
    159         if self.sortTestMethodsUsing:
    160             testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
    161         return testFnNames
    162 
    163     def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
    164         """Find and return all test modules from the specified start
    165         directory, recursing into subdirectories to find them. Only test files
    166         that match the pattern will be loaded. (Using shell style pattern
    167         matching.)
    168 
    169         All test modules must be importable from the top level of the project.
    170         If the start directory is not the top level directory then the top
    171         level directory must be specified separately.
    172 
    173         If a test package name (directory with '__init__.py') matches the
    174         pattern then the package will be checked for a 'load_tests' function. If
    175         this exists then it will be called with loader, tests, pattern.
    176 
    177         If load_tests exists then discovery does  *not* recurse into the package,
    178         load_tests is responsible for loading all tests in the package.
    179 
    180         The pattern is deliberately not stored as a loader attribute so that
    181         packages can continue discovery themselves. top_level_dir is stored so
    182         load_tests does not need to pass this argument in to loader.discover().
    183         """
    184         set_implicit_top = False
    185         if top_level_dir is None and self._top_level_dir is not None:
    186             # make top_level_dir optional if called from load_tests in a package
    187             top_level_dir = self._top_level_dir
    188         elif top_level_dir is None:
    189             set_implicit_top = True
    190             top_level_dir = start_dir
    191 
    192         top_level_dir = os.path.abspath(top_level_dir)
    193 
    194         if not top_level_dir in sys.path:
    195             # all test modules must be importable from the top level directory
    196             # should we *unconditionally* put the start directory in first
    197             # in sys.path to minimise likelihood of conflicts between installed
    198             # modules and development versions?
    199             sys.path.insert(0, top_level_dir)
    200         self._top_level_dir = top_level_dir
    201 
    202         is_not_importable = False
    203         if os.path.isdir(os.path.abspath(start_dir)):
    204             start_dir = os.path.abspath(start_dir)
    205             if start_dir != top_level_dir:
    206                 is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
    207         else:
    208             # support for discovery from dotted module names
    209             try:
    210                 __import__(start_dir)
    211             except ImportError:
    212                 is_not_importable = True
    213             else:
    214                 the_module = sys.modules[start_dir]
    215                 top_part = start_dir.split('.')[0]
    216                 start_dir = os.path.abspath(os.path.dirname((the_module.__file__)))
    217                 if set_implicit_top:
    218                     self._top_level_dir = os.path.abspath(os.path.dirname(os.path.dirname(sys.modules[top_part].__file__)))
    219                     sys.path.remove(top_level_dir)
    220 
    221         if is_not_importable:
    222             raise ImportError('Start directory is not importable: %r' % start_dir)
    223 
    224         tests = list(self._find_tests(start_dir, pattern))
    225         return self.suiteClass(tests)
    226 
    227     def _get_name_from_path(self, path):
    228         path = os.path.splitext(os.path.normpath(path))[0]
    229 
    230         _relpath = relpath(path, self._top_level_dir)
    231         assert not os.path.isabs(_relpath), "Path must be within the project"
    232         assert not _relpath.startswith('..'), "Path must be within the project"
    233 
    234         name = _relpath.replace(os.path.sep, '.')
    235         return name
    236 
    237     def _get_module_from_name(self, name):
    238         __import__(name)
    239         return sys.modules[name]
    240 
    241     def _match_path(self, path, full_path, pattern):
    242         # override this method to use alternative matching strategy
    243         return fnmatch(path, pattern)
    244 
    245     def _find_tests(self, start_dir, pattern):
    246         """Used by discovery. Yields test suites it loads."""
    247         paths = os.listdir(start_dir)
    248 
    249         for path in paths:
    250             full_path = os.path.join(start_dir, path)
    251             if os.path.isfile(full_path):
    252                 if not VALID_MODULE_NAME.match(path):
    253                     # valid Python identifiers only
    254                     continue
    255                 if not self._match_path(path, full_path, pattern):
    256                     continue
    257                 # if the test file matches, load it
    258                 name = self._get_name_from_path(full_path)
    259                 try:
    260                     module = self._get_module_from_name(name)
    261                 except:
    262                     yield _make_failed_import_test(name, self.suiteClass)
    263                 else:
    264                     mod_file = os.path.abspath(getattr(module, '__file__', full_path))
    265                     realpath = os.path.splitext(mod_file)[0]
    266                     fullpath_noext = os.path.splitext(full_path)[0]
    267                     if realpath.lower() != fullpath_noext.lower():
    268                         module_dir = os.path.dirname(realpath)
    269                         mod_name = os.path.splitext(os.path.basename(full_path))[0]
    270                         expected_dir = os.path.dirname(full_path)
    271                         msg = ("%r module incorrectly imported from %r. Expected %r. "
    272                                "Is this module globally installed?")
    273                         raise ImportError(msg % (mod_name, module_dir, expected_dir))
    274                     yield self.loadTestsFromModule(module)
    275             elif os.path.isdir(full_path):
    276                 if not os.path.isfile(os.path.join(full_path, '__init__.py')):
    277                     continue
    278 
    279                 load_tests = None
    280                 tests = None
    281                 if fnmatch(path, pattern):
    282                     # only check load_tests if the package directory itself matches the filter
    283                     name = self._get_name_from_path(full_path)
    284                     package = self._get_module_from_name(name)
    285                     load_tests = getattr(package, 'load_tests', None)
    286                     tests = self.loadTestsFromModule(package, use_load_tests=False)
    287 
    288                 if load_tests is None:
    289                     if tests is not None:
    290                         # tests loaded from package file
    291                         yield tests
    292                     # recurse into the package
    293                     for test in self._find_tests(full_path, pattern):
    294                         yield test
    295                 else:
    296                     try:
    297                         yield load_tests(self, tests, pattern)
    298                     except Exception, e:
    299                         yield _make_failed_load_tests(package.__name__, e,
    300                                                       self.suiteClass)
    301 
# Shared module-level TestLoader instance created with the class defaults.
defaultTestLoader = TestLoader()
    303 
    304 
    305 def _makeLoader(prefix, sortUsing, suiteClass=None):
    306     loader = TestLoader()
    307     loader.sortTestMethodsUsing = sortUsing
    308     loader.testMethodPrefix = prefix
    309     if suiteClass:
    310         loader.suiteClass = suiteClass
    311     return loader
    312 
    313 def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    314     return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass)
    315 
    316 def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
    317               suiteClass=suite.TestSuite):
    318     return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(testCaseClass)
    319 
    320 def findTestCases(module, prefix='test', sortUsing=cmp,
    321                   suiteClass=suite.TestSuite):
    322     return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(module)
    323