Home | History | Annotate | Download | only in tests
      1 """Support code for distutils test cases."""
      2 import os
      3 import sys
      4 import shutil
      5 import tempfile
      6 import unittest
      7 import sysconfig
      8 from copy import deepcopy
      9 import warnings
     10 
     11 from distutils import log
     12 from distutils.log import DEBUG, INFO, WARN, ERROR, FATAL
     13 from distutils.core import Distribution
     14 
     15 
     16 def capture_warnings(func):
     17     def _capture_warnings(*args, **kw):
     18         with warnings.catch_warnings():
     19             warnings.simplefilter("ignore")
     20             return func(*args, **kw)
     21     return _capture_warnings
     22 
     23 
     24 class LoggingSilencer(object):
     25 
     26     def setUp(self):
     27         super(LoggingSilencer, self).setUp()
     28         self.threshold = log.set_threshold(log.FATAL)
     29         # catching warnings
     30         # when log will be replaced by logging
     31         # we won't need such monkey-patch anymore
     32         self._old_log = log.Log._log
     33         log.Log._log = self._log
     34         self.logs = []
     35 
     36     def tearDown(self):
     37         log.set_threshold(self.threshold)
     38         log.Log._log = self._old_log
     39         super(LoggingSilencer, self).tearDown()
     40 
     41     def _log(self, level, msg, args):
     42         if level not in (DEBUG, INFO, WARN, ERROR, FATAL):
     43             raise ValueError('%s wrong log level' % str(level))
     44         self.logs.append((level, msg, args))
     45 
     46     def get_logs(self, *levels):
     47         def _format(msg, args):
     48             if len(args) == 0:
     49                 return msg
     50             return msg % args
     51         return [_format(msg, args) for level, msg, args
     52                 in self.logs if level in levels]
     53 
     54     def clear_logs(self):
     55         self.logs = []
     56 
     57 
     58 class TempdirManager(object):
     59     """Mix-in class that handles temporary directories for test cases.
     60 
     61     This is intended to be used with unittest.TestCase.
     62     """
     63 
     64     def setUp(self):
     65         super(TempdirManager, self).setUp()
     66         self.old_cwd = os.getcwd()
     67         self.tempdirs = []
     68 
     69     def tearDown(self):
     70         # Restore working dir, for Solaris and derivatives, where rmdir()
     71         # on the current directory fails.
     72         os.chdir(self.old_cwd)
     73         super(TempdirManager, self).tearDown()
     74         while self.tempdirs:
     75             d = self.tempdirs.pop()
     76             shutil.rmtree(d, os.name in ('nt', 'cygwin'))
     77 
     78     def mkdtemp(self):
     79         """Create a temporary directory that will be cleaned up.
     80 
     81         Returns the path of the directory.
     82         """
     83         d = tempfile.mkdtemp()
     84         self.tempdirs.append(d)
     85         return d
     86 
     87     def write_file(self, path, content='xxx'):
     88         """Writes a file in the given path.
     89 
     90 
     91         path can be a string or a sequence.
     92         """
     93         if isinstance(path, (list, tuple)):
     94             path = os.path.join(*path)
     95         f = open(path, 'w')
     96         try:
     97             f.write(content)
     98         finally:
     99             f.close()
    100 
    101     def create_dist(self, pkg_name='foo', **kw):
    102         """Will generate a test environment.
    103 
    104         This function creates:
    105          - a Distribution instance using keywords
    106          - a temporary directory with a package structure
    107 
    108         It returns the package directory and the distribution
    109         instance.
    110         """
    111         tmp_dir = self.mkdtemp()
    112         pkg_dir = os.path.join(tmp_dir, pkg_name)
    113         os.mkdir(pkg_dir)
    114         dist = Distribution(attrs=kw)
    115 
    116         return pkg_dir, dist
    117 
    118 
    119 class DummyCommand:
    120     """Class to store options for retrieval via set_undefined_options()."""
    121 
    122     def __init__(self, **kwargs):
    123         for kw, val in kwargs.items():
    124             setattr(self, kw, val)
    125 
    126     def ensure_finalized(self):
    127         pass
    128 
    129 
    130 class EnvironGuard(object):
    131 
    132     def setUp(self):
    133         super(EnvironGuard, self).setUp()
    134         self.old_environ = deepcopy(os.environ)
    135 
    136     def tearDown(self):
    137         for key, value in self.old_environ.items():
    138             if os.environ.get(key) != value:
    139                 os.environ[key] = value
    140 
    141         for key in os.environ.keys():
    142             if key not in self.old_environ:
    143                 del os.environ[key]
    144 
    145         super(EnvironGuard, self).tearDown()
    146 
    147 
    148 def copy_xxmodule_c(directory):
    149     """Helper for tests that need the xxmodule.c source file.
    150 
    151     Example use:
    152 
    153         def test_compile(self):
    154             copy_xxmodule_c(self.tmpdir)
    155             self.assertIn('xxmodule.c', os.listdir(self.tmpdir))
    156 
    157     If the source file can be found, it will be copied to *directory*.  If not,
    158     the test will be skipped.  Errors during copy are not caught.
    159     """
    160     filename = _get_xxmodule_path()
    161     if filename is None:
    162         raise unittest.SkipTest('cannot find xxmodule.c (test must run in '
    163                                 'the python build dir)')
    164     shutil.copy(filename, directory)
    165 
    166 
    167 def _get_xxmodule_path():
    168     # FIXME when run from regrtest, srcdir seems to be '.', which does not help
    169     # us find the xxmodule.c file
    170     srcdir = sysconfig.get_config_var('srcdir')
    171     candidates = [
    172         # use installed copy if available
    173         os.path.join(os.path.dirname(__file__), 'xxmodule.c'),
    174         # otherwise try using copy from build directory
    175         os.path.join(srcdir, 'Modules', 'xxmodule.c'),
    176         # srcdir mysteriously can be $srcdir/Lib/distutils/tests when
    177         # this file is run from its parent directory, so walk up the
    178         # tree to find the real srcdir
    179         os.path.join(srcdir, '..', '..', '..', 'Modules', 'xxmodule.c'),
    180     ]
    181     for path in candidates:
    182         if os.path.exists(path):
    183             return path
    184 
    185 
    186 def fixup_build_ext(cmd):
    187     """Function needed to make build_ext tests pass.
    188 
    189     When Python was build with --enable-shared on Unix, -L. is not good
    190     enough to find the libpython<blah>.so.  This is because regrtest runs
    191     it under a tempdir, not in the top level where the .so lives.  By the
    192     time we've gotten here, Python's already been chdir'd to the tempdir.
    193 
    194     When Python was built with in debug mode on Windows, build_ext commands
    195     need their debug attribute set, and it is not done automatically for
    196     some reason.
    197 
    198     This function handles both of these things.  Example use:
    199 
    200         cmd = build_ext(dist)
    201         support.fixup_build_ext(cmd)
    202         cmd.ensure_finalized()
    203 
    204     Unlike most other Unix platforms, Mac OS X embeds absolute paths
    205     to shared libraries into executables, so the fixup is not needed there.
    206     """
    207     if os.name == 'nt':
    208         cmd.debug = sys.executable.endswith('_d.exe')
    209     elif sysconfig.get_config_var('Py_ENABLE_SHARED'):
    210         # To further add to the shared builds fun on Unix, we can't just add
    211         # library_dirs to the Extension() instance because that doesn't get
    212         # plumbed through to the final compiler command.
    213         runshared = sysconfig.get_config_var('RUNSHARED')
    214         if runshared is None:
    215             cmd.library_dirs = ['.']
    216         else:
    217             if sys.platform == 'darwin':
    218                 cmd.library_dirs = []
    219             else:
    220                 name, equals, value = runshared.partition('=')
    221                 cmd.library_dirs = [d for d in value.split(os.pathsep) if d]
    222