Home | History | Annotate | Download | only in test
      1 import os, sys, unittest, getopt, time
      2 
      3 use_resources = []
      4 
      5 import ctypes
      6 ctypes_symbols = dir(ctypes)
      7 
      8 def need_symbol(name):
      9     return unittest.skipUnless(name in ctypes_symbols,
     10                                '{!r} is required'.format(name))
     11 
     12 
     13 class ResourceDenied(unittest.SkipTest):
     14     """Test skipped because it requested a disallowed resource.
     15 
     16     This is raised when a test calls requires() for a resource that
     17     has not be enabled.  Resources are defined by test modules.
     18     """
     19 
     20 def is_resource_enabled(resource):
     21     """Test whether a resource is enabled.
     22 
     23     If the caller's module is __main__ then automatically return True."""
     24     if sys._getframe().f_back.f_globals.get("__name__") == "__main__":
     25         return True
     26     result = use_resources is not None and \
     27            (resource in use_resources or "*" in use_resources)
     28     if not result:
     29         _unavail[resource] = None
     30     return result
     31 
     32 _unavail = {}
     33 def requires(resource, msg=None):
     34     """Raise ResourceDenied if the specified resource is not available.
     35 
     36     If the caller's module is __main__ then automatically return True."""
     37     # see if the caller's module is __main__ - if so, treat as if
     38     # the resource was set
     39     if sys._getframe().f_back.f_globals.get("__name__") == "__main__":
     40         return
     41     if not is_resource_enabled(resource):
     42         if msg is None:
     43             msg = "Use of the `%s' resource not enabled" % resource
     44         raise ResourceDenied(msg)
     45 
     46 def find_package_modules(package, mask):
     47     import fnmatch
     48     if (hasattr(package, "__loader__") and
     49             hasattr(package.__loader__, '_files')):
     50         path = package.__name__.replace(".", os.path.sep)
     51         mask = os.path.join(path, mask)
     52         for fnm in package.__loader__._files.iterkeys():
     53             if fnmatch.fnmatchcase(fnm, mask):
     54                 yield os.path.splitext(fnm)[0].replace(os.path.sep, ".")
     55     else:
     56         path = package.__path__[0]
     57         for fnm in os.listdir(path):
     58             if fnmatch.fnmatchcase(fnm, mask):
     59                 yield "%s.%s" % (package.__name__, os.path.splitext(fnm)[0])
     60 
     61 def get_tests(package, mask, verbosity, exclude=()):
     62     """Return a list of skipped test modules, and a list of test cases."""
     63     tests = []
     64     skipped = []
     65     for modname in find_package_modules(package, mask):
     66         if modname.split(".")[-1] in exclude:
     67             skipped.append(modname)
     68             if verbosity > 1:
     69                 print >> sys.stderr, "Skipped %s: excluded" % modname
     70             continue
     71         try:
     72             mod = __import__(modname, globals(), locals(), ['*'])
     73         except (ResourceDenied, unittest.SkipTest) as detail:
     74             skipped.append(modname)
     75             if verbosity > 1:
     76                 print >> sys.stderr, "Skipped %s: %s" % (modname, detail)
     77             continue
     78         for name in dir(mod):
     79             if name.startswith("_"):
     80                 continue
     81             o = getattr(mod, name)
     82             if type(o) is type(unittest.TestCase) and issubclass(o, unittest.TestCase):
     83                 tests.append(o)
     84     return skipped, tests
     85 
     86 def usage():
     87     print __doc__
     88     return 1
     89 
     90 def test_with_refcounts(runner, verbosity, testcase):
     91     """Run testcase several times, tracking reference counts."""
     92     import gc
     93     import ctypes
     94     ptc = ctypes._pointer_type_cache.copy()
     95     cfc = ctypes._c_functype_cache.copy()
     96     wfc = ctypes._win_functype_cache.copy()
     97 
     98     # when searching for refcount leaks, we have to manually reset any
     99     # caches that ctypes has.
    100     def cleanup():
    101         ctypes._pointer_type_cache = ptc.copy()
    102         ctypes._c_functype_cache = cfc.copy()
    103         ctypes._win_functype_cache = wfc.copy()
    104         gc.collect()
    105 
    106     test = unittest.makeSuite(testcase)
    107     for i in range(5):
    108         rc = sys.gettotalrefcount()
    109         runner.run(test)
    110         cleanup()
    111     COUNT = 5
    112     refcounts = [None] * COUNT
    113     for i in range(COUNT):
    114         rc = sys.gettotalrefcount()
    115         runner.run(test)
    116         cleanup()
    117         refcounts[i] = sys.gettotalrefcount() - rc
    118     if filter(None, refcounts):
    119         print "%s leaks:\n\t" % testcase, refcounts
    120     elif verbosity:
    121         print "%s: ok." % testcase
    122 
    123 class TestRunner(unittest.TextTestRunner):
    124     def run(self, test, skipped):
    125         "Run the given test case or test suite."
    126         # Same as unittest.TextTestRunner.run, except that it reports
    127         # skipped tests.
    128         result = self._makeResult()
    129         startTime = time.time()
    130         test(result)
    131         stopTime = time.time()
    132         timeTaken = stopTime - startTime
    133         result.printErrors()
    134         self.stream.writeln(result.separator2)
    135         run = result.testsRun
    136         if _unavail: #skipped:
    137             requested = _unavail.keys()
    138             requested.sort()
    139             self.stream.writeln("Ran %d test%s in %.3fs (%s module%s skipped)" %
    140                                 (run, run != 1 and "s" or "", timeTaken,
    141                                  len(skipped),
    142                                  len(skipped) != 1 and "s" or ""))
    143             self.stream.writeln("Unavailable resources: %s" % ", ".join(requested))
    144         else:
    145             self.stream.writeln("Ran %d test%s in %.3fs" %
    146                                 (run, run != 1 and "s" or "", timeTaken))
    147         self.stream.writeln()
    148         if not result.wasSuccessful():
    149             self.stream.write("FAILED (")
    150             failed, errored = map(len, (result.failures, result.errors))
    151             if failed:
    152                 self.stream.write("failures=%d" % failed)
    153             if errored:
    154                 if failed: self.stream.write(", ")
    155                 self.stream.write("errors=%d" % errored)
    156             self.stream.writeln(")")
    157         else:
    158             self.stream.writeln("OK")
    159         return result
    160 
    161 
    162 def main(*packages):
    163     try:
    164         opts, args = getopt.getopt(sys.argv[1:], "rqvu:x:")
    165     except getopt.error:
    166         return usage()
    167 
    168     verbosity = 1
    169     search_leaks = False
    170     exclude = []
    171     for flag, value in opts:
    172         if flag == "-q":
    173             verbosity -= 1
    174         elif flag == "-v":
    175             verbosity += 1
    176         elif flag == "-r":
    177             try:
    178                 sys.gettotalrefcount
    179             except AttributeError:
    180                 print >> sys.stderr, "-r flag requires Python debug build"
    181                 return -1
    182             search_leaks = True
    183         elif flag == "-u":
    184             use_resources.extend(value.split(","))
    185         elif flag == "-x":
    186             exclude.extend(value.split(","))
    187 
    188     mask = "test_*.py"
    189     if args:
    190         mask = args[0]
    191 
    192     for package in packages:
    193         run_tests(package, mask, verbosity, search_leaks, exclude)
    194 
    195 
    196 def run_tests(package, mask, verbosity, search_leaks, exclude):
    197     skipped, testcases = get_tests(package, mask, verbosity, exclude)
    198     runner = TestRunner(verbosity=verbosity)
    199 
    200     suites = [unittest.makeSuite(o) for o in testcases]
    201     suite = unittest.TestSuite(suites)
    202     result = runner.run(suite, skipped)
    203 
    204     if search_leaks:
    205         # hunt for refcount leaks
    206         runner = BasicTestRunner()
    207         for t in testcases:
    208             test_with_refcounts(runner, verbosity, t)
    209 
    210     return bool(result.errors)
    211 
    212 class BasicTestRunner:
    213     def run(self, test):
    214         result = unittest.TestResult()
    215         test(result)
    216         return result
    217