Home | History | Annotate | Download | only in test
"""Unit tests for the with statement specified in PEP 343.

Covers compile-time errors for invalid ``as`` targets, the __enter__/__exit__
call protocol and the arguments __exit__ receives, exception propagation and
suppression, break/continue/return/yield/raise inside a with-body, complex
assignment targets, and multi-item with statements.
"""


__author__ = "Mike Bland"
__email__ = "mbland at acm dot org"
      6 
      7 import sys
      8 import unittest
      9 from collections import deque
     10 from contextlib import GeneratorContextManager, contextmanager
     11 from test.test_support import run_unittest
     12 
     13 
     14 class MockContextManager(GeneratorContextManager):
     15     def __init__(self, gen):
     16         GeneratorContextManager.__init__(self, gen)
     17         self.enter_called = False
     18         self.exit_called = False
     19         self.exit_args = None
     20 
     21     def __enter__(self):
     22         self.enter_called = True
     23         return GeneratorContextManager.__enter__(self)
     24 
     25     def __exit__(self, type, value, traceback):
     26         self.exit_called = True
     27         self.exit_args = (type, value, traceback)
     28         return GeneratorContextManager.__exit__(self, type,
     29                                                 value, traceback)
     30 
     31 
     32 def mock_contextmanager(func):
     33     def helper(*args, **kwds):
     34         return MockContextManager(func(*args, **kwds))
     35     return helper
     36 
     37 
     38 class MockResource(object):
     39     def __init__(self):
     40         self.yielded = False
     41         self.stopped = False
     42 
     43 
     44 @mock_contextmanager
     45 def mock_contextmanager_generator():
     46     mock = MockResource()
     47     try:
     48         mock.yielded = True
     49         yield mock
     50     finally:
     51         mock.stopped = True
     52 
     53 
     54 class Nested(object):
     55 
     56     def __init__(self, *managers):
     57         self.managers = managers
     58         self.entered = None
     59 
     60     def __enter__(self):
     61         if self.entered is not None:
     62             raise RuntimeError("Context is not reentrant")
     63         self.entered = deque()
     64         vars = []
     65         try:
     66             for mgr in self.managers:
     67                 vars.append(mgr.__enter__())
     68                 self.entered.appendleft(mgr)
     69         except:
     70             if not self.__exit__(*sys.exc_info()):
     71                 raise
     72         return vars
     73 
     74     def __exit__(self, *exc_info):
     75         # Behave like nested with statements
     76         # first in, last out
     77         # New exceptions override old ones
     78         ex = exc_info
     79         for mgr in self.entered:
     80             try:
     81                 if mgr.__exit__(*ex):
     82                     ex = (None, None, None)
     83             except:
     84                 ex = sys.exc_info()
     85         self.entered = None
     86         if ex is not exc_info:
     87             raise ex[0], ex[1], ex[2]
     88 
     89 
     90 class MockNested(Nested):
     91     def __init__(self, *managers):
     92         Nested.__init__(self, *managers)
     93         self.enter_called = False
     94         self.exit_called = False
     95         self.exit_args = None
     96 
     97     def __enter__(self):
     98         self.enter_called = True
     99         return Nested.__enter__(self)
    100 
    101     def __exit__(self, *exc_info):
    102         self.exit_called = True
    103         self.exit_args = exc_info
    104         return Nested.__exit__(self, *exc_info)
    105 
    106 
