"""Unit tests for weakref.WeakSet."""

import unittest
from test import test_support
from weakref import proxy, ref, WeakSet
import operator
import copy
import string
import os
from random import randrange, shuffle
import sys
import warnings
import collections
import gc
import contextlib


class Foo:
    """Minimal weakref-able placeholder with default identity equality."""
    pass

class SomeClass(object):
    """Weakref-able value object.

    Instances compare equal when they have the same type and the same
    ``value``, and hash consistently with that equality, so distinct
    instances can stand in for one another in set lookups.
    """
    def __init__(self, value):
        self.value = value
    def __eq__(self, other):
        if type(other) != type(self):
            return False
        return other.value == self.value

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return hash((SomeClass, self.value))

class RefCycle(object):
    """Object that references itself, forming a reference cycle.

    Under CPython reference counting such an object is only reclaimed by
    the cyclic garbage collector.
    """
    def __init__(self):
        self.cycle = self

class TestWeakSet(unittest.TestCase):

    def setUp(self):
        # Keep strong references to every element; the WeakSets below only
        # hold weak references and would otherwise lose their contents.
        self.items = [SomeClass(c) for c in ('a', 'b', 'c')]
        self.items2 = [SomeClass(c) for c in ('x', 'y', 'z')]
        self.letters = [SomeClass(c) for c in string.ascii_letters]
        self.ab_items = [SomeClass(c) for c in 'ab']
        self.abcde_items = [SomeClass(c) for c in 'abcde']
        self.def_items = [SomeClass(c) for c in 'def']
        self.ab_weakset = WeakSet(self.ab_items)
        self.abcde_weakset = WeakSet(self.abcde_items)
        self.def_weakset = WeakSet(self.def_items)
        self.s = WeakSet(self.items)
        self.d = dict.fromkeys(self.items)
        self.obj = SomeClass('F')
        self.fs = WeakSet([self.obj])

    def test_methods(self):
        # Every public method of the built-in set must exist on WeakSet.
        weaksetmethods = dir(WeakSet)
        for method in dir(set):
            if method == 'test_c_api' or method.startswith('_'):
                continue
            self.assertIn(method, weaksetmethods,
                         "WeakSet missing method " + method)

    def test_new_or_init(self):
        self.assertRaises(TypeError, WeakSet, [], 2)

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))
        self.assertEqual(len(self.fs), 1)
        # Dropping the only strong reference should empty self.fs (relies on
        # the object being reclaimed immediately, as under CPython).
        del self.obj
        self.assertEqual(len(self.fs), 0)

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        # 1 is not weakref'able, but that TypeError is caught by __contains__
        self.assertNotIn(1, self.s)
        self.assertIn(self.obj, self.fs)
        del self.obj
        # An equal but distinct object must not match a dead entry.
        self.assertNotIn(SomeClass('F'), self.fs)

    def test_union(self):
        u = self.s.union(self.items2)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(u), WeakSet)
        self.assertRaises(TypeError, self.s.union, [[]])
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet(self.items + self.items2)
            c = C(self.items2)
            self.assertEqual(self.s.union(c), x)
            # Drop the strong container so it cannot keep items2 alive below.
            del c
        self.assertEqual(len(u), len(self.items) + len(self.items2))
        # The result holds only weak references: dropping an element from
        # items2 must shrink it.
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(u), len(self.items) + len(self.items2))

    def test_or(self):
        i = self.s.union(self.items2)
        self.assertEqual(self.s | set(self.items2), i)
        self.assertEqual(self.s | frozenset(self.items2), i)

    def test_intersection(self):
        s = WeakSet(self.letters)
        i = s.intersection(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.items2 and c in self.letters)
        self.assertEqual(s, WeakSet(self.letters))
        self.assertEqual(type(i), WeakSet)
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet([])
            self.assertEqual(i.intersection(C(self.items)), x)
        self.assertEqual(len(i), len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items2))

    def test_isdisjoint(self):
        self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
        self.assertFalse(self.s.isdisjoint(WeakSet(self.letters)))

    def test_and(self):
        i = self.s.intersection(self.items2)
        self.assertEqual(self.s & set(self.items2), i)
        self.assertEqual(self.s & frozenset(self.items2), i)

    def test_difference(self):
        i = self.s.difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.difference, [[]])

    def test_sub(self):
        i = self.s.difference(self.items2)
        self.assertEqual(self.s - set(self.items2), i)
        self.assertEqual(self.s - frozenset(self.items2), i)

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.items2))
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
        self.assertEqual(len(i), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items) + len(self.items2))

    def test_xor(self):
        i = self.s.symmetric_difference(self.items2)
        self.assertEqual(self.s ^ set(self.items2), i)
        self.assertEqual(self.s ^ frozenset(self.items2), i)

    def test_sub_and_super(self):
        self.assertTrue(self.ab_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset >= self.ab_weakset)
        self.assertFalse(self.abcde_weakset <= self.def_weakset)
        self.assertFalse(self.abcde_weakset >= self.def_weakset)
        # NOTE(review): the following four assertions exercise the built-in
        # set type, not WeakSet.
        self.assertTrue(set('a').issubset('abc'))
        self.assertTrue(set('abc').issuperset('a'))
        self.assertFalse(set('a').issubset('cbs'))
        self.assertFalse(set('cbs').issuperset('a'))

    def test_lt(self):
        self.assertTrue(self.ab_weakset < self.abcde_weakset)
        self.assertFalse(self.abcde_weakset < self.def_weakset)
        self.assertFalse(self.ab_weakset < self.ab_weakset)
        self.assertFalse(WeakSet() < WeakSet())

    def test_gt(self):
        self.assertTrue(self.abcde_weakset > self.ab_weakset)
        self.assertFalse(self.abcde_weakset > self.def_weakset)
        self.assertFalse(self.ab_weakset > self.ab_weakset)
        self.assertFalse(WeakSet() > WeakSet())

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check
        # NOTE(review): the Foo instances have no other strong references,
        # so under CPython they are likely reclaimed as soon as they are
        # added, leaving `s` empty and the loop body unexecuted -- confirm
        # whether that is intended.
        s = WeakSet(Foo() for i in range(1000))
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = WeakSet([elem])

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        class H(WeakSet):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s = H()
        f = set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_init(self):
        # Calling __init__ again replaces the contents rather than adding.
        s = WeakSet()
        s.__init__(self.items)
        self.assertEqual(s, self.s)
        s.__init__(self.items2)
        self.assertEqual(s, WeakSet(self.items2))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = WeakSet(self.items)
        t = WeakSet(s)
        self.assertNotEqual(id(s), id(t))

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, WeakSet([]))
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertNotEqual(id(self.s), id(dup))

    def test_add(self):
        x = SomeClass('Q')
        self.s.add(x)
        self.assertIn(x, self.s)
        dup = self.s.copy()
        # Adding an element that is already present is a no-op.
        self.s.add(x)
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])
        # A temporary with no other strong reference does not stay in the set.
        self.fs.add(Foo())
        self.assertEqual(len(self.fs), 1)
        self.fs.add(self.obj)
        self.assertEqual(len(self.fs), 1)

    def test_remove(self):
        x = SomeClass('a')
        self.s.remove(x)
        self.assertNotIn(x, self.s)
        self.assertRaises(KeyError, self.s.remove, x)
        self.assertRaises(TypeError, self.s.remove, [])

    def test_discard(self):
        a, q = SomeClass('a'), SomeClass('Q')
        self.s.discard(a)
        self.assertNotIn(a, self.s)
        # Discarding an absent element must not raise.
        self.s.discard(q)
        self.assertRaises(TypeError, self.s.discard, [])

    def test_pop(self):
        for _ in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)
        self.assertRaises(TypeError, self.s.update, [[]])

    def test_update_set(self):
        self.s.update(set(self.items2))
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_ior(self):
        self.s |= set(self.items2)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.intersection_update, [[]])

    def test_iand(self):
        self.s &= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_isub(self):
        self.s -= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_ixor(self):
        self.s ^= set(self.items2)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, WeakSet())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, WeakSet())

    def test_eq(self):
        # issue 5964
        self.assertTrue(self.s == self.s)
        self.assertTrue(self.s == WeakSet(self.items))
        self.assertFalse(self.s == set(self.items))
        self.assertFalse(self.s == list(self.items))
        self.assertFalse(self.s == tuple(self.items))
        self.assertFalse(self.s == 1)

    def test_weak_destroy_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        # Create new items to be sure no-one else holds a reference
        items = [SomeClass(c) for c in ('a', 'b', 'c')]
        s = WeakSet(items)
        it = iter(s)
        next(it)             # Trigger internal iteration
        # Destroy an item
        del items[-1]
        gc.collect()    # just in case
        # We have removed either the first consumed items, or another one
        self.assertIn(len(list(it)), [len(items), len(items) - 1])
        del it
        # The removal has been committed
        self.assertEqual(len(s), len(items))

    def test_weak_destroy_and_mutate_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        items = [SomeClass(c) for c in string.ascii_letters]
        s = WeakSet(items)
        @contextlib.contextmanager
        def testcontext():
            # Start iterating over `s`, kill one element while the iterator
            # is live, and yield a fresh object EQUAL to the dead element.
            try:
                it = iter(s)
                # Start the iterator; keep only the yielded value so this
                # frame holds no strong reference to the element itself.
                yielded_value = next(it).value
                # Schedule an item for removal and recreate an equal one.
                u = SomeClass(items.pop().value)
                if yielded_value == u.value:
                    # The suspended iterator may still hold a strong
                    # reference to the popped item; advance it so the item
                    # can actually die (issue #20006).
                    next(it)
                gc.collect()      # just in case
                yield u
            finally:
                it = None           # should commit all removals

        with testcontext() as u:
            # An equal object must not match the dead, pending-removal entry.
            self.assertNotIn(u, s)
        with testcontext() as u:
            self.assertRaises(KeyError, s.remove, u)
        self.assertNotIn(u, s)
        with testcontext() as u:
            s.add(u)
        self.assertIn(u, s)
        t = s.copy()
        with testcontext() as u:
            s.update(t)
        self.assertEqual(len(s), len(t))
        with testcontext() as u:
            s.clear()
        self.assertEqual(len(s), 0)

    def test_len_cycles(self):
        N = 20
        items = [RefCycle() for i in range(N)]
        s = WeakSet(items)
        del items
        it = iter(s)
        try:
            next(it)
        except StopIteration:
            pass
        gc.collect()
        n1 = len(s)
        del it
        gc.collect()
        n2 = len(s)
        # one item may be kept alive inside the iterator
        self.assertIn(n1, (0, 1))
        self.assertEqual(n2, 0)

    def test_len_race(self):
        # Extended sanity checks for len() in the face of cyclic collection
        self.addCleanup(gc.set_threshold, *gc.get_threshold())
        for th in range(1, 100):
            N = 20
            gc.collect(0)
            # Vary the GC thresholds so collection triggers at different
            # points during the allocations below.
            gc.set_threshold(th, th, th)
            items = [RefCycle() for i in range(N)]
            s = WeakSet(items)
            del items
            # All items will be collected at next garbage collection pass
            it = iter(s)
            try:
                next(it)
            except StopIteration:
                pass
            n1 = len(s)
            del it
            n2 = len(s)
            self.assertGreaterEqual(n1, 0)
            self.assertLessEqual(n1, N)
            self.assertGreaterEqual(n2, 0)
            self.assertLessEqual(n2, n1)


def test_main(verbose=None):
    # `verbose` is accepted for the regrtest calling convention; it is not
    # used in this body.
    test_support.run_unittest(TestWeakSet)

if __name__ == "__main__":
    test_main(verbose=True)