"""Unit tests for weakref.WeakSet."""

import unittest
from weakref import WeakSet
import string
from collections import UserString as ustr
import gc
import contextlib


class Foo:
    """Minimal class whose instances are stored in WeakSets by the tests."""
    pass


class RefCycle:
    """Object holding a reference to itself, so it forms a reference cycle."""

    def __init__(self):
        self.cycle = self


class TestWeakSet(unittest.TestCase):
    """Exercise the set-like API of WeakSet and its weak-reference behaviour.

    A WeakSet only holds weak references, so every fixture keeps its items
    alive through a strong reference stored on the test instance.
    """

    def setUp(self):
        # need to keep references to them
        self.items = [ustr(c) for c in ('a', 'b', 'c')]
        self.items2 = [ustr(c) for c in ('x', 'y', 'z')]
        self.ab_items = [ustr(c) for c in 'ab']
        self.abcde_items = [ustr(c) for c in 'abcde']
        self.def_items = [ustr(c) for c in 'def']
        self.ab_weakset = WeakSet(self.ab_items)
        self.abcde_weakset = WeakSet(self.abcde_items)
        self.def_weakset = WeakSet(self.def_items)
        self.letters = [ustr(c) for c in string.ascii_letters]
        self.s = WeakSet(self.items)
        # Dict with the same keys as self.s; used as the membership oracle.
        self.d = dict.fromkeys(self.items)
        self.obj = ustr('F')
        self.fs = WeakSet([self.obj])

    def test_methods(self):
        # Every public method of set must also exist on WeakSet.
        weaksetmethods = dir(WeakSet)
        for method in dir(set):
            if method == 'test_c_api' or method.startswith('_'):
                continue
            self.assertIn(method, weaksetmethods,
                          "WeakSet missing method " + method)

    def test_new_or_init(self):
        self.assertRaises(TypeError, WeakSet, [], 2)

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))
        self.assertEqual(len(self.fs), 1)
        del self.obj
        # Do not rely on immediate reference-count finalisation.
        gc.collect()
        self.assertEqual(len(self.fs), 0)

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        # 1 is not weakref'able, but that TypeError is caught by __contains__
        self.assertNotIn(1, self.s)
        self.assertIn(self.obj, self.fs)
        del self.obj
        # Do not rely on immediate reference-count finalisation.
        gc.collect()
        self.assertNotIn(ustr('F'), self.fs)

    def test_union(self):
        u = self.s.union(self.items2)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.items2)
        # union() must not modify the receiver.
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertIs(type(u), WeakSet)
        self.assertRaises(TypeError, self.s.union, [[]])
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet(self.items + self.items2)
            c = C(self.items2)
            self.assertEqual(self.s.union(c), x)
            # Drop the strong references held by the temporary container.
            del c
        self.assertEqual(len(u), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(u), len(self.items) + len(self.items2))

    def test_or(self):
        i = self.s.union(self.items2)
        self.assertEqual(self.s | set(self.items2), i)
        self.assertEqual(self.s | frozenset(self.items2), i)

    def test_intersection(self):
        s = WeakSet(self.letters)
        i = s.intersection(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.items2 and c in self.letters)
        # intersection() must not modify the receiver.
        self.assertEqual(s, WeakSet(self.letters))
        self.assertIs(type(i), WeakSet)
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet([])
            self.assertEqual(i.intersection(C(self.items)), x)
        self.assertEqual(len(i), len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items2))

    def test_isdisjoint(self):
        self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
        self.assertFalse(self.s.isdisjoint(WeakSet(self.letters)))

    def test_and(self):
        i = self.s.intersection(self.items2)
        self.assertEqual(self.s & set(self.items2), i)
        self.assertEqual(self.s & frozenset(self.items2), i)

    def test_difference(self):
        i = self.s.difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.items2)
        # difference() must not modify the receiver.
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertIs(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.difference, [[]])

    def test_sub(self):
        i = self.s.difference(self.items2)
        self.assertEqual(self.s - set(self.items2), i)
        self.assertEqual(self.s - frozenset(self.items2), i)

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.items2))
        # symmetric_difference() must not modify the receiver.
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertIs(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
        self.assertEqual(len(i), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items) + len(self.items2))

    def test_xor(self):
        i = self.s.symmetric_difference(self.items2)
        self.assertEqual(self.s ^ set(self.items2), i)
        self.assertEqual(self.s ^ frozenset(self.items2), i)

    def test_sub_and_super(self):
        self.assertTrue(self.ab_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset >= self.ab_weakset)
        self.assertFalse(self.abcde_weakset <= self.def_weakset)
        self.assertFalse(self.abcde_weakset >= self.def_weakset)
        self.assertTrue(set('a').issubset('abc'))
        self.assertTrue(set('abc').issuperset('a'))
        self.assertFalse(set('a').issubset('cbs'))
        self.assertFalse(set('cbs').issuperset('a'))

    def test_lt(self):
        self.assertTrue(self.ab_weakset < self.abcde_weakset)
        self.assertFalse(self.abcde_weakset < self.def_weakset)
        self.assertFalse(self.ab_weakset < self.ab_weakset)
        self.assertFalse(WeakSet() < WeakSet())

    def test_gt(self):
        self.assertTrue(self.abcde_weakset > self.ab_weakset)
        self.assertFalse(self.abcde_weakset > self.def_weakset)
        self.assertFalse(self.ab_weakset > self.ab_weakset)
        self.assertFalse(WeakSet() > WeakSet())

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check
        s = WeakSet(Foo() for _ in range(1000))
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = WeakSet([elem])

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        class H(WeakSet):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s = H()
        f = set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_init(self):
        s = WeakSet()
        s.__init__(self.items)
        self.assertEqual(s, self.s)
        s.__init__(self.items2)
        self.assertEqual(s, WeakSet(self.items2))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = WeakSet(self.items)
        t = WeakSet(s)
        self.assertIsNot(s, t)

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, WeakSet([]))
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertIsNot(self.s, dup)

    def test_add(self):
        x = ustr('Q')
        self.s.add(x)
        self.assertIn(x, self.s)
        dup = self.s.copy()
        # Adding an element that is already present must be a no-op.
        self.s.add(x)
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])
        # The temporary Foo() has no strong reference left after add().
        self.fs.add(Foo())
        # Do not rely on immediate reference-count finalisation.
        gc.collect()
        self.assertEqual(len(self.fs), 1)
        self.fs.add(self.obj)
        self.assertEqual(len(self.fs), 1)

    def test_remove(self):
        x = ustr('a')
        self.s.remove(x)
        self.assertNotIn(x, self.s)
        self.assertRaises(KeyError, self.s.remove, x)
        self.assertRaises(TypeError, self.s.remove, [])

    def test_discard(self):
        a, q = ustr('a'), ustr('Q')
        self.s.discard(a)
        self.assertNotIn(a, self.s)
        # Discarding a missing element must not raise.
        self.s.discard(q)
        self.assertRaises(TypeError, self.s.discard, [])

    def test_pop(self):
        for _ in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.items2)
        self.assertIsNone(retval)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)
        self.assertRaises(TypeError, self.s.update, [[]])

    def test_update_set(self):
        self.s.update(set(self.items2))
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_ior(self):
        self.s |= set(self.items2)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.items2)
        self.assertIsNone(retval)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.intersection_update, [[]])

    def test_iand(self):
        self.s &= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.items2)
        self.assertIsNone(retval)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_isub(self):
        self.s -= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.items2)
        self.assertIsNone(retval)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_ixor(self):
        self.s ^= set(self.items2)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, WeakSet())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, WeakSet())

    def test_eq(self):
        # issue 5964
        # The == operator is spelled out on purpose: it is what is under test.
        self.assertTrue(self.s == self.s)
        self.assertTrue(self.s == WeakSet(self.items))
        self.assertFalse(self.s == set(self.items))
        self.assertFalse(self.s == list(self.items))
        self.assertFalse(self.s == tuple(self.items))
        self.assertFalse(self.s == WeakSet([Foo]))
        self.assertFalse(self.s == 1)

    def test_ne(self):
        # The != operator is spelled out on purpose: it is what is under test.
        self.assertTrue(self.s != set(self.items))
        s1 = WeakSet()
        s2 = WeakSet()
        self.assertFalse(s1 != s2)

    def test_weak_destroy_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        # Create new items to be sure no-one else holds a reference
        items = [ustr(c) for c in ('a', 'b', 'c')]
        s = WeakSet(items)
        it = iter(s)
        next(it)             # Trigger internal iteration
        # Destroy an item
        del items[-1]
        gc.collect()    # just in case
        # We have removed either the first consumed items, or another one
        self.assertIn(len(list(it)), [len(items), len(items) - 1])
        del it
        # The removal has been committed
        self.assertEqual(len(s), len(items))

    def test_weak_destroy_and_mutate_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        items = [ustr(c) for c in string.ascii_letters]
        s = WeakSet(items)

        @contextlib.contextmanager
        def testcontext():
            try:
                it = iter(s)
                # Start iterator
                yielded = ustr(str(next(it)))
                # Schedule an item for removal and recreate it
                u = ustr(str(items.pop()))
                if yielded == u:
                    # The iterator still has a reference to the removed item,
                    # advance it (issue #20006).
                    next(it)
                gc.collect()      # just in case
                yield u
            finally:
                it = None           # should commit all removals

        with testcontext() as u:
            self.assertNotIn(u, s)
        with testcontext() as u:
            self.assertRaises(KeyError, s.remove, u)
        self.assertNotIn(u, s)
        with testcontext() as u:
            s.add(u)
        self.assertIn(u, s)
        t = s.copy()
        with testcontext() as u:
            s.update(t)
        self.assertEqual(len(s), len(t))
        with testcontext() as u:
            s.clear()
        self.assertEqual(len(s), 0)

    def test_len_cycles(self):
        N = 20
        items = [RefCycle() for _ in range(N)]
        s = WeakSet(items)
        del items
        it = iter(s)
        try:
            next(it)
        except StopIteration:
            pass
        gc.collect()
        n1 = len(s)
        del it
        gc.collect()
        n2 = len(s)
        # one item may be kept alive inside the iterator
        self.assertIn(n1, (0, 1))
        self.assertEqual(n2, 0)

    def test_len_race(self):
        # Extended sanity checks for len() in the face of cyclic collection
        # Restore the collector thresholds once the test is over.
        self.addCleanup(gc.set_threshold, *gc.get_threshold())
        for th in range(1, 100):
            N = 20
            gc.collect(0)
            gc.set_threshold(th, th, th)
            items = [RefCycle() for _ in range(N)]
            s = WeakSet(items)
            del items
            # All items will be collected at next garbage collection pass
            it = iter(s)
            try:
                next(it)
            except StopIteration:
                pass
            n1 = len(s)
            del it
            n2 = len(s)
            self.assertGreaterEqual(n1, 0)
            self.assertLessEqual(n1, N)
            self.assertGreaterEqual(n2, 0)
            self.assertLessEqual(n2, n1)


if __name__ == "__main__":
    unittest.main()