class FailureTestCase(unittest.TestCase):
    """With statements that must fail to compile or fail when executed."""

    def testNameError(self):
        # An undefined context expression raises NameError when evaluated.
        def fooNotDeclared():
            with foo: pass
        self.assertRaises(NameError, fooNotDeclared)

    def testEnterAttributeError(self):
        # A manager lacking __enter__ is expected to raise AttributeError.
        class LacksEnter(object):
            def __exit__(self, type, value, traceback):
                pass

        def fooLacksEnter():
            foo = LacksEnter()
            with foo: pass
        self.assertRaises(AttributeError, fooLacksEnter)

    def testExitAttributeError(self):
        # A manager lacking __exit__ is expected to raise AttributeError.
        class LacksExit(object):
            def __enter__(self):
                pass

        def fooLacksExit():
            foo = LacksExit()
            with foo: pass
        self.assertRaises(AttributeError, fooLacksExit)

    def assertRaisesSyntaxError(self, codestr):
        # Helper: assert that compiling codestr in 'single' (interactive
        # statement) mode raises SyntaxError.
        def shouldRaiseSyntaxError(s):
            compile(s, '', 'single')
        self.assertRaises(SyntaxError, shouldRaiseSyntaxError, codestr)

    def testAssignmentToNoneError(self):
        # None is not a valid ``as`` target, bare or parenthesized.
        self.assertRaisesSyntaxError('with mock as None:\n  pass')
        self.assertRaisesSyntaxError(
            'with mock as (None):\n'
            '  pass')

    def testAssignmentToEmptyTupleError(self):
        # An empty tuple is not a valid ``as`` target.
        self.assertRaisesSyntaxError(
            'with mock as ():\n'
            '  pass')

    def testAssignmentToTupleOnlyContainingNoneError(self):
        # A one-element tuple of None is rejected, with or without parens.
        self.assertRaisesSyntaxError('with mock as None,:\n  pass')
        self.assertRaisesSyntaxError(
            'with mock as (None,):\n'
            '  pass')

    def testAssignmentToTupleContainingNoneError(self):
        # None anywhere inside a target tuple is rejected.
        self.assertRaisesSyntaxError(
            'with mock as (foo, None, bar):\n'
            '  pass')

    def testEnterThrows(self):
        # An exception from __enter__ propagates, and the ``as`` target is
        # never assigned (self.foo keeps its None placeholder).
        class EnterThrows(object):
            def __enter__(self):
                raise RuntimeError("Enter threw")
            def __exit__(self, *args):
                pass

        def shouldThrow():
            ct = EnterThrows()
            self.foo = None
            with ct as self.foo:
                pass
        self.assertRaises(RuntimeError, shouldThrow)
        self.assertEqual(self.foo, None)

    def testExitThrows(self):
        # An exception raised by __exit__ after a clean body propagates.
        class ExitThrows(object):
            def __enter__(self):
                return
            def __exit__(self, *args):
                raise RuntimeError(42)
        def shouldThrow():
            with ExitThrows():
                pass
        self.assertRaises(RuntimeError, shouldThrow)
    185 
    186 class ContextmanagerAssertionMixin(object):
    187     TEST_EXCEPTION = RuntimeError("test exception")
    188 
    189     def assertInWithManagerInvariants(self, mock_manager):
    190         self.assertTrue(mock_manager.enter_called)
    191         self.assertFalse(mock_manager.exit_called)
    192         self.assertEqual(mock_manager.exit_args, None)
    193 
    194     def assertAfterWithManagerInvariants(self, mock_manager, exit_args):
    195         self.assertTrue(mock_manager.enter_called)
    196         self.assertTrue(mock_manager.exit_called)
    197         self.assertEqual(mock_manager.exit_args, exit_args)
    198 
    199     def assertAfterWithManagerInvariantsNoError(self, mock_manager):
    200         self.assertAfterWithManagerInvariants(mock_manager,
    201             (None, None, None))
    202 
    203     def assertInWithGeneratorInvariants(self, mock_generator):
    204         self.assertTrue(mock_generator.yielded)
    205         self.assertFalse(mock_generator.stopped)
    206 
    207     def assertAfterWithGeneratorInvariantsNoError(self, mock_generator):
    208         self.assertTrue(mock_generator.yielded)
    209         self.assertTrue(mock_generator.stopped)
    210 
    211     def raiseTestException(self):
    212         raise self.TEST_EXCEPTION
    213 
    214     def assertAfterWithManagerInvariantsWithError(self, mock_manager,
    215                                                   exc_type=None):
    216         self.assertTrue(mock_manager.enter_called)
    217         self.assertTrue(mock_manager.exit_called)
    218         if exc_type is None:
    219             self.assertEqual(mock_manager.exit_args[1], self.TEST_EXCEPTION)
    220             exc_type = type(self.TEST_EXCEPTION)
    221         self.assertEqual(mock_manager.exit_args[0], exc_type)
    222         # Test the __exit__ arguments. Issue #7853
    223         self.assertIsInstance(mock_manager.exit_args[1], exc_type)
    224         self.assertIsNot(mock_manager.exit_args[2], None)
    225 
    226     def assertAfterWithGeneratorInvariantsWithError(self, mock_generator):
    227         self.assertTrue(mock_generator.yielded)
    228         self.assertTrue(mock_generator.stopped)
    229 
    230 
