Home | History | Annotate | Download | only in test
      1 """Supporting definitions for the Python regression tests."""
      2 
      3 if __name__ != 'test.test_support':
      4     raise ImportError('test_support must be imported from the test package')
      5 
      6 import contextlib
      7 import errno
      8 import functools
      9 import gc
     10 import socket
     11 import sys
     12 import os
     13 import platform
     14 import shutil
     15 import warnings
     16 import unittest
     17 import importlib
     18 import UserDict
     19 import re
     20 import time
     21 try:
     22     import thread
     23 except ImportError:
     24     thread = None
     25 
     26 __all__ = ["Error", "TestFailed", "ResourceDenied", "import_module",
     27            "verbose", "use_resources", "max_memuse", "record_original_stdout",
     28            "get_original_stdout", "unload", "unlink", "rmtree", "forget",
     29            "is_resource_enabled", "requires", "find_unused_port", "bind_port",
     30            "fcmp", "have_unicode", "is_jython", "TESTFN", "HOST", "FUZZ",
     31            "SAVEDCWD", "temp_cwd", "findfile", "sortdict", "check_syntax_error",
     32            "open_urlresource", "check_warnings", "check_py3k_warnings",
     33            "CleanImport", "EnvironmentVarGuard", "captured_output",
     34            "captured_stdout", "TransientResource", "transient_internet",
     35            "run_with_locale", "set_memlimit", "bigmemtest", "bigaddrspacetest",
     36            "BasicTestRunner", "run_unittest", "run_doctest", "threading_setup",
     37            "threading_cleanup", "reap_children", "cpython_only",
     38            "check_impl_detail", "get_attribute", "py3k_bytes",
     39            "import_fresh_module"]
     40 
     41 
     42 class Error(Exception):
     43     """Base class for regression test exceptions."""
     44 
     45 class TestFailed(Error):
     46     """Test failed."""
     47 
     48 class ResourceDenied(unittest.SkipTest):
     49     """Test skipped because it requested a disallowed resource.
     50 
     51     This is raised when a test calls requires() for a resource that
     52     has not been enabled.  It is used to distinguish between expected
     53     and unexpected skips.
     54     """
     55 
     56 @contextlib.contextmanager
     57 def _ignore_deprecated_imports(ignore=True):
     58     """Context manager to suppress package and module deprecation
     59     warnings when importing them.
     60 
     61     If ignore is False, this context manager has no effect."""
     62     if ignore:
     63         with warnings.catch_warnings():
     64             warnings.filterwarnings("ignore", ".+ (module|package)",
     65                                     DeprecationWarning)
     66             yield
     67     else:
     68         yield
     69 
     70 
     71 def import_module(name, deprecated=False):
     72     """Import and return the module to be tested, raising SkipTest if
     73     it is not available.
     74 
     75     If deprecated is True, any module or package deprecation messages
     76     will be suppressed."""
     77     with _ignore_deprecated_imports(deprecated):
     78         try:
     79             return importlib.import_module(name)
     80         except ImportError, msg:
     81             raise unittest.SkipTest(str(msg))
     82 
     83 
     84 def _save_and_remove_module(name, orig_modules):
     85     """Helper function to save and remove a module from sys.modules
     86 
     87        Raise ImportError if the module can't be imported."""
     88     # try to import the module and raise an error if it can't be imported

     89     if name not in sys.modules:
     90         __import__(name)
     91         del sys.modules[name]
     92     for modname in list(sys.modules):
     93         if modname == name or modname.startswith(name + '.'):
     94             orig_modules[modname] = sys.modules[modname]
     95             del sys.modules[modname]
     96 
     97 def _save_and_block_module(name, orig_modules):
     98     """Helper function to save and block a module in sys.modules
     99 
    100        Return True if the module was in sys.modules, False otherwise."""
    101     saved = True
    102     try:
    103         orig_modules[name] = sys.modules[name]
    104     except KeyError:
    105         saved = False
    106     sys.modules[name] = None
    107     return saved
    108 
    109 
