Home | History | Annotate | Download | only in test
      1 """
      2 Various tests for synchronization primitives.
      3 """
      4 
      5 import sys
      6 import time
      7 from thread import start_new_thread, get_ident
      8 import threading
      9 import unittest
     10 
     11 from test import test_support as support
     12 
     13 
     14 def _wait():
     15     # A crude wait/yield function not relying on synchronization primitives.
     16     time.sleep(0.01)
     17 
     18 class Bunch(object):
     19     """
     20     A bunch of threads.
     21     """
     22     def __init__(self, f, n, wait_before_exit=False):
     23         """
     24         Construct a bunch of `n` threads running the same function `f`.
     25         If `wait_before_exit` is True, the threads won't terminate until
     26         do_finish() is called.
     27         """
     28         self.f = f
     29         self.n = n
     30         self.started = []
     31         self.finished = []
     32         self._can_exit = not wait_before_exit
     33         def task():
     34             tid = get_ident()
     35             self.started.append(tid)
     36             try:
     37                 f()
     38             finally:
     39                 self.finished.append(tid)
     40                 while not self._can_exit:
     41                     _wait()
     42         for i in range(n):
     43             start_new_thread(task, ())
     44 
     45     def wait_for_started(self):
     46         while len(self.started) < self.n:
     47             _wait()
     48 
     49     def wait_for_finished(self):
     50         while len(self.finished) < self.n:
     51             _wait()
     52 
     53     def do_finish(self):
     54         self._can_exit = True
     55 
     56 
     57 class BaseTestCase(unittest.TestCase):
     58     def setUp(self):
     59         self._threads = support.threading_setup()
     60 
     61     def tearDown(self):
     62         support.threading_cleanup(*self._threads)
     63         support.reap_children()
     64 
     65 
     66 class BaseLockTests(BaseTestCase):
     67     """
     68     Tests for both recursive and non-recursive locks.
     69     """
     70 
     71     def test_constructor(self):
     72         lock = self.locktype()
     73         del lock
     74 
     75     def test_acquire_destroy(self):
     76         lock = self.locktype()
     77         lock.acquire()
     78         del lock
     79 
     80     def test_acquire_release(self):
     81         lock = self.locktype()
     82         lock.acquire()
     83         lock.release()
     84         del lock
     85 
     86     def test_try_acquire(self):
     87         lock = self.locktype()
     88         self.assertTrue(lock.acquire(False))
     89         lock.release()
     90 
     91     def test_try_acquire_contended(self):
     92         lock = self.locktype()
     93         lock.acquire()
     94         result = []
     95         def f():
     96             result.append(lock.acquire(False))
     97         Bunch(f, 1).wait_for_finished()
     98         self.assertFalse(result[0])
     99         lock.release()
    100 
    101     def test_acquire_contended(self):
    102         lock = self.locktype()
    103         lock.acquire()
    104         N = 5
    105         def f():
    106             lock.acquire()
    107             lock.release()
    108 
    109         b = Bunch(f, N)
    110         b.wait_for_started()
    111         _wait()
    112         self.assertEqual(len(b.finished), 0)
    113         lock.release()
    114         b.wait_for_finished()
    115         self.assertEqual(len(b.finished), N)
    116 
    117     def test_with(self):
    118         lock = self.locktype()
    119         def f():
    120             lock.acquire()
    121             lock.release()
    122         def _with(err=None):
    123             with lock:
    124                 if err is not None:
    125                     raise err
    126         _with()
    127         # Check the lock is unacquired
    128         Bunch(f, 1).wait_for_finished()
    129         self.assertRaises(TypeError, _with, TypeError)
    130         # Check the lock is unacquired
    131         Bunch(f, 1).wait_for_finished()
    132 
    133     def test_thread_leak(self):
    134         # The lock shouldn't leak a Thread instance when used from a foreign
    135         # (non-threading) thread.
    136         lock = self.locktype()
    137         def f():
    138             lock.acquire()
    139             lock.release()
    140         n = len(threading.enumerate())
    141         # We run many threads in the hope that existing threads ids won't
    142         # be recycled.
    143         Bunch(f, 15).wait_for_finished()
    144         self.assertEqual(n, len(threading.enumerate()))
    145 
    146 
    147 class LockTests(BaseLockTests):
    148     """
    149     Tests for non-recursive, weak locks
    150     (which can be acquired and released from different threads).
    151     """
    152     def test_reacquire(self):
    153         # Lock needs to be released before re-acquiring.
    154         lock = self.locktype()
    155         phase = []
    156         def f():
    157             lock.acquire()
    158             phase.append(None)
    159             lock.acquire()
    160             phase.append(None)
    161         start_new_thread(f, ())
    162         while len(phase) == 0:
    163             _wait()
    164         _wait()
    165         self.assertEqual(len(phase), 1)
    166         lock.release()
    167         while len(phase) == 1:
    168             _wait()
    169         self.assertEqual(len(phase), 2)
    170 
    171     def test_different_thread(self):
    172         # Lock can be released from a different thread.
    173         lock = self.locktype()
    174         lock.acquire()
    175         def f():
    176             lock.release()
    177         b = Bunch(f, 1)
    178         b.wait_for_finished()
    179         lock.acquire()
    180         lock.release()
    181 
    182 
    183 class RLockTests(BaseLockTests):
    184     """
    185     Tests for recursive locks.
    186     """
    187     def test_reacquire(self):
    188         lock = self.locktype()
    189         lock.acquire()
    190         lock.acquire()
    191         lock.release()
    192         lock.acquire()
    193         lock.release()
    194         lock.release()
    195 
    196     def test_release_unacquired(self):
    197         # Cannot release an unacquired lock
    198         lock = self.locktype()
    199         self.assertRaises(RuntimeError, lock.release)
    200         lock.acquire()
    201         lock.acquire()
    202         lock.release()
    203         lock.acquire()
    204         lock.release()
    205         lock.release()
    206         self.assertRaises(RuntimeError, lock.release)
    207 
    208     def test_different_thread(self):
    209         # Cannot release from a different thread
    210         lock = self.locktype()
    211         def f():
    212             lock.acquire()
    213         b = Bunch(f, 1, True)
    214         try:
    215             self.assertRaises(RuntimeError, lock.release)
    216         finally:
    217             b.do_finish()
    218 
    219     def test__is_owned(self):
    220         lock = self.locktype()
    221         self.assertFalse(lock._is_owned())
    222         lock.acquire()
    223         self.assertTrue(lock._is_owned())
    224         lock.acquire()
    225         self.assertTrue(lock._is_owned())
    226         result = []
    227         def f():
    228             result.append(lock._is_owned())
    229         Bunch(f, 1).wait_for_finished()
    230         self.assertFalse(result[0])
    231         lock.release()
    232         self.assertTrue(lock._is_owned())
    233         lock.release()
    234         self.assertFalse(lock._is_owned())
    235 
    236 
    237 class EventTests(BaseTestCase):
    238     """
    239     Tests for Event objects.
    240     """
    241 
    242     def test_is_set(self):
    243         evt = self.eventtype()
    244         self.assertFalse(evt.is_set())
    245         evt.set()
    246         self.assertTrue(evt.is_set())
    247         evt.set()
    248         self.assertTrue(evt.is_set())
    249         evt.clear()
    250         self.assertFalse(evt.is_set())
    251         evt.clear()
    252         self.assertFalse(evt.is_set())
    253 
    254     def _check_notify(self, evt):
    255         # All threads get notified
    256         N = 5
    257         results1 = []
    258         results2 = []
    259         def f():
    260             results1.append(evt.wait())
    261             results2.append(evt.wait())
    262         b = Bunch(f, N)
    263         b.wait_for_started()
    264         _wait()
    265         self.assertEqual(len(results1), 0)
    266         evt.set()
    267         b.wait_for_finished()
    268         self.assertEqual(results1, [True] * N)
    269         self.assertEqual(results2, [True] * N)
    270 
    271     def test_notify(self):
    272         evt = self.eventtype()
    273         self._check_notify(evt)
    274         # Another time, after an explicit clear()
    275         evt.set()
    276         evt.clear()
    277         self._check_notify(evt)
    278 
    279     def test_timeout(self):
    280         evt = self.eventtype()
    281         results1 = []
    282         results2 = []
    283         N = 5
    284         def f():
    285             results1.append(evt.wait(0.0))
    286             t1 = time.time()
    287             r = evt.wait(0.2)
    288             t2 = time.time()
    289             results2.append((r, t2 - t1))
    290         Bunch(f, N).wait_for_finished()
    291         self.assertEqual(results1, [False] * N)
    292         for r, dt in results2:
    293             self.assertFalse(r)
    294             self.assertTrue(dt >= 0.2, dt)
    295         # The event is set
    296         results1 = []
    297         results2 = []
    298         evt.set()
    299         Bunch(f, N).wait_for_finished()
    300         self.assertEqual(results1, [True] * N)
    301         for r, dt in results2:
    302             self.assertTrue(r)
    303 
    304 
    305 class ConditionTests(BaseTestCase):
    306     """
    307     Tests for condition variables.
    308     """
    309 
    310     def test_acquire(self):
    311         cond = self.condtype()
    312         # Be default we have an RLock: the condition can be acquired multiple
    313         # times.
    314         cond.acquire()
    315         cond.acquire()
    316         cond.release()
    317         cond.release()
    318         lock = threading.Lock()
    319         cond = self.condtype(lock)
    320         cond.acquire()
    321         self.assertFalse(lock.acquire(False))
    322         cond.release()
    323         self.assertTrue(lock.acquire(False))
    324         self.assertFalse(cond.acquire(False))
    325         lock.release()
    326         with cond:
    327             self.assertFalse(lock.acquire(False))
    328 
    329     def test_unacquired_wait(self):
    330         cond = self.condtype()
    331         self.assertRaises(RuntimeError, cond.wait)
    332 
    333     def test_unacquired_notify(self):
    334         cond = self.condtype()
    335         self.assertRaises(RuntimeError, cond.notify)
    336 
    337     def _check_notify(self, cond):
    338         N = 5
    339         results1 = []
    340         results2 = []
    341         phase_num = 0
    342         def f():
    343             cond.acquire()
    344             cond.wait()
    345             cond.release()
    346             results1.append(phase_num)
    347             cond.acquire()
    348             cond.wait()
    349             cond.release()
    350             results2.append(phase_num)
    351         b = Bunch(f, N)
    352         b.wait_for_started()
    353         _wait()
    354         self.assertEqual(results1, [])
    355         # Notify 3 threads at first
    356         cond.acquire()
    357         cond.notify(3)
    358         _wait()
    359         phase_num = 1
    360         cond.release()
    361         while len(results1) < 3:
    362             _wait()
    363         self.assertEqual(results1, [1] * 3)
    364         self.assertEqual(results2, [])
    365         # Notify 5 threads: they might be in their first or second wait
    366         cond.acquire()
    367         cond.notify(5)
    368         _wait()
    369         phase_num = 2
    370         cond.release()
    371         while len(results1) + len(results2) < 8:
    372             _wait()
    373         self.assertEqual(results1, [1] * 3 + [2] * 2)
    374         self.assertEqual(results2, [2] * 3)
    375         # Notify all threads: they are all in their second wait
    376         cond.acquire()
    377         cond.notify_all()
    378         _wait()
    379         phase_num = 3
    380         cond.release()
    381         while len(results2) < 5:
    382             _wait()
    383         self.assertEqual(results1, [1] * 3 + [2] * 2)
    384         self.assertEqual(results2, [2] * 3 + [3] * 2)
    385         b.wait_for_finished()
    386 
    387     def test_notify(self):
    388         cond = self.condtype()
    389         self._check_notify(cond)
    390         # A second time, to check internal state is still ok.
    391         self._check_notify(cond)
    392 
    393     def test_timeout(self):
    394         cond = self.condtype()
    395         results = []
    396         N = 5
    397         def f():
    398             cond.acquire()
    399             t1 = time.time()
    400             cond.wait(0.2)
    401             t2 = time.time()
    402             cond.release()
    403             results.append(t2 - t1)
    404         Bunch(f, N).wait_for_finished()
    405         self.assertEqual(len(results), 5)
    406         for dt in results:
    407             self.assertTrue(dt >= 0.2, dt)
    408 
    409 
    410 class BaseSemaphoreTests(BaseTestCase):
    411     """
    412     Common tests for {bounded, unbounded} semaphore objects.
    413     """
    414 
    415     def test_constructor(self):
    416         self.assertRaises(ValueError, self.semtype, value = -1)
    417         self.assertRaises(ValueError, self.semtype, value = -sys.maxint)
    418 
    419     def test_acquire(self):
    420         sem = self.semtype(1)
    421         sem.acquire()
    422         sem.release()
    423         sem = self.semtype(2)
    424         sem.acquire()
    425         sem.acquire()
    426         sem.release()
    427         sem.release()
    428 
    429     def test_acquire_destroy(self):
    430         sem = self.semtype()
    431         sem.acquire()
    432         del sem
    433 
    434     def test_acquire_contended(self):
    435         sem = self.semtype(7)
    436         sem.acquire()
    437         N = 10
    438         results1 = []
    439         results2 = []
    440         phase_num = 0
    441         def f():
    442             sem.acquire()
    443             results1.append(phase_num)
    444             sem.acquire()
    445             results2.append(phase_num)
    446         b = Bunch(f, 10)
    447         b.wait_for_started()
    448         while len(results1) + len(results2) < 6:
    449             _wait()
    450         self.assertEqual(results1 + results2, [0] * 6)
    451         phase_num = 1
    452         for i in range(7):
    453             sem.release()
    454         while len(results1) + len(results2) < 13:
    455             _wait()
    456         self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7)
    457         phase_num = 2
    458         for i in range(6):
    459             sem.release()
    460         while len(results1) + len(results2) < 19:
    461             _wait()
    462         self.assertEqual(sorted(results1 + results2), [0] * 6 + [1] * 7 + [2] * 6)
    463         # The semaphore is still locked
    464         self.assertFalse(sem.acquire(False))
    465         # Final release, to let the last thread finish
    466         sem.release()
    467         b.wait_for_finished()
    468 
    469     def test_try_acquire(self):
    470         sem = self.semtype(2)
    471         self.assertTrue(sem.acquire(False))
    472         self.assertTrue(sem.acquire(False))
    473         self.assertFalse(sem.acquire(False))
    474         sem.release()
    475         self.assertTrue(sem.acquire(False))
    476 
    477     def test_try_acquire_contended(self):
    478         sem = self.semtype(4)
    479         sem.acquire()
    480         results = []
    481         def f():
    482             results.append(sem.acquire(False))
    483             results.append(sem.acquire(False))
    484         Bunch(f, 5).wait_for_finished()
    485         # There can be a thread switch between acquiring the semaphore and
    486         # appending the result, therefore results will not necessarily be
    487         # ordered.
    488         self.assertEqual(sorted(results), [False] * 7 + [True] *  3 )
    489 
    490     def test_default_value(self):
    491         # The default initial value is 1.
    492         sem = self.semtype()
    493         sem.acquire()
    494         def f():
    495             sem.acquire()
    496             sem.release()
    497         b = Bunch(f, 1)
    498         b.wait_for_started()
    499         _wait()
    500         self.assertFalse(b.finished)
    501         sem.release()
    502         b.wait_for_finished()
    503 
    504     def test_with(self):
    505         sem = self.semtype(2)
    506         def _with(err=None):
    507             with sem:
    508                 self.assertTrue(sem.acquire(False))
    509                 sem.release()
    510                 with sem:
    511                     self.assertFalse(sem.acquire(False))
    512                     if err:
    513                         raise err
    514         _with()
    515         self.assertTrue(sem.acquire(False))
    516         sem.release()
    517         self.assertRaises(TypeError, _with, TypeError)
    518         self.assertTrue(sem.acquire(False))
    519         sem.release()
    520 
    521 class SemaphoreTests(BaseSemaphoreTests):
    522     """
    523     Tests for unbounded semaphores.
    524     """
    525 
    526     def test_release_unacquired(self):
    527         # Unbounded releases are allowed and increment the semaphore's value
    528         sem = self.semtype(1)
    529         sem.release()
    530         sem.acquire()
    531         sem.acquire()
    532         sem.release()
    533 
    534 
    535 class BoundedSemaphoreTests(BaseSemaphoreTests):
    536     """
    537     Tests for bounded semaphores.
    538     """
    539 
    540     def test_release_unacquired(self):
    541         # Cannot go past the initial value
    542         sem = self.semtype()
    543         self.assertRaises(ValueError, sem.release)
    544         sem.acquire()
    545         sem.release()
    546         self.assertRaises(ValueError, sem.release)
    547