Home | History | Annotate | Download | only in test
      1 import unittest
      2 import sys
      3 
      4 from test import test_support
      5 
      6 class G:
      7     'Sequence using __getitem__'
      8     def __init__(self, seqn):
      9         self.seqn = seqn
     10     def __getitem__(self, i):
     11         return self.seqn[i]
     12 
     13 class I:
     14     'Sequence using iterator protocol'
     15     def __init__(self, seqn):
     16         self.seqn = seqn
     17         self.i = 0
     18     def __iter__(self):
     19         return self
     20     def next(self):
     21         if self.i >= len(self.seqn): raise StopIteration
     22         v = self.seqn[self.i]
     23         self.i += 1
     24         return v
     25 
     26 class Ig:
     27     'Sequence using iterator protocol defined with a generator'
     28     def __init__(self, seqn):
     29         self.seqn = seqn
     30         self.i = 0
     31     def __iter__(self):
     32         for val in self.seqn:
     33             yield val
     34 
     35 class X:
     36     'Missing __getitem__ and __iter__'
     37     def __init__(self, seqn):
     38         self.seqn = seqn
     39         self.i = 0
     40     def next(self):
     41         if self.i >= len(self.seqn): raise StopIteration
     42         v = self.seqn[self.i]
     43         self.i += 1
     44         return v
     45 
     46 class E:
     47     'Test propagation of exceptions'
     48     def __init__(self, seqn):
     49         self.seqn = seqn
     50         self.i = 0
     51     def __iter__(self):
     52         return self
     53     def next(self):
     54         3 // 0
     55 
     56 class N:
     57     'Iterator missing next()'
     58     def __init__(self, seqn):
     59         self.seqn = seqn
     60         self.i = 0
     61     def __iter__(self):
     62         return self
     63 
     64 class EnumerateTestCase(unittest.TestCase):
     65 
     66     enum = enumerate
     67     seq, res = 'abc', [(0,'a'), (1,'b'), (2,'c')]
     68 
     69     def test_basicfunction(self):
     70         self.assertEqual(type(self.enum(self.seq)), self.enum)
     71         e = self.enum(self.seq)
     72         self.assertEqual(iter(e), e)
     73         self.assertEqual(list(self.enum(self.seq)), self.res)
     74         self.enum.__doc__
     75 
     76     def test_getitemseqn(self):
     77         self.assertEqual(list(self.enum(G(self.seq))), self.res)
     78         e = self.enum(G(''))
     79         self.assertRaises(StopIteration, e.next)
     80 
     81     def test_iteratorseqn(self):
     82         self.assertEqual(list(self.enum(I(self.seq))), self.res)
     83         e = self.enum(I(''))
     84         self.assertRaises(StopIteration, e.next)
     85 
     86     def test_iteratorgenerator(self):
     87         self.assertEqual(list(self.enum(Ig(self.seq))), self.res)
     88         e = self.enum(Ig(''))
     89         self.assertRaises(StopIteration, e.next)
     90 
     91     def test_noniterable(self):
     92         self.assertRaises(TypeError, self.enum, X(self.seq))
     93 
     94     def test_illformediterable(self):
     95         self.assertRaises(TypeError, list, self.enum(N(self.seq)))
     96 
     97     def test_exception_propagation(self):
     98         self.assertRaises(ZeroDivisionError, list, self.enum(E(self.seq)))
     99 
    100     def test_argumentcheck(self):
    101         self.assertRaises(TypeError, self.enum) # no arguments
    102         self.assertRaises(TypeError, self.enum, 1) # wrong type (not iterable)
    103         self.assertRaises(TypeError, self.enum, 'abc', 'a') # wrong type
    104         self.assertRaises(TypeError, self.enum, 'abc', 2, 3) # too many arguments
    105 
    106     @test_support.cpython_only
    107     def test_tuple_reuse(self):
    108         # Tests an implementation detail where tuple is reused
    109         # whenever nothing else holds a reference to it
    110         self.assertEqual(len(set(map(id, list(enumerate(self.seq))))), len(self.seq))
    111         self.assertEqual(len(set(map(id, enumerate(self.seq)))), min(1,len(self.seq)))
    112 
    113 class MyEnum(enumerate):
    114     pass
    115 
    116 class SubclassTestCase(EnumerateTestCase):
    117 
    118     enum = MyEnum
    119 
    120 class TestEmpty(EnumerateTestCase):
    121 
    122     seq, res = '', []
    123 
    124 class TestBig(EnumerateTestCase):
    125 
    126     seq = range(10,20000,2)
    127     res = zip(range(20000), seq)
    128 
    129 class TestReversed(unittest.TestCase):
    130 
    131     def test_simple(self):
    132         class A:
    133             def __getitem__(self, i):
    134                 if i < 5:
    135                     return str(i)
    136                 raise StopIteration
    137             def __len__(self):
    138                 return 5
    139         for data in 'abc', range(5), tuple(enumerate('abc')), A(), xrange(1,17,5):
    140             self.assertEqual(list(data)[::-1], list(reversed(data)))
    141         self.assertRaises(TypeError, reversed, {})
    142         # don't allow keyword arguments
    143         self.assertRaises(TypeError, reversed, [], a=1)
    144 
    145     def test_classic_class(self):
    146         class A:
    147             def __reversed__(self):
    148                 return [2, 1]
    149         self.assertEqual(list(reversed(A())), [2, 1])
    150 
    151     def test_xrange_optimization(self):
    152         x = xrange(1)
    153         self.assertEqual(type(reversed(x)), type(iter(x)))
    154 
    155     @test_support.cpython_only
    156     def test_len(self):
    157         # This is an implementation detail, not an interface requirement
    158         from test.test_iterlen import len
    159         for s in ('hello', tuple('hello'), list('hello'), xrange(5)):
    160             self.assertEqual(len(reversed(s)), len(s))
    161             r = reversed(s)
    162             list(r)
    163             self.assertEqual(len(r), 0)
    164         class SeqWithWeirdLen:
    165             called = False
    166             def __len__(self):
    167                 if not self.called:
    168                     self.called = True
    169                     return 10
    170                 raise ZeroDivisionError
    171             def __getitem__(self, index):
    172                 return index
    173         r = reversed(SeqWithWeirdLen())
    174         self.assertRaises(ZeroDivisionError, len, r)
    175 
    176 
    177     def test_gc(self):
    178         class Seq:
    179             def __len__(self):
    180                 return 10
    181             def __getitem__(self, index):
    182                 return index
    183         s = Seq()
    184         r = reversed(s)
    185         s.r = r
    186 
    187     def test_args(self):
    188         self.assertRaises(TypeError, reversed)
    189         self.assertRaises(TypeError, reversed, [], 'extra')
    190 
    191     @unittest.skipUnless(hasattr(sys, 'getrefcount'), 'test needs sys.getrefcount()')
    192     def test_bug1229429(self):
    193         # this bug was never in reversed, it was in
    194         # PyObject_CallMethod, and reversed_new calls that sometimes.
    195         def f():
    196             pass
    197         r = f.__reversed__ = object()
    198         rc = sys.getrefcount(r)
    199         for i in range(10):
    200             try:
    201                 reversed(f)
    202             except TypeError:
    203                 pass
    204             else:
    205                 self.fail("non-callable __reversed__ didn't raise!")
    206         self.assertEqual(rc, sys.getrefcount(r))
    207 
    208     def test_objmethods(self):
    209         # Objects must have __len__() and __getitem__() implemented.
    210         class NoLen(object):
    211             def __getitem__(self): return 1
    212         nl = NoLen()
    213         self.assertRaises(TypeError, reversed, nl)
    214 
    215         class NoGetItem(object):
    216             def __len__(self): return 2
    217         ngi = NoGetItem()
    218         self.assertRaises(TypeError, reversed, ngi)
    219 
    220 
    221 class EnumerateStartTestCase(EnumerateTestCase):
    222 
    223     def test_basicfunction(self):
    224         e = self.enum(self.seq)
    225         self.assertEqual(iter(e), e)
    226         self.assertEqual(list(self.enum(self.seq)), self.res)
    227 
    228 
    229 class TestStart(EnumerateStartTestCase):
    230 
    231     enum = lambda self, i: enumerate(i, start=11)
    232     seq, res = 'abc', [(11, 'a'), (12, 'b'), (13, 'c')]
    233 
    234 
    235 class TestLongStart(EnumerateStartTestCase):
    236 
    237     enum = lambda self, i: enumerate(i, start=sys.maxint+1)
    238     seq, res = 'abc', [(sys.maxint+1,'a'), (sys.maxint+2,'b'),
    239                        (sys.maxint+3,'c')]
    240 
    241 
    242 def test_main(verbose=None):
    243     test_support.run_unittest(__name__)
    244 
    245     # verify reference counting
    246     if verbose and hasattr(sys, "gettotalrefcount"):
    247         counts = [None] * 5
    248         for i in xrange(len(counts)):
    249             test_support.run_unittest(__name__)
    250             counts[i] = sys.gettotalrefcount()
    251         print counts
    252 
    253 if __name__ == "__main__":
    254     test_main(verbose=True)
    255