def import_fresh_module(name, fresh=(), blocked=(), deprecated=False):
    """Import and return a fresh copy of module *name*, bypassing the
    sys.modules cache.  Once the import is complete, the sys.modules entries
    that were removed or blocked along the way are put back.

    Modules named in *fresh* (and their submodules) are also removed from
    sys.modules first, so they are imported anew if the import of *name*
    needs them.  If importing one of the *fresh* modules, or the fresh
    import of *name*, raises ImportError, None is returned.

    Importing of modules named in *blocked* is prevented while the fresh
    import takes place (their sys.modules entries are set to None).

    If *deprecated* is True, any module or package deprecation messages
    will be suppressed.

    Raises ImportError (instead of returning None) when *name* is not
    already in sys.modules and cannot be imported at all, because that
    first import happens before the try block below.
    """
    # NOTE: test_heapq, test_json, and test_warnings include extra sanity
    # checks to make sure that this utility function is working as expected
    with _ignore_deprecated_imports(deprecated):
        # Keep track of modules saved for later restoration as well
        # as those which just need a blocking entry removed
        orig_modules = {}
        names_to_remove = []
        # Deliberately outside the try: an ImportError here propagates.
        _save_and_remove_module(name, orig_modules)
        try:
            for fresh_name in fresh:
                _save_and_remove_module(fresh_name, orig_modules)
            for blocked_name in blocked:
                # Names that had no sys.modules entry only get a None
                # placeholder, which must be deleted again afterwards.
                if not _save_and_block_module(blocked_name, orig_modules):
                    names_to_remove.append(blocked_name)
            fresh_module = importlib.import_module(name)
        except ImportError:
            fresh_module = None
        finally:
            # Put back every saved entry; this also swaps the fresh copy of
            # *name* back out when *name* was cached before the call.
            # NOTE(review): entries that did not exist beforehand are not
            # removed -- e.g. if *name* (or a *fresh* module) was not in
            # sys.modules on entry, the freshly imported module stays cached.
            for orig_name, module in orig_modules.items():
                sys.modules[orig_name] = module
            for name_to_remove in names_to_remove:
                del sys.modules[name_to_remove]
        return fresh_module
    146 
    147 
    148 def get_attribute(obj, name):
    149     """Get an attribute, raising SkipTest if AttributeError is raised."""
    150     try:
    151         attribute = getattr(obj, name)
    152     except AttributeError:
    153         raise unittest.SkipTest("module %s has no attribute %s" % (
    154             obj.__name__, name))
    155     else:
    156         return attribute
    157 
    158 
    159 verbose = 1              # Flag set to 0 by regrtest.py

    160 use_resources = None     # Flag set to [] by regrtest.py

    161 max_memuse = 0           # Disable bigmem tests (they will still be run with

    162                          # small sizes, to make sure they work.)

    163 real_max_memuse = 0
    164 
    165 # _original_stdout is meant to hold stdout at the time regrtest began.

    166 # This may be "the real" stdout, or IDLE's emulation of stdout, or whatever.

    167 # The point is to have some flavor of stdout the user can actually see.

    168 _original_stdout = None
    169 def record_original_stdout(stdout):
    170     global _original_stdout
    171     _original_stdout = stdout
    172 
    173 def get_original_stdout():
    174     return _original_stdout or sys.stdout
    175 
    176 def unload(name):
    177     try:
    178         del sys.modules[name]
    179     except KeyError:
    180         pass
    181 
    182 def unlink(filename):
    183     try:
    184         os.unlink(filename)
    185     except OSError:
    186         pass
    187 
    188 def rmtree(path):
    189     try:
    190         shutil.rmtree(path)
    191     except OSError, e:
    192         # Unix returns ENOENT, Windows returns ESRCH.

    193         if e.errno not in (errno.ENOENT, errno.ESRCH):
    194             raise
    195 
    196 def forget(modname):
    197     '''"Forget" a module was ever imported by removing it from sys.modules and
    198     deleting any .pyc and .pyo files.'''
    199     unload(modname)
    200     for dirname in sys.path:
    201         unlink(os.path.join(dirname, modname + os.extsep + 'pyc'))
    202         # Deleting the .pyo file cannot be within the 'try' for the .pyc since

    203         # the chance exists that there is no .pyc (and thus the 'try' statement

    204         # is exited) but there is a .pyo file.

    205         unlink(os.path.join(dirname, modname + os.extsep + 'pyo'))
    206 
    207 def is_resource_enabled(resource):
    208     """Test whether a resource is enabled.  Known resources are set by
    209     regrtest.py."""
    210     return use_resources is not None and resource in use_resources
    211 
    212 def requires(resource, msg=None):
    213     """Raise ResourceDenied if the specified resource is not available.
    214 
    215     If the caller's module is __main__ then automatically return True.  The
    216     possibility of False being returned occurs when regrtest.py is executing."""
    217     # see if the caller's module is __main__ - if so, treat as if

    218     # the resource was set

    219     if sys._getframe(1).f_globals.get("__name__") == "__main__":
    220         return
    221     if not is_resource_enabled(resource):
    222         if msg is None:
    223             msg = "Use of the `%s' resource not enabled" % resource
    224         raise ResourceDenied(msg)
    225 
    226 HOST = 'localhost'
    227 
    228 def find_unused_port(family=socket.AF_INET, socktype=socket.SOCK_STREAM):
    229     """Returns an unused port that should be suitable for binding.  This is
    230     achieved by creating a temporary socket with the same family and type as
    231     the 'sock' parameter (default is AF_INET, SOCK_STREAM), and binding it to
    232     the specified host address (defaults to 0.0.0.0) with the port set to 0,
    233     eliciting an unused ephemeral port from the OS.  The temporary socket is
    234     then closed and deleted, and the ephemeral port is returned.
    235 
    236     Either this method or bind_port() should be used for any tests where a
    237     server socket needs to be bound to a particular port for the duration of
    238     the test.  Which one to use depends on whether the calling code is creating
    239     a python socket, or if an unused port needs to be provided in a constructor
    240     or passed to an external program (i.e. the -accept argument to openssl's
    241     s_server mode).  Always prefer bind_port() over find_unused_port() where
    242     possible.  Hard coded ports should *NEVER* be used.  As soon as a server
    243     socket is bound to a hard coded port, the ability to run multiple instances
    244     of the test simultaneously on the same host is compromised, which makes the
    245     test a ticking time bomb in a buildbot environment. On Unix buildbots, this
    246     may simply manifest as a failed test, which can be recovered from without
    247     intervention in most cases, but on Windows, the entire python process can
    248     completely and utterly wedge, requiring someone to log in to the buildbot
    249     and manually kill the affected process.
    250 
    251     (This is easy to reproduce on Windows, unfortunately, and can be traced to
    252     the SO_REUSEADDR socket option having different semantics on Windows versus
    253     Unix/Linux.  On Unix, you can't have two AF_INET SOCK_STREAM sockets bind,
    254     listen and then accept connections on identical host/ports.  An EADDRINUSE
    255     socket.error will be raised at some point (depending on the platform and
    256     the order bind and listen were called on each socket).
    257 
    258     However, on Windows, if SO_REUSEADDR is set on the sockets, no EADDRINUSE
    259     will ever be raised when attempting to bind two identical host/ports. When
    260     accept() is called on each socket, the second caller's process will steal
    261     the port from the first caller, leaving them both in an awkwardly wedged
    262     state where they'll no longer respond to any signals or graceful kills, and
    263     must be forcibly killed via OpenProcess()/TerminateProcess().
    264 
    265     The solution on Windows is to use the SO_EXCLUSIVEADDRUSE socket option
    266     instead of SO_REUSEADDR, which effectively affords the same semantics as
    267     SO_REUSEADDR on Unix.  Given the propensity of Unix developers in the Open
    268     Source world compared to Windows ones, this is a common mistake.  A quick
    269     look over OpenSSL's 0.9.8g source shows that they use SO_REUSEADDR when
    270     openssl.exe is called with the 's_server' option, for example. See
    271     http://bugs.python.org/issue2550 for more info.  The following site also
    272     has a very thorough description about the implications of both REUSEADDR
    273     and EXCLUSIVEADDRUSE on Windows:
    274     http://msdn2.microsoft.com/en-us/library/ms740621(VS.85).aspx)
    275 
    276     XXX: although this approach is a vast improvement on previous attempts to
    277     elicit unused ports, it rests heavily on the assumption that the ephemeral
    278     port returned to us by the OS won't immediately be dished back out to some
    279     other process when we close and delete our temporary socket but before our
    280     calling code has a chance to bind the returned port.  We can deal with this
    281     issue if/when we come across it."""
    282     tempsock = socket.socket(family, socktype)
    283     port = bind_port(tempsock)
    284     tempsock.close()
    285     del tempsock
    286     return port
    287 
    288 def bind_port(sock, host=HOST):
    289     """Bind the socket to a free port and return the port number.  Relies on
    290     ephemeral ports in order to ensure we are using an unbound port.  This is
    291     important as many tests may be running simultaneously, especially in a
    292     buildbot environment.  This method raises an exception if the sock.family
    293     is AF_INET and sock.type is SOCK_STREAM, *and* the socket has SO_REUSEADDR
    294     or SO_REUSEPORT set on it.  Tests should *never* set these socket options
    295     for TCP/IP sockets.  The only case for setting these options is testing
    296     multicasting via multiple UDP sockets.
    297 
    298     Additionally, if the SO_EXCLUSIVEADDRUSE socket option is available (i.e.
    299     on Windows), it will be set on the socket.  This will prevent anyone else
    300     from bind()'ing to our host/port for the duration of the test.
    301     """
    302     if sock.family == socket.AF_INET and sock.type == socket.SOCK_STREAM:
    303         if hasattr(socket, 'SO_REUSEADDR'):
    304             if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR) == 1:
    305                 raise TestFailed("tests should never set the SO_REUSEADDR "   \
    306                                  "socket option on TCP/IP sockets!")
    307         if hasattr(socket, 'SO_REUSEPORT'):
    308             if sock.getsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT) == 1:
    309                 raise TestFailed("tests should never set the SO_REUSEPORT "   \
    310                                  "socket option on TCP/IP sockets!")
    311         if hasattr(socket, 'SO_EXCLUSIVEADDRUSE'):
    312             sock.setsockopt(socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1)
    313 
    314     sock.bind((host, 0))
    315     port = sock.getsockname()[1]
    316     return port
    317 
    318 FUZZ = 1e-6
    319 
    320 def fcmp(x, y): # fuzzy comparison function
    321     if isinstance(x, float) or isinstance(y, float):
    322         try:
    323             fuzz = (abs(x) + abs(y)) * FUZZ
    324             if abs(x-y) <= fuzz:
    325                 return 0
    326         except:
    327             pass
    328     elif type(x) == type(y) and isinstance(x, (tuple, list)):
    329         for i in range(min(len(x), len(y))):
    330             outcome = fcmp(x[i], y[i])
    331             if outcome != 0:
    332                 return outcome
    333         return (len(x) > len(y)) - (len(x) < len(y))
    334     return (x > y) - (x < y)
    335 
    336 try:
    337     unicode
    338     have_unicode = True
    339 except NameError:
    340     have_unicode = False
    341 
    342 is_jython = sys.platform.startswith('java')
    343 
