Home | History | Annotate | Download | only in unittest
      1 """Loading unittests."""
      2 
      3 import os
      4 import re
      5 import sys
      6 import traceback
      7 import types
      8 
      9 from functools import cmp_to_key as _CmpToKey
     10 from fnmatch import fnmatch
     11 
     12 from . import case, suite
     13 
     14 __unittest = True
     15 
     16 # what about .pyc or .pyo (etc)
     17 # we would need to avoid loading the same tests multiple times
     18 # from '.py', '.pyc' *and* '.pyo'
     19 VALID_MODULE_NAME = re.compile(r'[_a-z]\w*\.py$', re.IGNORECASE)
     20 
     21 
     22 def _make_failed_import_test(name, suiteClass):
     23     message = 'Failed to import test module: %s\n%s' % (name, traceback.format_exc())
     24     return _make_failed_test('ModuleImportFailure', name, ImportError(message),
     25                              suiteClass)
     26 
     27 def _make_failed_load_tests(name, exception, suiteClass):
     28     return _make_failed_test('LoadTestsFailure', name, exception, suiteClass)
     29 
     30 def _make_failed_test(classname, methodname, exception, suiteClass):
     31     def testFailure(self):
     32         raise exception
     33     attrs = {methodname: testFailure}
     34     TestClass = type(classname, (case.TestCase,), attrs)
     35     return suiteClass((TestClass(methodname),))
     36 
     37 
     38 class TestLoader(object):
     39     """
     40     This class is responsible for loading tests according to various criteria
     41     and returning them wrapped in a TestSuite
     42     """
     43     testMethodPrefix = 'test'
     44     sortTestMethodsUsing = cmp
     45     suiteClass = suite.TestSuite
     46     _top_level_dir = None
     47 
     48     def loadTestsFromTestCase(self, testCaseClass):
     49         """Return a suite of all test cases contained in testCaseClass"""
     50         if issubclass(testCaseClass, suite.TestSuite):
     51             raise TypeError("Test cases should not be derived from TestSuite." \
     52                                 " Maybe you meant to derive from TestCase?")
     53         testCaseNames = self.getTestCaseNames(testCaseClass)
     54         if not testCaseNames and hasattr(testCaseClass, 'runTest'):
     55             testCaseNames = ['runTest']
     56         loaded_suite = self.suiteClass(map(testCaseClass, testCaseNames))
     57         return loaded_suite
     58 
     59     def loadTestsFromModule(self, module, use_load_tests=True):
     60         """Return a suite of all test cases contained in the given module"""
     61         tests = []
     62         for name in dir(module):
     63             obj = getattr(module, name)
     64             if isinstance(obj, type) and issubclass(obj, case.TestCase):
     65                 tests.append(self.loadTestsFromTestCase(obj))
     66 
     67         load_tests = getattr(module, 'load_tests', None)
     68         tests = self.suiteClass(tests)
     69         if use_load_tests and load_tests is not None:
     70             try:
     71                 return load_tests(self, tests, None)
     72             except Exception, e:
     73                 return _make_failed_load_tests(module.__name__, e,
     74                                                self.suiteClass)
     75         return tests
     76 
     77     def loadTestsFromName(self, name, module=None):
     78         """Return a suite of all test cases given a string specifier.
     79 
     80         The name may resolve either to a module, a test case class, a
     81         test method within a test case class, or a callable object which
     82         returns a TestCase or TestSuite instance.
     83 
     84         The method optionally resolves the names relative to a given module.
     85         """
     86         parts = name.split('.')
     87         if module is None:
     88             parts_copy = parts[:]
     89             while parts_copy:
     90                 try:
     91                     module = __import__('.'.join(parts_copy))
     92                     break
     93                 except ImportError:
     94                     del parts_copy[-1]
     95                     if not parts_copy:
     96                         raise
     97             parts = parts[1:]
     98         obj = module
     99         for part in parts:
    100             parent, obj = obj, getattr(obj, part)
    101 
    102         if isinstance(obj, types.ModuleType):
    103             return self.loadTestsFromModule(obj)
    104         elif isinstance(obj, type) and issubclass(obj, case.TestCase):
    105             return self.loadTestsFromTestCase(obj)
    106         elif (isinstance(obj, types.UnboundMethodType) and
    107               isinstance(parent, type) and
    108               issubclass(parent, case.TestCase)):
    109             name = parts[-1]
    110             inst = parent(name)
    111             return self.suiteClass([inst])
    112         elif isinstance(obj, suite.TestSuite):
    113             return obj
    114         elif hasattr(obj, '__call__'):
    115             test = obj()
    116             if isinstance(test, suite.TestSuite):
    117                 return test
    118             elif isinstance(test, case.TestCase):
    119                 return self.suiteClass([test])
    120             else:
    121                 raise TypeError("calling %s returned %s, not a test" %
    122                                 (obj, test))
    123         else:
    124             raise TypeError("don't know how to make test from: %s" % obj)
    125 
    126     def loadTestsFromNames(self, names, module=None):
    127         """Return a suite of all test cases found using the given sequence
    128         of string specifiers. See 'loadTestsFromName()'.
    129         """
    130         suites = [self.loadTestsFromName(name, module) for name in names]
    131         return self.suiteClass(suites)
    132 
    133     def getTestCaseNames(self, testCaseClass):
    134         """Return a sorted sequence of method names found within testCaseClass
    135         """
    136         def isTestMethod(attrname, testCaseClass=testCaseClass,
    137                          prefix=self.testMethodPrefix):
    138             return attrname.startswith(prefix) and \
    139                 hasattr(getattr(testCaseClass, attrname), '__call__')
    140         testFnNames = filter(isTestMethod, dir(testCaseClass))
    141         if self.sortTestMethodsUsing:
    142             testFnNames.sort(key=_CmpToKey(self.sortTestMethodsUsing))
    143         return testFnNames
    144 
    145     def discover(self, start_dir, pattern='test*.py', top_level_dir=None):
    146         """Find and return all test modules from the specified start
    147         directory, recursing into subdirectories to find them. Only test files
    148         that match the pattern will be loaded. (Using shell style pattern
    149         matching.)
    150 
    151         All test modules must be importable from the top level of the project.
    152         If the start directory is not the top level directory then the top
    153         level directory must be specified separately.
    154 
    155         If a test package name (directory with '__init__.py') matches the
    156         pattern then the package will be checked for a 'load_tests' function. If
    157         this exists then it will be called with loader, tests, pattern.
    158 
    159         If load_tests exists then discovery does  *not* recurse into the package,
    160         load_tests is responsible for loading all tests in the package.
    161 
    162         The pattern is deliberately not stored as a loader attribute so that
    163         packages can continue discovery themselves. top_level_dir is stored so
    164         load_tests does not need to pass this argument in to loader.discover().
    165         """
    166         set_implicit_top = False
    167         if top_level_dir is None and self._top_level_dir is not None:
    168             # make top_level_dir optional if called from load_tests in a package
    169             top_level_dir = self._top_level_dir
    170         elif top_level_dir is None:
    171             set_implicit_top = True
    172             top_level_dir = start_dir
    173 
    174         top_level_dir = os.path.abspath(top_level_dir)
    175 
    176         if not top_level_dir in sys.path:
    177             # all test modules must be importable from the top level directory
    178             # should we *unconditionally* put the start directory in first
    179             # in sys.path to minimise likelihood of conflicts between installed
    180             # modules and development versions?
    181             sys.path.insert(0, top_level_dir)
    182         self._top_level_dir = top_level_dir
    183 
    184         is_not_importable = False
    185         if os.path.isdir(os.path.abspath(start_dir)):
    186             start_dir = os.path.abspath(start_dir)
    187             if start_dir != top_level_dir:
    188                 is_not_importable = not os.path.isfile(os.path.join(start_dir, '__init__.py'))
    189         else:
    190             # support for discovery from dotted module names
    191             try:
    192                 __import__(start_dir)
    193             except ImportError:
    194                 is_not_importable = True
    195             else:
    196                 the_module = sys.modules[start_dir]
    197                 top_part = start_dir.split('.')[0]
    198                 start_dir = os.path.abspath(os.path.dirname((the_module.__file__)))
    199                 if set_implicit_top:
    200                     self._top_level_dir = self._get_directory_containing_module(top_part)
    201                     sys.path.remove(top_level_dir)
    202 
    203         if is_not_importable:
    204             raise ImportError('Start directory is not importable: %r' % start_dir)
    205 
    206         tests = list(self._find_tests(start_dir, pattern))
    207         return self.suiteClass(tests)
    208 
    209     def _get_directory_containing_module(self, module_name):
    210         module = sys.modules[module_name]
    211         full_path = os.path.abspath(module.__file__)
    212 
    213         if os.path.basename(full_path).lower().startswith('__init__.py'):
    214             return os.path.dirname(os.path.dirname(full_path))
    215         else:
    216             # here we have been given a module rather than a package - so
    217             # all we can do is search the *same* directory the module is in
    218             # should an exception be raised instead
    219             return os.path.dirname(full_path)
    220 
    221     def _get_name_from_path(self, path):
    222         path = os.path.splitext(os.path.normpath(path))[0]
    223 
    224         _relpath = os.path.relpath(path, self._top_level_dir)
    225         assert not os.path.isabs(_relpath), "Path must be within the project"
    226         assert not _relpath.startswith('..'), "Path must be within the project"
    227 
    228         name = _relpath.replace(os.path.sep, '.')
    229         return name
    230 
    231     def _get_module_from_name(self, name):
    232         __import__(name)
    233         return sys.modules[name]
    234 
    235     def _match_path(self, path, full_path, pattern):
    236         # override this method to use alternative matching strategy
    237         return fnmatch(path, pattern)
    238 
    239     def _find_tests(self, start_dir, pattern):
    240         """Used by discovery. Yields test suites it loads."""
    241         paths = os.listdir(start_dir)
    242 
    243         for path in paths:
    244             full_path = os.path.join(start_dir, path)
    245             if os.path.isfile(full_path):
    246                 if not VALID_MODULE_NAME.match(path):
    247                     # valid Python identifiers only
    248                     continue
    249                 if not self._match_path(path, full_path, pattern):
    250                     continue
    251                 # if the test file matches, load it
    252                 name = self._get_name_from_path(full_path)
    253                 try:
    254                     module = self._get_module_from_name(name)
    255                 except:
    256                     yield _make_failed_import_test(name, self.suiteClass)
    257                 else:
    258                     mod_file = os.path.abspath(getattr(module, '__file__', full_path))
    259                     realpath = os.path.splitext(os.path.realpath(mod_file))[0]
    260                     fullpath_noext = os.path.splitext(os.path.realpath(full_path))[0]
    261                     if realpath.lower() != fullpath_noext.lower():
    262                         module_dir = os.path.dirname(realpath)
    263                         mod_name = os.path.splitext(os.path.basename(full_path))[0]
    264                         expected_dir = os.path.dirname(full_path)
    265                         msg = ("%r module incorrectly imported from %r. Expected %r. "
    266                                "Is this module globally installed?")
    267                         raise ImportError(msg % (mod_name, module_dir, expected_dir))
    268                     yield self.loadTestsFromModule(module)
    269             elif os.path.isdir(full_path):
    270                 if not os.path.isfile(os.path.join(full_path, '__init__.py')):
    271                     continue
    272 
    273                 load_tests = None
    274                 tests = None
    275                 if fnmatch(path, pattern):
    276                     # only check load_tests if the package directory itself matches the filter
    277                     name = self._get_name_from_path(full_path)
    278                     package = self._get_module_from_name(name)
    279                     load_tests = getattr(package, 'load_tests', None)
    280                     tests = self.loadTestsFromModule(package, use_load_tests=False)
    281 
    282                 if load_tests is None:
    283                     if tests is not None:
    284                         # tests loaded from package file
    285                         yield tests
    286                     # recurse into the package
    287                     for test in self._find_tests(full_path, pattern):
    288                         yield test
    289                 else:
    290                     try:
    291                         yield load_tests(self, tests, pattern)
    292                     except Exception, e:
    293                         yield _make_failed_load_tests(package.__name__, e,
    294                                                       self.suiteClass)
    295 
# Module-level TestLoader instance using the class-level default settings.
defaultTestLoader = TestLoader()
    297 
    298 
    299 def _makeLoader(prefix, sortUsing, suiteClass=None):
    300     loader = TestLoader()
    301     loader.sortTestMethodsUsing = sortUsing
    302     loader.testMethodPrefix = prefix
    303     if suiteClass:
    304         loader.suiteClass = suiteClass
    305     return loader
    306 
    307 def getTestCaseNames(testCaseClass, prefix, sortUsing=cmp):
    308     return _makeLoader(prefix, sortUsing).getTestCaseNames(testCaseClass)
    309 
    310 def makeSuite(testCaseClass, prefix='test', sortUsing=cmp,
    311               suiteClass=suite.TestSuite):
    312     return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromTestCase(testCaseClass)
    313 
    314 def findTestCases(module, prefix='test', sortUsing=cmp,
    315                   suiteClass=suite.TestSuite):
    316     return _makeLoader(prefix, sortUsing, suiteClass).loadTestsFromModule(module)
    317