Home | History | Annotate | Download | only in test
      1 
      2 import unittest
      3 from test import test_support
      4 import gc
      5 import weakref
      6 import operator
      7 import copy
      8 import pickle
      9 from random import randrange, shuffle
     10 import sys
     11 import collections
     12 
     13 class PassThru(Exception):
     14     pass
     15 
     16 def check_pass_thru():
     17     raise PassThru
     18     yield 1
     19 
     20 class BadCmp:
     21     def __hash__(self):
     22         return 1
     23     def __cmp__(self, other):
     24         raise RuntimeError
     25 
     26 class ReprWrapper:
     27     'Used to test self-referential repr() calls'
     28     def __repr__(self):
     29         return repr(self.value)
     30 
     31 class HashCountingInt(int):
     32     'int-like object that counts the number of times __hash__ is called'
     33     def __init__(self, *args):
     34         self.hash_count = 0
     35     def __hash__(self):
     36         self.hash_count += 1
     37         return int.__hash__(self)
     38 
     39 class TestJointOps(unittest.TestCase):
     40     # Tests common to both set and frozenset
     41 
     42     def setUp(self):
     43         self.word = word = 'simsalabim'
     44         self.otherword = 'madagascar'
     45         self.letters = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
     46         self.s = self.thetype(word)
     47         self.d = dict.fromkeys(word)
     48 
     49     def test_new_or_init(self):
     50         self.assertRaises(TypeError, self.thetype, [], 2)
     51         self.assertRaises(TypeError, set().__init__, a=1)
     52 
     53     def test_uniquification(self):
     54         actual = sorted(self.s)
     55         expected = sorted(self.d)
     56         self.assertEqual(actual, expected)
     57         self.assertRaises(PassThru, self.thetype, check_pass_thru())
     58         self.assertRaises(TypeError, self.thetype, [[]])
     59 
     60     def test_len(self):
     61         self.assertEqual(len(self.s), len(self.d))
     62 
     63     def test_contains(self):
     64         for c in self.letters:
     65             self.assertEqual(c in self.s, c in self.d)
     66         self.assertRaises(TypeError, self.s.__contains__, [[]])
     67         s = self.thetype([frozenset(self.letters)])
     68         self.assertIn(self.thetype(self.letters), s)
     69 
     70     def test_union(self):
     71         u = self.s.union(self.otherword)
     72         for c in self.letters:
     73             self.assertEqual(c in u, c in self.d or c in self.otherword)
     74         self.assertEqual(self.s, self.thetype(self.word))
     75         self.assertEqual(type(u), self.thetype)
     76         self.assertRaises(PassThru, self.s.union, check_pass_thru())
     77         self.assertRaises(TypeError, self.s.union, [[]])
     78         for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
     79             self.assertEqual(self.thetype('abcba').union(C('cdc')), set('abcd'))
     80             self.assertEqual(self.thetype('abcba').union(C('efgfe')), set('abcefg'))
     81             self.assertEqual(self.thetype('abcba').union(C('ccb')), set('abc'))
     82             self.assertEqual(self.thetype('abcba').union(C('ef')), set('abcef'))
     83             self.assertEqual(self.thetype('abcba').union(C('ef'), C('fg')), set('abcefg'))
     84 
     85         # Issue #6573
     86         x = self.thetype()
     87         self.assertEqual(x.union(set([1]), x, set([2])), self.thetype([1, 2]))
     88 
     89     def test_or(self):
     90         i = self.s.union(self.otherword)
     91         self.assertEqual(self.s | set(self.otherword), i)
     92         self.assertEqual(self.s | frozenset(self.otherword), i)
     93         try:
     94             self.s | self.otherword
     95         except TypeError:
     96             pass
     97         else:
     98             self.fail("s|t did not screen-out general iterables")
     99 
    100     def test_intersection(self):
    101         i = self.s.intersection(self.otherword)
    102         for c in self.letters:
    103             self.assertEqual(c in i, c in self.d and c in self.otherword)
    104         self.assertEqual(self.s, self.thetype(self.word))
    105         self.assertEqual(type(i), self.thetype)
    106         self.assertRaises(PassThru, self.s.intersection, check_pass_thru())
    107         for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
    108             self.assertEqual(self.thetype('abcba').intersection(C('cdc')), set('cc'))
    109             self.assertEqual(self.thetype('abcba').intersection(C('efgfe')), set(''))
    110             self.assertEqual(self.thetype('abcba').intersection(C('ccb')), set('bc'))
    111             self.assertEqual(self.thetype('abcba').intersection(C('ef')), set(''))
    112             self.assertEqual(self.thetype('abcba').intersection(C('cbcf'), C('bag')), set('b'))
    113         s = self.thetype('abcba')
    114         z = s.intersection()
    115         if self.thetype == frozenset():
    116             self.assertEqual(id(s), id(z))
    117         else:
    118             self.assertNotEqual(id(s), id(z))
    119 
    120     def test_isdisjoint(self):
    121         def f(s1, s2):
    122             'Pure python equivalent of isdisjoint()'
    123             return not set(s1).intersection(s2)
    124         for larg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef':
    125             s1 = self.thetype(larg)
    126             for rarg in '', 'a', 'ab', 'abc', 'ababac', 'cdc', 'cc', 'efgfe', 'ccb', 'ef':
    127                 for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
    128                     s2 = C(rarg)
    129                     actual = s1.isdisjoint(s2)
    130                     expected = f(s1, s2)
    131                     self.assertEqual(actual, expected)
    132                     self.assertTrue(actual is True or actual is False)
    133 
    134     def test_and(self):
    135         i = self.s.intersection(self.otherword)
    136         self.assertEqual(self.s & set(self.otherword), i)
    137         self.assertEqual(self.s & frozenset(self.otherword), i)
    138         try:
    139             self.s & self.otherword
    140         except TypeError:
    141             pass
    142         else:
    143             self.fail("s&t did not screen-out general iterables")
    144 
    145     def test_difference(self):
    146         i = self.s.difference(self.otherword)
    147         for c in self.letters:
    148             self.assertEqual(c in i, c in self.d and c not in self.otherword)
    149         self.assertEqual(self.s, self.thetype(self.word))
    150         self.assertEqual(type(i), self.thetype)
    151         self.assertRaises(PassThru, self.s.difference, check_pass_thru())
    152         self.assertRaises(TypeError, self.s.difference, [[]])
    153         for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
    154             self.assertEqual(self.thetype('abcba').difference(C('cdc')), set('ab'))
    155             self.assertEqual(self.thetype('abcba').difference(C('efgfe')), set('abc'))
    156             self.assertEqual(self.thetype('abcba').difference(C('ccb')), set('a'))
    157             self.assertEqual(self.thetype('abcba').difference(C('ef')), set('abc'))
    158             self.assertEqual(self.thetype('abcba').difference(), set('abc'))
    159             self.assertEqual(self.thetype('abcba').difference(C('a'), C('b')), set('c'))
    160 
    161     def test_sub(self):
    162         i = self.s.difference(self.otherword)
    163         self.assertEqual(self.s - set(self.otherword), i)
    164         self.assertEqual(self.s - frozenset(self.otherword), i)
    165         try:
    166             self.s - self.otherword
    167         except TypeError:
    168             pass
    169         else:
    170             self.fail("s-t did not screen-out general iterables")
    171 
    172     def test_symmetric_difference(self):
    173         i = self.s.symmetric_difference(self.otherword)
    174         for c in self.letters:
    175             self.assertEqual(c in i, (c in self.d) ^ (c in self.otherword))
    176         self.assertEqual(self.s, self.thetype(self.word))
    177         self.assertEqual(type(i), self.thetype)
    178         self.assertRaises(PassThru, self.s.symmetric_difference, check_pass_thru())
    179         self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
    180         for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
    181             self.assertEqual(self.thetype('abcba').symmetric_difference(C('cdc')), set('abd'))
    182             self.assertEqual(self.thetype('abcba').symmetric_difference(C('efgfe')), set('abcefg'))
    183             self.assertEqual(self.thetype('abcba').symmetric_difference(C('ccb')), set('a'))
    184             self.assertEqual(self.thetype('abcba').symmetric_difference(C('ef')), set('abcef'))
    185 
    186     def test_xor(self):
    187         i = self.s.symmetric_difference(self.otherword)
    188         self.assertEqual(self.s ^ set(self.otherword), i)
    189         self.assertEqual(self.s ^ frozenset(self.otherword), i)
    190         try:
    191             self.s ^ self.otherword
    192         except TypeError:
    193             pass
    194         else:
    195             self.fail("s^t did not screen-out general iterables")
    196 
    197     def test_equality(self):
    198         self.assertEqual(self.s, set(self.word))
    199         self.assertEqual(self.s, frozenset(self.word))
    200         self.assertEqual(self.s == self.word, False)
    201         self.assertNotEqual(self.s, set(self.otherword))
    202         self.assertNotEqual(self.s, frozenset(self.otherword))
    203         self.assertEqual(self.s != self.word, True)
    204 
    205     def test_setOfFrozensets(self):
    206         t = map(frozenset, ['abcdef', 'bcd', 'bdcb', 'fed', 'fedccba'])
    207         s = self.thetype(t)
    208         self.assertEqual(len(s), 3)
    209 
    210     def test_compare(self):
    211         self.assertRaises(TypeError, self.s.__cmp__, self.s)
    212 
    213     def test_sub_and_super(self):
    214         p, q, r = map(self.thetype, ['ab', 'abcde', 'def'])
    215         self.assertTrue(p < q)
    216         self.assertTrue(p <= q)
    217         self.assertTrue(q <= q)
    218         self.assertTrue(q > p)
    219         self.assertTrue(q >= p)
    220         self.assertFalse(q < r)
    221         self.assertFalse(q <= r)
    222         self.assertFalse(q > r)
    223         self.assertFalse(q >= r)
    224         self.assertTrue(set('a').issubset('abc'))
    225         self.assertTrue(set('abc').issuperset('a'))
    226         self.assertFalse(set('a').issubset('cbs'))
    227         self.assertFalse(set('cbs').issuperset('a'))
    228 
    229     def test_pickling(self):
    230         for i in range(pickle.HIGHEST_PROTOCOL + 1):
    231             p = pickle.dumps(self.s, i)
    232             dup = pickle.loads(p)
    233             self.assertEqual(self.s, dup, "%s != %s" % (self.s, dup))
    234             if type(self.s) not in (set, frozenset):
    235                 self.s.x = 10
    236                 p = pickle.dumps(self.s, i)
    237                 dup = pickle.loads(p)
    238                 self.assertEqual(self.s.x, dup.x)
    239 
    240     def test_deepcopy(self):
    241         class Tracer:
    242             def __init__(self, value):
    243                 self.value = value
    244             def __hash__(self):
    245                 return self.value
    246             def __deepcopy__(self, memo=None):
    247                 return Tracer(self.value + 1)
    248         t = Tracer(10)
    249         s = self.thetype([t])
    250         dup = copy.deepcopy(s)
    251         self.assertNotEqual(id(s), id(dup))
    252         for elem in dup:
    253             newt = elem
    254         self.assertNotEqual(id(t), id(newt))
    255         self.assertEqual(t.value + 1, newt.value)
    256 
    257     def test_gc(self):
    258         # Create a nest of cycles to exercise overall ref count check
    259         class A:
    260             pass
    261         s = set(A() for i in xrange(1000))
    262         for elem in s:
    263             elem.cycle = s
    264             elem.sub = elem
    265             elem.set = set([elem])
    266 
    267     def test_subclass_with_custom_hash(self):
    268         # Bug #1257731
    269         class H(self.thetype):
    270             def __hash__(self):
    271                 return int(id(self) & 0x7fffffff)
    272         s=H()
    273         f=set()
    274         f.add(s)
    275         self.assertIn(s, f)
    276         f.remove(s)
    277         f.add(s)
    278         f.discard(s)
    279 
    280     def test_badcmp(self):
    281         s = self.thetype([BadCmp()])
    282         # Detect comparison errors during insertion and lookup
    283         self.assertRaises(RuntimeError, self.thetype, [BadCmp(), BadCmp()])
    284         self.assertRaises(RuntimeError, s.__contains__, BadCmp())
    285         # Detect errors during mutating operations
    286         if hasattr(s, 'add'):
    287             self.assertRaises(RuntimeError, s.add, BadCmp())
    288             self.assertRaises(RuntimeError, s.discard, BadCmp())
    289             self.assertRaises(RuntimeError, s.remove, BadCmp())
    290 
    291     def test_cyclical_repr(self):
    292         w = ReprWrapper()
    293         s = self.thetype([w])
    294         w.value = s
    295         name = repr(s).partition('(')[0]    # strip class name from repr string
    296         self.assertEqual(repr(s), '%s([%s(...)])' % (name, name))
    297 
    298     def test_cyclical_print(self):
    299         w = ReprWrapper()
    300         s = self.thetype([w])
    301         w.value = s
    302         fo = open(test_support.TESTFN, "wb")
    303         try:
    304             print >> fo, s,
    305             fo.close()
    306             fo = open(test_support.TESTFN, "rb")
    307             self.assertEqual(fo.read(), repr(s))
    308         finally:
    309             fo.close()
    310             test_support.unlink(test_support.TESTFN)
    311 
    312     def test_do_not_rehash_dict_keys(self):
    313         n = 10
    314         d = dict.fromkeys(map(HashCountingInt, xrange(n)))
    315         self.assertEqual(sum(elem.hash_count for elem in d), n)
    316         s = self.thetype(d)
    317         self.assertEqual(sum(elem.hash_count for elem in d), n)
    318         s.difference(d)
    319         self.assertEqual(sum(elem.hash_count for elem in d), n)
    320         if hasattr(s, 'symmetric_difference_update'):
    321             s.symmetric_difference_update(d)
    322         self.assertEqual(sum(elem.hash_count for elem in d), n)
    323         d2 = dict.fromkeys(set(d))
    324         self.assertEqual(sum(elem.hash_count for elem in d), n)
    325         d3 = dict.fromkeys(frozenset(d))
    326         self.assertEqual(sum(elem.hash_count for elem in d), n)
    327         d3 = dict.fromkeys(frozenset(d), 123)
    328         self.assertEqual(sum(elem.hash_count for elem in d), n)
    329         self.assertEqual(d3, dict.fromkeys(d, 123))
    330 
    331     def test_container_iterator(self):
    332         # Bug #3680: tp_traverse was not implemented for set iterator object
    333         class C(object):
    334             pass
    335         obj = C()
    336         ref = weakref.ref(obj)
    337         container = set([obj, 1])
    338         obj.x = iter(container)
    339         del obj, container
    340         gc.collect()
    341         self.assertTrue(ref() is None, "Cycle was not collected")
    342 
    343     def test_free_after_iterating(self):
    344         test_support.check_free_after_iterating(self, iter, self.thetype)
    345 
    346 class TestSet(TestJointOps):
    347     thetype = set
    348 
    349     def test_init(self):
    350         s = self.thetype()
    351         s.__init__(self.word)
    352         self.assertEqual(s, set(self.word))
    353         s.__init__(self.otherword)
    354         self.assertEqual(s, set(self.otherword))
    355         self.assertRaises(TypeError, s.__init__, s, 2);
    356         self.assertRaises(TypeError, s.__init__, 1);
    357 
    358     def test_constructor_identity(self):
    359         s = self.thetype(range(3))
    360         t = self.thetype(s)
    361         self.assertNotEqual(id(s), id(t))
    362 
    363     def test_set_literal_insertion_order(self):
    364         # SF Issue #26020 -- Expect left to right insertion
    365         s = {1, 1.0, True}
    366         self.assertEqual(len(s), 1)
    367         stored_value = s.pop()
    368         self.assertEqual(type(stored_value), int)
    369 
    370     def test_set_literal_evaluation_order(self):
    371         # Expect left to right expression evaluation
    372         events = []
    373         def record(obj):
    374             events.append(obj)
    375         s = {record(1), record(2), record(3)}
    376         self.assertEqual(events, [1, 2, 3])
    377 
    378     def test_hash(self):
    379         self.assertRaises(TypeError, hash, self.s)
    380 
    381     def test_clear(self):
    382         self.s.clear()
    383         self.assertEqual(self.s, set())
    384         self.assertEqual(len(self.s), 0)
    385 
    386     def test_copy(self):
    387         dup = self.s.copy()
    388         self.assertEqual(self.s, dup)
    389         self.assertNotEqual(id(self.s), id(dup))
    390 
    391     def test_add(self):
    392         self.s.add('Q')
    393         self.assertIn('Q', self.s)
    394         dup = self.s.copy()
    395         self.s.add('Q')
    396         self.assertEqual(self.s, dup)
    397         self.assertRaises(TypeError, self.s.add, [])
    398 
    399     def test_remove(self):
    400         self.s.remove('a')
    401         self.assertNotIn('a', self.s)
    402         self.assertRaises(KeyError, self.s.remove, 'Q')
    403         self.assertRaises(TypeError, self.s.remove, [])
    404         s = self.thetype([frozenset(self.word)])
    405         self.assertIn(self.thetype(self.word), s)
    406         s.remove(self.thetype(self.word))
    407         self.assertNotIn(self.thetype(self.word), s)
    408         self.assertRaises(KeyError, self.s.remove, self.thetype(self.word))
    409 
    410     def test_remove_keyerror_unpacking(self):
    411         # bug:  www.python.org/sf/1576657
    412         for v1 in ['Q', (1,)]:
    413             try:
    414                 self.s.remove(v1)
    415             except KeyError, e:
    416                 v2 = e.args[0]
    417                 self.assertEqual(v1, v2)
    418             else:
    419                 self.fail()
    420 
    421     def test_remove_keyerror_set(self):
    422         key = self.thetype([3, 4])
    423         try:
    424             self.s.remove(key)
    425         except KeyError as e:
    426             self.assertTrue(e.args[0] is key,
    427                          "KeyError should be {0}, not {1}".format(key,
    428                                                                   e.args[0]))
    429         else:
    430             self.fail()
    431 
    432     def test_discard(self):
    433         self.s.discard('a')
    434         self.assertNotIn('a', self.s)
    435         self.s.discard('Q')
    436         self.assertRaises(TypeError, self.s.discard, [])
    437         s = self.thetype([frozenset(self.word)])
    438         self.assertIn(self.thetype(self.word), s)
    439         s.discard(self.thetype(self.word))
    440         self.assertNotIn(self.thetype(self.word), s)
    441         s.discard(self.thetype(self.word))
    442 
    443     def test_pop(self):
    444         for i in xrange(len(self.s)):
    445             elem = self.s.pop()
    446             self.assertNotIn(elem, self.s)
    447         self.assertRaises(KeyError, self.s.pop)
    448 
    449     def test_update(self):
    450         retval = self.s.update(self.otherword)
    451         self.assertEqual(retval, None)
    452         for c in (self.word + self.otherword):
    453             self.assertIn(c, self.s)
    454         self.assertRaises(PassThru, self.s.update, check_pass_thru())
    455         self.assertRaises(TypeError, self.s.update, [[]])
    456         for p, q in (('cdc', 'abcd'), ('efgfe', 'abcefg'), ('ccb', 'abc'), ('ef', 'abcef')):
    457             for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
    458                 s = self.thetype('abcba')
    459                 self.assertEqual(s.update(C(p)), None)
    460                 self.assertEqual(s, set(q))
    461         for p in ('cdc', 'efgfe', 'ccb', 'ef', 'abcda'):
    462             q = 'ahi'
    463             for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
    464                 s = self.thetype('abcba')
    465                 self.assertEqual(s.update(C(p), C(q)), None)
    466                 self.assertEqual(s, set(s) | set(p) | set(q))
    467 
    468     def test_ior(self):
    469         self.s |= set(self.otherword)
    470         for c in (self.word + self.otherword):
    471             self.assertIn(c, self.s)
    472 
    473     def test_intersection_update(self):
    474         retval = self.s.intersection_update(self.otherword)
    475         self.assertEqual(retval, None)
    476         for c in (self.word + self.otherword):
    477             if c in self.otherword and c in self.word:
    478                 self.assertIn(c, self.s)
    479             else:
    480                 self.assertNotIn(c, self.s)
    481         self.assertRaises(PassThru, self.s.intersection_update, check_pass_thru())
    482         self.assertRaises(TypeError, self.s.intersection_update, [[]])
    483         for p, q in (('cdc', 'c'), ('efgfe', ''), ('ccb', 'bc'), ('ef', '')):
    484             for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
    485                 s = self.thetype('abcba')
    486                 self.assertEqual(s.intersection_update(C(p)), None)
    487                 self.assertEqual(s, set(q))
    488                 ss = 'abcba'
    489                 s = self.thetype(ss)
    490                 t = 'cbc'
    491                 self.assertEqual(s.intersection_update(C(p), C(t)), None)
    492                 self.assertEqual(s, set('abcba')&set(p)&set(t))
    493 
    494     def test_iand(self):
    495         self.s &= set(self.otherword)
    496         for c in (self.word + self.otherword):
    497             if c in self.otherword and c in self.word:
    498                 self.assertIn(c, self.s)
    499             else:
    500                 self.assertNotIn(c, self.s)
    501 
    502     def test_difference_update(self):
    503         retval = self.s.difference_update(self.otherword)
    504         self.assertEqual(retval, None)
    505         for c in (self.word + self.otherword):
    506             if c in self.word and c not in self.otherword:
    507                 self.assertIn(c, self.s)
    508             else:
    509                 self.assertNotIn(c, self.s)
    510         self.assertRaises(PassThru, self.s.difference_update, check_pass_thru())
    511         self.assertRaises(TypeError, self.s.difference_update, [[]])
    512         self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
    513         for p, q in (('cdc', 'ab'), ('efgfe', 'abc'), ('ccb', 'a'), ('ef', 'abc')):
    514             for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
    515                 s = self.thetype('abcba')
    516                 self.assertEqual(s.difference_update(C(p)), None)
    517                 self.assertEqual(s, set(q))
    518 
    519                 s = self.thetype('abcdefghih')
    520                 s.difference_update()
    521                 self.assertEqual(s, self.thetype('abcdefghih'))
    522 
    523                 s = self.thetype('abcdefghih')
    524                 s.difference_update(C('aba'))
    525                 self.assertEqual(s, self.thetype('cdefghih'))
    526 
    527                 s = self.thetype('abcdefghih')
    528                 s.difference_update(C('cdc'), C('aba'))
    529                 self.assertEqual(s, self.thetype('efghih'))
    530 
    531     def test_isub(self):
    532         self.s -= set(self.otherword)
    533         for c in (self.word + self.otherword):
    534             if c in self.word and c not in self.otherword:
    535                 self.assertIn(c, self.s)
    536             else:
    537                 self.assertNotIn(c, self.s)
    538 
    539     def test_symmetric_difference_update(self):
    540         retval = self.s.symmetric_difference_update(self.otherword)
    541         self.assertEqual(retval, None)
    542         for c in (self.word + self.otherword):
    543             if (c in self.word) ^ (c in self.otherword):
    544                 self.assertIn(c, self.s)
    545             else:
    546                 self.assertNotIn(c, self.s)
    547         self.assertRaises(PassThru, self.s.symmetric_difference_update, check_pass_thru())
    548         self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])
    549         for p, q in (('cdc', 'abd'), ('efgfe', 'abcefg'), ('ccb', 'a'), ('ef', 'abcef')):
    550             for C in set, frozenset, dict.fromkeys, str, unicode, list, tuple:
    551                 s = self.thetype('abcba')
    552                 self.assertEqual(s.symmetric_difference_update(C(p)), None)
    553                 self.assertEqual(s, set(q))
    554 
    555     def test_ixor(self):
    556         self.s ^= set(self.otherword)
    557         for c in (self.word + self.otherword):
    558             if (c in self.word) ^ (c in self.otherword):
    559                 self.assertIn(c, self.s)
    560             else:
    561                 self.assertNotIn(c, self.s)
    562 
    563     def test_inplace_on_self(self):
    564         t = self.s.copy()
    565         t |= t
    566         self.assertEqual(t, self.s)
    567         t &= t
    568         self.assertEqual(t, self.s)
    569         t -= t
    570         self.assertEqual(t, self.thetype())
    571         t = self.s.copy()
    572         t ^= t
    573         self.assertEqual(t, self.thetype())
    574 
    575     def test_weakref(self):
    576         s = self.thetype('gallahad')
    577         p = weakref.proxy(s)
    578         self.assertEqual(str(p), str(s))
    579         s = None
    580         self.assertRaises(ReferenceError, str, p)
    581 
    582     @unittest.skipUnless(hasattr(set, "test_c_api"),
    583                          'C API test only available in a debug build')
    584     def test_c_api(self):
    585         self.assertEqual(set().test_c_api(), True)
    586 
    587 class SetSubclass(set):
    588     pass
    589 
    590 class TestSetSubclass(TestSet):
    591     thetype = SetSubclass
    592 
    593 class SetSubclassWithKeywordArgs(set):
    594     def __init__(self, iterable=[], newarg=None):
    595         set.__init__(self, iterable)
    596 
    597 class TestSetSubclassWithKeywordArgs(TestSet):
    598 
    599     def test_keywords_in_subclass(self):
    600         'SF bug #1486663 -- this used to erroneously raise a TypeError'
    601         SetSubclassWithKeywordArgs(newarg=1)
    602 
    603 class TestFrozenSet(TestJointOps):
    604     thetype = frozenset
    605 
    606     def test_init(self):
    607         s = self.thetype(self.word)
    608         s.__init__(self.otherword)
    609         self.assertEqual(s, set(self.word))
    610 
    611     def test_singleton_empty_frozenset(self):
    612         f = frozenset()
    613         efs = [frozenset(), frozenset([]), frozenset(()), frozenset(''),
    614                frozenset(), frozenset([]), frozenset(()), frozenset(''),
    615                frozenset(xrange(0)), frozenset(frozenset()),
    616                frozenset(f), f]
    617         # All of the empty frozensets should have just one id()
    618         self.assertEqual(len(set(map(id, efs))), 1)
    619 
    620     def test_constructor_identity(self):
    621         s = self.thetype(range(3))
    622         t = self.thetype(s)
    623         self.assertEqual(id(s), id(t))
    624 
    625     def test_hash(self):
    626         self.assertEqual(hash(self.thetype('abcdeb')),
    627                          hash(self.thetype('ebecda')))
    628 
    629         # make sure that all permutations give the same hash value
    630         n = 100
    631         seq = [randrange(n) for i in xrange(n)]
    632         results = set()
    633         for i in xrange(200):
    634             shuffle(seq)
    635             results.add(hash(self.thetype(seq)))
    636         self.assertEqual(len(results), 1)
    637 
    638     def test_copy(self):
    639         dup = self.s.copy()
    640         self.assertEqual(id(self.s), id(dup))
    641 
    642     def test_frozen_as_dictkey(self):
    643         seq = range(10) + list('abcdefg') + ['apple']
    644         key1 = self.thetype(seq)
    645         key2 = self.thetype(reversed(seq))
    646         self.assertEqual(key1, key2)
    647         self.assertNotEqual(id(key1), id(key2))
    648         d = {}
    649         d[key1] = 42
    650         self.assertEqual(d[key2], 42)
    651 
    652     def test_hash_caching(self):
    653         f = self.thetype('abcdcda')
    654         self.assertEqual(hash(f), hash(f))
    655 
    656     def test_hash_effectiveness(self):
    657         n = 13
    658         hashvalues = set()
    659         addhashvalue = hashvalues.add
    660         elemmasks = [(i+1, 1<<i) for i in range(n)]
    661         for i in xrange(2**n):
    662             addhashvalue(hash(frozenset([e for e, m in elemmasks if m&i])))
    663         self.assertEqual(len(hashvalues), 2**n)
    664 
    665 class FrozenSetSubclass(frozenset):
    666     pass
    667 
    668 class TestFrozenSetSubclass(TestFrozenSet):
    669     thetype = FrozenSetSubclass
    670 
    671     def test_constructor_identity(self):
    672         s = self.thetype(range(3))
    673         t = self.thetype(s)
    674         self.assertNotEqual(id(s), id(t))
    675 
    676     def test_copy(self):
    677         dup = self.s.copy()
    678         self.assertNotEqual(id(self.s), id(dup))
    679 
    680     def test_nested_empty_constructor(self):
    681         s = self.thetype()
    682         t = self.thetype(s)
    683         self.assertEqual(s, t)
    684 
    685     def test_singleton_empty_frozenset(self):
    686         Frozenset = self.thetype
    687         f = frozenset()
    688         F = Frozenset()
    689         efs = [Frozenset(), Frozenset([]), Frozenset(()), Frozenset(''),
    690                Frozenset(), Frozenset([]), Frozenset(()), Frozenset(''),
    691                Frozenset(xrange(0)), Frozenset(Frozenset()),
    692                Frozenset(frozenset()), f, F, Frozenset(f), Frozenset(F)]
    693         # All empty frozenset subclass instances should have different ids
    694         self.assertEqual(len(set(map(id, efs))), len(efs))
    695 
    696 # Tests taken from test_sets.py =============================================
    697 
    698 empty_set = set()
    699 
    700 #==============================================================================
    701 
    702 class TestBasicOps(unittest.TestCase):
    703 
    704     def test_repr(self):
    705         if self.repr is not None:
    706             self.assertEqual(repr(self.set), self.repr)
    707 
    708     def check_repr_against_values(self):
    709         text = repr(self.set)
    710         self.assertTrue(text.startswith('{'))
    711         self.assertTrue(text.endswith('}'))
    712 
    713         result = text[1:-1].split(', ')
    714         result.sort()
    715         sorted_repr_values = [repr(value) for value in self.values]
    716         sorted_repr_values.sort()
    717         self.assertEqual(result, sorted_repr_values)
    718 
    719     def test_print(self):
    720         fo = open(test_support.TESTFN, "wb")
    721         try:
    722             print >> fo, self.set,
    723             fo.close()
    724             fo = open(test_support.TESTFN, "rb")
    725             self.assertEqual(fo.read(), repr(self.set))
    726         finally:
    727             fo.close()
    728             test_support.unlink(test_support.TESTFN)
    729 
    730     def test_length(self):
    731         self.assertEqual(len(self.set), self.length)
    732 
    733     def test_self_equality(self):
    734         self.assertEqual(self.set, self.set)
    735 
    736     def test_equivalent_equality(self):
    737         self.assertEqual(self.set, self.dup)
    738 
    739     def test_copy(self):
    740         self.assertEqual(self.set.copy(), self.dup)
    741 
    742     def test_self_union(self):
    743         result = self.set | self.set
    744         self.assertEqual(result, self.dup)
    745 
    746     def test_empty_union(self):
    747         result = self.set | empty_set
    748         self.assertEqual(result, self.dup)
    749 
    750     def test_union_empty(self):
    751         result = empty_set | self.set
    752         self.assertEqual(result, self.dup)
    753 
    754     def test_self_intersection(self):
    755         result = self.set & self.set
    756         self.assertEqual(result, self.dup)
    757 
    758     def test_empty_intersection(self):
    759         result = self.set & empty_set
    760         self.assertEqual(result, empty_set)
    761 
    762     def test_intersection_empty(self):
    763         result = empty_set & self.set
    764         self.assertEqual(result, empty_set)
    765 
    766     def test_self_isdisjoint(self):
    767         result = self.set.isdisjoint(self.set)
    768         self.assertEqual(result, not self.set)
    769 
    770     def test_empty_isdisjoint(self):
    771         result = self.set.isdisjoint(empty_set)
    772         self.assertEqual(result, True)
    773 
    774     def test_isdisjoint_empty(self):
    775         result = empty_set.isdisjoint(self.set)
    776         self.assertEqual(result, True)
    777 
    778     def test_self_symmetric_difference(self):
    779         result = self.set ^ self.set
    780         self.assertEqual(result, empty_set)
    781 
    782     def test_empty_symmetric_difference(self):
    783         result = self.set ^ empty_set
    784         self.assertEqual(result, self.set)
    785 
    786     def test_self_difference(self):
    787         result = self.set - self.set
    788         self.assertEqual(result, empty_set)
    789 
    790     def test_empty_difference(self):
    791         result = self.set - empty_set
    792         self.assertEqual(result, self.dup)
    793 
    794     def test_empty_difference_rev(self):
    795         result = empty_set - self.set
    796         self.assertEqual(result, empty_set)
    797 
    798     def test_iteration(self):
    799         for v in self.set:
    800             self.assertIn(v, self.values)
    801         setiter = iter(self.set)
    802         # note: __length_hint__ is an internal undocumented API,
    803         # don't rely on it in your own programs
    804         self.assertEqual(setiter.__length_hint__(), len(self.set))
    805 
    806     def test_pickling(self):
    807         for proto in range(pickle.HIGHEST_PROTOCOL + 1):
    808             p = pickle.dumps(self.set, proto)
    809             copy = pickle.loads(p)
    810             self.assertEqual(self.set, copy,
    811                              "%s != %s" % (self.set, copy))
    812 
    813 #------------------------------------------------------------------------------
    814 
    815 class TestBasicOpsEmpty(TestBasicOps):
    816     def setUp(self):
    817         self.case   = "empty set"
    818         self.values = []
    819         self.set    = set(self.values)
    820         self.dup    = set(self.values)
    821         self.length = 0
    822         self.repr   = "set([])"
    823 
    824 #------------------------------------------------------------------------------
    825 
    826 class TestBasicOpsSingleton(TestBasicOps):
    827     def setUp(self):
    828         self.case   = "unit set (number)"
    829         self.values = [3]
    830         self.set    = set(self.values)
    831         self.dup    = set(self.values)
    832         self.length = 1
    833         self.repr   = "set([3])"
    834 
    835     def test_in(self):
    836         self.assertIn(3, self.set)
    837 
    838     def test_not_in(self):
    839         self.assertNotIn(2, self.set)
    840 
    841 #------------------------------------------------------------------------------
    842 
    843 class TestBasicOpsTuple(TestBasicOps):
    844     def setUp(self):
    845         self.case   = "unit set (tuple)"
    846         self.values = [(0, "zero")]
    847         self.set    = set(self.values)
    848         self.dup    = set(self.values)
    849         self.length = 1
    850         self.repr   = "set([(0, 'zero')])"
    851 
    852     def test_in(self):
    853         self.assertIn((0, "zero"), self.set)
    854 
    855     def test_not_in(self):
    856         self.assertNotIn(9, self.set)
    857 
    858 #------------------------------------------------------------------------------
    859 
    860 class TestBasicOpsTriple(TestBasicOps):
    861     def setUp(self):
    862         self.case   = "triple set"
    863         self.values = [0, "zero", operator.add]
    864         self.set    = set(self.values)
    865         self.dup    = set(self.values)
    866         self.length = 3
    867         self.repr   = None
    868 
    869 #------------------------------------------------------------------------------
    870 
    871 class TestBasicOpsString(TestBasicOps):
    872     def setUp(self):
    873         self.case   = "string set"
    874         self.values = ["a", "b", "c"]
    875         self.set    = set(self.values)
    876         self.dup    = set(self.values)
    877         self.length = 3
    878 
    879     def test_repr(self):
    880         self.check_repr_against_values()
    881 
    882 #------------------------------------------------------------------------------
    883 
    884 class TestBasicOpsUnicode(TestBasicOps):
    885     def setUp(self):
    886         self.case   = "unicode set"
    887         self.values = [u"a", u"b", u"c"]
    888         self.set    = set(self.values)
    889         self.dup    = set(self.values)
    890         self.length = 3
    891 
    892     def test_repr(self):
    893         self.check_repr_against_values()
    894 
    895 #------------------------------------------------------------------------------
    896 
    897 class TestBasicOpsMixedStringUnicode(TestBasicOps):
    898     def setUp(self):
    899         self.case   = "string and bytes set"
    900         self.values = ["a", "b", u"a", u"b"]
    901         self.set    = set(self.values)
    902         self.dup    = set(self.values)
    903         self.length = 4
    904 
    905     def test_repr(self):
    906         with test_support.check_warnings():
    907             self.check_repr_against_values()
    908 
    909 #==============================================================================
    910 
    911 def baditer():
    912     raise TypeError
    913     yield True
    914 
    915 def gooditer():
    916     yield True
    917 
    918 class TestExceptionPropagation(unittest.TestCase):
    919     """SF 628246:  Set constructor should not trap iterator TypeErrors"""
    920 
    921     def test_instanceWithException(self):
    922         self.assertRaises(TypeError, set, baditer())
    923 
    924     def test_instancesWithoutException(self):
    925         # All of these iterables should load without exception.
    926         set([1,2,3])
    927         set((1,2,3))
    928         set({'one':1, 'two':2, 'three':3})
    929         set(xrange(3))
    930         set('abc')
    931         set(gooditer())
    932 
    933     def test_changingSizeWhileIterating(self):
    934         s = set([1,2,3])
    935         try:
    936             for i in s:
    937                 s.update([4])
    938         except RuntimeError:
    939             pass
    940         else:
    941             self.fail("no exception when changing size during iteration")
    942 
    943 #==============================================================================
    944 
    945 class TestSetOfSets(unittest.TestCase):
    946     def test_constructor(self):
    947         inner = frozenset([1])
    948         outer = set([inner])
    949         element = outer.pop()
    950         self.assertEqual(type(element), frozenset)
    951         outer.add(inner)        # Rebuild set of sets with .add method
    952         outer.remove(inner)
    953         self.assertEqual(outer, set())   # Verify that remove worked
    954         outer.discard(inner)    # Absence of KeyError indicates working fine
    955 
    956 #==============================================================================
    957 
    958 class TestBinaryOps(unittest.TestCase):
    959     def setUp(self):
    960         self.set = set((2, 4, 6))
    961 
    962     def test_eq(self):              # SF bug 643115
    963         self.assertEqual(self.set, set({2:1,4:3,6:5}))
    964 
    965     def test_union_subset(self):
    966         result = self.set | set([2])
    967         self.assertEqual(result, set((2, 4, 6)))
    968 
    969     def test_union_superset(self):
    970         result = self.set | set([2, 4, 6, 8])
    971         self.assertEqual(result, set([2, 4, 6, 8]))
    972 
    973     def test_union_overlap(self):
    974         result = self.set | set([3, 4, 5])
    975         self.assertEqual(result, set([2, 3, 4, 5, 6]))
    976 
    977     def test_union_non_overlap(self):
    978         result = self.set | set([8])
    979         self.assertEqual(result, set([2, 4, 6, 8]))
    980 
    981     def test_intersection_subset(self):
    982         result = self.set & set((2, 4))
    983         self.assertEqual(result, set((2, 4)))
    984 
    985     def test_intersection_superset(self):
    986         result = self.set & set([2, 4, 6, 8])
    987         self.assertEqual(result, set([2, 4, 6]))
    988 
    989     def test_intersection_overlap(self):
    990         result = self.set & set([3, 4, 5])
    991         self.assertEqual(result, set([4]))
    992 
    993     def test_intersection_non_overlap(self):
    994         result = self.set & set([8])
    995         self.assertEqual(result, empty_set)
    996 
    997     def test_isdisjoint_subset(self):
    998         result = self.set.isdisjoint(set((2, 4)))
    999         self.assertEqual(result, False)
   1000 
   1001     def test_isdisjoint_superset(self):
   1002         result = self.set.isdisjoint(set([2, 4, 6, 8]))
   1003         self.assertEqual(result, False)
   1004 
   1005     def test_isdisjoint_overlap(self):
   1006         result = self.set.isdisjoint(set([3, 4, 5]))
   1007         self.assertEqual(result, False)
   1008 
   1009     def test_isdisjoint_non_overlap(self):
   1010         result = self.set.isdisjoint(set([8]))
   1011         self.assertEqual(result, True)
   1012 
   1013     def test_sym_difference_subset(self):
   1014         result = self.set ^ set((2, 4))
   1015         self.assertEqual(result, set([6]))
   1016 
   1017     def test_sym_difference_superset(self):
   1018         result = self.set ^ set((2, 4, 6, 8))
   1019         self.assertEqual(result, set([8]))
   1020 
   1021     def test_sym_difference_overlap(self):
   1022         result = self.set ^ set((3, 4, 5))
   1023         self.assertEqual(result, set([2, 3, 5, 6]))
   1024 
   1025     def test_sym_difference_non_overlap(self):
   1026         result = self.set ^ set([8])
   1027         self.assertEqual(result, set([2, 4, 6, 8]))
   1028 
   1029     def test_cmp(self):
   1030         a, b = set('a'), set('b')
   1031         self.assertRaises(TypeError, cmp, a, b)
   1032 
   1033         # You can view this as a buglet:  cmp(a, a) does not raise TypeError,
   1034         # because __eq__ is tried before __cmp__, and a.__eq__(a) returns True,
   1035         # which Python thinks is good enough to synthesize a cmp() result
   1036         # without calling __cmp__.
   1037         self.assertEqual(cmp(a, a), 0)
   1038 
   1039 
   1040 #==============================================================================
   1041 
   1042 class TestUpdateOps(unittest.TestCase):
   1043     def setUp(self):
   1044         self.set = set((2, 4, 6))
   1045 
   1046     def test_union_subset(self):
   1047         self.set |= set([2])
   1048         self.assertEqual(self.set, set((2, 4, 6)))
   1049 
   1050     def test_union_superset(self):
   1051         self.set |= set([2, 4, 6, 8])
   1052         self.assertEqual(self.set, set([2, 4, 6, 8]))
   1053 
   1054     def test_union_overlap(self):
   1055         self.set |= set([3, 4, 5])
   1056         self.assertEqual(self.set, set([2, 3, 4, 5, 6]))
   1057 
   1058     def test_union_non_overlap(self):
   1059         self.set |= set([8])
   1060         self.assertEqual(self.set, set([2, 4, 6, 8]))
   1061 
   1062     def test_union_method_call(self):
   1063         self.set.update(set([3, 4, 5]))
   1064         self.assertEqual(self.set, set([2, 3, 4, 5, 6]))
   1065 
   1066     def test_intersection_subset(self):
   1067         self.set &= set((2, 4))
   1068         self.assertEqual(self.set, set((2, 4)))
   1069 
   1070     def test_intersection_superset(self):
   1071         self.set &= set([2, 4, 6, 8])
   1072         self.assertEqual(self.set, set([2, 4, 6]))
   1073 
   1074     def test_intersection_overlap(self):
   1075         self.set &= set([3, 4, 5])
   1076         self.assertEqual(self.set, set([4]))
   1077 
   1078     def test_intersection_non_overlap(self):
   1079         self.set &= set([8])
   1080         self.assertEqual(self.set, empty_set)
   1081 
   1082     def test_intersection_method_call(self):
   1083         self.set.intersection_update(set([3, 4, 5]))
   1084         self.assertEqual(self.set, set([4]))
   1085 
   1086     def test_sym_difference_subset(self):
   1087         self.set ^= set((2, 4))
   1088         self.assertEqual(self.set, set([6]))
   1089 
   1090     def test_sym_difference_superset(self):
   1091         self.set ^= set((2, 4, 6, 8))
   1092         self.assertEqual(self.set, set([8]))
   1093 
   1094     def test_sym_difference_overlap(self):
   1095         self.set ^= set((3, 4, 5))
   1096         self.assertEqual(self.set, set([2, 3, 5, 6]))
   1097 
   1098     def test_sym_difference_non_overlap(self):
   1099         self.set ^= set([8])
   1100         self.assertEqual(self.set, set([2, 4, 6, 8]))
   1101 
   1102     def test_sym_difference_method_call(self):
   1103         self.set.symmetric_difference_update(set([3, 4, 5]))
   1104         self.assertEqual(self.set, set([2, 3, 5, 6]))
   1105 
   1106     def test_difference_subset(self):
   1107         self.set -= set((2, 4))
   1108         self.assertEqual(self.set, set([6]))
   1109 
   1110     def test_difference_superset(self):
   1111         self.set -= set((2, 4, 6, 8))
   1112         self.assertEqual(self.set, set([]))
   1113 
   1114     def test_difference_overlap(self):
   1115         self.set -= set((3, 4, 5))
   1116         self.assertEqual(self.set, set([2, 6]))
   1117 
   1118     def test_difference_non_overlap(self):
   1119         self.set -= set([8])
   1120         self.assertEqual(self.set, set([2, 4, 6]))
   1121 
   1122     def test_difference_method_call(self):
   1123         self.set.difference_update(set([3, 4, 5]))
   1124         self.assertEqual(self.set, set([2, 6]))
   1125 
   1126 #==============================================================================
   1127 
   1128 class TestMutate(unittest.TestCase):
   1129     def setUp(self):
   1130         self.values = ["a", "b", "c"]
   1131         self.set = set(self.values)
   1132 
   1133     def test_add_present(self):
   1134         self.set.add("c")
   1135         self.assertEqual(self.set, set("abc"))
   1136 
   1137     def test_add_absent(self):
   1138         self.set.add("d")
   1139         self.assertEqual(self.set, set("abcd"))
   1140 
   1141     def test_add_until_full(self):
   1142         tmp = set()
   1143         expected_len = 0
   1144         for v in self.values:
   1145             tmp.add(v)
   1146             expected_len += 1
   1147             self.assertEqual(len(tmp), expected_len)
   1148         self.assertEqual(tmp, self.set)
   1149 
   1150     def test_remove_present(self):
   1151         self.set.remove("b")
   1152         self.assertEqual(self.set, set("ac"))
   1153 
   1154     def test_remove_absent(self):
   1155         try:
   1156             self.set.remove("d")
   1157             self.fail("Removing missing element should have raised LookupError")
   1158         except LookupError:
   1159             pass
   1160 
   1161     def test_remove_until_empty(self):
   1162         expected_len = len(self.set)
   1163         for v in self.values:
   1164             self.set.remove(v)
   1165             expected_len -= 1
   1166             self.assertEqual(len(self.set), expected_len)
   1167 
   1168     def test_discard_present(self):
   1169         self.set.discard("c")
   1170         self.assertEqual(self.set, set("ab"))
   1171 
   1172     def test_discard_absent(self):
   1173         self.set.discard("d")
   1174         self.assertEqual(self.set, set("abc"))
   1175 
   1176     def test_clear(self):
   1177         self.set.clear()
   1178         self.assertEqual(len(self.set), 0)
   1179 
   1180     def test_pop(self):
   1181         popped = {}
   1182         while self.set:
   1183             popped[self.set.pop()] = None
   1184         self.assertEqual(len(popped), len(self.values))
   1185         for v in self.values:
   1186             self.assertIn(v, popped)
   1187 
   1188     def test_update_empty_tuple(self):
   1189         self.set.update(())
   1190         self.assertEqual(self.set, set(self.values))
   1191 
   1192     def test_update_unit_tuple_overlap(self):
   1193         self.set.update(("a",))
   1194         self.assertEqual(self.set, set(self.values))
   1195 
   1196     def test_update_unit_tuple_non_overlap(self):
   1197         self.set.update(("a", "z"))
   1198         self.assertEqual(self.set, set(self.values + ["z"]))
   1199 
   1200 #==============================================================================
   1201 
   1202 class TestSubsets(unittest.TestCase):
   1203 
   1204     case2method = {"<=": "issubset",
   1205                    ">=": "issuperset",
   1206                   }
   1207 
   1208     reverse = {"==": "==",
   1209                "!=": "!=",
   1210                "<":  ">",
   1211                ">":  "<",
   1212                "<=": ">=",
   1213                ">=": "<=",
   1214               }
   1215 
   1216     def test_issubset(self):
   1217         x = self.left
   1218         y = self.right
   1219         for case in "!=", "==", "<", "<=", ">", ">=":
   1220             expected = case in self.cases
   1221             # Test the binary infix spelling.
   1222             result = eval("x" + case + "y", locals())
   1223             self.assertEqual(result, expected)
   1224             # Test the "friendly" method-name spelling, if one exists.
   1225             if case in TestSubsets.case2method:
   1226                 method = getattr(x, TestSubsets.case2method[case])
   1227                 result = method(y)
   1228                 self.assertEqual(result, expected)
   1229 
   1230             # Now do the same for the operands reversed.
   1231             rcase = TestSubsets.reverse[case]
   1232             result = eval("y" + rcase + "x", locals())
   1233             self.assertEqual(result, expected)
   1234             if rcase in TestSubsets.case2method:
   1235                 method = getattr(y, TestSubsets.case2method[rcase])
   1236                 result = method(x)
   1237                 self.assertEqual(result, expected)
   1238 #------------------------------------------------------------------------------
   1239 
   1240 class TestSubsetEqualEmpty(TestSubsets):
   1241     left  = set()
   1242     right = set()
   1243     name  = "both empty"
   1244     cases = "==", "<=", ">="
   1245 
   1246 #------------------------------------------------------------------------------
   1247 
   1248 class TestSubsetEqualNonEmpty(TestSubsets):
   1249     left  = set([1, 2])
   1250     right = set([1, 2])
   1251     name  = "equal pair"
   1252     cases = "==", "<=", ">="
   1253 
   1254 #------------------------------------------------------------------------------
   1255 
   1256 class TestSubsetEmptyNonEmpty(TestSubsets):
   1257     left  = set()
   1258     right = set([1, 2])
   1259     name  = "one empty, one non-empty"
   1260     cases = "!=", "<", "<="
   1261 
   1262 #------------------------------------------------------------------------------
   1263 
   1264 class TestSubsetPartial(TestSubsets):
   1265     left  = set([1])
   1266     right = set([1, 2])
   1267     name  = "one a non-empty proper subset of other"
   1268     cases = "!=", "<", "<="
   1269 
   1270 #------------------------------------------------------------------------------
   1271 
   1272 class TestSubsetNonOverlap(TestSubsets):
   1273     left  = set([1])
   1274     right = set([2])
   1275     name  = "neither empty, neither contains"
   1276     cases = "!="
   1277 
   1278 #==============================================================================
   1279 
   1280 class TestOnlySetsInBinaryOps(unittest.TestCase):
   1281 
   1282     def test_eq_ne(self):
   1283         # Unlike the others, this is testing that == and != *are* allowed.
   1284         self.assertEqual(self.other == self.set, False)
   1285         self.assertEqual(self.set == self.other, False)
   1286         self.assertEqual(self.other != self.set, True)
   1287         self.assertEqual(self.set != self.other, True)
   1288 
   1289     def test_update_operator(self):
   1290         try:
   1291             self.set |= self.other
   1292         except TypeError:
   1293             pass
   1294         else:
   1295             self.fail("expected TypeError")
   1296 
   1297     def test_update(self):
   1298         if self.otherIsIterable:
   1299             self.set.update(self.other)
   1300         else:
   1301             self.assertRaises(TypeError, self.set.update, self.other)
   1302 
   1303     def test_union(self):
   1304         self.assertRaises(TypeError, lambda: self.set | self.other)
   1305         self.assertRaises(TypeError, lambda: self.other | self.set)
   1306         if self.otherIsIterable:
   1307             self.set.union(self.other)
   1308         else:
   1309             self.assertRaises(TypeError, self.set.union, self.other)
   1310 
   1311     def test_intersection_update_operator(self):
   1312         try:
   1313             self.set &= self.other
   1314         except TypeError:
   1315             pass
   1316         else:
   1317             self.fail("expected TypeError")
   1318 
   1319     def test_intersection_update(self):
   1320         if self.otherIsIterable:
   1321             self.set.intersection_update(self.other)
   1322         else:
   1323             self.assertRaises(TypeError,
   1324                               self.set.intersection_update,
   1325                               self.other)
   1326 
   1327     def test_intersection(self):
   1328         self.assertRaises(TypeError, lambda: self.set & self.other)
   1329         self.assertRaises(TypeError, lambda: self.other & self.set)
   1330         if self.otherIsIterable:
   1331             self.set.intersection(self.other)
   1332         else:
   1333             self.assertRaises(TypeError, self.set.intersection, self.other)
   1334 
   1335     def test_sym_difference_update_operator(self):
   1336         try:
   1337             self.set ^= self.other
   1338         except TypeError:
   1339             pass
   1340         else:
   1341             self.fail("expected TypeError")
   1342 
   1343     def test_sym_difference_update(self):
   1344         if self.otherIsIterable:
   1345             self.set.symmetric_difference_update(self.other)
   1346         else:
   1347             self.assertRaises(TypeError,
   1348                               self.set.symmetric_difference_update,
   1349                               self.other)
   1350 
   1351     def test_sym_difference(self):
   1352         self.assertRaises(TypeError, lambda: self.set ^ self.other)
   1353         self.assertRaises(TypeError, lambda: self.other ^ self.set)
   1354         if self.otherIsIterable:
   1355             self.set.symmetric_difference(self.other)
   1356         else:
   1357             self.assertRaises(TypeError, self.set.symmetric_difference, self.other)
   1358 
   1359     def test_difference_update_operator(self):
   1360         try:
   1361             self.set -= self.other
   1362         except TypeError:
   1363             pass
   1364         else:
   1365             self.fail("expected TypeError")
   1366 
   1367     def test_difference_update(self):
   1368         if self.otherIsIterable:
   1369             self.set.difference_update(self.other)
   1370         else:
   1371             self.assertRaises(TypeError,
   1372                               self.set.difference_update,
   1373                               self.other)
   1374 
   1375     def test_difference(self):
   1376         self.assertRaises(TypeError, lambda: self.set - self.other)
   1377         self.assertRaises(TypeError, lambda: self.other - self.set)
   1378         if self.otherIsIterable:
   1379             self.set.difference(self.other)
   1380         else:
   1381             self.assertRaises(TypeError, self.set.difference, self.other)
   1382 
   1383 #------------------------------------------------------------------------------
   1384 
   1385 class TestOnlySetsNumeric(TestOnlySetsInBinaryOps):
   1386     def setUp(self):
   1387         self.set   = set((1, 2, 3))
   1388         self.other = 19
   1389         self.otherIsIterable = False
   1390 
   1391 #------------------------------------------------------------------------------
   1392 
   1393 class TestOnlySetsDict(TestOnlySetsInBinaryOps):
   1394     def setUp(self):
   1395         self.set   = set((1, 2, 3))
   1396         self.other = {1:2, 3:4}
   1397         self.otherIsIterable = True
   1398 
   1399 #------------------------------------------------------------------------------
   1400 
   1401 class TestOnlySetsTuple(TestOnlySetsInBinaryOps):
   1402     def setUp(self):
   1403         self.set   = set((1, 2, 3))
   1404         self.other = (2, 4, 6)
   1405         self.otherIsIterable = True
   1406 
   1407 #------------------------------------------------------------------------------
   1408 
   1409 class TestOnlySetsString(TestOnlySetsInBinaryOps):
   1410     def setUp(self):
   1411         self.set   = set((1, 2, 3))
   1412         self.other = 'abc'
   1413         self.otherIsIterable = True
   1414 
   1415 #------------------------------------------------------------------------------
   1416 
   1417 class TestOnlySetsGenerator(TestOnlySetsInBinaryOps):
   1418     def setUp(self):
   1419         def gen():
   1420             for i in xrange(0, 10, 2):
   1421                 yield i
   1422         self.set   = set((1, 2, 3))
   1423         self.other = gen()
   1424         self.otherIsIterable = True
   1425 
   1426 #==============================================================================
   1427 
   1428 class TestCopying(unittest.TestCase):
   1429 
   1430     def test_copy(self):
   1431         dup = list(self.set.copy())
   1432         self.assertEqual(len(dup), len(self.set))
   1433         for el in self.set:
   1434             self.assertIn(el, dup)
   1435             pos = dup.index(el)
   1436             self.assertIs(el, dup.pop(pos))
   1437         self.assertFalse(dup)
   1438 
   1439     def test_deep_copy(self):
   1440         dup = copy.deepcopy(self.set)
   1441         self.assertSetEqual(dup, self.set)
   1442 
   1443 #------------------------------------------------------------------------------
   1444 
   1445 class TestCopyingEmpty(TestCopying):
   1446     def setUp(self):
   1447         self.set = set()
   1448 
   1449 #------------------------------------------------------------------------------
   1450 
   1451 class TestCopyingSingleton(TestCopying):
   1452     def setUp(self):
   1453         self.set = set(["hello"])
   1454 
   1455 #------------------------------------------------------------------------------
   1456 
   1457 class TestCopyingTriple(TestCopying):
   1458     def setUp(self):
   1459         self.set = set(["zero", 0, None])
   1460 
   1461 #------------------------------------------------------------------------------
   1462 
   1463 class TestCopyingTuple(TestCopying):
   1464     def setUp(self):
   1465         self.set = set([(1, 2)])
   1466 
   1467 #------------------------------------------------------------------------------
   1468 
   1469 class TestCopyingNested(TestCopying):
   1470     def setUp(self):
   1471         self.set = set([((1, 2), (3, 4))])
   1472 
   1473 #==============================================================================
   1474 
   1475 class TestIdentities(unittest.TestCase):
   1476     def setUp(self):
   1477         self.a = set('abracadabra')
   1478         self.b = set('alacazam')
   1479 
   1480     def test_binopsVsSubsets(self):
   1481         a, b = self.a, self.b
   1482         self.assertTrue(a - b < a)
   1483         self.assertTrue(b - a < b)
   1484         self.assertTrue(a & b < a)
   1485         self.assertTrue(a & b < b)
   1486         self.assertTrue(a | b > a)
   1487         self.assertTrue(a | b > b)
   1488         self.assertTrue(a ^ b < a | b)
   1489 
   1490     def test_commutativity(self):
   1491         a, b = self.a, self.b
   1492         self.assertEqual(a&b, b&a)
   1493         self.assertEqual(a|b, b|a)
   1494         self.assertEqual(a^b, b^a)
   1495         if a != b:
   1496             self.assertNotEqual(a-b, b-a)
   1497 
   1498     def test_summations(self):
   1499         # check that sums of parts equal the whole
   1500         a, b = self.a, self.b
   1501         self.assertEqual((a-b)|(a&b)|(b-a), a|b)
   1502         self.assertEqual((a&b)|(a^b), a|b)
   1503         self.assertEqual(a|(b-a), a|b)
   1504         self.assertEqual((a-b)|b, a|b)
   1505         self.assertEqual((a-b)|(a&b), a)
   1506         self.assertEqual((b-a)|(a&b), b)
   1507         self.assertEqual((a-b)|(b-a), a^b)
   1508 
   1509     def test_exclusion(self):
   1510         # check that inverse operations show non-overlap
   1511         a, b, zero = self.a, self.b, set()
   1512         self.assertEqual((a-b)&b, zero)
   1513         self.assertEqual((b-a)&a, zero)
   1514         self.assertEqual((a&b)&(a^b), zero)
   1515 
   1516 # Tests derived from test_itertools.py =======================================
   1517 
   1518 def R(seqn):
   1519     'Regular generator'
   1520     for i in seqn:
   1521         yield i
   1522 
   1523 class G:
   1524     'Sequence using __getitem__'
   1525     def __init__(self, seqn):
   1526         self.seqn = seqn
   1527     def __getitem__(self, i):
   1528         return self.seqn[i]
   1529 
   1530 class I:
   1531     'Sequence using iterator protocol'
   1532     def __init__(self, seqn):
   1533         self.seqn = seqn
   1534         self.i = 0
   1535     def __iter__(self):
   1536         return self
   1537     def next(self):
   1538         if self.i >= len(self.seqn): raise StopIteration
   1539         v = self.seqn[self.i]
   1540         self.i += 1
   1541         return v
   1542 
   1543 class Ig:
   1544     'Sequence using iterator protocol defined with a generator'
   1545     def __init__(self, seqn):
   1546         self.seqn = seqn
   1547         self.i = 0
   1548     def __iter__(self):
   1549         for val in self.seqn:
   1550             yield val
   1551 
   1552 class X:
   1553     'Missing __getitem__ and __iter__'
   1554     def __init__(self, seqn):
   1555         self.seqn = seqn
   1556         self.i = 0
   1557     def next(self):
   1558         if self.i >= len(self.seqn): raise StopIteration
   1559         v = self.seqn[self.i]
   1560         self.i += 1
   1561         return v
   1562 
   1563 class N:
   1564     'Iterator missing next()'
   1565     def __init__(self, seqn):
   1566         self.seqn = seqn
   1567         self.i = 0
   1568     def __iter__(self):
   1569         return self
   1570 
   1571 class E:
   1572     'Test propagation of exceptions'
   1573     def __init__(self, seqn):
   1574         self.seqn = seqn
   1575         self.i = 0
   1576     def __iter__(self):
   1577         return self
   1578     def next(self):
   1579         3 // 0
   1580 
   1581 class S:
   1582     'Test immediate stop'
   1583     def __init__(self, seqn):
   1584         pass
   1585     def __iter__(self):
   1586         return self
   1587     def next(self):
   1588         raise StopIteration
   1589 
   1590 from itertools import chain, imap
   1591 def L(seqn):
   1592     'Test multiple tiers of iterators'
   1593     return chain(imap(lambda x:x, R(Ig(G(seqn)))))
   1594 
   1595 class TestVariousIteratorArgs(unittest.TestCase):
   1596 
   1597     def test_constructor(self):
   1598         for cons in (set, frozenset):
   1599             for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
   1600                 for g in (G, I, Ig, S, L, R):
   1601                     self.assertSetEqual(cons(g(s)), set(g(s)))
   1602                 self.assertRaises(TypeError, cons , X(s))
   1603                 self.assertRaises(TypeError, cons , N(s))
   1604                 self.assertRaises(ZeroDivisionError, cons , E(s))
   1605 
   1606     def test_inline_methods(self):
   1607         s = set('november')
   1608         for data in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5), 'december'):
   1609             for meth in (s.union, s.intersection, s.difference, s.symmetric_difference, s.isdisjoint):
   1610                 for g in (G, I, Ig, L, R):
   1611                     expected = meth(data)
   1612                     actual = meth(g(data))
   1613                     if isinstance(expected, bool):
   1614                         self.assertEqual(actual, expected)
   1615                     else:
   1616                         self.assertSetEqual(actual, expected)
   1617                 self.assertRaises(TypeError, meth, X(s))
   1618                 self.assertRaises(TypeError, meth, N(s))
   1619                 self.assertRaises(ZeroDivisionError, meth, E(s))
   1620 
   1621     def test_inplace_methods(self):
   1622         for data in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5), 'december'):
   1623             for methname in ('update', 'intersection_update',
   1624                              'difference_update', 'symmetric_difference_update'):
   1625                 for g in (G, I, Ig, S, L, R):
   1626                     s = set('january')
   1627                     t = s.copy()
   1628                     getattr(s, methname)(list(g(data)))
   1629                     getattr(t, methname)(g(data))
   1630                     self.assertSetEqual(s, t)
   1631 
   1632                 self.assertRaises(TypeError, getattr(set('january'), methname), X(data))
   1633                 self.assertRaises(TypeError, getattr(set('january'), methname), N(data))
   1634                 self.assertRaises(ZeroDivisionError, getattr(set('january'), methname), E(data))
   1635 
   1636 class bad_eq:
   1637     def __eq__(self, other):
   1638         if be_bad:
   1639             set2.clear()
   1640             raise ZeroDivisionError
   1641         return self is other
   1642     def __hash__(self):
   1643         return 0
   1644 
   1645 class bad_dict_clear:
   1646     def __eq__(self, other):
   1647         if be_bad:
   1648             dict2.clear()
   1649         return self is other
   1650     def __hash__(self):
   1651         return 0
   1652 
