Home | History | Annotate | Download | only in unittest
      1 """Loading unittests."""
      2 
      3 import os
      4 import re
      5 import sys
      6 import traceback
      7 import types
      8 
      9 from functools import cmp_to_key as _CmpToKey
     10 from fnmatch import fnmatch
     11 
     12 from . import case, suite
     13 
# Module-level marker flag. NOTE(review): nothing in this file reads it;
# presumably other parts of the unittest package look for it (e.g. to trim
# this module's frames from reported tracebacks) -- confirm.
__unittest = True

# Candidate test files found during discovery must be plain '.py' sources
# whose name starts with a letter or underscore followed by word characters
# (matched case-insensitively, anchored at the start by re.match).
# Compiled '.pyc'/'.pyo' files are deliberately not accepted: doing so would
# require avoiding loading the same tests multiple times from '.py', '.pyc'
# *and* '.pyo'.
VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
     20 
     21 
     22 def _make_failed_import_test(name, suiteClass):
     23     message = 'Failed to import test module: %s\n%s' % (name, traceback.format_exc())
     24     return _make_failed_test('ModuleImportFailure', name, ImportError(message),
     25                              suiteClass)
     26 
     27 def _make_failed_load_tests(name, exception, suiteClass):
     28     return _make_failed_test('LoadTestsFailure', name, exception, suiteClass)
     29 
     30 def _make_failed_test(classname, methodname, exception, suiteClass):
     31     def testFailure(self):
     32         raise exception
     33     attrs = {methodname: testFailure}
     34     TestClass = type(classname, (case.TestCase,), attrs)
     35     return suiteClass((TestClass(methodname),))
     36 
     37 
     38 class TestLoader(object):
     39     """
     40     This class is responsible for loading tests according to various criteria
     41     and returning them wrapped in a TestSuite
     42     """
     43     testMethodPrefix = 'test'
     44     sortTestMethodsUsing = cmp
     45     suiteClass = suite.TestSuite
     46     _top_level_dir = None
     47 
     48     def loadTestsFromTestCase(self, testCaseClass):
     49         """Return a suite of all tests cases contained in testCaseClass"""
     50         if issubclass(testCaseClass, suite.TestSuite):
     51             raise TypeError("Test cases should not be derived from TestSuite." \
     52                                 " Maybe you meant to derive from TestCase?")
     53         testCaseNames = self.getTestCaseNames(testCaseClass)
     54         if not testCaseNames and hasattr(testCaseClass, 'runTest'):
     55             testCaseNames = ['runTest']
     56         loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
     57         return loaded_suite
     58 
     59     def loadTestsFromModule(self, module, use_load_tests=True):
     60         """Return a suite of all tests cases contained in the given module"""
     61         tests = []
     62         for name in dir(module):
     63             obj = getattr(module, name)
     64             if isinstance(obj, type) and issubclass(obj, case.TestCase):
     65                 tests.append(self.loadTestsFromTestCase(obj))
     66 
     67         load_tests = getattr(module, 'load_tests', None)
     68         tests = self.suiteClass(tests)
     69         if use_load_tests and load_tests is not None:
     70             try:
     71                 return load_tests(self, tests, None)
     72             except Exception, e:
     73                 return _make_failed_load_tests(module.__name__, e,
     74                                                self.suiteClass)
     75         return tests
     76 
     77     def loadTestsFromName(self, name, module=None):
     78         """Return a suite of all tests cases given a string specifier.
     79 
     80         The name may resolve either to a module, a test case class, a
     81         test method within a test case class, or a callable object which
     82         returns a TestCase or TestSuite instance.
     83 
     84         The method optionally resolves the names relative to a given module.
     85         """
     86         parts = name.split('.')
     87         if module is None:
     88             parts_copy = parts[:]
     89             while parts_copy:
     90                 try:
     91                     module = __import__('.'.join(parts_copy))
     92                     break
     93                 except ImportError:
     94                     del parts_copy[-1]
     95                     if not parts_copy:
     96                         raise
     97             parts = parts[1:]
     98         obj = module
     99         for part in parts:
    100             parent, obj = obj, getattr(obj, part)
    101 
    102         if isinstance(obj, types.ModuleType):
    103             return self.loadTestsFromModule(obj)
    104         elif isinstance(obj, type) and issubclass(obj, case.TestCase):
    105             return self.loadTestsFromTestCase(obj)
    106         elif (isinstance(obj, types.UnboundMethodType) and
    107               isinstance(parent, type) and
    108               issubclass(parent, case.TestCase)):
    109             return self.suiteClass([parent(obj.__name__)])
    110         elif isinstance(obj, suite.TestSuite):
    111             return obj
    112         elif hasattr(obj, '__call__'):
    113             test = obj()
    114             if isinstance(test, suite.TestSuite):
    115                 return test
    116             elif isinstance(test, case.TestCase):
    117                 return self.suiteClass([test])
    118             else:
    119                 raise TypeError("calling %s returned %s, not a test" %
    120                                 (obj, test))
    121         else:
    122             raise TypeError("don't know how to make test from: %s" % obj)
    123 
    124     def loadTestsFromNames(self, names, module=None):
    125         """Return a suite of all tests cases found using the given sequence
    126         of string specifiers. See 'loadTestsFromName()'.
    127         """
    128         suites = [self.loadTestsFromName(name, module) for name in names]
    129         return self.suiteClass(suites)
    130 
    131     def getTestCaseNames(self, testCaseClass):
    132         """Return a sorted sequence of method names found within testCaseClass
    133         """
    134         def isTestMethod(attrname, testCaseClass=testCaseClass,
    135                          prefix=self.testMethodPrefix):
    136             return attrname.startswith(prefix) and \
    137                 hasattr(getattr(testCaseClass, attrname), '__call__')
    138         testFnNames = filter(isTestMethod, dir(testCaseClass))
    139         if self.sortTestMethodsUsing:
    140             testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
    141         return testFnNames
    142 
    143     def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
    144         """Find and return all test modules from the specified start
    145         directory, recursing into subdirectories to find them. Only test files
    146         that match the pattern will be loaded. (Using shell style pattern
    147         matching.)
    148 
    149         All test modules must be importable from the top level of the project.
    150         If the start directory is not the top level directory then the top
    151         level directory must be specified separately.
    152 
    153         If a test package name (directory with '__init__.py') matches the
    154         pattern then the package will be checked for a 'load_tests' function. If
    155         this exists then it will be called with loader, tests, pattern.
    156 
    157         If load_tests exists then discovery does  *not* recurse into the package,
    158         load_tests is responsible for loading all tests in the package.
    159 
    160         The pattern is deliberately not stored as a loader attribute so that
    161         packages can continue discovery themselves. top_level_dir is stored so
    162         load_tests does not need to pass this argument in to loader.discover().
    163         """
    164         set_implicit_top = False
    165         if top_level_dir is None and self._top_level_dir is not None:
    166             # make top_level_dir optional if called from load_tests in a package
    167             top_level_dir = self._top_level_dir
    168         elif top_level_dir is None:
    169             set_implicit_top = True
    170             top_level_dir = start_dir
    171 
    172         top_level_dir = os.path.abspath(top_level_dir)
    173 
    174         if not top_level_dir in sys.path:
    175             # all test modules must be importable from the top level directory
    176             # should we *unconditionally* put the start directory in first
    177             # in sys.path to minimise likelihood of conflicts between installed
    178             # modules and development versions?
    179             sys.path.insert(0, top_level_dir)
    180         self._top_level_dir = top_level_dir
    181 
    182         is_not_importable = False
    183         if os.path.isdir(os.path.abspath(start_dir)):
    184             start_dir = os.path.abspath(start_dir)
    185             if start_dir != top_level_dir:
    186                 is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
    187         else:
    188             # support for discovery from dotted module names
    189             try:
    190                 __import__(start_dir)
    191             except ImportError:
    192                 is_not_importable = True
    193             else:
    194                 the_module = sys.modules[start_dir]
    195                 top_part = start_dir.split('.')[0]
    196                 start_dir = os.path.abspath(os.path.dirname((the_module.__file__)))
    197                 if set_implicit_top:
    198                     self._top_level_dir = self._get_directory_containing_module(top_part)
    199                     sys.path.remove(top_level_dir)
    200 
    201         if is_not_importable:
    202             raise ImportError('Start directory is not importable: %r' % start_dir)
    203 
    204         tests = list(self._find_tests(start_dir, pattern))
    205         return self.suiteClass(tests)
    206 
    207     def _get_directory_containing_module(self, module_name):
    208         module = sys.modules[module_name]
    209         full_path = os.path.abspath(module.__file__)
    210 
    211         if os.path.basename(full_path).lower().startswith('__init__.py'):
    212             return os.path.dirname(os.path.dirname(full_path))
    213         else:
    214             # here we have been given a module rather than a package - so
    215             # all we can do is search the *same* directory the module is in
    216             # should an exception be raised instead
    217             return os.path.dirname(full_path)
    218 
    219     def _get_name_from_path(self, path):
    220         path = os.path.splitext(os.path.normpath(path))[0]
    221 
    222         _relpath = os.path.relpath(path, self._top_level_dir)
    223         assert not os.path.isabs(_relpath), "Path must be within the project"
    224         assert not _relpath.startswith('..'), "Path must be within the project"
    225 
    226         name = _relpath.replace(os.path.sep, '.')
    227         return name
    228 
    229     def _get_module_from_name(self, name):
    230         __import__(name)
    231         return sys.modules[name]
    232 
    233     def _match_path(self, path, full_path, pattern):
    234         # override this method to use alternative matching strategy
    235         return fnmatch(path, pattern)
    236 
    237     def _find_tests(self, start_dir, pattern):
    238         """Used by discovery. Yields test suites it loads."""
    239         paths = os.listdir(start_dir)
    240 
    241         for path in paths:
    242             full_path = os.path.join(start_dir, path)
    243             if os.path.isfile(full_path):
    244                 if not VALID_MODULE_NAME.match(path):
    245                     # valid Python identifiers only
    246                     continue
    247                 if not self._match_path(path, full_path, pattern):
    248                     continue
    249                 # if the test file matches, load it
    250                 name = self._get_name_from_path(full_path)
    251                 try:
    252                     module = self._get_module_from_name(name)
    253                 except:
    254                     yield _make_failed_import_test(name, self.suiteClass)
    255                 else:
    256                     mod_file = os.path.abspath(getattr(module, '__file__', full_path))
    257                     realpath = os.path.splitext(mod_file)[0]
    258                     fullpath_noext = os.path.splitext(full_path)[0]
    259                     if realpath.lower() != fullpath_noext.lower():
    260                         module_dir = os.path.dirname(realpath)
    261                         mod_name = os.path.splitext(os.path.basename(full_path))[0]
    262                         expected_dir = os.path.dirname(full_path)
    263                         msg = ("%r module incorrectly imported from %r. Expected %r. "
    264                                "Is this module globally installed?")
    265                         raise ImportError(msg % (mod_name, module_dir, expected_dir))
    266                     yield self.loadTestsFromModule(module)
    267             elif os.path.isdir(full_path):
    268                 if not os.path.isfile(os.path.join(full_path, '__init__.py')):
    269                     continue
    270 
    271                 load_tests = None
    272                 tests = None
    273                 if fnmatch(path, pattern):
    274                     # only check load_tests if the package directory itself matches the filter
    275                     name = self._get_name_from_path(full_path)
    276                     package = self._get_module_from_name(name)
    277                     load_tests = getattr(package, 'load_tests', None)
    278                     tests = self.loadTestsFromModule(package, use_load_tests=False)
    279 
    280                 if load_tests is None:
    281                     if tests is not None:
    282                         # tests loaded from package file
    283                         yield tests
    284                     # recurse into the package
    285                     for test in self._find_tests(full_path, pattern):
    286                         yield test
    287                 else:
    288                     try:
    289                         yield load_tests(self, tests, pattern)
    290                     except Exception, e:
    291                         yield _make_failed_load_tests(package.__name__, e,
    292                                                       self.suiteClass)
    293 
# Module-level TestLoader instance using the class defaults (method prefix
# 'test', cmp ordering, suite.TestSuite as the suite type).
defaultTestLoader = TestLoader()
    295 
    296 
    297 def _makeLoader(prefix, sortUsing, suiteClass=None):
    298     loader = TestLoader()
    299     loader.sortTestMethodsUsing = sortUsing
    300     loader.testMethodPrefix = prefix
    301     if suiteClass:
    302         loader.suiteClass = suiteClass
    303     return loader
    304 
    305 def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    306     return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass)
    307 
    308 def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
    309               suiteClass=suite.TestSuite):
    310     return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(testCaseClass)
    311 
    312 def findTestCases(module, prefix='test', sortUsing=cmp,
    313                   suiteClass=suite.TestSuite):
    314     return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(module)
    315