"""Unit tests for contextlib.py, and other context managers."""

import os
import sys
import tempfile
import unittest
from contextlib import *  # Tests __all__
from test import test_support
try:
    import threading
except ImportError:
    threading = None


class ContextManagerTestCase(unittest.TestCase):
    """Tests for generator-based context managers built with @contextmanager."""

    def test_contextmanager_plain(self):
        # Code before the yield runs on entry, the yielded value is bound by
        # "as", and code after the yield runs on a normal exit.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # A finally clause around the yield still runs when the with-body
        # raises, and the exception keeps propagating.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        @contextmanager
        def whee():
            yield
        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields again after the exception is thrown into
        # it is a programming error; __exit__ must report it as RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except:
                yield
        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        # The generator sees the with-body's exception at the yield point and
        # may handle it; handling it suppresses the exception for the caller.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def _create_contextmanager_attribs(self):
        # Build a @contextmanager-decorated function whose underlying
        # generator function carries a custom attribute and a docstring, so
        # the tests can check what the decorator copies onto its wrapper.
        def attribs(**kw):
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate
        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        return baz

    def test_contextmanager_attribs(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_contextmanager_doc_attrib(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

class NestedTestCase(unittest.TestCase):
    """Tests for contextlib.nested()."""

    # XXX This needs more work

    def test_nested(self):
        # The values produced by each manager come back as a tuple, in order.
        @contextmanager
        def a():
            yield 1
        @contextmanager
        def b():
            yield 2
        @contextmanager
        def c():
            yield 3
        with nested(a(), b(), c()) as (x, y, z):
            self.assertEqual(x, 1)
            self.assertEqual(y, 2)
            self.assertEqual(z, 3)

    def test_nested_cleanup(self):
        # Managers are entered left to right and exited right to left, even
        # when the body raises.
        state = []
        @contextmanager
        def a():
            state.append(1)
            try:
                yield 2
            finally:
                state.append(3)
        @contextmanager
        def b():
            state.append(4)
            try:
                yield 5
            finally:
                state.append(6)
        with self.assertRaises(ZeroDivisionError):
            with nested(a(), b()) as (x, y):
                state.append(x)
                state.append(y)
                1 // 0
        self.assertEqual(state, [1, 4, 2, 5, 6, 3])

    def test_nested_right_exception(self):
        # b.__exit__ raises and handles an unrelated exception internally;
        # nested() must still propagate the body's original ZeroDivisionError.
        @contextmanager
        def a():
            yield 1
        class b(object):
            def __enter__(self):
                return 2
            def __exit__(self, *exc_info):
                try:
                    raise Exception()
                except:
                    pass
        with self.assertRaises(ZeroDivisionError):
            with nested(a(), b()) as (x, y):
                1 // 0
        self.assertEqual((x, y), (1, 2))

    def test_nested_b_swallows(self):
        # If the inner manager suppresses the exception, the outer manager
        # and the caller must not see it.
        @contextmanager
        def a():
            yield
        @contextmanager
        def b():
            try:
                yield
            except:
                # Swallow the exception
                pass
        try:
            with nested(a(), b()):
                1 // 0
        except ZeroDivisionError:
            self.fail("Didn't swallow ZeroDivisionError")

    def test_nested_break(self):
        # "break" inside the with-body leaves the loop immediately.
        @contextmanager
        def a():
            yield
        state = 0
        while True:
            state += 1
            with nested(a(), a()):
                break
            state += 10
        self.assertEqual(state, 1)

    def test_nested_continue(self):
        # "continue" inside the with-body skips the rest of the iteration.
        @contextmanager
        def a():
            yield
        state = 0
        while state < 3:
            state += 1
            with nested(a(), a()):
                continue
            state += 10
        self.assertEqual(state, 3)

    def test_nested_return(self):
        # "return" inside the with-body returns from the enclosing function.
        @contextmanager
        def a():
            try:
                yield
            except:
                pass
        def foo():
            with nested(a(), a()):
                return 1
            return 10
        self.assertEqual(foo(), 1)

class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing()."""

    # XXX This needs more work

    def test_closing(self):
        # closing() yields the wrapped object and calls its close() on exit.
        state = []
        class C:
            def close(self):
                state.append(1)
        x = C()
        self.assertEqual(state, [])
        with closing(x) as y:
            self.assertEqual(x, y)
        self.assertEqual(state, [1])

    def test_closing_error(self):
        # close() is still called when the body raises; the error propagates.
        state = []
        class C:
            def close(self):
                state.append(1)
        x = C()
        self.assertEqual(state, [])
        with self.assertRaises(ZeroDivisionError):
            with closing(x) as y:
                self.assertEqual(x, y)
                1 // 0
        self.assertEqual(state, [1])

class FileContextTestCase(unittest.TestCase):
    """Tests for file objects used as context managers."""

    def testWithOpen(self):
        # mkstemp() creates the file atomically, avoiding the name race that
        # makes tempfile.mktemp() unsafe.  Only the path is needed here, so
        # the OS-level descriptor is closed straight away.
        fd, tfn = tempfile.mkstemp()
        os.close(fd)
        try:
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            f = None
            with self.assertRaises(ZeroDivisionError):
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1 // 0
            self.assertTrue(f.closed)
        finally:
            test_support.unlink(tfn)

@unittest.skipUnless(threading, 'Threading required for this test.')
class LockContextTestCase(unittest.TestCase):
    """Tests for threading synchronization primitives used in with-blocks."""

    def boilerPlate(self, lock, locked):
        # Shared check: `lock` is held inside the with-body and released on
        # exit, both normally and when the body raises.  `locked` is a
        # zero-argument callable reporting whether `lock` is currently held.
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        with self.assertRaises(ZeroDivisionError):
            with lock:
                self.assertTrue(locked())
                1 // 0
        self.assertFalse(locked())

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        # NOTE: _is_owned is a private RLock helper; there is no public query.
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        def locked():
            return lock._is_owned()
        self.boilerPlate(lock, locked)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        def locked():
            # Probe with a non-blocking acquire; give it back if it succeeded.
            if lock.acquire(False):
                lock.release()
                return False
            else:
                return True
        self.boilerPlate(lock, locked)

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        def locked():
            # Probe with a non-blocking acquire; give it back if it succeeded.
            if lock.acquire(False):
                lock.release()
                return False
            else:
                return True
        self.boilerPlate(lock, locked)

# This is needed to make the test actually run under regrtest.py!
def test_main():
    # The nested() tests trigger a DeprecationWarning with this message;
    # check_warnings() expects it rather than letting it leak out as noise.
    with test_support.check_warnings(("With-statements now directly support "
                                      "multiple context managers",
                                      DeprecationWarning)):
        test_support.run_unittest(__name__)

if __name__ == "__main__":
    test_main()