Home | History | Annotate | Download | only in test
      1 import unittest
      2 from test import test_support
      3 from weakref import proxy, ref, WeakSet
      4 import operator
      5 import copy
      6 import string
      7 import os
      8 from random import randrange, shuffle
      9 import sys
     10 import warnings
     11 import collections
     12 import gc
     13 import contextlib
     14 
     15 
     16 class Foo:
     17     pass
     18 
     19 class SomeClass(object):
     20     def __init__(self, value):
     21         self.value = value
     22     def __eq__(self, other):
     23         if type(other) != type(self):
     24             return False
     25         return other.value == self.value
     26 
     27     def __ne__(self, other):
     28         return not self.__eq__(other)
     29 
     30     def __hash__(self):
     31         return hash((SomeClass, self.value))
     32 
     33 class TestWeakSet(unittest.TestCase):
     34 
     35     def setUp(self):
     36         # need to keep references to them

     37         self.items = [SomeClass(c) for c in ('a', 'b', 'c')]
     38         self.items2 = [SomeClass(c) for c in ('x', 'y', 'z')]
     39         self.letters = [SomeClass(c) for c in string.ascii_letters]
     40         self.s = WeakSet(self.items)
     41         self.d = dict.fromkeys(self.items)
     42         self.obj = SomeClass('F')
     43         self.fs = WeakSet([self.obj])
     44 
     45     def test_methods(self):
     46         weaksetmethods = dir(WeakSet)
     47         for method in dir(set):
     48             if method == 'test_c_api' or method.startswith('_'):
     49                 continue
     50             self.assertIn(method, weaksetmethods,
     51                          "WeakSet missing method " + method)
     52 
     53     def test_new_or_init(self):
     54         self.assertRaises(TypeError, WeakSet, [], 2)
     55 
     56     def test_len(self):
     57         self.assertEqual(len(self.s), len(self.d))
     58         self.assertEqual(len(self.fs), 1)
     59         del self.obj
     60         self.assertEqual(len(self.fs), 0)
     61 
     62     def test_contains(self):
     63         for c in self.letters:
     64             self.assertEqual(c in self.s, c in self.d)
     65         # 1 is not weakref'able, but that TypeError is caught by __contains__

     66         self.assertNotIn(1, self.s)
     67         self.assertIn(self.obj, self.fs)
     68         del self.obj
     69         self.assertNotIn(SomeClass('F'), self.fs)
     70 
     71     def test_union(self):
     72         u = self.s.union(self.items2)
     73         for c in self.letters:
     74             self.assertEqual(c in u, c in self.d or c in self.items2)
     75         self.assertEqual(self.s, WeakSet(self.items))
     76         self.assertEqual(type(u), WeakSet)
     77         self.assertRaises(TypeError, self.s.union, [[]])
     78         for C in set, frozenset, dict.fromkeys, list, tuple:
     79             x = WeakSet(self.items + self.items2)
     80             c = C(self.items2)
     81             self.assertEqual(self.s.union(c), x)
     82 
     83     def test_or(self):
     84         i = self.s.union(self.items2)
     85         self.assertEqual(self.s | set(self.items2), i)
     86         self.assertEqual(self.s | frozenset(self.items2), i)
     87 
     88     def test_intersection(self):
     89         i = self.s.intersection(self.items2)
     90         for c in self.letters:
     91             self.assertEqual(c in i, c in self.d and c in self.items2)
     92         self.assertEqual(self.s, WeakSet(self.items))
     93         self.assertEqual(type(i), WeakSet)
     94         for C in set, frozenset, dict.fromkeys, list, tuple:
     95             x = WeakSet([])
     96             self.assertEqual(self.s.intersection(C(self.items2)), x)
     97 
     98     def test_isdisjoint(self):
     99         self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
    100         self.assertTrue(not self.s.isdisjoint(WeakSet(self.letters)))
    101 
    102     def test_and(self):
    103         i = self.s.intersection(self.items2)
    104         self.assertEqual(self.s & set(self.items2), i)
    105         self.assertEqual(self.s & frozenset(self.items2), i)
    106 
    107     def test_difference(self):
    108         i = self.s.difference(self.items2)
    109         for c in self.letters:
    110             self.assertEqual(c in i, c in self.d and c not in self.items2)
    111         self.assertEqual(self.s, WeakSet(self.items))
    112         self.assertEqual(type(i), WeakSet)
    113         self.assertRaises(TypeError, self.s.difference, [[]])
    114 
    115     def test_sub(self):
    116         i = self.s.difference(self.items2)
    117         self.assertEqual(self.s - set(self.items2), i)
    118         self.assertEqual(self.s - frozenset(self.items2), i)
    119 
    120     def test_symmetric_difference(self):
    121         i = self.s.symmetric_difference(self.items2)
    122         for c in self.letters:
    123             self.assertEqual(c in i, (c in self.d) ^ (c in self.items2))
    124         self.assertEqual(self.s, WeakSet(self.items))
    125         self.assertEqual(type(i), WeakSet)
    126         self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
    127 
    128     def test_xor(self):
    129         i = self.s.symmetric_difference(self.items2)
    130         self.assertEqual(self.s ^ set(self.items2), i)
    131         self.assertEqual(self.s ^ frozenset(self.items2), i)
    132 
    133     def test_sub_and_super(self):
    134         pl, ql, rl = map(lambda s: [SomeClass(c) for c in s], ['ab', 'abcde', 'def'])
    135         p, q, r = map(WeakSet, (pl, ql, rl))
    136         self.assertTrue(p < q)
    137         self.assertTrue(p <= q)
    138         self.assertTrue(q <= q)
    139         self.assertTrue(q > p)
    140         self.assertTrue(q >= p)
    141         self.assertFalse(q < r)
    142         self.assertFalse(q <= r)
    143         self.assertFalse(q > r)
    144         self.assertFalse(q >= r)
    145         self.assertTrue(set('a').issubset('abc'))
    146         self.assertTrue(set('abc').issuperset('a'))
    147         self.assertFalse(set('a').issubset('cbs'))
    148         self.assertFalse(set('cbs').issuperset('a'))
    149 
    150     def test_gc(self):
    151         # Create a nest of cycles to exercise overall ref count check

    152         s = WeakSet(Foo() for i in range(1000))
    153         for elem in s:
    154             elem.cycle = s
    155             elem.sub = elem
    156             elem.set = WeakSet([elem])
    157 
    158     def test_subclass_with_custom_hash(self):
    159         # Bug #1257731

    160         class H(WeakSet):
    161             def __hash__(self):
    162                 return int(id(self) & 0x7fffffff)
    163         s=H()
    164         f=set()
    165         f.add(s)
    166         self.assertIn(s, f)
    167         f.remove(s)
    168         f.add(s)
    169         f.discard(s)
    170 
    171     def test_init(self):
    172         s = WeakSet()
    173         s.__init__(self.items)
    174         self.assertEqual(s, self.s)
    175         s.__init__(self.items2)
    176         self.assertEqual(s, WeakSet(self.items2))
    177         self.assertRaises(TypeError, s.__init__, s, 2);
    178         self.assertRaises(TypeError, s.__init__, 1);
    179 
    180     def test_constructor_identity(self):
    181         s = WeakSet(self.items)
    182         t = WeakSet(s)
    183         self.assertNotEqual(id(s), id(t))
    184 
    185     def test_hash(self):
    186         self.assertRaises(TypeError, hash, self.s)
    187 
    188     def test_clear(self):
    189         self.s.clear()
    190         self.assertEqual(self.s, WeakSet([]))
    191         self.assertEqual(len(self.s), 0)
    192 
    193     def test_copy(self):
    194         dup = self.s.copy()
    195         self.assertEqual(self.s, dup)
    196         self.assertNotEqual(id(self.s), id(dup))
    197 
    198     def test_add(self):
    199         x = SomeClass('Q')
    200         self.s.add(x)
    201         self.assertIn(x, self.s)
    202         dup = self.s.copy()
    203         self.s.add(x)
    204         self.assertEqual(self.s, dup)
    205         self.assertRaises(TypeError, self.s.add, [])
    206         self.fs.add(Foo())
    207         self.assertTrue(len(self.fs) == 1)
    208         self.fs.add(self.obj)
    209         self.assertTrue(len(self.fs) == 1)
    210 
    211     def test_remove(self):
    212         x = SomeClass('a')
    213         self.s.remove(x)
    214         self.assertNotIn(x, self.s)
    215         self.assertRaises(KeyError, self.s.remove, x)
    216         self.assertRaises(TypeError, self.s.remove, [])
    217 
    218     def test_discard(self):
    219         a, q = SomeClass('a'), SomeClass('Q')
    220         self.s.discard(a)
    221         self.assertNotIn(a, self.s)
    222         self.s.discard(q)
    223         self.assertRaises(TypeError, self.s.discard, [])
    224 
    225     def test_pop(self):
    226         for i in range(len(self.s)):
    227             elem = self.s.pop()
    228             self.assertNotIn(elem, self.s)
    229         self.assertRaises(KeyError, self.s.pop)
    230 
    231     def test_update(self):
    232         retval = self.s.update(self.items2)
    233         self.assertEqual(retval, None)
    234         for c in (self.items + self.items2):
    235             self.assertIn(c, self.s)
    236         self.assertRaises(TypeError, self.s.update, [[]])
    237 
    238     def test_update_set(self):
    239         self.s.update(set(self.items2))
    240         for c in (self.items + self.items2):
    241             self.assertIn(c, self.s)
    242 
    243     def test_ior(self):
    244         self.s |= set(self.items2)
    245         for c in (self.items + self.items2):
    246             self.assertIn(c, self.s)
    247 
    248     def test_intersection_update(self):
    249         retval = self.s.intersection_update(self.items2)
    250         self.assertEqual(retval, None)
    251         for c in (self.items + self.items2):
    252             if c in self.items2 and c in self.items:
    253                 self.assertIn(c, self.s)
    254             else:
    255                 self.assertNotIn(c, self.s)
    256         self.assertRaises(TypeError, self.s.intersection_update, [[]])
    257 
    258     def test_iand(self):
    259         self.s &= set(self.items2)
    260         for c in (self.items + self.items2):
    261             if c in self.items2 and c in self.items:
    262                 self.assertIn(c, self.s)
    263             else:
    264                 self.assertNotIn(c, self.s)
    265 
    266     def test_difference_update(self):
    267         retval = self.s.difference_update(self.items2)
    268         self.assertEqual(retval, None)
    269         for c in (self.items + self.items2):
    270             if c in self.items and c not in self.items2:
    271                 self.assertIn(c, self.s)
    272             else:
    273                 self.assertNotIn(c, self.s)
    274         self.assertRaises(TypeError, self.s.difference_update, [[]])
    275         self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
    276 
    277     def test_isub(self):
    278         self.s -= set(self.items2)
    279         for c in (self.items + self.items2):
    280             if c in self.items and c not in self.items2:
    281                 self.assertIn(c, self.s)
    282             else:
    283                 self.assertNotIn(c, self.s)
    284 
    285     def test_symmetric_difference_update(self):
    286         retval = self.s.symmetric_difference_update(self.items2)
    287         self.assertEqual(retval, None)
    288         for c in (self.items + self.items2):
    289             if (c in self.items) ^ (c in self.items2):
    290                 self.assertIn(c, self.s)
    291             else:
    292                 self.assertNotIn(c, self.s)
    293         self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
    294 
    295     def test_ixor(self):
    296         self.s ^= set(self.items2)
    297         for c in (self.items + self.items2):
    298             if (c in self.items) ^ (c in self.items2):
    299                 self.assertIn(c, self.s)
    300             else:
    301                 self.assertNotIn(c, self.s)
    302 
    303     def test_inplace_on_self(self):
    304         t = self.s.copy()
    305         t |= t
    306         self.assertEqual(t, self.s)
    307         t &= t
    308         self.assertEqual(t, self.s)
    309         t -= t
    310         self.assertEqual(t, WeakSet())
    311         t = self.s.copy()
    312         t ^= t
    313         self.assertEqual(t, WeakSet())
    314 
    315     def test_eq(self):
    316         # issue 5964

    317         self.assertTrue(self.s == self.s)
    318         self.assertTrue(self.s == WeakSet(self.items))
    319         self.assertFalse(self.s == set(self.items))
    320         self.assertFalse(self.s == list(self.items))
    321         self.assertFalse(self.s == tuple(self.items))
    322         self.assertFalse(self.s == 1)
    323 
    324     def test_weak_destroy_while_iterating(self):
    325         # Issue #7105: iterators shouldn't crash when a key is implicitly removed

    326         # Create new items to be sure no-one else holds a reference

    327         items = [SomeClass(c) for c in ('a', 'b', 'c')]
    328         s = WeakSet(items)
    329         it = iter(s)
    330         next(it)             # Trigger internal iteration

    331         # Destroy an item

    332         del items[-1]
    333         gc.collect()    # just in case

    334         # We have removed either the first consumed items, or another one

    335         self.assertIn(len(list(it)), [len(items), len(items) - 1])
    336         del it
    337         # The removal has been committed

    338         self.assertEqual(len(s), len(items))
    339 
    340     def test_weak_destroy_and_mutate_while_iterating(self):
    341         # Issue #7105: iterators shouldn't crash when a key is implicitly removed

    342         items = [SomeClass(c) for c in string.ascii_letters]
    343         s = WeakSet(items)
    344         @contextlib.contextmanager
    345         def testcontext():
    346             try:
    347                 it = iter(s)
    348                 next(it)
    349                 # Schedule an item for removal and recreate it

    350                 u = SomeClass(str(items.pop()))
    351                 gc.collect()      # just in case

    352                 yield u
    353             finally:
    354                 it = None           # should commit all removals

    355 
    356         with testcontext() as u:
    357             self.assertNotIn(u, s)
    358         with testcontext() as u:
    359             self.assertRaises(KeyError, s.remove, u)
    360         self.assertNotIn(u, s)
    361         with testcontext() as u:
    362             s.add(u)
    363         self.assertIn(u, s)
    364         t = s.copy()
    365         with testcontext() as u:
    366             s.update(t)
    367         self.assertEqual(len(s), len(t))
    368         with testcontext() as u:
    369             s.clear()
    370         self.assertEqual(len(s), 0)
    371 
    372 
    373 def test_main(verbose=None):
    374     test_support.run_unittest(TestWeakSet)
    375 
    376 if __name__ == "__main__":
    377     test_main(verbose=True)
    378