Home | History | Annotate | Download | only in test
      1 import os
      2 import sys
      3 import shutil
      4 import string
      5 import random
      6 import tempfile
      7 import unittest
      8 
      9 from importlib.util import cache_from_source
     10 from test.support import create_empty_file
     11 
     12 class TestImport(unittest.TestCase):
     13 
     14     def __init__(self, *args, **kw):
     15         self.package_name = 'PACKAGE_'
     16         while self.package_name in sys.modules:
     17             self.package_name += random.choose(string.ascii_letters)
     18         self.module_name = self.package_name + '.foo'
     19         unittest.TestCase.__init__(self, *args, **kw)
     20 
     21     def remove_modules(self):
     22         for module_name in (self.package_name, self.module_name):
     23             if module_name in sys.modules:
     24                 del sys.modules[module_name]
     25 
     26     def setUp(self):
     27         self.test_dir = tempfile.mkdtemp()
     28         sys.path.append(self.test_dir)
     29         self.package_dir = os.path.join(self.test_dir,
     30                                         self.package_name)
     31         os.mkdir(self.package_dir)
     32         create_empty_file(os.path.join(self.package_dir, '__init__.py'))
     33         self.module_path = os.path.join(self.package_dir, 'foo.py')
     34 
     35     def tearDown(self):
     36         shutil.rmtree(self.test_dir)
     37         self.assertNotEqual(sys.path.count(self.test_dir), 0)
     38         sys.path.remove(self.test_dir)
     39         self.remove_modules()
     40 
     41     def rewrite_file(self, contents):
     42         compiled_path = cache_from_source(self.module_path)
     43         if os.path.exists(compiled_path):
     44             os.remove(compiled_path)
     45         with open(self.module_path, 'w') as f:
     46             f.write(contents)
     47 
     48     def test_package_import__semantics(self):
     49 
     50         # Generate a couple of broken modules to try importing.
     51 
     52         # ...try loading the module when there's a SyntaxError
     53         self.rewrite_file('for')
     54         try: __import__(self.module_name)
     55         except SyntaxError: pass
     56         else: raise RuntimeError('Failed to induce SyntaxError') # self.fail()?
     57         self.assertNotIn(self.module_name, sys.modules)
     58         self.assertFalse(hasattr(sys.modules[self.package_name], 'foo'))
     59 
     60         # ...make up a variable name that isn't bound in __builtins__
     61         var = 'a'
     62         while var in dir(__builtins__):
     63             var += random.choose(string.ascii_letters)
     64 
     65         # ...make a module that just contains that
     66         self.rewrite_file(var)
     67 
     68         try: __import__(self.module_name)
     69         except NameError: pass
     70         else: raise RuntimeError('Failed to induce NameError.')
     71 
     72         # ...now  change  the module  so  that  the NameError  doesn't
     73         # happen
     74         self.rewrite_file('%s = 1' % var)
     75         module = __import__(self.module_name).foo
     76         self.assertEqual(getattr(module, var), 1)
     77 
     78 
     79 if __name__ == "__main__":
     80     unittest.main()
     81