Home | History | Annotate | Download | only in asyncio
      1 """Synchronization primitives."""
      2 
      3 __all__ = ['Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore']
      4 
      5 import collections
      6 
      7 from . import compat
      8 from . import events
      9 from . import futures
     10 from .coroutines import coroutine
     11 
     12 
     13 class _ContextManager:
     14     """Context manager.
     15 
     16     This enables the following idiom for acquiring and releasing a
     17     lock around a block:
     18 
     19         with (yield from lock):
     20             <block>
     21 
     22     while failing loudly when accidentally using:
     23 
     24         with lock:
     25             <block>
     26     """
     27 
     28     def __init__(self, lock):
     29         self._lock = lock
     30 
     31     def __enter__(self):
     32         # We have no use for the "as ..."  clause in the with
     33         # statement for locks.
     34         return None
     35 
     36     def __exit__(self, *args):
     37         try:
     38             self._lock.release()
     39         finally:
     40             self._lock = None  # Crudely prevent reuse.
     41 
     42 
     43 class _ContextManagerMixin:
     44     def __enter__(self):
     45         raise RuntimeError(
     46             '"yield from" should be used as context manager expression')
     47 
     48     def __exit__(self, *args):
     49         # This must exist because __enter__ exists, even though that
     50         # always raises; that's how the with-statement works.
     51         pass
     52 
     53     @coroutine
     54     def __iter__(self):
     55         # This is not a coroutine.  It is meant to enable the idiom:
     56         #
     57         #     with (yield from lock):
     58         #         <block>
     59         #
     60         # as an alternative to:
     61         #
     62         #     yield from lock.acquire()
     63         #     try:
     64         #         <block>
     65         #     finally:
     66         #         lock.release()
     67         yield from self.acquire()
     68         return _ContextManager(self)
     69 
     70     if compat.PY35:
     71 
     72         def __await__(self):
     73             # To make "with await lock" work.
     74             yield from self.acquire()
     75             return _ContextManager(self)
     76 
     77         @coroutine
     78         def __aenter__(self):
     79             yield from self.acquire()
     80             # We have no use for the "as ..."  clause in the with
     81             # statement for locks.
     82             return None
     83 
     84         @coroutine
     85         def __aexit__(self, exc_type, exc, tb):
     86             self.release()
     87 
     88 
     89 class Lock(_ContextManagerMixin):
     90     """Primitive lock objects.
     91 
     92     A primitive lock is a synchronization primitive that is not owned
     93     by a particular coroutine when locked.  A primitive lock is in one
     94     of two states, 'locked' or 'unlocked'.
     95 
     96     It is created in the unlocked state.  It has two basic methods,
     97     acquire() and release().  When the state is unlocked, acquire()
     98     changes the state to locked and returns immediately.  When the
     99     state is locked, acquire() blocks until a call to release() in
    100     another coroutine changes it to unlocked, then the acquire() call
    101     resets it to locked and returns.  The release() method should only
    102     be called in the locked state; it changes the state to unlocked
    103     and returns immediately.  If an attempt is made to release an
    104     unlocked lock, a RuntimeError will be raised.
    105 
    106     When more than one coroutine is blocked in acquire() waiting for
    107     the state to turn to unlocked, only one coroutine proceeds when a
    108     release() call resets the state to unlocked; first coroutine which
    109     is blocked in acquire() is being processed.
    110 
    111     acquire() is a coroutine and should be called with 'yield from'.
    112 
    113     Locks also support the context management protocol.  '(yield from lock)'
    114     should be used as the context manager expression.
    115 
    116     Usage:
    117 
    118         lock = Lock()
    119         ...
    120         yield from lock
    121         try:
    122             ...
    123         finally:
    124             lock.release()
    125 
    126     Context manager usage:
    127 
    128         lock = Lock()
    129         ...
    130         with (yield from lock):
    131              ...
    132 
    133     Lock objects can be tested for locking state:
    134 
    135         if not lock.locked():
    136            yield from lock
    137         else:
    138            # lock is acquired
    139            ...
    140 
    141     """
    142 
    143     def __init__(self, *, loop=None):
    144         self._waiters = collections.deque()
    145         self._locked = False
    146         if loop is not None:
    147             self._loop = loop
    148         else:
    149             self._loop = events.get_event_loop()
    150 
    151     def __repr__(self):
    152         res = super().__repr__()
    153         extra = 'locked' if self._locked else 'unlocked'
    154         if self._waiters:
    155             extra = '{},waiters:{}'.format(extra, len(self._waiters))
    156         return '<{} [{}]>'.format(res[1:-1], extra)
    157 
    158     def locked(self):
    159         """Return True if lock is acquired."""
    160         return self._locked
    161 
    162     @coroutine
    163     def acquire(self):
    164         """Acquire a lock.
    165 
    166         This method blocks until the lock is unlocked, then sets it to
    167         locked and returns True.
    168         """
    169         if not self._locked and all(w.cancelled() for w in self._waiters):
    170             self._locked = True
    171             return True
    172 
    173         fut = self._loop.create_future()
    174         self._waiters.append(fut)
    175         try:
    176             yield from fut
    177             self._locked = True
    178             return True
    179         finally:
    180             self._waiters.remove(fut)
    181 
    182     def release(self):
    183         """Release a lock.
    184 
    185         When the lock is locked, reset it to unlocked, and return.
    186         If any other coroutines are blocked waiting for the lock to become
    187         unlocked, allow exactly one of them to proceed.
    188 
    189         When invoked on an unlocked lock, a RuntimeError is raised.
    190 
    191         There is no return value.
    192         """
    193         if self._locked:
    194             self._locked = False
    195             # Wake up the first waiter who isn't cancelled.
    196             for fut in self._waiters:
    197                 if not fut.done():
    198                     fut.set_result(True)
    199                     break
    200         else:
    201             raise RuntimeError('Lock is not acquired.')
    202 
    203 
    204 class Event:
    205     """Asynchronous equivalent to threading.Event.
    206 
    207     Class implementing event objects. An event manages a flag that can be set
    208     to true with the set() method and reset to false with the clear() method.
    209     The wait() method blocks until the flag is true. The flag is initially
    210     false.
    211     """
    212 
    213     def __init__(self, *, loop=None):
    214         self._waiters = collections.deque()
    215         self._value = False
    216         if loop is not None:
    217             self._loop = loop
    218         else:
    219             self._loop = events.get_event_loop()
    220 
    221     def __repr__(self):
    222         res = super().__repr__()
    223         extra = 'set' if self._value else 'unset'
    224         if self._waiters:
    225             extra = '{},waiters:{}'.format(extra, len(self._waiters))
    226         return '<{} [{}]>'.format(res[1:-1], extra)
    227 
    228     def is_set(self):
    229         """Return True if and only if the internal flag is true."""
    230         return self._value
    231 
    232     def set(self):
    233         """Set the internal flag to true. All coroutines waiting for it to
    234         become true are awakened. Coroutine that call wait() once the flag is
    235         true will not block at all.
    236         """
    237         if not self._value:
    238             self._value = True
    239 
    240             for fut in self._waiters:
    241                 if not fut.done():
    242                     fut.set_result(True)
    243 
    244     def clear(self):
    245         """Reset the internal flag to false. Subsequently, coroutines calling
    246         wait() will block until set() is called to set the internal flag
    247         to true again."""
    248         self._value = False
    249 
    250     @coroutine
    251     def wait(self):
    252         """Block until the internal flag is true.
    253 
    254         If the internal flag is true on entry, return True
    255         immediately.  Otherwise, block until another coroutine calls
    256         set() to set the flag to true, then return True.
    257         """
    258         if self._value:
    259             return True
    260 
    261         fut = self._loop.create_future()
    262         self._waiters.append(fut)
    263         try:
    264             yield from fut
    265             return True
    266         finally:
    267             self._waiters.remove(fut)
    268 
    269 
    270 class Condition(_ContextManagerMixin):
    271     """Asynchronous equivalent to threading.Condition.
    272 
    273     This class implements condition variable objects. A condition variable
    274     allows one or more coroutines to wait until they are notified by another
    275     coroutine.
    276 
    277     A new Lock object is created and used as the underlying lock.
    278     """
    279 
    280     def __init__(self, lock=None, *, loop=None):
    281         if loop is not None:
    282             self._loop = loop
    283         else:
    284             self._loop = events.get_event_loop()
    285 
    286         if lock is None:
    287             lock = Lock(loop=self._loop)
    288         elif lock._loop is not self._loop:
    289             raise ValueError("loop argument must agree with lock")
    290 
    291         self._lock = lock
    292         # Export the lock's locked(), acquire() and release() methods.
    293         self.locked = lock.locked
    294         self.acquire = lock.acquire
    295         self.release = lock.release
    296 
    297         self._waiters = collections.deque()
    298 
    299     def __repr__(self):
    300         res = super().__repr__()
    301         extra = 'locked' if self.locked() else 'unlocked'
    302         if self._waiters:
    303             extra = '{},waiters:{}'.format(extra, len(self._waiters))
    304         return '<{} [{}]>'.format(res[1:-1], extra)
    305 
    306     @coroutine
    307     def wait(self):
    308         """Wait until notified.
    309 
    310         If the calling coroutine has not acquired the lock when this
    311         method is called, a RuntimeError is raised.
    312 
    313         This method releases the underlying lock, and then blocks
    314         until it is awakened by a notify() or notify_all() call for
    315         the same condition variable in another coroutine.  Once
    316         awakened, it re-acquires the lock and returns True.
    317         """
    318         if not self.locked():
    319             raise RuntimeError('cannot wait on un-acquired lock')
    320 
    321         self.release()
    322         try:
    323             fut = self._loop.create_future()
    324             self._waiters.append(fut)
    325             try:
    326                 yield from fut
    327                 return True
    328             finally:
    329                 self._waiters.remove(fut)
    330 
    331         finally:
    332             # Must reacquire lock even if wait is cancelled
    333             while True:
    334                 try:
    335                     yield from self.acquire()
    336                     break
    337                 except futures.CancelledError:
    338                     pass
    339 
    340     @coroutine
    341     def wait_for(self, predicate):
    342         """Wait until a predicate becomes true.
    343 
    344         The predicate should be a callable which result will be
    345         interpreted as a boolean value.  The final predicate value is
    346         the return value.
    347         """
    348         result = predicate()
    349         while not result:
    350             yield from self.wait()
    351             result = predicate()
    352         return result
    353 
    354     def notify(self, n=1):
    355         """By default, wake up one coroutine waiting on this condition, if any.
    356         If the calling coroutine has not acquired the lock when this method
    357         is called, a RuntimeError is raised.
    358 
    359         This method wakes up at most n of the coroutines waiting for the
    360         condition variable; it is a no-op if no coroutines are waiting.
    361 
    362         Note: an awakened coroutine does not actually return from its
    363         wait() call until it can reacquire the lock. Since notify() does
    364         not release the lock, its caller should.
    365         """
    366         if not self.locked():
    367             raise RuntimeError('cannot notify on un-acquired lock')
    368 
    369         idx = 0
    370         for fut in self._waiters:
    371             if idx >= n:
    372                 break
    373 
    374             if not fut.done():
    375                 idx += 1
    376                 fut.set_result(False)
    377 
    378     def notify_all(self):
    379         """Wake up all threads waiting on this condition. This method acts
    380         like notify(), but wakes up all waiting threads instead of one. If the
    381         calling thread has not acquired the lock when this method is called,
    382         a RuntimeError is raised.
    383         """
    384         self.notify(len(self._waiters))
    385 
    386 
    387 class Semaphore(_ContextManagerMixin):
    388     """A Semaphore implementation.
    389 
    390     A semaphore manages an internal counter which is decremented by each
    391     acquire() call and incremented by each release() call. The counter
    392     can never go below zero; when acquire() finds that it is zero, it blocks,
    393     waiting until some other thread calls release().
    394 
    395     Semaphores also support the context management protocol.
    396 
    397     The optional argument gives the initial value for the internal
    398     counter; it defaults to 1. If the value given is less than 0,
    399     ValueError is raised.
    400     """
    401 
    402     def __init__(self, value=1, *, loop=None):
    403         if value < 0:
    404             raise ValueError("Semaphore initial value must be >= 0")
    405         self._value = value
    406         self._waiters = collections.deque()
    407         if loop is not None:
    408             self._loop = loop
    409         else:
    410             self._loop = events.get_event_loop()
    411 
    412     def __repr__(self):
    413         res = super().__repr__()
    414         extra = 'locked' if self.locked() else 'unlocked,value:{}'.format(
    415             self._value)
    416         if self._waiters:
    417             extra = '{},waiters:{}'.format(extra, len(self._waiters))
    418         return '<{} [{}]>'.format(res[1:-1], extra)
    419 
    420     def _wake_up_next(self):
    421         while self._waiters:
    422             waiter = self._waiters.popleft()
    423             if not waiter.done():
    424                 waiter.set_result(None)
    425                 return
    426 
    427     def locked(self):
    428         """Returns True if semaphore can not be acquired immediately."""
    429         return self._value == 0
    430 
    431     @coroutine
    432     def acquire(self):
    433         """Acquire a semaphore.
    434 
    435         If the internal counter is larger than zero on entry,
    436         decrement it by one and return True immediately.  If it is
    437         zero on entry, block, waiting until some other coroutine has
    438         called release() to make it larger than 0, and then return
    439         True.
    440         """
    441         while self._value <= 0:
    442             fut = self._loop.create_future()
    443             self._waiters.append(fut)
    444             try:
    445                 yield from fut
    446             except:
    447                 # See the similar code in Queue.get.
    448                 fut.cancel()
    449                 if self._value > 0 and not fut.cancelled():
    450                     self._wake_up_next()
    451                 raise
    452         self._value -= 1
    453         return True
    454 
    455     def release(self):
    456         """Release a semaphore, incrementing the internal counter by one.
    457         When it was zero on entry and another coroutine is waiting for it to
    458         become larger than zero again, wake up that coroutine.
    459         """
    460         self._value += 1
    461         self._wake_up_next()
    462 
    463 
    464 class BoundedSemaphore(Semaphore):
    465     """A bounded semaphore implementation.
    466 
    467     This raises ValueError in release() if it would increase the value
    468     above the initial value.
    469     """
    470 
    471     def __init__(self, value=1, *, loop=None):
    472         self._bound_value = value
    473         super().__init__(value, loop=loop)
    474 
    475     def release(self):
    476         if self._value >= self._bound_value:
    477             raise ValueError('BoundedSemaphore released too many times')
    478         super().release()
    479