Home | History | Annotate | Download | only in webpagereplay
      1 #!/usr/bin/env python
      2 # Copyright (c) 2014 The Chromium Authors. All rights reserved.
      3 # Use of this source code is governed by a BSD-style license that can be
      4 # found in the LICENSE file.
      5 import unittest
      6 import sys
      7 import os
      8 import optparse
      9 
     10 __all__ = []
     11 
     12 def FilterSuite(suite, predicate):
     13   new_suite = suite.__class__()
     14 
     15   for x in suite:
     16     if isinstance(x, unittest.TestSuite):
     17       subsuite = FilterSuite(x, predicate)
     18       if subsuite.countTestCases() == 0:
     19         continue
     20 
     21       new_suite.addTest(subsuite)
     22       continue
     23 
     24     assert isinstance(x, unittest.TestCase)
     25     if predicate(x):
     26       new_suite.addTest(x)
     27 
     28   return new_suite
     29 
     30 class _TestLoader(unittest.TestLoader):
     31   def __init__(self, *args):
     32     super(_TestLoader, self).__init__(*args)
     33     self.discover_calls = []
     34 
     35   def loadTestsFromModule(self, module, use_load_tests=True):
     36     if module.__file__ != __file__:
     37       return super(_TestLoader, self).loadTestsFromModule(
     38           module, use_load_tests)
     39 
     40     suite = unittest.TestSuite()
     41     for discover_args in self.discover_calls:
     42       subsuite = self.discover(*discover_args)
     43       suite.addTest(subsuite)
     44     return suite
     45 
     46 class _RunnerImpl(unittest.TextTestRunner):
     47   def __init__(self, filters):
     48     super(_RunnerImpl, self).__init__(verbosity=2)
     49     self.filters = filters
     50 
     51   def ShouldTestRun(self, test):
     52     return not self.filters or any(name in test.id() for name in self.filters)
     53 
     54   def run(self, suite):
     55     filtered_test = FilterSuite(suite, self.ShouldTestRun)
     56     return super(_RunnerImpl, self).run(filtered_test)
     57 
     58 
     59 class TestRunner(object):
     60   def __init__(self):
     61     self._loader = _TestLoader()
     62 
     63   def AddDirectory(self, dir_path, test_file_pattern="*test.py"):
     64     assert os.path.isdir(dir_path)
     65 
     66     self._loader.discover_calls.append((dir_path, test_file_pattern, dir_path))
     67 
     68   def Main(self, argv=None):
     69     if argv is None:
     70       argv = sys.argv
     71 
     72     parser = optparse.OptionParser()
     73     options, args = parser.parse_args(argv[1:])
     74 
     75     runner = _RunnerImpl(filters=args)
     76     return unittest.main(module=__name__, argv=[sys.argv[0]],
     77                          testLoader=self._loader,
     78                          testRunner=runner)
     79