# Filename used for testing
if os.name == 'java':
    # Jython disallows @ in module names
    TESTFN = '$test'
elif os.name == 'riscos':
    TESTFN = 'testfile'
else:
    TESTFN = '@test'
    # The Unicode names (TESTFN_UNICODE, TESTFN_ENCODING, TESTFN_UNENCODABLE)
    # are only defined when the interpreter has unicode support.
    if have_unicode:
        # Assuming sys.getfilesystemencoding()!=sys.getdefaultencoding()
        # TESTFN_UNICODE is a filename that can be encoded using the
        # file system encoding, but *not* with the default (ascii) encoding
        if isinstance('', unicode):
            # python -U
            # XXX perhaps unicode() should accept Unicode strings?
            TESTFN_UNICODE = "@test-\xe0\xf2"
        else:
            # 2 latin characters.
            TESTFN_UNICODE = unicode("@test-\xe0\xf2", "latin-1")
        TESTFN_ENCODING = sys.getfilesystemencoding()
        # TESTFN_UNENCODABLE is a filename that should *not* be
        # able to be encoded by *either* the default or filesystem encoding.
        # This test really only makes sense on Windows NT platforms
        # which have special Unicode support in posixmodule.
        # Index 3 of sys.getwindowsversion() is the platform code.
        if (not hasattr(sys, "getwindowsversion") or
                sys.getwindowsversion()[3] < 2): #  0=win32s or 1=9x/ME
            TESTFN_UNENCODABLE = None
        else:
            # Japanese characters (I think - from bug 846133)
            # "\u" is not an escape in a Python 2 byte-string literal, so
            # eval() receives the source of a u"..." literal and builds the
            # Unicode string from it.
            TESTFN_UNENCODABLE = eval('u"@test-\u5171\u6709\u3055\u308c\u308b"')
            try:
                # XXX - Note - should be using TESTFN_ENCODING here - but for
                # Windows, "mbcs" currently always operates as if in
                # 'errors=ignore' mode - hence we get '?' characters rather than
                # the exception.  'Latin1' operates as we expect - ie, fails.
                # See [ 850997 ] mbcs encoding ignores errors
                TESTFN_UNENCODABLE.encode("Latin1")
            except UnicodeEncodeError:
                pass
            else:
                # The name turned out to be encodable: warn that the
                # Unicode filename tests may be ineffective.
                print \
                'WARNING: The filename %r CAN be encoded by the filesystem.  ' \
                'Unicode filename tests may not be effective' \
                % TESTFN_UNENCODABLE
    389 
    390 
    391 # Disambiguate TESTFN for parallel testing, while letting it remain a valid
    392 # module name.
    393 TESTFN = "{}_{}_tmp".format(TESTFN, os.getpid())
    394 
    395 # Save the initial cwd
    396 SAVEDCWD = os.getcwd()
    397 
    398 @contextlib.contextmanager
    399 def temp_cwd(name='tempcwd', quiet=False):
    400     """
    401     Context manager that creates a temporary directory and set it as CWD.
    402 
    403     The new CWD is created in the current directory and it's named *name*.
    404     If *quiet* is False (default) and it's not possible to create or change
    405     the CWD, an error is raised.  If it's True, only a warning is raised
    406     and the original CWD is used.
    407     """
    408     if isinstance(name, unicode):
    409         try:
    410             name = name.encode(sys.getfilesystemencoding() or 'ascii')
    411         except UnicodeEncodeError:
    412             if not quiet:
    413                 raise unittest.SkipTest('unable to encode the cwd name with '
    414                                         'the filesystem encoding.')
    415     saved_dir = os.getcwd()
    416     is_temporary = False
    417     try:
    418         os.mkdir(name)
    419         os.chdir(name)
    420         is_temporary = True
    421     except OSError:
    422         if not quiet:
    423             raise
    424         warnings.warn('tests may fail, unable to change the CWD to ' + name,
    425                       RuntimeWarning, stacklevel=3)
    426     try:
    427         yield os.getcwd()
    428     finally:
    429         os.chdir(saved_dir)
    430         if is_temporary:
    431             rmtree(name)
    432 
    433 
    434 def findfile(file, here=__file__, subdir=None):
    435     """Try to find a file on sys.path and the working directory.  If it is not
    436     found the argument passed to the function is returned (this does not
    437     necessarily signal failure; could still be the legitimate path)."""
    438     if os.path.isabs(file):
    439         return file
    440     if subdir is not None:
    441         file = os.path.join(subdir, file)
    442     path = sys.path
    443     path = [os.path.dirname(here)] + path
    444     for dn in path:
    445         fn = os.path.join(dn, file)
    446         if os.path.exists(fn): return fn
    447     return file
    448 
    449 def sortdict(dict):
    450     "Like repr(dict), but in sorted order."
    451     items = dict.items()
    452     items.sort()
    453     reprpairs = ["%r: %r" % pair for pair in items]
    454     withcommas = ", ".join(reprpairs)
    455     return "{%s}" % withcommas
    456 
    457 def make_bad_fd():
    458     """
    459     Create an invalid file descriptor by opening and closing a file and return
    460     its fd.
    461     """
    462     file = open(TESTFN, "wb")
    463     try:
    464         return file.fileno()
    465     finally:
    466         file.close()
    467         unlink(TESTFN)
    468 
    469 def check_syntax_error(testcase, statement):
    470     testcase.assertRaises(SyntaxError, compile, statement,
    471                           '<test string>', 'exec')
    472 
    473 def open_urlresource(url, check=None):
    474     import urlparse, urllib2
    475 
    476     filename = urlparse.urlparse(url)[2].split('/')[-1] # '/': it's URL!
    477 
    478     fn = os.path.join(os.path.dirname(__file__), "data", filename)
    479 
    480     def check_valid_file(fn):
    481         f = open(fn)
    482         if check is None:
    483             return f
    484         elif check(f):
    485             f.seek(0)
    486             return f
    487         f.close()
    488 
    489     if os.path.exists(fn):
    490         f = check_valid_file(fn)
    491         if f is not None:
    492             return f
    493         unlink(fn)
    494 
    495     # Verify the requirement before downloading the file

    496     requires('urlfetch')
    497 
    498     print >> get_original_stdout(), '\tfetching %s ...' % url
    499     f = urllib2.urlopen(url, timeout=15)
    500     try:
    501         with open(fn, "wb") as out:
    502             s = f.read()
    503             while s:
    504                 out.write(s)
    505                 s = f.read()
    506     finally:
    507         f.close()
    508 
    509     f = check_valid_file(fn)
    510     if f is not None:
    511         return f
    512     raise TestFailed('invalid resource "%s"' % fn)
    513 
    514 
    515 class WarningsRecorder(object):
    516     """Convenience wrapper for the warnings list returned on
    517        entry to the warnings.catch_warnings() context manager.
    518     """
    519     def __init__(self, warnings_list):
    520         self._warnings = warnings_list
    521         self._last = 0
    522 
    523     def __getattr__(self, attr):
    524         if len(self._warnings) > self._last:
    525             return getattr(self._warnings[-1], attr)
    526         elif attr in warnings.WarningMessage._WARNING_DETAILS:
    527             return None
    528         raise AttributeError("%r has no attribute %r" % (self, attr))
    529 
    530     @property
    531     def warnings(self):
    532         return self._warnings[self._last:]
    533 
    534     def reset(self):
    535         self._last = len(self._warnings)
    536 
    537 
