1 import unittest 2 import operator 3 import sys 4 import pickle 5 6 from test import support 7 8 class G: 9 'Sequence using __getitem__' 10 def __init__(self, seqn): 11 self.seqn = seqn 12 def __getitem__(self, i): 13 return self.seqn[i] 14 15 class I: 16 'Sequence using iterator protocol' 17 def __init__(self, seqn): 18 self.seqn = seqn 19 self.i = 0 20 def __iter__(self): 21 return self 22 def __next__(self): 23 if self.i >= len(self.seqn): raise StopIteration 24 v = self.seqn[self.i] 25 self.i += 1 26 return v 27 28 class Ig: 29 'Sequence using iterator protocol defined with a generator' 30 def __init__(self, seqn): 31 self.seqn = seqn 32 self.i = 0 33 def __iter__(self): 34 for val in self.seqn: 35 yield val 36 37 class X: 38 'Missing __getitem__ and __iter__' 39 def __init__(self, seqn): 40 self.seqn = seqn 41 self.i = 0 42 def __next__(self): 43 if self.i >= len(self.seqn): raise StopIteration 44 v = self.seqn[self.i] 45 self.i += 1 46 return v 47 48 class E: 49 'Test propagation of exceptions' 50 def __init__(self, seqn): 51 self.seqn = seqn 52 self.i = 0 53 def __iter__(self): 54 return self 55 def __next__(self): 56 3 // 0 57 58 class N: 59 'Iterator missing __next__()' 60 def __init__(self, seqn): 61 self.seqn = seqn 62 self.i = 0 63 def __iter__(self): 64 return self 65 66 class PickleTest: 67 # Helper to check picklability 68 def check_pickle(self, itorg, seq): 69 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 70 d = pickle.dumps(itorg, proto) 71 it = pickle.loads(d) 72 self.assertEqual(type(itorg), type(it)) 73 self.assertEqual(list(it), seq) 74 75 it = pickle.loads(d) 76 try: 77 next(it) 78 except StopIteration: 79 self.assertFalse(seq[1:]) 80 continue 81 d = pickle.dumps(it, proto) 82 it = pickle.loads(d) 83 self.assertEqual(list(it), seq[1:]) 84 85 class EnumerateTestCase(unittest.TestCase, PickleTest): 86 87 enum = enumerate 88 seq, res = 'abc', [(0,'a'), (1,'b'), (2,'c')] 89 90 def test_basicfunction(self): 91 self.assertEqual(type(self.enum(self.seq)), self.enum) 92 e = self.enum(self.seq) 93 self.assertEqual(iter(e), e) 94 self.assertEqual(list(self.enum(self.seq)), self.res) 95 self.enum.__doc__ 96 97 def test_pickle(self): 98 self.check_pickle(self.enum(self.seq), self.res) 99 100 def test_getitemseqn(self): 101 self.assertEqual(list(self.enum(G(self.seq))), self.res) 102 e = self.enum(G('')) 103 self.assertRaises(StopIteration, next, e) 104 105 def test_iteratorseqn(self): 106 self.assertEqual(list(self.enum(I(self.seq))), self.res) 107 e = self.enum(I('')) 108 self.assertRaises(StopIteration, next, e) 109 110 def test_iteratorgenerator(self): 111 self.assertEqual(list(self.enum(Ig(self.seq))), self.res) 112 e = self.enum(Ig('')) 113 self.assertRaises(StopIteration, next, e) 114 115 def test_noniterable(self): 116 self.assertRaises(TypeError, self.enum, X(self.seq)) 117 118 def test_illformediterable(self): 119 self.assertRaises(TypeError, self.enum, N(self.seq)) 120 121 def test_exception_propagation(self): 122 self.assertRaises(ZeroDivisionError, list, self.enum(E(self.seq))) 123 124 def test_argumentcheck(self): 125 self.assertRaises(TypeError, self.enum) # no arguments 126 self.assertRaises(TypeError, self.enum, 1) # wrong type (not iterable) 127 self.assertRaises(TypeError, self.enum, 'abc', 'a') # wrong type 128 self.assertRaises(TypeError, self.enum, 'abc', 2, 3) # too many arguments 129 130 @support.cpython_only 131 def test_tuple_reuse(self): 132 # Tests an implementation detail where tuple is reused 133 # whenever nothing else holds a reference to it 134 self.assertEqual(len(set(map(id, list(enumerate(self.seq))))), len(self.seq)) 135 self.assertEqual(len(set(map(id, enumerate(self.seq)))), min(1,len(self.seq))) 136 137 class MyEnum(enumerate): 138 pass 139 140 class SubclassTestCase(EnumerateTestCase): 141 142 enum = MyEnum 143 144 class TestEmpty(EnumerateTestCase): 145 146 seq, res = '', [] 147 148 class TestBig(EnumerateTestCase): 149 150 seq = range(10,20000,2) 151 res = list(zip(range(20000), seq)) 152 153 class TestReversed(unittest.TestCase, PickleTest): 154 155 def test_simple(self): 156 class A: 157 def __getitem__(self, i): 158 if i < 5: 159 return str(i) 160 raise StopIteration 161 def __len__(self): 162 return 5 163 for data in 'abc', range(5), tuple(enumerate('abc')), A(), range(1,17,5): 164 self.assertEqual(list(data)[::-1], list(reversed(data))) 165 self.assertRaises(TypeError, reversed, {}) 166 # don't allow keyword arguments 167 self.assertRaises(TypeError, reversed, [], a=1) 168 169 def test_range_optimization(self): 170 x = range(1) 171 self.assertEqual(type(reversed(x)), type(iter(x))) 172 173 def test_len(self): 174 for s in ('hello', tuple('hello'), list('hello'), range(5)): 175 self.assertEqual(operator.length_hint(reversed(s)), len(s)) 176 r = reversed(s) 177 list(r) 178 self.assertEqual(operator.length_hint(r), 0) 179 class SeqWithWeirdLen: 180 called = False 181 def __len__(self): 182 if not self.called: 183 self.called = True 184 return 10 185 raise ZeroDivisionError 186 def __getitem__(self, index): 187 return index 188 r = reversed(SeqWithWeirdLen()) 189 self.assertRaises(ZeroDivisionError, operator.length_hint, r) 190 191 192 def test_gc(self): 193 class Seq: 194 def __len__(self): 195 return 10 196 def __getitem__(self, index): 197 return index 198 s = Seq() 199 r = reversed(s) 200 s.r = r 201 202 def test_args(self): 203 self.assertRaises(TypeError, reversed) 204 self.assertRaises(TypeError, reversed, [], 'extra') 205 206 @unittest.skipUnless(hasattr(sys, 'getrefcount'), 'test needs sys.getrefcount()') 207 def test_bug1229429(self): 208 # this bug was never in reversed, it was in 209 # PyObject_CallMethod, and reversed_new calls that sometimes. 210 def f(): 211 pass 212 r = f.__reversed__ = object() 213 rc = sys.getrefcount(r) 214 for i in range(10): 215 try: 216 reversed(f) 217 except TypeError: 218 pass 219 else: 220 self.fail("non-callable __reversed__ didn't raise!") 221 self.assertEqual(rc, sys.getrefcount(r)) 222 223 def test_objmethods(self): 224 # Objects must have __len__() and __getitem__() implemented. 225 class NoLen(object): 226 def __getitem__(self, i): return 1 227 nl = NoLen() 228 self.assertRaises(TypeError, reversed, nl) 229 230 class NoGetItem(object): 231 def __len__(self): return 2 232 ngi = NoGetItem() 233 self.assertRaises(TypeError, reversed, ngi) 234 235 class Blocked(object): 236 def __getitem__(self, i): return 1 237 def __len__(self): return 2 238 __reversed__ = None 239 b = Blocked() 240 self.assertRaises(TypeError, reversed, b) 241 242 def test_pickle(self): 243 for data in 'abc', range(5), tuple(enumerate('abc')), range(1,17,5): 244 self.check_pickle(reversed(data), list(data)[::-1]) 245 246 247 class EnumerateStartTestCase(EnumerateTestCase): 248 249 def test_basicfunction(self): 250 e = self.enum(self.seq) 251 self.assertEqual(iter(e), e) 252 self.assertEqual(list(self.enum(self.seq)), self.res) 253 254 255 class TestStart(EnumerateStartTestCase): 256 257 enum = lambda self, i: enumerate(i, start=11) 258 seq, res = 'abc', [(11, 'a'), (12, 'b'), (13, 'c')] 259 260 261 class TestLongStart(EnumerateStartTestCase): 262 263 enum = lambda self, i: enumerate(i, start=sys.maxsize+1) 264 seq, res = 'abc', [(sys.maxsize+1,'a'), (sys.maxsize+2,'b'), 265 (sys.maxsize+3,'c')] 266 267 268 if __name__ == "__main__": 269 unittest.main() 270