Home | History | Annotate | Download | only in Lib
# Access WeakSet through the weakref module.
# This code lives in a separate module because abc.py needs it
# in order to load everything else at startup.
      4 
      5 from _weakref import ref
      6 
      7 __all__ = ['WeakSet']
      8 
      9 
     10 class _IterationGuard:
     11     # This context manager registers itself in the current iterators of the
     12     # weak container, such as to delay all removals until the context manager
     13     # exits.
     14     # This technique should be relatively thread-safe (since sets are).
     15 
     16     def __init__(self, weakcontainer):
     17         # Don't create cycles
     18         self.weakcontainer = ref(weakcontainer)
     19 
     20     def __enter__(self):
     21         w = self.weakcontainer()
     22         if w is not None:
     23             w._iterating.add(self)
     24         return self
     25 
     26     def __exit__(self, e, t, b):
     27         w = self.weakcontainer()
     28         if w is not None:
     29             s = w._iterating
     30             s.remove(self)
     31             if not s:
     32                 w._commit_removals()
     33 
     34 
     35 class WeakSet:
     36     def __init__(self, data=None):
     37         self.data = set()
     38         def _remove(item, selfref=ref(self)):
     39             self = selfref()
     40             if self is not None:
     41                 if self._iterating:
     42                     self._pending_removals.append(item)
     43                 else:
     44                     self.data.discard(item)
     45         self._remove = _remove
     46         # A list of keys to be removed
     47         self._pending_removals = []
     48         self._iterating = set()
     49         if data is not None:
     50             self.update(data)
     51 
     52     def _commit_removals(self):
     53         l = self._pending_removals
     54         discard = self.data.discard
     55         while l:
     56             discard(l.pop())
     57 
     58     def __iter__(self):
     59         with _IterationGuard(self):
     60             for itemref in self.data:
     61                 item = itemref()
     62                 if item is not None:
     63                     # Caveat: the iterator will keep a strong reference to
     64                     # `item` until it is resumed or closed.
     65                     yield item
     66 
     67     def __len__(self):
     68         return len(self.data) - len(self._pending_removals)
     69 
     70     def __contains__(self, item):
     71         try:
     72             wr = ref(item)
     73         except TypeError:
     74             return False
     75         return wr in self.data
     76 
     77     def __reduce__(self):
     78         return (self.__class__, (list(self),),
     79                 getattr(self, '__dict__', None))
     80 
     81     def add(self, item):
     82         if self._pending_removals:
     83             self._commit_removals()
     84         self.data.add(ref(item, self._remove))
     85 
     86     def clear(self):
     87         if self._pending_removals:
     88             self._commit_removals()
     89         self.data.clear()
     90 
     91     def copy(self):
     92         return self.__class__(self)
     93 
     94     def pop(self):
     95         if self._pending_removals:
     96             self._commit_removals()
     97         while True:
     98             try:
     99                 itemref = self.data.pop()
    100             except KeyError:
    101                 raise KeyError('pop from empty WeakSet')
    102             item = itemref()
    103             if item is not None:
    104                 return item
    105 
    106     def remove(self, item):
    107         if self._pending_removals:
    108             self._commit_removals()
    109         self.data.remove(ref(item))
    110 
    111     def discard(self, item):
    112         if self._pending_removals:
    113             self._commit_removals()
    114         self.data.discard(ref(item))
    115 
    116     def update(self, other):
    117         if self._pending_removals:
    118             self._commit_removals()
    119         for element in other:
    120             self.add(element)
    121 
    122     def __ior__(self, other):
    123         self.update(other)
    124         return self
    125 
    126     def difference(self, other):
    127         newset = self.copy()
    128         newset.difference_update(other)
    129         return newset
    130     __sub__ = difference
    131 
    132     def difference_update(self, other):
    133         self.__isub__(other)
    134     def __isub__(self, other):
    135         if self._pending_removals:
    136             self._commit_removals()
    137         if self is other:
    138             self.data.clear()
    139         else:
    140             self.data.difference_update(ref(item) for item in other)
    141         return self
    142 
    143     def intersection(self, other):
    144         return self.__class__(item for item in other if item in self)
    145     __and__ = intersection
    146 
    147     def intersection_update(self, other):
    148         self.__iand__(other)
    149     def __iand__(self, other):
    150         if self._pending_removals:
    151             self._commit_removals()
    152         self.data.intersection_update(ref(item) for item in other)
    153         return self
    154 
    155     def issubset(self, other):
    156         return self.data.issubset(ref(item) for item in other)
    157     __le__ = issubset
    158 
    159     def __lt__(self, other):
    160         return self.data < set(ref(item) for item in other)
    161 
    162     def issuperset(self, other):
    163         return self.data.issuperset(ref(item) for item in other)
    164     __ge__ = issuperset
    165 
    166     def __gt__(self, other):
    167         return self.data > set(ref(item) for item in other)
    168 
    169     def __eq__(self, other):
    170         if not isinstance(other, self.__class__):
    171             return NotImplemented
    172         return self.data == set(ref(item) for item in other)
    173 
    174     def symmetric_difference(self, other):
    175         newset = self.copy()
    176         newset.symmetric_difference_update(other)
    177         return newset
    178     __xor__ = symmetric_difference
    179 
    180     def symmetric_difference_update(self, other):
    181         self.__ixor__(other)
    182     def __ixor__(self, other):
    183         if self._pending_removals:
    184             self._commit_removals()
    185         if self is other:
    186             self.data.clear()
    187         else:
    188             self.data.symmetric_difference_update(ref(item, self._remove) for item in other)
    189         return self
    190 
    191     def union(self, other):
    192         return self.__class__(e for s in (self, other) for e in s)
    193     __or__ = union
    194 
    195     def isdisjoint(self, other):
    196         return len(self.intersection(other)) == 0
    197