Home | History | Annotate | Download | only in test
      1 """Unit tests for contextlib.py, and other context managers."""
      2 
      3 import sys
      4 import tempfile
      5 import unittest
      6 from contextlib import *  # Tests __all__
      7 from test import test_support
      8 try:
      9     import threading
     10 except ImportError:
     11     threading = None
     12 
     13 
     14 class ContextManagerTestCase(unittest.TestCase):
     15 
     16     def test_contextmanager_plain(self):
     17         state = []
     18         @contextmanager
     19         def woohoo():
     20             state.append(1)
     21             yield 42
     22             state.append(999)
     23         with woohoo() as x:
     24             self.assertEqual(state, [1])
     25             self.assertEqual(x, 42)
     26             state.append(x)
     27         self.assertEqual(state, [1, 42, 999])
     28 
     29     def test_contextmanager_finally(self):
     30         state = []
     31         @contextmanager
     32         def woohoo():
     33             state.append(1)
     34             try:
     35                 yield 42
     36             finally:
     37                 state.append(999)
     38         with self.assertRaises(ZeroDivisionError):
     39             with woohoo() as x:
     40                 self.assertEqual(state, [1])
     41                 self.assertEqual(x, 42)
     42                 state.append(x)
     43                 raise ZeroDivisionError()
     44         self.assertEqual(state, [1, 42, 999])
     45 
     46     def test_contextmanager_no_reraise(self):
     47         @contextmanager
     48         def whee():
     49             yield
     50         ctx = whee()
     51         ctx.__enter__()
     52         # Calling __exit__ should not result in an exception
     53         self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))
     54 
     55     def test_contextmanager_trap_yield_after_throw(self):
     56         @contextmanager
     57         def whoo():
     58             try:
     59                 yield
     60             except:
     61                 yield
     62         ctx = whoo()
     63         ctx.__enter__()
     64         self.assertRaises(
     65             RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
     66         )
     67 
     68     def test_contextmanager_except(self):
     69         state = []
     70         @contextmanager
     71         def woohoo():
     72             state.append(1)
     73             try:
     74                 yield 42
     75             except ZeroDivisionError, e:
     76                 state.append(e.args[0])
     77                 self.assertEqual(state, [1, 42, 999])
     78         with woohoo() as x:
     79             self.assertEqual(state, [1])
     80             self.assertEqual(x, 42)
     81             state.append(x)
     82             raise ZeroDivisionError(999)
     83         self.assertEqual(state, [1, 42, 999])
     84 
     85     def _create_contextmanager_attribs(self):
     86         def attribs(**kw):
     87             def decorate(func):
     88                 for k,v in kw.items():
     89                     setattr(func,k,v)
     90                 return func
     91             return decorate
     92         @contextmanager
     93         @attribs(foo='bar')
     94         def baz(spam):
     95             """Whee!"""
     96         return baz
     97 
     98     def test_contextmanager_attribs(self):
     99         baz = self._create_contextmanager_attribs()
    100         self.assertEqual(baz.__name__,'baz')
    101         self.assertEqual(baz.foo, 'bar')
    102 
    103     @unittest.skipIf(sys.flags.optimize >= 2,
    104                      "Docstrings are omitted with -O2 and above")
    105     def test_contextmanager_doc_attrib(self):
    106         baz = self._create_contextmanager_attribs()
    107         self.assertEqual(baz.__doc__, "Whee!")
    108 
    109 class NestedTestCase(unittest.TestCase):
    110 
    111     # XXX This needs more work
    112 
    113     def test_nested(self):
    114         @contextmanager
    115         def a():
    116             yield 1
    117         @contextmanager
    118         def b():
    119             yield 2
    120         @contextmanager
    121         def c():
    122             yield 3
    123         with nested(a(), b(), c()) as (x, y, z):
    124             self.assertEqual(x, 1)
    125             self.assertEqual(y, 2)
    126             self.assertEqual(z, 3)
    127 
    128     def test_nested_cleanup(self):
    129         state = []
    130         @contextmanager
    131         def a():
    132             state.append(1)
    133             try:
    134                 yield 2
    135             finally:
    136                 state.append(3)
    137         @contextmanager
    138         def b():
    139             state.append(4)
    140             try:
    141                 yield 5
    142             finally:
    143                 state.append(6)
    144         with self.assertRaises(ZeroDivisionError):
    145             with nested(a(), b()) as (x, y):
    146                 state.append(x)
    147                 state.append(y)
    148                 1 // 0
    149         self.assertEqual(state, [1, 4, 2, 5, 6, 3])
    150 
    151     def test_nested_right_exception(self):
    152         @contextmanager
    153         def a():
    154             yield 1
    155         class b(object):
    156             def __enter__(self):
    157                 return 2
    158             def __exit__(self, *exc_info):
    159                 try:
    160                     raise Exception()
    161                 except:
    162                     pass
    163         with self.assertRaises(ZeroDivisionError):
    164             with nested(a(), b()) as (x, y):
    165                 1 // 0
    166         self.assertEqual((x, y), (1, 2))
    167 
    168     def test_nested_b_swallows(self):
    169         @contextmanager
    170         def a():
    171             yield
    172         @contextmanager
    173         def b():
    174             try:
    175                 yield
    176             except:
    177                 # Swallow the exception
    178                 pass
    179         try:
    180             with nested(a(), b()):
    181                 1 // 0
    182         except ZeroDivisionError:
    183             self.fail("Didn't swallow ZeroDivisionError")
    184 
    185     def test_nested_break(self):
    186         @contextmanager
    187         def a():
    188             yield
    189         state = 0
    190         while True:
    191             state += 1
    192             with nested(a(), a()):
    193                 break
    194             state += 10
    195         self.assertEqual(state, 1)
    196 
    197     def test_nested_continue(self):
    198         @contextmanager
    199         def a():
    200             yield
    201         state = 0
    202         while state < 3:
    203             state += 1
    204             with nested(a(), a()):
    205                 continue
    206             state += 10
    207         self.assertEqual(state, 3)
    208 
    209     def test_nested_return(self):
    210         @contextmanager
    211         def a():
    212             try:
    213                 yield
    214             except:
    215                 pass
    216         def foo():
    217             with nested(a(), a()):
    218                 return 1
    219             return 10
    220         self.assertEqual(foo(), 1)
    221 
    222 class ClosingTestCase(unittest.TestCase):
    223 
    224     # XXX This needs more work
    225 
    226     def test_closing(self):
    227         state = []
    228         class C:
    229             def close(self):
    230                 state.append(1)
    231         x = C()
    232         self.assertEqual(state, [])
    233         with closing(x) as y:
    234             self.assertEqual(x, y)
    235         self.assertEqual(state, [1])
    236 
    237     def test_closing_error(self):
    238         state = []
    239         class C:
    240             def close(self):
    241                 state.append(1)
    242         x = C()
    243         self.assertEqual(state, [])
    244         with self.assertRaises(ZeroDivisionError):
    245             with closing(x) as y:
    246                 self.assertEqual(x, y)
    247                 1 // 0
    248         self.assertEqual(state, [1])
    249 
    250 class FileContextTestCase(unittest.TestCase):
    251 
    252     def testWithOpen(self):
    253         tfn = tempfile.mktemp()
    254         try:
    255             f = None
    256             with open(tfn, "w") as f:
    257                 self.assertFalse(f.closed)
    258                 f.write("Booh\n")
    259             self.assertTrue(f.closed)
    260             f = None
    261             with self.assertRaises(ZeroDivisionError):
    262                 with open(tfn, "r") as f:
    263                     self.assertFalse(f.closed)
    264                     self.assertEqual(f.read(), "Booh\n")
    265                     1 // 0
    266             self.assertTrue(f.closed)
    267         finally:
    268             test_support.unlink(tfn)
    269 
    270 @unittest.skipUnless(threading, 'Threading required for this test.')
    271 class LockContextTestCase(unittest.TestCase):
    272 
    273     def boilerPlate(self, lock, locked):
    274         self.assertFalse(locked())
    275         with lock:
    276             self.assertTrue(locked())
    277         self.assertFalse(locked())
    278         with self.assertRaises(ZeroDivisionError):
    279             with lock:
    280                 self.assertTrue(locked())
    281                 1 // 0
    282         self.assertFalse(locked())
    283 
    284     def testWithLock(self):
    285         lock = threading.Lock()
    286         self.boilerPlate(lock, lock.locked)
    287 
    288     def testWithRLock(self):
    289         lock = threading.RLock()
    290         self.boilerPlate(lock, lock._is_owned)
    291 
    292     def testWithCondition(self):
    293         lock = threading.Condition()
    294         def locked():
    295             return lock._is_owned()
    296         self.boilerPlate(lock, locked)
    297 
    298     def testWithSemaphore(self):
    299         lock = threading.Semaphore()
    300         def locked():
    301             if lock.acquire(False):
    302                 lock.release()
    303                 return False
    304             else:
    305                 return True
    306         self.boilerPlate(lock, locked)
    307 
    308     def testWithBoundedSemaphore(self):
    309         lock = threading.BoundedSemaphore()
    310         def locked():
    311             if lock.acquire(False):
    312                 lock.release()
    313                 return False
    314             else:
    315                 return True
    316         self.boilerPlate(lock, locked)
    317 
    318 # This is needed to make the test actually run under regrtest.py!
    319 def test_main():
    320     with test_support.check_warnings(("With-statements now directly support "
    321                                       "multiple context managers",
    322                                       DeprecationWarning)):
    323         test_support.run_unittest(__name__)
    324 
    325 if __name__ == "__main__":
    326     test_main()
    327