def _filterwarnings(filters, quiet=False):
    """Record warnings raised in a with-block and check them against *filters*.

    This is a generator: check_warnings() and check_py3k_warnings() return
    it from functions decorated with contextlib.contextmanager.  It yields a
    WarningsRecorder while every warning is being recorded (filter "always").

    When the with-block finishes without an exception, each recorded warning
    is matched against *filters*, a sequence of (regexp, category) pairs,
    tried in order: a warning is consumed by the first filter for which
    re.match(regexp, str(warning), re.I) succeeds and whose category is a
    base class of the warning's class.

    Raises AssertionError for the first recorded warning that no filter
    consumed, or -- unless *quiet* is true -- for the first filter that
    consumed no warning.
    """
    # Clear the warning registry of the calling module, so that warnings it
    # already issued once (and which the registry would suppress) are
    # emitted -- and recorded -- again.
    # Frame 0 is this generator and frame 1 the contextlib __enter__ that
    # advances it, so frame 2 is presumably the code running the with
    # statement; this relies on being driven through contextmanager.
    frame = sys._getframe(2)
    registry = frame.f_globals.get('__warningregistry__')
    if registry:
        registry.clear()
    with warnings.catch_warnings(record=True) as w:
        # Set filter "always" to record all warnings.  Because
        # test_warnings swap the module, we need to look up in
        # the sys.modules dictionary.
        sys.modules['warnings'].simplefilter("always")
        yield WarningsRecorder(w)
    # Filter the recorded warnings
    reraise = [warning.message for warning in w]
    missing = []
    for msg, cat in filters:
        seen = False
        # Iterate over a copy because matched warnings are removed.
        for exc in reraise[:]:
            message = str(exc)
            # Filter out the matching messages
            if (re.match(msg, message, re.I) and
                issubclass(exc.__class__, cat)):
                seen = True
                reraise.remove(exc)
        if not seen and not quiet:
            # This filter caught nothing
            missing.append((msg, cat.__name__))
    if reraise:
        raise AssertionError("unhandled warning %r" % reraise[0])
    if missing:
        raise AssertionError("filter (%r, %s) did not catch any warning" %
                             missing[0])
    575 
    576 
    577 @contextlib.contextmanager
    578 def check_warnings(*filters, **kwargs):
    579     """Context manager to silence warnings.
    580 
    581     Accept 2-tuples as positional arguments:
    582         ("message regexp", WarningCategory)
    583 
    584     Optional argument:
    585      - if 'quiet' is True, it does not fail if a filter catches nothing
    586         (default True without argument,
    587          default False if some filters are defined)
    588 
    589     Without argument, it defaults to:
    590         check_warnings(("", Warning), quiet=True)
    591     """
    592     quiet = kwargs.get('quiet')
    593     if not filters:
    594         filters = (("", Warning),)
    595         # Preserve backward compatibility

    596         if quiet is None:
    597             quiet = True
    598     return _filterwarnings(filters, quiet)
    599 
    600 
    601 @contextlib.contextmanager
    602 def check_py3k_warnings(*filters, **kwargs):
    603     """Context manager to silence py3k warnings.
    604 
    605     Accept 2-tuples as positional arguments:
    606         ("message regexp", WarningCategory)
    607 
    608     Optional argument:
    609      - if 'quiet' is True, it does not fail if a filter catches nothing
    610         (default False)
    611 
    612     Without argument, it defaults to:
    613         check_py3k_warnings(("", DeprecationWarning), quiet=False)
    614     """
    615     if sys.py3kwarning:
    616         if not filters:
    617             filters = (("", DeprecationWarning),)
    618     else:
    619         # It should not raise any py3k warning

    620         filters = ()
    621     return _filterwarnings(filters, kwargs.get('quiet'))
    622 
    623 
    624 class CleanImport(object):
    625     """Context manager to force import to return a new module reference.
    626 
    627     This is useful for testing module-level behaviours, such as
    628     the emission of a DeprecationWarning on import.
    629 
    630     Use like this:
    631 
    632         with CleanImport("foo"):
    633             importlib.import_module("foo") # new reference
    634     """
    635 
    636     def __init__(self, *module_names):
    637         self.original_modules = sys.modules.copy()
    638         for module_name in module_names:
    639             if module_name in sys.modules:
    640                 module = sys.modules[module_name]
    641                 # It is possible that module_name is just an alias for

    642                 # another module (e.g. stub for modules renamed in 3.x).

    643                 # In that case, we also need delete the real module to clear

    644                 # the import cache.

    645                 if module.__name__ != module_name:
    646                     del sys.modules[module.__name__]
    647                 del sys.modules[module_name]
    648 
    649     def __enter__(self):
    650         return self
    651 
    652     def __exit__(self, *ignore_exc):
    653         sys.modules.update(self.original_modules)
    654 
    655 
    656 class EnvironmentVarGuard(UserDict.DictMixin):
    657 
    658     """Class to help protect the environment variable properly.  Can be used as
    659     a context manager."""
    660 
    661     def __init__(self):
    662         self._environ = os.environ
    663         self._changed = {}
    664 
    665     def __getitem__(self, envvar):
    666         return self._environ[envvar]
    667 
    668     def __setitem__(self, envvar, value):
    669         # Remember the initial value on the first access

    670         if envvar not in self._changed:
    671             self._changed[envvar] = self._environ.get(envvar)
    672         self._environ[envvar] = value
    673 
    674     def __delitem__(self, envvar):
    675         # Remember the initial value on the first access

    676         if envvar not in self._changed:
    677             self._changed[envvar] = self._environ.get(envvar)
    678         if envvar in self._environ:
    679             del self._environ[envvar]
    680 
    681     def keys(self):
    682         return self._environ.keys()
    683 
    684     def set(self, envvar, value):
    685         self[envvar] = value
    686 
    687     def unset(self, envvar):
    688         del self[envvar]
    689 
    690     def __enter__(self):
    691         return self
    692 
    693     def __exit__(self, *ignore_exc):
    694         for (k, v) in self._changed.items():
    695             if v is None:
    696                 if k in self._environ:
    697                     del self._environ[k]
    698             else:
    699                 self._environ[k] = v
    700         os.environ = self._environ
    701 
    702 
    703 class DirsOnSysPath(object):
    704     """Context manager to temporarily add directories to sys.path.
    705 
    706     This makes a copy of sys.path, appends any directories given
    707     as positional arguments, then reverts sys.path to the copied
    708     settings when the context ends.
    709 
    710     Note that *all* sys.path modifications in the body of the
    711     context manager, including replacement of the object,
    712     will be reverted at the end of the block.
    713     """
    714 
    715     def __init__(self, *paths):
    716         self.original_value = sys.path[:]
    717         self.original_object = sys.path
    718         sys.path.extend(paths)
    719 
    720     def __enter__(self):
    721         return self
    722 
    723     def __exit__(self, *ignore_exc):
    724         sys.path = self.original_object
    725         sys.path[:] = self.original_value
    726 
    727 
    728 class TransientResource(object):
    729 
    730     """Raise ResourceDenied if an exception is raised while the context manager
    731     is in effect that matches the specified exception and attributes."""
    732 
    733     def __init__(self, exc, **kwargs):
    734         self.exc = exc
    735         self.attrs = kwargs
    736 
    737     def __enter__(self):
    738         return self
    739 
    740     def __exit__(self, type_=None, value=None, traceback=None):
    741         """If type_ is a subclass of self.exc and value has attributes matching
    742         self.attrs, raise ResourceDenied.  Otherwise let the exception
    743         propagate (if any)."""
    744         if type_ is not None and issubclass(self.exc, type_):
    745             for attr, attr_value in self.attrs.iteritems():
    746                 if not hasattr(value, attr):
    747                     break
    748                 if getattr(value, attr) != attr_value:
    749                     break
    750             else:
    751                 raise ResourceDenied("an optional resource is not available")
    752 
    753 