class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
    """Single generator-based managers whose with-body raises nothing."""

    def testInlineGeneratorSyntax(self):
        # Smoke test: the context expression is a call made in the header.
        with mock_contextmanager_generator():
            pass

    def testUnboundGenerator(self):
        # Without an ``as`` clause the manager is still entered, and it is
        # exited with (None, None, None).
        mock = mock_contextmanager_generator()
        with mock:
            pass
        self.assertAfterWithManagerInvariantsNoError(mock)

    def testInlineGeneratorBoundSyntax(self):
        # ``as foo`` binds the MockResource that the generator yielded.
        with mock_contextmanager_generator() as foo:
            self.assertInWithGeneratorInvariants(foo)
        # FIXME: In the future, we'll try to keep the bound names from leaking
        self.assertAfterWithGeneratorInvariantsNoError(foo)

    def testInlineGeneratorBoundToExistingVariable(self):
        # The ``as`` target may rebind a name that already exists.
        foo = None
        with mock_contextmanager_generator() as foo:
            self.assertInWithGeneratorInvariants(foo)
        self.assertAfterWithGeneratorInvariantsNoError(foo)

    def testInlineGeneratorBoundToDottedVariable(self):
        # The ``as`` target may be an attribute reference.
        with mock_contextmanager_generator() as self.foo:
            self.assertInWithGeneratorInvariants(self.foo)
        self.assertAfterWithGeneratorInvariantsNoError(self.foo)

    def testBoundGenerator(self):
        # The manager is created before the with statement, then bound by
        # ``as``; both the manager and the yielded value are checked.
        mock = mock_contextmanager_generator()
        with mock as foo:
            self.assertInWithGeneratorInvariants(foo)
            self.assertInWithManagerInvariants(mock)
        self.assertAfterWithGeneratorInvariantsNoError(foo)
        self.assertAfterWithManagerInvariantsNoError(mock)

    def testNestedSingleStatements(self):
        # Two nested with statements: the inner manager is fully exited while
        # the outer one is still active.
        mock_a = mock_contextmanager_generator()
        with mock_a as foo:
            mock_b = mock_contextmanager_generator()
            with mock_b as bar:
                self.assertInWithManagerInvariants(mock_a)
                self.assertInWithManagerInvariants(mock_b)
                self.assertInWithGeneratorInvariants(foo)
                self.assertInWithGeneratorInvariants(bar)
            self.assertAfterWithManagerInvariantsNoError(mock_b)
            self.assertAfterWithGeneratorInvariantsNoError(bar)
            self.assertInWithManagerInvariants(mock_a)
            self.assertInWithGeneratorInvariants(foo)
        self.assertAfterWithManagerInvariantsNoError(mock_a)
        self.assertAfterWithGeneratorInvariantsNoError(foo)
    282 
    283 
