Home | History | Annotate | Download | only in test
      1 import os, sys, unittest, getopt, time
      2 
# Names of the resources enabled for this run. main() extends it from the
# comma-separated -u option; the entry "*" enables every resource.
use_resources = []
      4 
      5 class ResourceDenied(Exception):
      6     """Test skipped because it requested a disallowed resource.
      7 
      8     This is raised when a test calls requires() for a resource that
      9     has not be enabled.  Resources are defined by test modules.
     10     """
     11 
     12 def is_resource_enabled(resource):
     13     """Test whether a resource is enabled.
     14 
     15     If the caller's module is __main__ then automatically return True."""
     16     if sys._getframe().f_back.f_globals.get("__name__") == "__main__":
     17         return True
     18     result = use_resources is not None and \
     19            (resource in use_resources or "*" in use_resources)
     20     if not result:
     21         _unavail[resource] = None
     22     return result
     23 
# Resources that were requested but not enabled. is_resource_enabled()
# records each refused name as a key (the values are unused), and the test
# runner lists these names in its summary.
_unavail = {}
     25 def requires(resource, msg=None):
     26     """Raise ResourceDenied if the specified resource is not available.
     27 
     28     If the caller's module is __main__ then automatically return True."""
     29     # see if the caller's module is __main__ - if so, treat as if
     30     # the resource was set
     31     if sys._getframe().f_back.f_globals.get("__name__") == "__main__":
     32         return
     33     if not is_resource_enabled(resource):
     34         if msg is None:
     35             msg = "Use of the `%s' resource not enabled" % resource
     36         raise ResourceDenied(msg)
     37 
     38 def find_package_modules(package, mask):
     39     import fnmatch
     40     if (hasattr(package, "__loader__") and
     41             hasattr(package.__loader__, '_files')):
     42         path = package.__name__.replace(".", os.path.sep)
     43         mask = os.path.join(path, mask)
     44         for fnm in package.__loader__._files.iterkeys():
     45             if fnmatch.fnmatchcase(fnm, mask):
     46                 yield os.path.splitext(fnm)[0].replace(os.path.sep, ".")
     47     else:
     48         path = package.__path__[0]
     49         for fnm in os.listdir(path):
     50             if fnmatch.fnmatchcase(fnm, mask):
     51                 yield "%s.%s" % (package.__name__, os.path.splitext(fnm)[0])
     52 
     53 def get_tests(package, mask, verbosity, exclude=()):
     54     """Return a list of skipped test modules, and a list of test cases."""
     55     tests = []
     56     skipped = []
     57     for modname in find_package_modules(package, mask):
     58         if modname.split(".")[-1] in exclude:
     59             skipped.append(modname)
     60             if verbosity > 1:
     61                 print >> sys.stderr, "Skipped %s: excluded" % modname
     62             continue
     63         try:
     64             mod = __import__(modname, globals(), locals(), ['*'])
     65         except (ResourceDenied, unittest.SkipTest) as detail:
     66             skipped.append(modname)
     67             if verbosity > 1:
     68                 print >> sys.stderr, "Skipped %s: %s" % (modname, detail)
     69             continue
     70         for name in dir(mod):
     71             if name.startswith("_"):
     72                 continue
     73             o = getattr(mod, name)
     74             if type(o) is type(unittest.TestCase) and issubclass(o, unittest.TestCase):
     75                 tests.append(o)
     76     return skipped, tests
     77 
     78 def usage():
     79     print __doc__
     80     return 1
     81 
     82 def test_with_refcounts(runner, verbosity, testcase):
     83     """Run testcase several times, tracking reference counts."""
     84     import gc
     85     import ctypes
     86     ptc = ctypes._pointer_type_cache.copy()
     87     cfc = ctypes._c_functype_cache.copy()
     88     wfc = ctypes._win_functype_cache.copy()
     89 
     90     # when searching for refcount leaks, we have to manually reset any
     91     # caches that ctypes has.
     92     def cleanup():
     93         ctypes._pointer_type_cache = ptc.copy()
     94         ctypes._c_functype_cache = cfc.copy()
     95         ctypes._win_functype_cache = wfc.copy()
     96         gc.collect()
     97 
     98     test = unittest.makeSuite(testcase)
     99     for i in range(5):
    100         rc = sys.gettotalrefcount()
    101         runner.run(test)
    102         cleanup()
    103     COUNT = 5
    104     refcounts = [None] * COUNT
    105     for i in range(COUNT):
    106         rc = sys.gettotalrefcount()
    107         runner.run(test)
    108         cleanup()
    109         refcounts[i] = sys.gettotalrefcount() - rc
    110     if filter(None, refcounts):
    111         print "%s leaks:\n\t" % testcase, refcounts
    112     elif verbosity:
    113         print "%s: ok." % testcase
    114 
    115 class TestRunner(unittest.TextTestRunner):
    116     def run(self, test, skipped):
    117         "Run the given test case or test suite."
    118         # Same as unittest.TextTestRunner.run, except that it reports
    119         # skipped tests.
    120         result = self._makeResult()
    121         startTime = time.time()
    122         test(result)
    123         stopTime = time.time()
    124         timeTaken = stopTime - startTime
    125         result.printErrors()
    126         self.stream.writeln(result.separator2)
    127         run = result.testsRun
    128         if _unavail: #skipped:
    129             requested = _unavail.keys()
    130             requested.sort()
    131             self.stream.writeln("Ran %d test%s in %.3fs (%s module%s skipped)" %
    132                                 (run, run != 1 and "s" or "", timeTaken,
    133                                  len(skipped),
    134                                  len(skipped) != 1 and "s" or ""))
    135             self.stream.writeln("Unavailable resources: %s" % ", ".join(requested))
    136         else:
    137             self.stream.writeln("Ran %d test%s in %.3fs" %
    138                                 (run, run != 1 and "s" or "", timeTaken))
    139         self.stream.writeln()
    140         if not result.wasSuccessful():
    141             self.stream.write("FAILED (")
    142             failed, errored = map(len, (result.failures, result.errors))
    143             if failed:
    144                 self.stream.write("failures=%d" % failed)
    145             if errored:
    146                 if failed: self.stream.write(", ")
    147                 self.stream.write("errors=%d" % errored)
    148             self.stream.writeln(")")
    149         else:
    150             self.stream.writeln("OK")
    151         return result
    152 
    153 
    154 def main(*packages):
    155     try:
    156         opts, args = getopt.getopt(sys.argv[1:], "rqvu:x:")
    157     except getopt.error:
    158         return usage()
    159 
    160     verbosity = 1
    161     search_leaks = False
    162     exclude = []
    163     for flag, value in opts:
    164         if flag == "-q":
    165             verbosity -= 1
    166         elif flag == "-v":
    167             verbosity += 1
    168         elif flag == "-r":
    169             try:
    170                 sys.gettotalrefcount
    171             except AttributeError:
    172                 print >> sys.stderr, "-r flag requires Python debug build"
    173                 return -1
    174             search_leaks = True
    175         elif flag == "-u":
    176             use_resources.extend(value.split(","))
    177         elif flag == "-x":
    178             exclude.extend(value.split(","))
    179 
    180     mask = "test_*.py"
    181     if args:
    182         mask = args[0]
    183 
    184     for package in packages:
    185         run_tests(package, mask, verbosity, search_leaks, exclude)
    186 
    187 
    188 def run_tests(package, mask, verbosity, search_leaks, exclude):
    189     skipped, testcases = get_tests(package, mask, verbosity, exclude)
    190     runner = TestRunner(verbosity=verbosity)
    191 
    192     suites = [unittest.makeSuite(o) for o in testcases]
    193     suite = unittest.TestSuite(suites)
    194     result = runner.run(suite, skipped)
    195 
    196     if search_leaks:
    197         # hunt for refcount leaks
    198         runner = BasicTestRunner()
    199         for t in testcases:
    200             test_with_refcounts(runner, verbosity, t)
    201 
    202     return bool(result.errors)
    203 
    204 class BasicTestRunner:
    205     def run(self, test):
    206         result = unittest.TestResult()
    207         test(result)
    208         return result
    209