@contextlib.contextmanager
def transient_internet(resource_name, timeout=30.0, errnos=()):
    """Return a context manager that raises ResourceDenied when various issues
    with the Internet connection manifest themselves as exceptions.

    resource_name -- name interpolated into the ResourceDenied message.
    timeout -- if not None, installed with socket.setdefaulttimeout() for
        the duration of the block.  The previous default timeout is always
        restored on exit.
    errnos -- errno values to treat as "resource not available".  When
        empty (the default), a built-in list of connection errnos is used,
        and the socket.gaierror codes EAI_NONAME / EAI_NODATA are matched
        too.  When given, only these errnos are matched and no gaierror
        codes are.

    A socket.timeout is always converted to ResourceDenied.  Other IOErrors
    are unwrapped (see below) and, if they do not match, re-raised
    unchanged.  Exceptions that are not IOErrors are not inspected at all.
    """
    # (name, fallback) pairs: the numeric fallback is only used when the
    # errno / socket module does not define the name.  NOTE(review): the
    # fallbacks look like Linux/glibc values -- confirm before relying on
    # them elsewhere.
    default_errnos = [
        ('ECONNREFUSED', 111),
        ('ECONNRESET', 104),
        ('EHOSTUNREACH', 113),
        ('ENETUNREACH', 101),
        ('ETIMEDOUT', 110),
    ]
    default_gai_errnos = [
        ('EAI_NONAME', -2),
        ('EAI_NODATA', -5),
    ]

    # A single ResourceDenied instance is built up front and raised on a match.
    denied = ResourceDenied("Resource '%s' is not available" % resource_name)
    captured_errnos = errnos
    gai_errnos = []
    if not captured_errnos:
        captured_errnos = [getattr(errno, name, num)
                           for (name, num) in default_errnos]
        gai_errnos = [getattr(socket, name, num)
                      for (name, num) in default_gai_errnos]

    def filter_error(err):
        # Raise `denied` if err looks like a transient network failure.
        # Otherwise return normally so the caller can re-raise the original.
        n = getattr(err, 'errno', None)
        if (isinstance(err, socket.timeout) or
            (isinstance(err, socket.gaierror) and n in gai_errnos) or
            n in captured_errnos):
            # NOTE(review): the message goes to stderr only when *not*
            # verbose; presumably verbose runs report the skip elsewhere.
            # Confirm.
            if not verbose:
                sys.stderr.write(denied.args[0] + "\n")
            raise denied

    old_timeout = socket.getdefaulttimeout()
    try:
        if timeout is not None:
            socket.setdefaulttimeout(timeout)
        yield
    except IOError as err:
        # urllib can wrap original socket errors multiple times (!), so we
        # must unwrap to get at the original error.
        while True:
            a = err.args
            if len(a) >= 1 and isinstance(a[0], IOError):
                err = a[0]
            # The error can also be wrapped as args[1], e.g. by code shaped
            # like:
            #    except socket.error as msg:
            #        raise IOError('socket error', msg)
            elif len(a) >= 2 and isinstance(a[1], IOError):
                err = a[1]
            else:
                break
        filter_error(err)
        # Not a transient failure.  This bare raise re-raises the exception
        # caught by this handler (not the unwrapped one), together with its
        # traceback.  Nothing between the except and here may catch another
        # exception, or Python 2 would re-raise that one instead.
        raise
    # XXX should we catch generic exceptions and look for their
    # __cause__ or __context__?
    finally:
        socket.setdefaulttimeout(old_timeout)
    813 
    814 
    815 @contextlib.contextmanager
    816 def captured_output(stream_name):
    817     """Return a context manager used by captured_stdout and captured_stdin
    818     that temporarily replaces the sys stream *stream_name* with a StringIO."""
    819     import StringIO
    820     orig_stdout = getattr(sys, stream_name)
    821     setattr(sys, stream_name, StringIO.StringIO())
    822     try:
    823         yield getattr(sys, stream_name)
    824     finally:
    825         setattr(sys, stream_name, orig_stdout)
    826 
    827 def captured_stdout():
    828     """Capture the output of sys.stdout:
    829 
    830        with captured_stdout() as s:
    831            print "hello"
    832        self.assertEqual(s.getvalue(), "hello")
    833     """
    834     return captured_output("stdout")
    835 
    836 def captured_stdin():
    837     return captured_output("stdin")
    838 
    839 def gc_collect():
    840     """Force as many objects as possible to be collected.
    841 
    842     In non-CPython implementations of Python, this is needed because timely
    843     deallocation is not guaranteed by the garbage collector.  (Even in CPython
    844     this can be the case in case of reference cycles.)  This means that __del__
    845     methods may be called later than expected and weakrefs may remain alive for
    846     longer than expected.  This function tries its best to force all garbage
    847     objects to disappear.
    848     """
    849     gc.collect()
    850     if is_jython:
    851         time.sleep(0.1)
    852     gc.collect()
    853     gc.collect()
    854 
    855 
    856 #=======================================================================

    857 # Decorator for running a function in a different locale, correctly resetting

    858 # it afterwards.

    859 
    860 def run_with_locale(catstr, *locales):
    861     def decorator(func):
    862         def inner(*args, **kwds):
    863             try:
    864                 import locale
    865                 category = getattr(locale, catstr)
    866                 orig_locale = locale.setlocale(category)
    867             except AttributeError:
    868                 # if the test author gives us an invalid category string

    869                 raise
    870             except:
    871                 # cannot retrieve original locale, so do nothing

    872                 locale = orig_locale = None
    873             else:
    874                 for loc in locales:
    875                     try:
    876                         locale.setlocale(category, loc)
    877                         break
    878                     except:
    879                         pass
    880 
    881             # now run the function, resetting the locale on exceptions

    882             try:
    883                 return func(*args, **kwds)
    884             finally:
    885                 if locale and orig_locale:
    886                     locale.setlocale(category, orig_locale)
    887         inner.func_name = func.func_name
    888         inner.__doc__ = func.__doc__
    889         return inner
    890     return decorator
    891 
    892 #=======================================================================

    893 # Big-memory-test support. Separate from 'resources' because memory use should be configurable.

    894 
    895 # Some handy shorthands. Note that these are used for byte-limits as well

    896 # as size-limits, in the various bigmem tests

    897 _1M = 1024*1024
    898 _1G = 1024 * _1M
    899 _2G = 2 * _1G
    900 _4G = 4 * _1G
    901 
    902 MAX_Py_ssize_t = sys.maxsize
    903 
    904 def set_memlimit(limit):
    905     global max_memuse
    906     global real_max_memuse
    907     sizes = {
    908         'k': 1024,
    909         'm': _1M,
    910         'g': _1G,
    911         't': 1024*_1G,
    912     }
    913     m = re.match(r'(\d+(\.\d+)?) (K|M|G|T)b?$', limit,
    914                  re.IGNORECASE | re.VERBOSE)
    915     if m is None:
    916         raise ValueError('Invalid memory limit %r' % (limit,))
    917     memlimit = int(float(m.group(1)) * sizes[m.group(3).lower()])
    918     real_max_memuse = memlimit
    919     if memlimit > MAX_Py_ssize_t:
    920         memlimit = MAX_Py_ssize_t
    921     if memlimit < _2G - 1:
    922         raise ValueError('Memory limit %r too low to be useful' % (limit,))
    923     max_memuse = memlimit
    924 
    925 def bigmemtest(minsize, memuse, overhead=5*_1M):
    926     """Decorator for bigmem tests.
    927 
    928     'minsize' is the minimum useful size for the test (in arbitrary,
    929     test-interpreted units.) 'memuse' is the number of 'bytes per size' for
    930     the test, or a good estimate of it. 'overhead' specifies fixed overhead,
    931     independent of the testsize, and defaults to 5Mb.
    932 
    933     The decorator tries to guess a good value for 'size' and passes it to
    934     the decorated test function. If minsize * memuse is more than the
    935     allowed memory use (as defined by max_memuse), the test is skipped.
    936     Otherwise, minsize is adjusted upward to use up to max_memuse.
    937     """
    938     def decorator(f):
    939         def wrapper(self):
    940             if not max_memuse:
    941                 # If max_memuse is 0 (the default),

    942                 # we still want to run the tests with size set to a few kb,

    943                 # to make sure they work. We still want to avoid using

    944                 # too much memory, though, but we do that noisily.

    945                 maxsize = 5147
    946                 self.assertFalse(maxsize * memuse + overhead > 20 * _1M)
    947             else:
    948                 maxsize = int((max_memuse - overhead) / memuse)
    949                 if maxsize < minsize:
    950                     # Really ought to print 'test skipped' or something

    951                     if verbose:
    952                         sys.stderr.write("Skipping %s because of memory "
    953                                          "constraint\n" % (f.__name__,))
    954                     return
    955                 # Try to keep some breathing room in memory use

    956                 maxsize = max(maxsize - 50 * _1M, minsize)
    957             return f(self, maxsize)
    958         wrapper.minsize = minsize
    959         wrapper.memuse = memuse
    960         wrapper.overhead = overhead
    961         return wrapper
    962     return decorator
    963 
    964 def precisionbigmemtest(size, memuse, overhead=5*_1M):
    965     def decorator(f):
    966         def wrapper(self):
    967             if not real_max_memuse:
    968                 maxsize = 5147
    969             else:
    970                 maxsize = size
    971 
    972                 if real_max_memuse and real_max_memuse < maxsize * memuse:
    973                     if verbose:
    974                         sys.stderr.write("Skipping %s because of memory "
    975                                          "constraint\n" % (f.__name__,))
    976                     return
    977 
    978             return f(self, maxsize)
    979         wrapper.size = size
    980         wrapper.memuse = memuse
    981         wrapper.overhead = overhead
    982         return wrapper
    983     return decorator
    984 
    985 def bigaddrspacetest(f):
    986     """Decorator for tests that fill the address space."""
    987     def wrapper(self):
    988         if max_memuse < MAX_Py_ssize_t:
    989             if verbose:
    990                 sys.stderr.write("Skipping %s because of memory "
    991                                  "constraint\n" % (f.__name__,))
    992         else:
    993             return f(self)
    994     return wrapper
    995 
    996 #=======================================================================

    997 # unittest integration.

    998 
    999 class BasicTestRunner:
   1000     def run(self, test):
   1001         result = unittest.TestResult()
   1002         test(result)
   1003         return result
   1004 
   1005 def _id(obj):
   1006     return obj
   1007 
   1008 def requires_resource(resource):
   1009     if is_resource_enabled(resource):
   1010         return _id
   1011     else:
   1012         return unittest.skip("resource {0!r} is not enabled".format(resource))
   1013 
   1014 def cpython_only(test):
   1015     """
   1016     Decorator for tests only applicable on CPython.
   1017     """
   1018     return impl_detail(cpython=True)(test)
   1019 
   1020 def impl_detail(msg=None, **guards):
   1021     if check_impl_detail(**guards):
   1022         return _id
   1023     if msg is None:
   1024         guardnames, default = _parse_guards(guards)
   1025         if default:
   1026             msg = "implementation detail not available on {0}"
   1027         else:
   1028             msg = "implementation detail specific to {0}"
   1029         guardnames = sorted(guardnames.keys())
   1030         msg = msg.format(' or '.join(guardnames))
   1031     return unittest.skip(msg)
   1032 
   1033 def _parse_guards(guards):
   1034     # Returns a tuple ({platform_name: run_me}, default_value)

   1035     if not guards:
   1036         return ({'cpython': True}, False)
   1037     is_true = guards.values()[0]
   1038     assert guards.values() == [is_true] * len(guards)   # all True or all False

   1039     return (guards, not is_true)
   1040 
   1041 # Use the following check to guard CPython's implementation-specific tests --

   1042 # or to run them only on the implementation(s) guarded by the arguments.

   1043 def check_impl_detail(**guards):
   1044     """This function returns True or False depending on the host platform.
   1045        Examples:
   1046           if check_impl_detail():               # only on CPython (default)
   1047           if check_impl_detail(jython=True):    # only on Jython
   1048           if check_impl_detail(cpython=False):  # everywhere except on CPython
   1049     """
   1050     guards, default = _parse_guards(guards)
   1051     return guards.get(platform.python_implementation().lower(), default)
   1052 
   1053 
   1054 
   1055 def _run_suite(suite):
   1056     """Run tests from a unittest.TestSuite-derived class."""
   1057     if verbose:
   1058         runner = unittest.TextTestRunner(sys.stdout, verbosity=2)
   1059     else:
   1060         runner = BasicTestRunner()
   1061 
   1062     result = runner.run(suite)
   1063     if not result.wasSuccessful():
   1064         if len(result.errors) == 1 and not result.failures:
   1065             err = result.errors[0][1]
   1066         elif len(result.failures) == 1 and not result.errors:
   1067             err = result.failures[0][1]
   1068         else:
   1069             err = "multiple errors occurred"
   1070             if not verbose:
   1071                 err += "; run in verbose mode for details"
   1072         raise TestFailed(err)
   1073 
   1074 
   1075 def run_unittest(*classes):
   1076     """Run tests from unittest.TestCase-derived classes."""
   1077     valid_types = (unittest.TestSuite, unittest.TestCase)
   1078     suite = unittest.TestSuite()
   1079     for cls in classes:
   1080         if isinstance(cls, str):
   1081             if cls in sys.modules:
   1082                 suite.addTest(unittest.findTestCases(sys.modules[cls]))
   1083             else:
   1084                 raise ValueError("str arguments must be keys in sys.modules")
   1085         elif isinstance(cls, valid_types):
   1086             suite.addTest(cls)
   1087         else:
   1088             suite.addTest(unittest.makeSuite(cls))
   1089     _run_suite(suite)
   1090 
   1091 
   1092 #=======================================================================

   1093 # doctest driver.

   1094 
   1095 def run_doctest(module, verbosity=None):
   1096     """Run doctest on the given module.  Return (#failures, #tests).
   1097 
   1098     If optional argument verbosity is not specified (or is None), pass
   1099     test_support's belief about verbosity on to doctest.  Else doctest's
   1100     usual behavior is used (it searches sys.argv for -v).
   1101     """
   1102 
   1103     import doctest
   1104 
   1105     if verbosity is None:
   1106         verbosity = verbose
   1107     else:
   1108         verbosity = None
   1109 
   1110     # Direct doctest output (normally just errors) to real stdout; doctest

   1111     # output shouldn't be compared by regrtest.

   1112     save_stdout = sys.stdout
   1113     sys.stdout = get_original_stdout()
   1114     try:
   1115         f, t = doctest.testmod(module, verbose=verbosity)
   1116         if f:
   1117             raise TestFailed("%d of %d doctests failed" % (f, t))
   1118     finally:
   1119         sys.stdout = save_stdout
   1120     if verbose:
   1121         print 'doctest (%s) ... %d tests with zero failures' % (module.__name__, t)
   1122     return f, t
   1123 
   1124 #=======================================================================

   1125 # Threading support to prevent reporting refleaks when running regrtest.py -R

   1126 
   1127 # NOTE: we use thread._count() rather than threading.enumerate() (or the

   1128 # moral equivalent thereof) because a threading.Thread object is still alive

   1129 # until its __bootstrap() method has returned, even after it has been

   1130 # unregistered from the threading module.

   1131 # thread._count(), on the other hand, only gets decremented *after* the

   1132 # __bootstrap() method has returned, which gives us reliable reference counts

   1133 # at the end of a test run.

   1134 
   1135 def threading_setup():
   1136     if thread:
   1137         return thread._count(),
   1138     else:
   1139         return 1,
   1140 
   1141 def threading_cleanup(nb_threads):
   1142     if not thread:
   1143         return
   1144 
   1145     _MAX_COUNT = 10
   1146     for count in range(_MAX_COUNT):
   1147         n = thread._count()
   1148         if n == nb_threads:
   1149             break
   1150         time.sleep(0.1)
   1151     # XXX print a warning in case of failure?

   1152 
   1153 def reap_threads(func):
   1154     """Use this function when threads are being used.  This will
   1155     ensure that the threads are cleaned up even when the test fails.
   1156     If threading is unavailable this function does nothing.
   1157     """
   1158     if not thread:
   1159         return func
   1160 
   1161     @functools.wraps(func)
   1162     def decorator(*args):
   1163         key = threading_setup()
   1164         try:
   1165             return func(*args)
   1166         finally:
   1167             threading_cleanup(*key)
   1168     return decorator
   1169 
   1170 def reap_children():
   1171     """Use this function at the end of test_main() whenever sub-processes
   1172     are started.  This will help ensure that no extra children (zombies)
   1173     stick around to hog resources and create problems when looking
   1174     for refleaks.
   1175     """
   1176 
   1177     # Reap all our dead child processes so we don't leave zombies around.

   1178     # These hog resources and might be causing some of the buildbots to die.

   1179     if hasattr(os, 'waitpid'):
   1180         any_process = -1
   1181         while True:
   1182             try:
   1183                 # This will raise an exception on Windows.  That's ok.

   1184                 pid, status = os.waitpid(any_process, os.WNOHANG)
   1185                 if pid == 0:
   1186                     break
   1187             except:
   1188                 break
   1189 
   1190 def py3k_bytes(b):
   1191     """Emulate the py3k bytes() constructor.
   1192 
   1193     NOTE: This is only a best effort function.
   1194     """
   1195     try:
   1196         # memoryview?

   1197         return b.tobytes()
   1198     except AttributeError:
   1199         try:
   1200             # iterable of ints?

   1201             return b"".join(chr(x) for x in b)
   1202         except TypeError:
   1203             return bytes(b)
   1204 
   1205 def args_from_interpreter_flags():
   1206     """Return a list of command-line arguments reproducing the current
   1207     settings in sys.flags."""
   1208     flag_opt_map = {
   1209         'bytes_warning': 'b',
   1210         'dont_write_bytecode': 'B',
   1211         'ignore_environment': 'E',
   1212         'no_user_site': 's',
   1213         'no_site': 'S',
   1214         'optimize': 'O',
   1215         'py3k_warning': '3',
   1216         'verbose': 'v',
   1217     }
   1218     args = []
   1219     for flag, opt in flag_opt_map.items():
   1220         v = getattr(sys.flags, flag)
   1221         if v > 0:
   1222             args.append('-' + opt * v)
   1223     return args
   1224 
   1225 def strip_python_stderr(stderr):
   1226     """Strip the stderr of a Python process from potential debug output
   1227     emitted by the interpreter.
   1228 
   1229     This will typically be run on the result of the communicate() method
   1230     of a subprocess.Popen object.
   1231     """
   1232     stderr = re.sub(br"\[\d+ refs\]\r?\n?$", b"", stderr).strip()
   1233     return stderr
   1234