Home | History | Annotate | Download | only in test
      1 import contextlib
      2 import imp
      3 import importlib
      4 import sys
      5 import unittest
      6 
      7 
      8 @contextlib.contextmanager
      9 def uncache(*names):
     10     """Uncache a module from sys.modules.
     11 
     12     A basic sanity check is performed to prevent uncaching modules that either
     13     cannot/shouldn't be uncached.
     14 
     15     """
     16     for name in names:
     17         if name in ('sys', 'marshal', 'imp'):
     18             raise ValueError(
     19                 "cannot uncache {0} as it will break _importlib".format(name))
     20         try:
     21             del sys.modules[name]
     22         except KeyError:
     23             pass
     24     try:
     25         yield
     26     finally:
     27         for name in names:
     28             try:
     29                 del sys.modules[name]
     30             except KeyError:
     31                 pass
     32 
     33 
     34 @contextlib.contextmanager
     35 def import_state(**kwargs):
     36     """Context manager to manage the various importers and stored state in the
     37     sys module.
     38 
     39     The 'modules' attribute is not supported as the interpreter state stores a
     40     pointer to the dict that the interpreter uses internally;
     41     reassigning to sys.modules does not have the desired effect.
     42 
     43     """
     44     originals = {}
     45     try:
     46         for attr, default in (('meta_path', []), ('path', []),
     47                               ('path_hooks', []),
     48                               ('path_importer_cache', {})):
     49             originals[attr] = getattr(sys, attr)
     50             if attr in kwargs:
     51                 new_value = kwargs[attr]
     52                 del kwargs[attr]
     53             else:
     54                 new_value = default
     55             setattr(sys, attr, new_value)
     56         if len(kwargs):
     57             raise ValueError(
     58                     'unrecognized arguments: {0}'.format(kwargs.keys()))
     59         yield
     60     finally:
     61         for attr, value in originals.items():
     62             setattr(sys, attr, value)
     63 
     64 
     65 class mock_modules(object):
     66 
     67     """A mock importer/loader."""
     68 
     69     def __init__(self, *names):
     70         self.modules = {}
     71         for name in names:
     72             if not name.endswith('.__init__'):
     73                 import_name = name
     74             else:
     75                 import_name = name[:-len('.__init__')]
     76             if '.' not in name:
     77                 package = None
     78             elif import_name == name:
     79                 package = name.rsplit('.', 1)[0]
     80             else:
     81                 package = import_name
     82             module = imp.new_module(import_name)
     83             module.__loader__ = self
     84             module.__file__ = '<mock __file__>'
     85             module.__package__ = package
     86             module.attr = name
     87             if import_name != name:
     88                 module.__path__ = ['<mock __path__>']
     89             self.modules[import_name] = module
     90 
     91     def __getitem__(self, name):
     92         return self.modules[name]
     93 
     94     def find_module(self, fullname, path=None):
     95         if fullname not in self.modules:
     96             return None
     97         else:
     98             return self
     99 
    100     def load_module(self, fullname):
    101         if fullname not in self.modules:
    102             raise ImportError
    103         else:
    104             sys.modules[fullname] = self.modules[fullname]
    105             return self.modules[fullname]
    106 
    107     def __enter__(self):
    108         self._uncache = uncache(*self.modules.keys())
    109         self._uncache.__enter__()
    110         return self
    111 
    112     def __exit__(self, *exc_info):
    113         self._uncache.__exit__(None, None, None)
    114 
    115 
    116 
    117 class ImportModuleTests(unittest.TestCase):
    118 
    119     """Test importlib.import_module."""
    120 
    121     def test_module_import(self):
    122         # Test importing a top-level module.

    123         with mock_modules('top_level') as mock:
    124             with import_state(meta_path=[mock]):
    125                 module = importlib.import_module('top_level')
    126                 self.assertEqual(module.__name__, 'top_level')
    127 
    128     def test_absolute_package_import(self):
    129         # Test importing a module from a package with an absolute name.

    130         pkg_name = 'pkg'
    131         pkg_long_name = '{0}.__init__'.format(pkg_name)
    132         name = '{0}.mod'.format(pkg_name)
    133         with mock_modules(pkg_long_name, name) as mock:
    134             with import_state(meta_path=[mock]):
    135                 module = importlib.import_module(name)
    136                 self.assertEqual(module.__name__, name)
    137 
    138     def test_shallow_relative_package_import(self):
    139         modules = ['a.__init__', 'a.b.__init__', 'a.b.c.__init__', 'a.b.c.d']
    140         with mock_modules(*modules) as mock:
    141             with import_state(meta_path=[mock]):
    142                 module = importlib.import_module('.d', 'a.b.c')
    143                 self.assertEqual(module.__name__, 'a.b.c.d')
    144 
    145     def test_deep_relative_package_import(self):
    146         # Test importing a module from a package through a relatve import.

    147         modules = ['a.__init__', 'a.b.__init__', 'a.c']
    148         with mock_modules(*modules) as mock:
    149             with import_state(meta_path=[mock]):
    150                 module = importlib.import_module('..c', 'a.b')
    151                 self.assertEqual(module.__name__, 'a.c')
    152 
    153     def test_absolute_import_with_package(self):
    154         # Test importing a module from a package with an absolute name with

    155         # the 'package' argument given.

    156         pkg_name = 'pkg'
    157         pkg_long_name = '{0}.__init__'.format(pkg_name)
    158         name = '{0}.mod'.format(pkg_name)
    159         with mock_modules(pkg_long_name, name) as mock:
    160             with import_state(meta_path=[mock]):
    161                 module = importlib.import_module(name, pkg_name)
    162                 self.assertEqual(module.__name__, name)
    163 
    164     def test_relative_import_wo_package(self):
    165         # Relative imports cannot happen without the 'package' argument being

    166         # set.

    167         self.assertRaises(TypeError, importlib.import_module, '.support')
    168 
    169 
    170 def test_main():
    171     from test.test_support import run_unittest
    172     run_unittest(ImportModuleTests)
    173 
    174 
    175 if __name__ == '__main__':
    176     test_main()
    177