Home | History | Annotate | Download | only in test
      1 """Unit tests for contextlib.py, and other context managers."""
      2 
      3 import sys
      4 import tempfile
      5 import unittest
      6 from contextlib import *  # Tests __all__
      7 from test import test_support
      8 try:
      9     import threading
     10 except ImportError:
     11     threading = None
     12 
     13 
     14 class ContextManagerTestCase(unittest.TestCase):
     15 
     16     def test_contextmanager_plain(self):
     17         state = []
     18         @contextmanager
     19         def woohoo():
     20             state.append(1)
     21             yield 42
     22             state.append(999)
     23         with woohoo() as x:
     24             self.assertEqual(state, [1])
     25             self.assertEqual(x, 42)
     26             state.append(x)
     27         self.assertEqual(state, [1, 42, 999])
     28 
     29     def test_contextmanager_finally(self):
     30         state = []
     31         @contextmanager
     32         def woohoo():
     33             state.append(1)
     34             try:
     35                 yield 42
     36             finally:
     37                 state.append(999)
     38         with self.assertRaises(ZeroDivisionError):
     39             with woohoo() as x:
     40                 self.assertEqual(state, [1])
     41                 self.assertEqual(x, 42)
     42                 state.append(x)
     43                 raise ZeroDivisionError()
     44         self.assertEqual(state, [1, 42, 999])
     45 
     46     def test_contextmanager_no_reraise(self):
     47         @contextmanager
     48         def whee():
     49             yield
     50         ctx = whee()
     51         ctx.__enter__()
     52         # Calling __exit__ should not result in an exception
     53         self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))
     54 
     55     def test_contextmanager_trap_yield_after_throw(self):
     56         @contextmanager
     57         def whoo():
     58             try:
     59                 yield
     60             except:
     61                 yield
     62         ctx = whoo()
     63         ctx.__enter__()
     64         self.assertRaises(
     65             RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
     66         )
     67 
     68     def test_contextmanager_except(self):
     69         state = []
     70         @contextmanager
     71         def woohoo():
     72             state.append(1)
     73             try:
     74                 yield 42
     75             except ZeroDivisionError, e:
     76                 state.append(e.args[0])
     77                 self.assertEqual(state, [1, 42, 999])
     78         with woohoo() as x:
     79             self.assertEqual(state, [1])
     80             self.assertEqual(x, 42)
     81             state.append(x)
     82             raise ZeroDivisionError(999)
     83         self.assertEqual(state, [1, 42, 999])
     84 
     85     def _create_contextmanager_attribs(self):
     86         def attribs(**kw):
     87             def decorate(func):
     88                 for k,v in kw.items():
     89                     setattr(func,k,v)
     90                 return func
     91             return decorate
     92         @contextmanager
     93         @attribs(foo='bar')
     94         def baz(spam):
     95             """Whee!"""
     96         return baz
     97 
     98     def test_contextmanager_attribs(self):
     99         baz = self._create_contextmanager_attribs()
    100         self.assertEqual(baz.__name__,'baz')
    101         self.assertEqual(baz.foo, 'bar')
    102 
    103     @unittest.skipIf(sys.flags.optimize >= 2,
    104                      "Docstrings are omitted with -O2 and above")
    105     def test_contextmanager_doc_attrib(self):
    106         baz = self._create_contextmanager_attribs()
    107         self.assertEqual(baz.__doc__, "Whee!")
    108 
    109     def test_keywords(self):
    110         # Ensure no keyword arguments are inhibited
    111         @contextmanager
    112         def woohoo(self, func, args, kwds):
    113             yield (self, func, args, kwds)
    114         with woohoo(self=11, func=22, args=33, kwds=44) as target:
    115             self.assertEqual(target, (11, 22, 33, 44))
    116 
    117 class NestedTestCase(unittest.TestCase):
    118 
    119     # XXX This needs more work
    120 
    121     def test_nested(self):
    122         @contextmanager
    123         def a():
    124             yield 1
    125         @contextmanager
    126         def b():
    127             yield 2
    128         @contextmanager
    129         def c():
    130             yield 3
    131         with nested(a(), b(), c()) as (x, y, z):
    132             self.assertEqual(x, 1)
    133             self.assertEqual(y, 2)
    134             self.assertEqual(z, 3)
    135 
    136     def test_nested_cleanup(self):
    137         state = []
    138         @contextmanager
    139         def a():
    140             state.append(1)
    141             try:
    142                 yield 2
    143             finally:
    144                 state.append(3)
    145         @contextmanager
    146         def b():
    147             state.append(4)
    148             try:
    149                 yield 5
    150             finally:
    151                 state.append(6)
    152         with self.assertRaises(ZeroDivisionError):
    153             with nested(a(), b()) as (x, y):
    154                 state.append(x)
    155                 state.append(y)
    156                 1 // 0
    157         self.assertEqual(state, [1, 4, 2, 5, 6, 3])
    158 
    159     def test_nested_right_exception(self):
    160         @contextmanager
    161         def a():
    162             yield 1
    163         class b(object):
    164             def __enter__(self):
    165                 return 2
    166             def __exit__(self, *exc_info):
    167                 try:
    168                     raise Exception()
    169                 except:
    170                     pass
    171         with self.assertRaises(ZeroDivisionError):
    172             with nested(a(), b()) as (x, y):
    173                 1 // 0
    174         self.assertEqual((x, y), (1, 2))
    175 
    176     def test_nested_b_swallows(self):
    177         @contextmanager
    178         def a():
    179             yield
    180         @contextmanager
    181         def b():
    182             try:
    183                 yield
    184             except:
    185                 # Swallow the exception
    186                 pass
    187         try:
    188             with nested(a(), b()):
    189                 1 // 0
    190         except ZeroDivisionError:
    191             self.fail("Didn't swallow ZeroDivisionError")
    192 
    193     def test_nested_break(self):
    194         @contextmanager
    195         def a():
    196             yield
    197         state = 0
    198         while True:
    199             state += 1
    200             with nested(a(), a()):
    201                 break
    202             state += 10
    203         self.assertEqual(state, 1)
    204 
    205     def test_nested_continue(self):
    206         @contextmanager
    207         def a():
    208             yield
    209         state = 0
    210         while state < 3:
    211             state += 1
    212             with nested(a(), a()):
    213                 continue
    214             state += 10
    215         self.assertEqual(state, 3)
    216 
    217     def test_nested_return(self):
    218         @contextmanager
    219         def a():
    220             try:
    221                 yield
    222             except:
    223                 pass
    224         def foo():
    225             with nested(a(), a()):
    226                 return 1
    227             return 10
    228         self.assertEqual(foo(), 1)
    229 
    230 class ClosingTestCase(unittest.TestCase):
    231 
    232     # XXX This needs more work
    233 
    234     def test_closing(self):
    235         state = []
    236         class C:
    237             def close(self):
    238                 state.append(1)
    239         x = C()
    240         self.assertEqual(state, [])
    241         with closing(x) as y:
    242             self.assertEqual(x, y)
    243         self.assertEqual(state, [1])
    244 
    245     def test_closing_error(self):
    246         state = []
    247         class C:
    248             def close(self):
    249                 state.append(1)
    250         x = C()
    251         self.assertEqual(state, [])
    252         with self.assertRaises(ZeroDivisionError):
    253             with closing(x) as y:
    254                 self.assertEqual(x, y)
    255                 1 // 0
    256         self.assertEqual(state, [1])
    257 
    258 class FileContextTestCase(unittest.TestCase):
    259 
    260     def testWithOpen(self):
    261         tfn = tempfile.mktemp()
    262         try:
    263             f = None
    264             with open(tfn, "w") as f:
    265                 self.assertFalse(f.closed)
    266                 f.write("Booh\n")
    267             self.assertTrue(f.closed)
    268             f = None
    269             with self.assertRaises(ZeroDivisionError):
    270                 with open(tfn, "r") as f:
    271                     self.assertFalse(f.closed)
    272                     self.assertEqual(f.read(), "Booh\n")
    273                     1 // 0
    274             self.assertTrue(f.closed)
    275         finally:
    276             test_support.unlink(tfn)
    277 
    278 @unittest.skipUnless(threading, 'Threading required for this test.')
    279 class LockContextTestCase(unittest.TestCase):
    280 
    281     def boilerPlate(self, lock, locked):
    282         self.assertFalse(locked())
    283         with lock:
    284             self.assertTrue(locked())
    285         self.assertFalse(locked())
    286         with self.assertRaises(ZeroDivisionError):
    287             with lock:
    288                 self.assertTrue(locked())
    289                 1 // 0
    290         self.assertFalse(locked())
    291 
    292     def testWithLock(self):
    293         lock = threading.Lock()
    294         self.boilerPlate(lock, lock.locked)
    295 
    296     def testWithRLock(self):
    297         lock = threading.RLock()
    298         self.boilerPlate(lock, lock._is_owned)
    299 
    300     def testWithCondition(self):
    301         lock = threading.Condition()
    302         def locked():
    303             return lock._is_owned()
    304         self.boilerPlate(lock, locked)
    305 
    306     def testWithSemaphore(self):
    307         lock = threading.Semaphore()
    308         def locked():
    309             if lock.acquire(False):
    310                 lock.release()
    311                 return False
    312             else:
    313                 return True
    314         self.boilerPlate(lock, locked)
    315 
    316     def testWithBoundedSemaphore(self):
    317         lock = threading.BoundedSemaphore()
    318         def locked():
    319             if lock.acquire(False):
    320                 lock.release()
    321                 return False
    322             else:
    323                 return True
    324         self.boilerPlate(lock, locked)
    325 
    326 # This is needed to make the test actually run under regrtest.py!
    327 def test_main():
    328     with test_support.check_warnings(("With-statements now directly support "
    329                                       "multiple context managers",
    330                                       DeprecationWarning)):
    331         test_support.run_unittest(__name__)
    332 
    333 if __name__ == "__main__":
    334     test_main()
    335