Home | History | Annotate | Download | only in test
      1 import os, sys, string, random, tempfile, unittest
      2 
      3 from test.test_support import run_unittest
      4 
      5 class TestImport(unittest.TestCase):
      6 
      7     def __init__(self, *args, **kw):
      8         self.package_name = 'PACKAGE_'
      9         while self.package_name in sys.modules:
     10             self.package_name += random.choose(string.letters)
     11         self.module_name = self.package_name + '.foo'
     12         unittest.TestCase.__init__(self, *args, **kw)
     13 
     14     def remove_modules(self):
     15         for module_name in (self.package_name, self.module_name):
     16             if module_name in sys.modules:
     17                 del sys.modules[module_name]
     18 
     19     def setUp(self):
     20         self.test_dir = tempfile.mkdtemp()
     21         sys.path.append(self.test_dir)
     22         self.package_dir = os.path.join(self.test_dir,
     23                                         self.package_name)
     24         os.mkdir(self.package_dir)
     25         open(os.path.join(
     26                 self.package_dir, '__init__'+os.extsep+'py'), 'w').close()
     27         self.module_path = os.path.join(self.package_dir, 'foo'+os.extsep+'py')
     28 
     29     def tearDown(self):
     30         for file in os.listdir(self.package_dir):
     31             os.remove(os.path.join(self.package_dir, file))
     32         os.rmdir(self.package_dir)
     33         os.rmdir(self.test_dir)
     34         self.assertNotEqual(sys.path.count(self.test_dir), 0)
     35         sys.path.remove(self.test_dir)
     36         self.remove_modules()
     37 
     38     def rewrite_file(self, contents):
     39         for extension in "co":
     40             compiled_path = self.module_path + extension
     41             if os.path.exists(compiled_path):
     42                 os.remove(compiled_path)
     43         f = open(self.module_path, 'w')
     44         f.write(contents)
     45         f.close()
     46 
     47     def test_package_import__semantics(self):
     48 
     49         # Generate a couple of broken modules to try importing.
     50 
     51         # ...try loading the module when there's a SyntaxError
     52         self.rewrite_file('for')
     53         try: __import__(self.module_name)
     54         except SyntaxError: pass
     55         else: raise RuntimeError, 'Failed to induce SyntaxError'
     56         self.assertNotIn(self.module_name, sys.modules)
     57         self.assertFalse(hasattr(sys.modules[self.package_name], 'foo'))
     58 
     59         # ...make up a variable name that isn't bound in __builtins__
     60         var = 'a'
     61         while var in dir(__builtins__):
     62             var += random.choose(string.letters)
     63 
     64         # ...make a module that just contains that
     65         self.rewrite_file(var)
     66 
     67         try: __import__(self.module_name)
     68         except NameError: pass
     69         else: raise RuntimeError, 'Failed to induce NameError.'
     70 
     71         # ...now  change  the module  so  that  the NameError  doesn't
     72         # happen
     73         self.rewrite_file('%s = 1' % var)
     74         module = __import__(self.module_name).foo
     75         self.assertEqual(getattr(module, var), 1)
     76 
     77 
     78 def test_main():
     79     run_unittest(TestImport)
     80 
     81 
     82 if __name__ == "__main__":
     83     test_main()
     84