Home | History | Annotate | Download | only in test
      1 import unittest
      2 import operator
      3 import sys
      4 import pickle
      5 
      6 from test import support
      7 
      8 class G:
      9     'Sequence using __getitem__'
     10     def __init__(self, seqn):
     11         self.seqn = seqn
     12     def __getitem__(self, i):
     13         return self.seqn[i]
     14 
     15 class I:
     16     'Sequence using iterator protocol'
     17     def __init__(self, seqn):
     18         self.seqn = seqn
     19         self.i = 0
     20     def __iter__(self):
     21         return self
     22     def __next__(self):
     23         if self.i >= len(self.seqn): raise StopIteration
     24         v = self.seqn[self.i]
     25         self.i += 1
     26         return v
     27 
     28 class Ig:
     29     'Sequence using iterator protocol defined with a generator'
     30     def __init__(self, seqn):
     31         self.seqn = seqn
     32         self.i = 0
     33     def __iter__(self):
     34         for val in self.seqn:
     35             yield val
     36 
     37 class X:
     38     'Missing __getitem__ and __iter__'
     39     def __init__(self, seqn):
     40         self.seqn = seqn
     41         self.i = 0
     42     def __next__(self):
     43         if self.i >= len(self.seqn): raise StopIteration
     44         v = self.seqn[self.i]
     45         self.i += 1
     46         return v
     47 
     48 class E:
     49     'Test propagation of exceptions'
     50     def __init__(self, seqn):
     51         self.seqn = seqn
     52         self.i = 0
     53     def __iter__(self):
     54         return self
     55     def __next__(self):
     56         3 // 0
     57 
     58 class N:
     59     'Iterator missing __next__()'
     60     def __init__(self, seqn):
     61         self.seqn = seqn
     62         self.i = 0
     63     def __iter__(self):
     64         return self
     65 
     66 class PickleTest:
     67     # Helper to check picklability
     68     def check_pickle(self, itorg, seq):
     69         for proto in range(pickle.HIGHEST_PROTOCOL + 1):
     70             d = pickle.dumps(itorg, proto)
     71             it = pickle.loads(d)
     72             self.assertEqual(type(itorg), type(it))
     73             self.assertEqual(list(it), seq)
     74 
     75             it = pickle.loads(d)
     76             try:
     77                 next(it)
     78             except StopIteration:
     79                 self.assertFalse(seq[1:])
     80                 continue
     81             d = pickle.dumps(it, proto)
     82             it = pickle.loads(d)
     83             self.assertEqual(list(it), seq[1:])
     84 
     85 class EnumerateTestCase(unittest.TestCase, PickleTest):
     86 
     87     enum = enumerate
     88     seq, res = 'abc', [(0,'a'), (1,'b'), (2,'c')]
     89 
     90     def test_basicfunction(self):
     91         self.assertEqual(type(self.enum(self.seq)), self.enum)
     92         e = self.enum(self.seq)
     93         self.assertEqual(iter(e), e)
     94         self.assertEqual(list(self.enum(self.seq)), self.res)
     95         self.enum.__doc__
     96 
     97     def test_pickle(self):
     98         self.check_pickle(self.enum(self.seq), self.res)
     99 
    100     def test_getitemseqn(self):
    101         self.assertEqual(list(self.enum(G(self.seq))), self.res)
    102         e = self.enum(G(''))
    103         self.assertRaises(StopIteration, next, e)
    104 
    105     def test_iteratorseqn(self):
    106         self.assertEqual(list(self.enum(I(self.seq))), self.res)
    107         e = self.enum(I(''))
    108         self.assertRaises(StopIteration, next, e)
    109 
    110     def test_iteratorgenerator(self):
    111         self.assertEqual(list(self.enum(Ig(self.seq))), self.res)
    112         e = self.enum(Ig(''))
    113         self.assertRaises(StopIteration, next, e)
    114 
    115     def test_noniterable(self):
    116         self.assertRaises(TypeError, self.enum, X(self.seq))
    117 
    118     def test_illformediterable(self):
    119         self.assertRaises(TypeError, self.enum, N(self.seq))
    120 
    121     def test_exception_propagation(self):
    122         self.assertRaises(ZeroDivisionError, list, self.enum(E(self.seq)))
    123 
    124     def test_argumentcheck(self):
    125         self.assertRaises(TypeError, self.enum) # no arguments
    126         self.assertRaises(TypeError, self.enum, 1) # wrong type (not iterable)
    127         self.assertRaises(TypeError, self.enum, 'abc', 'a') # wrong type
    128         self.assertRaises(TypeError, self.enum, 'abc', 2, 3) # too many arguments
    129 
    130     @support.cpython_only
    131     def test_tuple_reuse(self):
    132         # Tests an implementation detail where tuple is reused
    133         # whenever nothing else holds a reference to it
    134         self.assertEqual(len(set(map(id, list(enumerate(self.seq))))), len(self.seq))
    135         self.assertEqual(len(set(map(id, enumerate(self.seq)))), min(1,len(self.seq)))
    136 
    137 class MyEnum(enumerate):
    138     pass
    139 
    140 class SubclassTestCase(EnumerateTestCase):
    141 
    142     enum = MyEnum
    143 
    144 class TestEmpty(EnumerateTestCase):
    145 
    146     seq, res = '', []
    147 
    148 class TestBig(EnumerateTestCase):
    149 
    150     seq = range(10,20000,2)
    151     res = list(zip(range(20000), seq))
    152 
    153 class TestReversed(unittest.TestCase, PickleTest):
    154 
    155     def test_simple(self):
    156         class A:
    157             def __getitem__(self, i):
    158                 if i < 5:
    159                     return str(i)
    160                 raise StopIteration
    161             def __len__(self):
    162                 return 5
    163         for data in 'abc', range(5), tuple(enumerate('abc')), A(), range(1,17,5):
    164             self.assertEqual(list(data)[::-1], list(reversed(data)))
    165         self.assertRaises(TypeError, reversed, {})
    166         # don't allow keyword arguments
    167         self.assertRaises(TypeError, reversed, [], a=1)
    168 
    169     def test_range_optimization(self):
    170         x = range(1)
    171         self.assertEqual(type(reversed(x)), type(iter(x)))
    172 
    173     def test_len(self):
    174         for s in ('hello', tuple('hello'), list('hello'), range(5)):
    175             self.assertEqual(operator.length_hint(reversed(s)), len(s))
    176             r = reversed(s)
    177             list(r)
    178             self.assertEqual(operator.length_hint(r), 0)
    179         class SeqWithWeirdLen:
    180             called = False
    181             def __len__(self):
    182                 if not self.called:
    183                     self.called = True
    184                     return 10
    185                 raise ZeroDivisionError
    186             def __getitem__(self, index):
    187                 return index
    188         r = reversed(SeqWithWeirdLen())
    189         self.assertRaises(ZeroDivisionError, operator.length_hint, r)
    190 
    191 
    192     def test_gc(self):
    193         class Seq:
    194             def __len__(self):
    195                 return 10
    196             def __getitem__(self, index):
    197                 return index
    198         s = Seq()
    199         r = reversed(s)
    200         s.r = r
    201 
    202     def test_args(self):
    203         self.assertRaises(TypeError, reversed)
    204         self.assertRaises(TypeError, reversed, [], 'extra')
    205 
    206     @unittest.skipUnless(hasattr(sys, 'getrefcount'), 'test needs sys.getrefcount()')
    207     def test_bug1229429(self):
    208         # this bug was never in reversed, it was in
    209         # PyObject_CallMethod, and reversed_new calls that sometimes.
    210         def f():
    211             pass
    212         r = f.__reversed__ = object()
    213         rc = sys.getrefcount(r)
    214         for i in range(10):
    215             try:
    216                 reversed(f)
    217             except TypeError:
    218                 pass
    219             else:
    220                 self.fail("non-callable __reversed__ didn't raise!")
    221         self.assertEqual(rc, sys.getrefcount(r))
    222 
    223     def test_objmethods(self):
    224         # Objects must have __len__() and __getitem__() implemented.
    225         class NoLen(object):
    226             def __getitem__(self, i): return 1
    227         nl = NoLen()
    228         self.assertRaises(TypeError, reversed, nl)
    229 
    230         class NoGetItem(object):
    231             def __len__(self): return 2
    232         ngi = NoGetItem()
    233         self.assertRaises(TypeError, reversed, ngi)
    234 
    235         class Blocked(object):
    236             def __getitem__(self, i): return 1
    237             def __len__(self): return 2
    238             __reversed__ = None
    239         b = Blocked()
    240         self.assertRaises(TypeError, reversed, b)
    241 
    242     def test_pickle(self):
    243         for data in 'abc', range(5), tuple(enumerate('abc')), range(1,17,5):
    244             self.check_pickle(reversed(data), list(data)[::-1])
    245 
    246 
    247 class EnumerateStartTestCase(EnumerateTestCase):
    248 
    249     def test_basicfunction(self):
    250         e = self.enum(self.seq)
    251         self.assertEqual(iter(e), e)
    252         self.assertEqual(list(self.enum(self.seq)), self.res)
    253 
    254 
    255 class TestStart(EnumerateStartTestCase):
    256 
    257     enum = lambda self, i: enumerate(i, start=11)
    258     seq, res = 'abc', [(11, 'a'), (12, 'b'), (13, 'c')]
    259 
    260 
    261 class TestLongStart(EnumerateStartTestCase):
    262 
    263     enum = lambda self, i: enumerate(i, start=sys.maxsize+1)
    264     seq, res = 'abc', [(sys.maxsize+1,'a'), (sys.maxsize+2,'b'),
    265                        (sys.maxsize+3,'c')]
    266 
    267 
    268 if __name__ == "__main__":
    269     unittest.main()
    270