class TestWeirdBugs(unittest.TestCase):
    """Regression tests for historical bugs (issues 8420 and 24581)."""
    def test_8420_set_merge(self):
        # This used to segfault
        # bad_eq / bad_dict_clear read these module-level names, hence global.
        global be_bad, set2, dict2
        be_bad = False
        # Every bad_eq hashes to 0, so the elements collide; with be_bad off,
        # their __eq__ is a harmless identity check while building the sets.
        set1 = {bad_eq()}
        set2 = {bad_eq() for i in range(75)}
        be_bad = True
        # Now an __eq__ call during the merge clears set2 and raises; the
        # ZeroDivisionError must propagate out of update().
        self.assertRaises(ZeroDivisionError, set1.update, set2)

        be_bad = False
        set1 = {bad_dict_clear()}
        dict2 = {bad_dict_clear(): None}
        be_bad = True
        # Here __eq__ empties dict2 mid-operation without raising; there is
        # no assertion -- the call only has to complete.
        set1.symmetric_difference_update(dict2)

    def test_iter_and_mutate(self):
        # Issue #24581
        # Take an iterator, then clear and refill the set behind its back.
        s = set(range(100))
        s.clear()
        s.update(range(100))
        si = iter(s)
        s.clear()
        # NOTE(review): `a` is never read; it may be a deliberate allocation
        # (e.g. to disturb memory reuse) -- confirm before removing.
        a = list(range(100))
        s.update(range(100))
        # Draining the stale iterator must not raise or crash; its contents
        # are not checked.
        list(si)
   1679 
