Home | History | Annotate | Download | only in test
      1 import sys
      2 import imp
      3 import os
      4 import unittest
      5 from test import test_support
      6 
      7 
      8 test_src = """\
      9 def get_name():
     10     return __name__
     11 def get_file():
     12     return __file__
     13 """
     14 
     15 absimp = "import sub\n"
     16 relimp = "from . import sub\n"
     17 deeprelimp = "from .... import sub\n"
     18 futimp = "from __future__ import absolute_import\n"
     19 
     20 reload_src = test_src+"""\
     21 reloaded = True
     22 """
     23 
     24 test_co = compile(test_src, "<???>", "exec")
     25 reload_co = compile(reload_src, "<???>", "exec")
     26 
     27 test2_oldabs_co = compile(absimp + test_src, "<???>", "exec")
     28 test2_newabs_co = compile(futimp + absimp + test_src, "<???>", "exec")
     29 test2_newrel_co = compile(relimp + test_src, "<???>", "exec")
     30 test2_deeprel_co = compile(deeprelimp + test_src, "<???>", "exec")
     31 test2_futrel_co = compile(futimp + relimp + test_src, "<???>", "exec")
     32 
     33 test_path = "!!!_test_!!!"
     34 
     35 
     36 class TestImporter:
     37 
     38     modules = {
     39         "hooktestmodule": (False, test_co),
     40         "hooktestpackage": (True, test_co),
     41         "hooktestpackage.sub": (True, test_co),
     42         "hooktestpackage.sub.subber": (True, test_co),
     43         "hooktestpackage.oldabs": (False, test2_oldabs_co),
     44         "hooktestpackage.newabs": (False, test2_newabs_co),
     45         "hooktestpackage.newrel": (False, test2_newrel_co),
     46         "hooktestpackage.sub.subber.subest": (True, test2_deeprel_co),
     47         "hooktestpackage.futrel": (False, test2_futrel_co),
     48         "sub": (False, test_co),
     49         "reloadmodule": (False, test_co),
     50     }
     51 
     52     def __init__(self, path=test_path):
     53         if path != test_path:
     54             # if out class is on sys.path_hooks, we must raise
     55             # ImportError for any path item that we can't handle.
     56             raise ImportError
     57         self.path = path
     58 
     59     def _get__path__(self):
     60         raise NotImplementedError
     61 
     62     def find_module(self, fullname, path=None):
     63         if fullname in self.modules:
     64             return self
     65         else:
     66             return None
     67 
     68     def load_module(self, fullname):
     69         ispkg, code = self.modules[fullname]
     70         mod = sys.modules.setdefault(fullname,imp.new_module(fullname))
     71         mod.__file__ = "<%s>" % self.__class__.__name__
     72         mod.__loader__ = self
     73         if ispkg:
     74             mod.__path__ = self._get__path__()
     75         exec code in mod.__dict__
     76         return mod
     77 
     78 
     79 class MetaImporter(TestImporter):
     80     def _get__path__(self):
     81         return []
     82 
     83 class PathImporter(TestImporter):
     84     def _get__path__(self):
     85         return [self.path]
     86 
     87 
     88 class ImportBlocker:
     89     """Place an ImportBlocker instance on sys.meta_path and you
     90     can be sure the modules you specified can't be imported, even
     91     if it's a builtin."""
     92     def __init__(self, *namestoblock):
     93         self.namestoblock = dict.fromkeys(namestoblock)
     94     def find_module(self, fullname, path=None):
     95         if fullname in self.namestoblock:
     96             return self
     97         return None
     98     def load_module(self, fullname):
     99         raise ImportError, "I dare you"
    100 
    101 
    102 class ImpWrapper:
    103 
    104     def __init__(self, path=None):
    105         if path is not None and not os.path.isdir(path):
    106             raise ImportError
    107         self.path = path
    108 
    109     def find_module(self, fullname, path=None):
    110         subname = fullname.split(".")[-1]
    111         if subname != fullname and self.path is None:
    112             return None
    113         if self.path is None:
    114             path = None
    115         else:
    116             path = [self.path]
    117         try:
    118             file, filename, stuff = imp.find_module(subname, path)
    119         except ImportError:
    120             return None
    121         return ImpLoader(file, filename, stuff)
    122 
    123 
    124 class ImpLoader:
    125 
    126     def __init__(self, file, filename, stuff):
    127         self.file = file
    128         self.filename = filename
    129         self.stuff = stuff
    130 
    131     def load_module(self, fullname):
    132         mod = imp.load_module(fullname, self.file, self.filename, self.stuff)
    133         if self.file:
    134             self.file.close()
    135         mod.__loader__ = self  # for introspection
    136         return mod
    137 
    138 
    139 class ImportHooksBaseTestCase(unittest.TestCase):
    140 
    141     def setUp(self):
    142         self.path = sys.path[:]
    143         self.meta_path = sys.meta_path[:]
    144         self.path_hooks = sys.path_hooks[:]
    145         sys.path_importer_cache.clear()
    146         self.modules_before = sys.modules.copy()
    147 
    148     def tearDown(self):
    149         sys.path[:] = self.path
    150         sys.meta_path[:] = self.meta_path
    151         sys.path_hooks[:] = self.path_hooks
    152         sys.path_importer_cache.clear()
    153         sys.modules.clear()
    154         sys.modules.update(self.modules_before)
    155 
    156 
    157 class ImportHooksTestCase(ImportHooksBaseTestCase):
    158 
    159     def doTestImports(self, importer=None):
    160         import hooktestmodule
    161         import hooktestpackage
    162         import hooktestpackage.sub
    163         import hooktestpackage.sub.subber
    164         self.assertEqual(hooktestmodule.get_name(),
    165                          "hooktestmodule")
    166         self.assertEqual(hooktestpackage.get_name(),
    167                          "hooktestpackage")
    168         self.assertEqual(hooktestpackage.sub.get_name(),
    169                          "hooktestpackage.sub")
    170         self.assertEqual(hooktestpackage.sub.subber.get_name(),
    171                          "hooktestpackage.sub.subber")
    172         if importer:
    173             self.assertEqual(hooktestmodule.__loader__, importer)
    174             self.assertEqual(hooktestpackage.__loader__, importer)
    175             self.assertEqual(hooktestpackage.sub.__loader__, importer)
    176             self.assertEqual(hooktestpackage.sub.subber.__loader__, importer)
    177 
    178         TestImporter.modules['reloadmodule'] = (False, test_co)
    179         import reloadmodule
    180         self.assertFalse(hasattr(reloadmodule,'reloaded'))
    181 
    182         TestImporter.modules['reloadmodule'] = (False, reload_co)
    183         imp.reload(reloadmodule)
    184         self.assertTrue(hasattr(reloadmodule,'reloaded'))
    185 
    186         import hooktestpackage.oldabs
    187         self.assertEqual(hooktestpackage.oldabs.get_name(),
    188                          "hooktestpackage.oldabs")
    189         self.assertEqual(hooktestpackage.oldabs.sub,
    190                          hooktestpackage.sub)
    191 
    192         import hooktestpackage.newrel
    193         self.assertEqual(hooktestpackage.newrel.get_name(),
    194                          "hooktestpackage.newrel")
    195         self.assertEqual(hooktestpackage.newrel.sub,
    196                          hooktestpackage.sub)
    197 
    198         import hooktestpackage.sub.subber.subest as subest
    199         self.assertEqual(subest.get_name(),
    200                          "hooktestpackage.sub.subber.subest")
    201         self.assertEqual(subest.sub,
    202                          hooktestpackage.sub)
    203 
    204         import hooktestpackage.futrel
    205         self.assertEqual(hooktestpackage.futrel.get_name(),
    206                          "hooktestpackage.futrel")
    207         self.assertEqual(hooktestpackage.futrel.sub,
    208                          hooktestpackage.sub)
    209 
    210         import sub
    211         self.assertEqual(sub.get_name(), "sub")
    212 
    213         import hooktestpackage.newabs
    214         self.assertEqual(hooktestpackage.newabs.get_name(),
    215                          "hooktestpackage.newabs")
    216         self.assertEqual(hooktestpackage.newabs.sub, sub)
    217 
    218     def testMetaPath(self):
    219         i = MetaImporter()
    220         sys.meta_path.append(i)
    221         self.doTestImports(i)
    222 
    223     def testPathHook(self):
    224         sys.path_hooks.append(PathImporter)
    225         sys.path.append(test_path)
    226         self.doTestImports()
    227 
    228     def testBlocker(self):
    229         mname = "exceptions"  # an arbitrary harmless builtin module
    230         test_support.unload(mname)
    231         sys.meta_path.append(ImportBlocker(mname))
    232         self.assertRaises(ImportError, __import__, mname)
    233 
    234     def testImpWrapper(self):
    235         i = ImpWrapper()
    236         sys.meta_path.append(i)
    237         sys.path_hooks.append(ImpWrapper)
    238         mnames = ("colorsys", "urlparse", "distutils.core", "compiler.misc")
    239         for mname in mnames:
    240             parent = mname.split(".")[0]
    241             for n in sys.modules.keys():
    242                 if n.startswith(parent):
    243                     del sys.modules[n]
    244         with test_support.check_warnings(("The compiler package is deprecated "
    245                                           "and removed", DeprecationWarning)):
    246             for mname in mnames:
    247                 m = __import__(mname, globals(), locals(), ["__dummy__"])
    248                 m.__loader__  # to make sure we actually handled the import
    249 
    250 
    251 def test_main():
    252     test_support.run_unittest(ImportHooksTestCase)
    253 
    254 if __name__ == "__main__":
    255     test_main()
    256