class NestedNonexceptionalTestCase(unittest.TestCase,
    ContextmanagerAssertionMixin):
    """The Nested helper combining generator-based managers, no exceptions."""

    def testSingleArgInlineGeneratorSyntax(self):
        # Smoke test: Nested wrapping a single inline manager.
        with Nested(mock_contextmanager_generator()):
            pass

    def testSingleArgBoundToNonTuple(self):
        m = mock_contextmanager_generator()
        # Nested.__enter__ returns a list of the managers' __enter__ results,
        # so foo is bound to that whole (one-element) list.
        with Nested(m) as foo:
            self.assertInWithManagerInvariants(m)
        self.assertAfterWithManagerInvariantsNoError(m)

    def testSingleArgBoundToSingleElementParenthesizedList(self):
        m = mock_contextmanager_generator()
        # ``(foo)`` is a parenthesized name, not a tuple, so foo is again
        # bound to the whole one-element list returned by Nested.__enter__.
        with Nested(m) as (foo):
            self.assertInWithManagerInvariants(m)
        self.assertAfterWithManagerInvariantsNoError(m)

    def testSingleArgBoundToMultipleElementTupleError(self):
        # Unpacking the one-element list into two targets raises ValueError.
        def shouldThrowValueError():
            with Nested(mock_contextmanager_generator()) as (foo, bar):
                pass
        self.assertRaises(ValueError, shouldThrowValueError)

    def testSingleArgUnbound(self):
        # MockNested records its own enter/exit alongside the inner manager.
        mock_contextmanager = mock_contextmanager_generator()
        mock_nested = MockNested(mock_contextmanager)
        with mock_nested:
            self.assertInWithManagerInvariants(mock_contextmanager)
            self.assertInWithManagerInvariants(mock_nested)
        self.assertAfterWithManagerInvariantsNoError(mock_contextmanager)
        self.assertAfterWithManagerInvariantsNoError(mock_nested)

    def testMultipleArgUnbound(self):
        # All three managers are entered before the body and exited after it.
        m = mock_contextmanager_generator()
        n = mock_contextmanager_generator()
        o = mock_contextmanager_generator()
        mock_nested = MockNested(m, n, o)
        with mock_nested:
            self.assertInWithManagerInvariants(m)
            self.assertInWithManagerInvariants(n)
            self.assertInWithManagerInvariants(o)
            self.assertInWithManagerInvariants(mock_nested)
        self.assertAfterWithManagerInvariantsNoError(m)
        self.assertAfterWithManagerInvariantsNoError(n)
        self.assertAfterWithManagerInvariantsNoError(o)
        self.assertAfterWithManagerInvariantsNoError(mock_nested)

    def testMultipleArgBound(self):
        # The returned list unpacks into one target per manager, in argument
        # order.
        mock_nested = MockNested(mock_contextmanager_generator(),
            mock_contextmanager_generator(), mock_contextmanager_generator())
        with mock_nested as (m, n, o):
            self.assertInWithGeneratorInvariants(m)
            self.assertInWithGeneratorInvariants(n)
            self.assertInWithGeneratorInvariants(o)
            self.assertInWithManagerInvariants(mock_nested)
        self.assertAfterWithGeneratorInvariantsNoError(m)
        self.assertAfterWithGeneratorInvariantsNoError(n)
        self.assertAfterWithGeneratorInvariantsNoError(o)
        self.assertAfterWithManagerInvariantsNoError(mock_nested)
    348 
    349 
class ExceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
    """Exceptions raised inside with statements and how managers see them."""

    def testSingleResource(self):
        # An exception from the body reaches __exit__ as a full
        # (type, value, traceback) triple and then propagates.
        cm = mock_contextmanager_generator()
        def shouldThrow():
            with cm as self.resource:
                self.assertInWithManagerInvariants(cm)
                self.assertInWithGeneratorInvariants(self.resource)
                self.raiseTestException()
        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(cm)
        self.assertAfterWithGeneratorInvariantsWithError(self.resource)

    def testExceptionNormalized(self):
        # __exit__ must receive an exception instance (checked with
        # isinstance by the helper), even for an internally raised error.
        cm = mock_contextmanager_generator()
        def shouldThrow():
            with cm as self.resource:
                # Note this relies on the fact that 1 // 0 produces an exception
                # that is not normalized immediately.
                1 // 0
        self.assertRaises(ZeroDivisionError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(cm, ZeroDivisionError)

    def testNestedSingleStatements(self):
        # An exception in the innermost body is seen by both nested managers.
        mock_a = mock_contextmanager_generator()
        mock_b = mock_contextmanager_generator()
        def shouldThrow():
            with mock_a as self.foo:
                with mock_b as self.bar:
                    self.assertInWithManagerInvariants(mock_a)
                    self.assertInWithManagerInvariants(mock_b)
                    self.assertInWithGeneratorInvariants(self.foo)
                    self.assertInWithGeneratorInvariants(self.bar)
                    self.raiseTestException()
        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(mock_a)
        self.assertAfterWithManagerInvariantsWithError(mock_b)
        self.assertAfterWithGeneratorInvariantsWithError(self.foo)
        self.assertAfterWithGeneratorInvariantsWithError(self.bar)

    def testMultipleResourcesInSingleStatement(self):
        # Through MockNested, the body's exception is passed to every
        # combined manager and still propagates.
        cm_a = mock_contextmanager_generator()
        cm_b = mock_contextmanager_generator()
        mock_nested = MockNested(cm_a, cm_b)
        def shouldThrow():
            with mock_nested as (self.resource_a, self.resource_b):
                self.assertInWithManagerInvariants(cm_a)
                self.assertInWithManagerInvariants(cm_b)
                self.assertInWithManagerInvariants(mock_nested)
                self.assertInWithGeneratorInvariants(self.resource_a)
                self.assertInWithGeneratorInvariants(self.resource_b)
                self.raiseTestException()
        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(cm_a)
        self.assertAfterWithManagerInvariantsWithError(cm_b)
        self.assertAfterWithManagerInvariantsWithError(mock_nested)
        self.assertAfterWithGeneratorInvariantsWithError(self.resource_a)
        self.assertAfterWithGeneratorInvariantsWithError(self.resource_b)

    def testNestedExceptionBeforeInnerStatement(self):
        # Raising before the inner with statement means its manager is never
        # entered, exited, or bound.
        mock_a = mock_contextmanager_generator()
        mock_b = mock_contextmanager_generator()
        self.bar = None
        def shouldThrow():
            with mock_a as self.foo:
                self.assertInWithManagerInvariants(mock_a)
                self.assertInWithGeneratorInvariants(self.foo)
                self.raiseTestException()
                with mock_b as self.bar:
                    pass
        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(mock_a)
        self.assertAfterWithGeneratorInvariantsWithError(self.foo)

        # The inner statement stuff should never have been touched
        self.assertEqual(self.bar, None)
        self.assertFalse(mock_b.enter_called)
        self.assertFalse(mock_b.exit_called)
        self.assertEqual(mock_b.exit_args, None)

    def testNestedExceptionAfterInnerStatement(self):
        # Raising after the inner with statement has completed: the inner
        # manager exited cleanly, and only the outer one sees the error.
        mock_a = mock_contextmanager_generator()
        mock_b = mock_contextmanager_generator()
        def shouldThrow():
            with mock_a as self.foo:
                with mock_b as self.bar:
                    self.assertInWithManagerInvariants(mock_a)
                    self.assertInWithManagerInvariants(mock_b)
                    self.assertInWithGeneratorInvariants(self.foo)
                    self.assertInWithGeneratorInvariants(self.bar)
                self.raiseTestException()
        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(mock_a)
        self.assertAfterWithManagerInvariantsNoError(mock_b)
        self.assertAfterWithGeneratorInvariantsWithError(self.foo)
        self.assertAfterWithGeneratorInvariantsNoError(self.bar)

    def testRaisedStopIteration1(self):
        # From bug 1462485
        # A StopIteration raised in the body must propagate out of a
        # @contextmanager-based with statement.
        @contextmanager
        def cm():
            yield

        def shouldThrow():
            with cm():
                raise StopIteration("from with")

        self.assertRaises(StopIteration, shouldThrow)

    def testRaisedStopIteration2(self):
        # From bug 1462485
        # Same as above, with a class-based manager whose __exit__ returns
        # None.
        class cm(object):
            def __enter__(self):
                pass
            def __exit__(self, type, value, traceback):
                pass

        def shouldThrow():
            with cm():
                raise StopIteration("from with")

        self.assertRaises(StopIteration, shouldThrow)

    def testRaisedStopIteration3(self):
        # Another variant where the exception hasn't been instantiated
        # From bug 1705170
        @contextmanager
        def cm():
            yield

        def shouldThrow():
            with cm():
                raise iter([]).next()

        self.assertRaises(StopIteration, shouldThrow)

    def testRaisedGeneratorExit1(self):
        # From bug 1462485
        # A GeneratorExit raised in the body must propagate out of a
        # @contextmanager-based with statement.
        @contextmanager
        def cm():
            yield

        def shouldThrow():
            with cm():
                raise GeneratorExit("from with")

        self.assertRaises(GeneratorExit, shouldThrow)

    def testRaisedGeneratorExit2(self):
        # From bug 1462485
        # Same as above, with a class-based manager.
        class cm (object):
            def __enter__(self):
                pass
            def __exit__(self, type, value, traceback):
                pass

        def shouldThrow():
            with cm():
                raise GeneratorExit("from with")

        self.assertRaises(GeneratorExit, shouldThrow)

    def testErrorsInBool(self):
        # issue4589: __exit__ return code may raise an exception
        # when looking at its truth value.

        class cm(object):
            def __init__(self, bool_conversion):
                # exit_result's truth value (via the Python 2 __nonzero__
                # hook) is whatever bool_conversion() returns or raises.
                class Bool:
                    def __nonzero__(self):
                        return bool_conversion()
                self.exit_result = Bool()
            def __enter__(self):
                return 3
            def __exit__(self, a, b, c):
                return self.exit_result

        def trueAsBool():
            # Truthy result: the AssertionError from self.fail is suppressed.
            with cm(lambda: True):
                self.fail("Should NOT see this")
        trueAsBool()

        def falseAsBool():
            # Falsy result: the AssertionError propagates.
            with cm(lambda: False):
                self.fail("Should raise")
        self.assertRaises(AssertionError, falseAsBool)

        def failAsBool():
            # Error while evaluating the truth value: that error propagates.
            with cm(lambda: 1 // 0):
                self.fail("Should NOT see this")
        self.assertRaises(ZeroDivisionError, failAsBool)
    540 
    541 
class NonLocalFlowControlTestCase(unittest.TestCase):
    """break, continue, return, yield and raise from inside a with-body."""

    def testWithBreak(self):
        # break leaves the loop straight from the body; counter == 11 shows
        # the statement after the with block never ran.
        counter = 0
        while True:
            counter += 1
            with mock_contextmanager_generator():
                counter += 10
                break
            counter += 100 # Not reached
        self.assertEqual(counter, 11)

    def testWithContinue(self):
        # continue from the body starts the next iteration (which then hits
        # the break), skipping the statement after the with block.
        counter = 0
        while True:
            counter += 1
            if counter > 2:
                break
            with mock_contextmanager_generator():
                counter += 10
                continue
            counter += 100 # Not reached
        self.assertEqual(counter, 12)

    def testWithReturn(self):
        # return from the body returns the value computed inside it.
        def foo():
            counter = 0
            while True:
                counter += 1
                with mock_contextmanager_generator():
                    counter += 10
                    return counter
                counter += 100 # Not reached
        self.assertEqual(foo(), 11)

    def testWithYield(self):
        # A generator can suspend inside a with-body and resume there.
        def gen():
            with mock_contextmanager_generator():
                yield 12
                yield 13
        x = list(gen())
        self.assertEqual(x, [12, 13])

    def testWithRaise(self):
        # An exception from the body propagates past the with statement,
        # skipping the statement that follows it.
        counter = 0
        try:
            counter += 1
            with mock_contextmanager_generator():
                counter += 10
                raise RuntimeError
            counter += 100 # Not reached
        except RuntimeError:
            self.assertEqual(counter, 11)
        else:
            self.fail("Didn't raise RuntimeError")
    597 
    598 
class AssignmentTargetTestCase(unittest.TestCase):
    """``as`` targets other than plain names: subscripts, attributes, tuples."""

    def testSingleComplexTarget(self):
        targets = {1: [0, 1, 2]}
        # Nested subscription target: stores into targets[1][0].  (Python 2:
        # keys() returns a list, hence the comparisons with [1].)
        with mock_contextmanager_generator() as targets[1][0]:
            self.assertEqual(targets.keys(), [1])
            self.assertEqual(targets[1][0].__class__, MockResource)
        # The target expression may contain calls; this stores into the same
        # list that targets[1] refers to.
        with mock_contextmanager_generator() as targets.values()[0][1]:
            self.assertEqual(targets.keys(), [1])
            self.assertEqual(targets[1][1].__class__, MockResource)
        # A subscription target can create a new dict key.
        with mock_contextmanager_generator() as targets[2]:
            keys = targets.keys()
            keys.sort()
            self.assertEqual(keys, [1, 2])
        # Attribute target.
        class C: pass
        blah = C()
        with mock_contextmanager_generator() as blah.foo:
            self.assertEqual(hasattr(blah, "foo"), True)

    def testMultipleComplexTargets(self):
        # C.__enter__ returns (1, 2, 3), which is unpacked into a tuple of
        # complex targets.
        class C:
            def __enter__(self): return 1, 2, 3
            def __exit__(self, t, v, tb): pass
        targets = {1: [0, 1, 2]}
        with C() as (targets[1][0], targets[1][1], targets[1][2]):
            self.assertEqual(targets, {1: [1, 2, 3]})
        # Same list, filled in reverse order through call-containing targets.
        with C() as (targets.values()[0][2], targets.values()[0][1], targets.values()[0][0]):
            self.assertEqual(targets, {1: [3, 2, 1]})
        # Replaces the list at key 1 and adds keys 2 and 3.
        with C() as (targets[1], targets[2], targets[3]):
            self.assertEqual(targets, {1: 1, 2: 2, 3: 3})
        # Tuple of attribute targets.
        class B: pass
        blah = B()
        with C() as (blah.one, blah.two, blah.three):
            self.assertEqual(blah.one, 1)
            self.assertEqual(blah.two, 2)
            self.assertEqual(blah.three, 3)
    635 
    636 
    637 class ExitSwallowsExceptionTestCase(unittest.TestCase):
    638 
    639     def testExitTrueSwallowsException(self):
    640         class AfricanSwallow:
    641             def __enter__(self): pass
    642             def __exit__(self, t, v, tb): return True
    643         try:
    644             with AfricanSwallow():
    645                 1 // 0
    646         except ZeroDivisionError:
    647             self.fail("ZeroDivisionError should have been swallowed")
    648 
    649     def testExitFalseDoesntSwallowException(self):
    650         class EuropeanSwallow:
    651             def __enter__(self): pass
    652             def __exit__(self, t, v, tb): return False
    653         try:
    654             with EuropeanSwallow():
    655                 1 // 0
    656         except ZeroDivisionError:
    657             pass
    658         else:
    659             self.fail("ZeroDivisionError should have been raised")
    660 
    661 
    662 class NestedWith(unittest.TestCase):
    663 
    664     class Dummy(object):
    665         def __init__(self, value=None, gobble=False):
    666             if value is None:
    667                 value = self
    668             self.value = value
    669             self.gobble = gobble
    670             self.enter_called = False
    671             self.exit_called = False
    672 
    673         def __enter__(self):
    674             self.enter_called = True
    675             return self.value
    676 
    677         def __exit__(self, *exc_info):
    678             self.exit_called = True
    679             self.exc_info = exc_info
    680             if self.gobble:
    681                 return True
    682 
    683     class InitRaises(object):
    684         def __init__(self): raise RuntimeError()
    685 
    686     class EnterRaises(object):
    687         def __enter__(self): raise RuntimeError()
    688         def __exit__(self, *exc_info): pass
    689 
    690     class ExitRaises(object):
    691         def __enter__(self): pass
    692         def __exit__(self, *exc_info): raise RuntimeError()
    693 
    694     def testNoExceptions(self):
    695         with self.Dummy() as a, self.Dummy() as b:
    696             self.assertTrue(a.enter_called)
    697             self.assertTrue(b.enter_called)
    698         self.assertTrue(a.exit_called)
    699         self.assertTrue(b.exit_called)
    700 
    701     def testExceptionInExprList(self):
    702         try:
    703             with self.Dummy() as a, self.InitRaises():
    704                 pass
    705         except:
    706             pass
    707         self.assertTrue(a.enter_called)
    708         self.assertTrue(a.exit_called)
    709 
    710     def testExceptionInEnter(self):
    711         try:
    712             with self.Dummy() as a, self.EnterRaises():
    713                 self.fail('body of bad with executed')
    714         except RuntimeError:
    715             pass
    716         else:
    717             self.fail('RuntimeError not reraised')
    718         self.assertTrue(a.enter_called)
    719         self.assertTrue(a.exit_called)
    720 
    721     def testExceptionInExit(self):
    722         body_executed = False
    723         with self.Dummy(gobble=True) as a, self.ExitRaises():
    724             body_executed = True
    725         self.assertTrue(a.enter_called)
    726         self.assertTrue(a.exit_called)
    727         self.assertTrue(body_executed)
    728         self.assertNotEqual(a.exc_info[0], None)
    729 
    730     def testEnterReturnsTuple(self):
    731         with self.Dummy(value=(1,2)) as (a1, a2), \
    732              self.Dummy(value=(10, 20)) as (b1, b2):
    733             self.assertEqual(1, a1)
    734             self.assertEqual(2, a2)
    735             self.assertEqual(10, b1)
    736             self.assertEqual(20, b2)
    737 
    738 def test_main():
    739     run_unittest(FailureTestCase, NonexceptionalTestCase,
    740                  NestedNonexceptionalTestCase, ExceptionalTestCase,
    741                  NonLocalFlowControlTestCase,
    742                  AssignmentTargetTestCase,
    743                  ExitSwallowsExceptionTestCase,
    744                  NestedWith)
    745 
    746 
    747 if __name__ == '__main__':
    748     test_main()
    749