# Application tests (based on David Eppstein's graph recipes) ==================================
   1681 
   1682 def powerset(U):
   1683     """Generates all subsets of a set or sequence U."""
   1684     U = iter(U)
   1685     try:
   1686         x = frozenset([U.next()])
   1687         for S in powerset(U):
   1688             yield S
   1689             yield S | x
   1690     except StopIteration:
   1691         yield frozenset()
   1692 
   1693 def cube(n):
   1694     """Graph of n-dimensional hypercube."""
   1695     singletons = [frozenset([x]) for x in range(n)]
   1696     return dict([(x, frozenset([x^s for s in singletons]))
   1697                  for x in powerset(range(n))])
   1698 
   1699 def linegraph(G):
   1700     """Graph, the vertices of which are edges of G,
   1701     with two vertices being adjacent iff the corresponding
   1702     edges share a vertex."""
   1703     L = {}
   1704     for x in G:
   1705         for y in G[x]:
   1706             nx = [frozenset([x,z]) for z in G[x] if z != y]
   1707             ny = [frozenset([y,z]) for z in G[y] if z != x]
   1708             L[frozenset([x,y])] = frozenset(nx+ny)
   1709     return L
   1710 
   1711 def faces(G):
   1712     'Return a set of faces in G.  Where a face is a set of vertices on that face'
   1713     # currently limited to triangles,squares, and pentagons
   1714     f = set()
   1715     for v1, edges in G.items():
   1716         for v2 in edges:
   1717             for v3 in G[v2]:
   1718                 if v1 == v3:
   1719                     continue
   1720                 if v1 in G[v3]:
   1721                     f.add(frozenset([v1, v2, v3]))
   1722                 else:
   1723                     for v4 in G[v3]:
   1724                         if v4 == v2:
   1725                             continue
   1726                         if v1 in G[v4]:
   1727                             f.add(frozenset([v1, v2, v3, v4]))
   1728                         else:
   1729                             for v5 in G[v4]:
   1730                                 if v5 == v3 or v5 == v2:
   1731                                     continue
   1732                                 if v1 in G[v5]:
   1733                                     f.add(frozenset([v1, v2, v3, v4, v5]))
   1734     return f
   1735 
   1736 
   1737 class TestGraphs(unittest.TestCase):
   1738 
   1739     def test_cube(self):
   1740 
   1741         g = cube(3)                             # vert --> {v1, v2, v3}
   1742         vertices1 = set(g)
   1743         self.assertEqual(len(vertices1), 8)     # eight vertices
   1744         for edge in g.values():
   1745             self.assertEqual(len(edge), 3)      # each vertex connects to three edges
   1746         vertices2 = set(v for edges in g.values() for v in edges)
   1747         self.assertEqual(vertices1, vertices2)  # edge vertices in original set
   1748 
   1749         cubefaces = faces(g)
   1750         self.assertEqual(len(cubefaces), 6)     # six faces
   1751         for face in cubefaces:
   1752             self.assertEqual(len(face), 4)      # each face is a square
   1753 
   1754     def test_cuboctahedron(self):
   1755 
   1756         # http://en.wikipedia.org/wiki/Cuboctahedron
   1757         # 8 triangular faces and 6 square faces
   1758         # 12 identical vertices each connecting a triangle and square
   1759 
   1760         g = cube(3)
   1761         cuboctahedron = linegraph(g)            # V( --> {V1, V2, V3, V4}
   1762         self.assertEqual(len(cuboctahedron), 12)# twelve vertices
   1763 
   1764         vertices = set(cuboctahedron)
   1765         for edges in cuboctahedron.values():
   1766             self.assertEqual(len(edges), 4)     # each vertex connects to four other vertices
   1767         othervertices = set(edge for edges in cuboctahedron.values() for edge in edges)
   1768         self.assertEqual(vertices, othervertices)   # edge vertices in original set
   1769 
   1770         cubofaces = faces(cuboctahedron)
   1771         facesizes = collections.defaultdict(int)
   1772         for face in cubofaces:
   1773             facesizes[len(face)] += 1
   1774         self.assertEqual(facesizes[3], 8)       # eight triangular faces
   1775         self.assertEqual(facesizes[4], 6)       # six square faces
   1776 
   1777         for vertex in cuboctahedron:
   1778             edge = vertex                       # Cuboctahedron vertices are edges in Cube
   1779             self.assertEqual(len(edge), 2)      # Two cube vertices define an edge
   1780             for cubevert in edge:
   1781                 self.assertIn(cubevert, g)
   1782 
   1783 
   1784 #==============================================================================
   1785 
   1786 def test_main(verbose=None):
   1787     test_classes = (
   1788         TestSet,
   1789         TestSetSubclass,
   1790         TestSetSubclassWithKeywordArgs,
   1791         TestFrozenSet,
   1792         TestFrozenSetSubclass,
   1793         TestSetOfSets,
   1794         TestExceptionPropagation,
   1795         TestBasicOpsEmpty,
   1796         TestBasicOpsSingleton,
   1797         TestBasicOpsTuple,
   1798         TestBasicOpsTriple,
   1799         TestBinaryOps,
   1800         TestUpdateOps,
   1801         TestMutate,
   1802         TestSubsetEqualEmpty,
   1803         TestSubsetEqualNonEmpty,
   1804         TestSubsetEmptyNonEmpty,
   1805         TestSubsetPartial,
   1806         TestSubsetNonOverlap,
   1807         TestOnlySetsNumeric,
   1808         TestOnlySetsDict,
   1809         TestOnlySetsTuple,
   1810         TestOnlySetsString,
   1811         TestOnlySetsGenerator,
   1812         TestCopyingEmpty,
   1813         TestCopyingSingleton,
   1814         TestCopyingTriple,
   1815         TestCopyingTuple,
   1816         TestCopyingNested,
   1817         TestIdentities,
   1818         TestVariousIteratorArgs,
   1819         TestGraphs,
   1820         TestWeirdBugs,
   1821         )
   1822 
   1823     test_support.run_unittest(*test_classes)
   1824 
   1825     # verify reference counting
   1826     if verbose and hasattr(sys, "gettotalrefcount"):
   1827         import gc
   1828         counts = [None] * 5
   1829         for i in xrange(len(counts)):
   1830             test_support.run_unittest(*test_classes)
   1831             gc.collect()
   1832             counts[i] = sys.gettotalrefcount()
   1833         print counts
   1834 
if __name__ == "__main__":
    # Run directly: verbose=True also enables test_main's refcount loop
    # when sys.gettotalrefcount is available.
    test_main(verbose=True)
   1837