# Test iterators.
      2 
      3 import unittest
      4 from test.test_support import run_unittest, TESTFN, unlink, have_unicode, \
      5                               check_py3k_warnings, cpython_only
      6 
      7 # Test result of triple loop (too big to inline)
# Every (i, j, k) with i, j, k in 0..2, listed with i varying slowest and
# k fastest.  Kept as a literal so the expected value does not depend on
# the loops and comprehensions being tested.
TRIPLETS = [(0, 0, 0), (0, 0, 1), (0, 0, 2),
            (0, 1, 0), (0, 1, 1), (0, 1, 2),
            (0, 2, 0), (0, 2, 1), (0, 2, 2),

            (1, 0, 0), (1, 0, 1), (1, 0, 2),
            (1, 1, 0), (1, 1, 1), (1, 1, 2),
            (1, 2, 0), (1, 2, 1), (1, 2, 2),

            (2, 0, 0), (2, 0, 1), (2, 0, 2),
            (2, 1, 0), (2, 1, 1), (2, 1, 2),
            (2, 2, 0), (2, 2, 1), (2, 2, 2)]
     19 
     20 # Helper classes
     21 
     22 class BasicIterClass:
     23     def __init__(self, n):
     24         self.n = n
     25         self.i = 0
     26     def next(self):
     27         res = self.i
     28         if res >= self.n:
     29             raise StopIteration
     30         self.i = res + 1
     31         return res
     32 
     33 class IteratingSequenceClass:
     34     def __init__(self, n):
     35         self.n = n
     36     def __iter__(self):
     37         return BasicIterClass(self.n)
     38 
     39 class SequenceClass:
     40     def __init__(self, n):
     41         self.n = n
     42     def __getitem__(self, i):
     43         if 0 <= i < self.n:
     44             return i
     45         else:
     46             raise IndexError
     47 
     48 # Main test suite
     49 
     50 class TestCase(unittest.TestCase):
     51 
     52     # Helper to check that an iterator returns a given sequence
     53     def check_iterator(self, it, seq):
     54         res = []
     55         while 1:
     56             try:
     57                 val = it.next()
     58             except StopIteration:
     59                 break
     60             res.append(val)
     61         self.assertEqual(res, seq)
     62 
     63     # Helper to check that a for loop generates a given sequence
     64     def check_for_loop(self, expr, seq):
     65         res = []
     66         for val in expr:
     67             res.append(val)
     68         self.assertEqual(res, seq)
     69 
     70     # Test basic use of iter() function
     71     def test_iter_basic(self):
     72         self.check_iterator(iter(range(10)), range(10))
     73 
     74     # Test that iter(iter(x)) is the same as iter(x)
     75     def test_iter_idempotency(self):
     76         seq = range(10)
     77         it = iter(seq)
     78         it2 = iter(it)
     79         self.assertTrue(it is it2)
     80 
     81     # Test that for loops over iterators work
     82     def test_iter_for_loop(self):
     83         self.check_for_loop(iter(range(10)), range(10))
     84 
     85     # Test several independent iterators over the same list
     86     def test_iter_independence(self):
     87         seq = range(3)
     88         res = []
     89         for i in iter(seq):
     90             for j in iter(seq):
     91                 for k in iter(seq):
     92                     res.append((i, j, k))
     93         self.assertEqual(res, TRIPLETS)
     94 
     95     # Test triple list comprehension using iterators
     96     def test_nested_comprehensions_iter(self):
     97         seq = range(3)
     98         res = [(i, j, k)
     99                for i in iter(seq) for j in iter(seq) for k in iter(seq)]
    100         self.assertEqual(res, TRIPLETS)
    101 
    102     # Test triple list comprehension without iterators
    def test_nested_comprehensions_for(self):
        # Three-level list comprehension iterating the list itself at
        # every level (no explicit iter() calls).
        seq = range(3)
        triples = [(i, j, k)
                   for i in seq
                   for j in seq
                   for k in seq]
        self.assertEqual(triples, TRIPLETS)
    107 
    108     # Test a class with __iter__ in a for loop
    def test_iter_class_for(self):
        # The for statement must go through the instance's __iter__.
        iterable = IteratingSequenceClass(10)
        self.check_for_loop(iterable, range(10))
    111 
    112     # Test a class with __iter__ with explicit iter()
    def test_iter_class_iter(self):
        # iter() on an instance with __iter__ must return something that
        # can be driven with next().
        iterator = iter(IteratingSequenceClass(10))
        self.check_iterator(iterator, range(10))
    115 
    116     # Test for loop on a sequence class without __iter__
    def test_seq_class_for(self):
        # With only __getitem__ available, the for statement must fall
        # back to indexing from 0 until IndexError.
        sequence = SequenceClass(10)
        self.check_for_loop(sequence, range(10))
    119 
    120     # Test iter() on a sequence class without __iter__
    def test_seq_class_iter(self):
        # iter() must build an iterator for an object that only has
        # __getitem__.
        iterator = iter(SequenceClass(10))
        self.check_iterator(iterator, range(10))
    123 
    124     # Test a new_style class with __iter__ but no next() method
    125     def test_new_style_iter_class(self):
    126         class IterClass(object):
    127             def __iter__(self):
    128                 return self
    129         self.assertRaises(TypeError, iter, IterClass())
    130 
    131     # Test two-argument iter() with callable instance
    132     def test_iter_callable(self):
    133         class C:
    134             def __init__(self):
    135                 self.i = 0
    136             def __call__(self):
    137                 i = self.i
    138                 self.i = i + 1
    139                 if i > 100:
    140                     raise IndexError # Emergency stop
    141                 return i
    142         self.check_iterator(iter(C(), 10), range(10))
    143 
    144     # Test two-argument iter() with function
    145     def test_iter_function(self):
    146         def spam(state=[0]):
    147             i = state[0]
    148             state[0] = i+1
    149             return i
    150         self.check_iterator(iter(spam, 10), range(10))
    151 
    152     # Test two-argument iter() with function that raises StopIteration
    153     def test_iter_function_stop(self):
    154         def spam(state=[0]):
    155             i = state[0]
    156             if i == 10:
    157                 raise StopIteration
    158             state[0] = i+1
    159             return i
    160         self.check_iterator(iter(spam, 20), range(10))
    161 
    162     # Test exception propagation through function iterator
    163     def test_exception_function(self):
    164         def spam(state=[0]):
    165             i = state[0]
    166             state[0] = i+1
    167             if i == 10:
    168                 raise RuntimeError
    169             return i
    170         res = []
    171         try:
    172             for x in iter(spam, 20):
    173                 res.append(x)
    174         except RuntimeError:
    175             self.assertEqual(res, range(10))
    176         else:
    177             self.fail("should have raised RuntimeError")
    178 
    179     # Test exception propagation through sequence iterator
    def test_exception_sequence(self):
        # __getitem__ raises RuntimeError for index 10; the error must
        # escape the for loop after items 0..9 have been delivered.
        class MySequenceClass(SequenceClass):
            def __getitem__(self, i):
                if i != 10:
                    return SequenceClass.__getitem__(self, i)
                raise RuntimeError
        seen = []
        try:
            for value in MySequenceClass(20):
                seen.append(value)
        except RuntimeError:
            self.assertEqual(seen, range(10))
        else:
            self.fail("should have raised RuntimeError")
    194 
    195     # Test for StopIteration from __getitem__
    def test_stop_sequence(self):
        # StopIteration raised from __getitem__ at index 10 must end the
        # for loop just like IndexError would.
        class MySequenceClass(SequenceClass):
            def __getitem__(self, i):
                if i != 10:
                    return SequenceClass.__getitem__(self, i)
                raise StopIteration
        sequence = MySequenceClass(20)
        self.check_for_loop(sequence, range(10))
    203 
    204     # Test a big range
    205     def test_iter_big_range(self):
    206         self.check_for_loop(iter(range(10000)), range(10000))
    207 
    208     # Test an empty list
    209     def test_iter_empty(self):
    210         self.check_for_loop(iter([]), [])
    211 
    212     # Test a tuple
    213     def test_iter_tuple(self):
    214         self.check_for_loop(iter((0,1,2,3,4,5,6,7,8,9)), range(10))
    215 
    216     # Test an xrange
    def test_iter_xrange(self):
        # An xrange iterator yields the same values as the matching list.
        lazy_iter = iter(xrange(10))
        self.check_for_loop(lazy_iter, range(10))
    219 
    220     # Test a string
    221     def test_iter_string(self):
    222         self.check_for_loop(iter("abcde"), ["a", "b", "c", "d", "e"])
    223 
    224     # Test a Unicode string
    if have_unicode:
        def test_iter_unicode(self):
            # Iterating a unicode string yields one-character unicode
            # strings.
            letters = "abcde"
            self.check_for_loop(iter(unicode(letters)),
                                [unicode(ch) for ch in letters])
    230 
    # Test a dictionary
    232     def test_iter_dict(self):
    233         dict = {}
    234         for i in range(10):
    235             dict[i] = None
    236         self.check_for_loop(dict, dict.keys())
    237 
    238     # Test a file
    def test_iter_file(self):
        # Write five numbered lines, then check that iterating the file
        # yields them and that a second pass over the exhausted file
        # yields nothing.
        with open(TESTFN, "w") as out:
            for number in range(5):
                out.write("%d\n" % number)
        f = open(TESTFN, "r")
        try:
            # Close the file before the cleanup below tries to remove it.
            with f:
                self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"])
                self.check_for_loop(f, [])
        finally:
            try:
                unlink(TESTFN)
            except OSError:
                pass
    256 
    257     # Test list()'s use of iterators.
    def test_builtin_list(self):
        # list() must consume __getitem__-only sequences, tuples, lists,
        # dictionaries (keys) and files, and reject non-iterables.
        self.assertEqual(list(SequenceClass(5)), range(5))
        self.assertEqual(list(SequenceClass(0)), [])
        self.assertEqual(list(()), [])
        countdown = range(10, -1, -1)
        self.assertEqual(list(countdown), range(10, -1, -1))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(list(d), d.keys())

        for not_iterable in (list, 42):
            self.assertRaises(TypeError, list, not_iterable)

        with open(TESTFN, "w") as out:
            for number in range(5):
                out.write("%d\n" % number)
        f = open(TESTFN, "r")
        try:
            # Close the file before the cleanup below tries to remove it.
            with f:
                self.assertEqual(list(f),
                                 ["0\n", "1\n", "2\n", "3\n", "4\n"])
                # Rewind so the second list() call reads everything again.
                f.seek(0, 0)
                self.assertEqual(list(f),
                                 ["0\n", "1\n", "2\n", "3\n", "4\n"])
        finally:
            try:
                unlink(TESTFN)
            except OSError:
                pass
    288 
    # Test tuple()'s use of iterators.
    def test_builtin_tuple(self):
        # tuple() must consume __getitem__-only sequences, lists, tuples,
        # strings, dictionaries (keys) and files, and reject non-iterables.
        self.assertEqual(tuple(SequenceClass(5)), (0, 1, 2, 3, 4))
        self.assertEqual(tuple(SequenceClass(0)), ())
        self.assertEqual(tuple([]), ())
        self.assertEqual(tuple(()), ())
        self.assertEqual(tuple("abc"), ("a", "b", "c"))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(tuple(d), tuple(d.keys()))

        for not_iterable in (list, 42):
            self.assertRaises(TypeError, tuple, not_iterable)

        with open(TESTFN, "w") as out:
            for number in range(5):
                out.write("%d\n" % number)
        f = open(TESTFN, "r")
        try:
            # Close the file before the cleanup below tries to remove it.
            with f:
                self.assertEqual(tuple(f),
                                 ("0\n", "1\n", "2\n", "3\n", "4\n"))
                # Rewind so the second tuple() call reads everything again.
                f.seek(0, 0)
                self.assertEqual(tuple(f),
                                 ("0\n", "1\n", "2\n", "3\n", "4\n"))
        finally:
            try:
                unlink(TESTFN)
            except OSError:
                pass
    321 
    322     # Test filter()'s use of iterators.
    def test_builtin_filter(self):
        # filter() must work over __getitem__-only sequences, tuples,
        # strings, dictionaries and user-defined iterables/iterators.
        self.assertEqual(filter(None, SequenceClass(5)), range(1, 5))
        self.assertEqual(filter(None, SequenceClass(0)), [])
        self.assertEqual(filter(None, ()), ())
        self.assertEqual(filter(None, "abc"), "abc")

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(filter(None, d), d.keys())

        for not_iterable in (list, 42):
            self.assertRaises(TypeError, filter, None, not_iterable)

        # Object whose truth value is whatever it was constructed with.
        class Boolean:
            def __init__(self, truth):
                self.truth = truth
            def __nonzero__(self):
                return self.truth
        bTrue = Boolean(1)
        bFalse = Boolean(0)

        # Iterator over a tuple of values; also iterable (returns itself).
        class SeqIter:
            def __init__(self, vals):
                self.vals = vals
                self.i = 0
            def __iter__(self):
                return self
            def next(self):
                position = self.i
                self.i = position + 1
                if position >= len(self.vals):
                    raise StopIteration
                return self.vals[position]

        # Iterable that hands out a new SeqIter on every __iter__ call.
        class Seq:
            def __init__(self, *args):
                self.vals = args
            def __iter__(self):
                return SeqIter(self.vals)

        seq = Seq(*([bTrue, bFalse] * 25))
        self.assertEqual(filter(lambda x: not x, seq), [bFalse]*25)
        self.assertEqual(filter(lambda x: not x, iter(seq)), [bFalse]*25)
    365 
    366     # Test max() and min()'s use of iterators.
    def test_builtin_max_min(self):
        # max()/min() over __getitem__-only sequences, dictionaries,
        # dictionary value iterators and files, plus the plain
        # multi-argument form.
        self.assertEqual(max(SequenceClass(5)), 4)
        self.assertEqual(min(SequenceClass(5)), 0)
        self.assertEqual(max(8, -1), 8)
        self.assertEqual(min(8, -1), -1)

        d = {"one": 1, "two": 2, "three": 3}
        # Iterating a dict compares its keys (strings).
        self.assertEqual(max(d), "two")
        self.assertEqual(min(d), "one")
        self.assertEqual(max(d.itervalues()), 3)
        self.assertEqual(min(iter(d.itervalues())), 1)

        with open(TESTFN, "w") as out:
            out.write("medium line\n")
            out.write("xtra large line\n")
            out.write("itty-bitty line\n")
        f = open(TESTFN, "r")
        try:
            # Close the file before the cleanup below tries to remove it.
            with f:
                self.assertEqual(min(f), "itty-bitty line\n")
                # min() exhausted the file; rewind before max().
                f.seek(0, 0)
                self.assertEqual(max(f), "xtra large line\n")
        finally:
            try:
                unlink(TESTFN)
            except OSError:
                pass
    397 
    398     # Test map()'s use of iterators.
    def test_builtin_map(self):
        # map() over __getitem__-only sequences, dictionaries, iterators
        # and files, including the deprecated map(None, ...) form.
        self.assertEqual(map(lambda x: x+1, SequenceClass(5)), range(1, 6))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(map(lambda k, d=d: (k, d[k]), d), d.items())
        dkeys = d.keys()

        # map(None, ...) pads the shorter inputs with None: the dict has
        # three keys while SequenceClass(5) has five items.
        def key_or_none(index):
            if index < len(d):
                return dkeys[index]
            return None
        expected = [(key_or_none(i), i, key_or_none(i)) for i in range(5)]

        # Deprecated map(None, ...)
        with check_py3k_warnings():
            self.assertEqual(map(None, SequenceClass(5)), range(5))
            self.assertEqual(map(None, d), d.keys())
            self.assertEqual(map(None, d,
                                       SequenceClass(5),
                                       iter(d.iterkeys())),
                             expected)

        with open(TESTFN, "w") as out:
            for i in range(10):
                out.write("xy" * i + "\n") # line i has len 2*i+1
        f = open(TESTFN, "r")
        try:
            # Close the file before the cleanup below tries to remove it.
            with f:
                self.assertEqual(map(len, f), range(1, 21, 2))
        finally:
            try:
                unlink(TESTFN)
            except OSError:
                pass
    434 
    435     # Test zip()'s use of iterators.
    def test_builtin_zip(self):
        # zip() over arbitrary iterables: user-defined iterators, legacy
        # sequences, dictionaries, files, xrange, and sequences whose
        # __len__ disagrees with what they actually yield.
        self.assertEqual(zip(), [])
        self.assertEqual(zip(*[]), [])
        self.assertEqual(zip(*[(1, 2), 'ab']), [(1, 'a'), (2, 'b')])

        for bad_args in ((None,), (range(10), 42), (range(10), zip)):
            self.assertRaises(TypeError, zip, *bad_args)

        self.assertEqual(zip(IteratingSequenceClass(3)),
                         [(0,), (1,), (2,)])
        self.assertEqual(zip(SequenceClass(3)),
                         [(0,), (1,), (2,)])

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(d.items(), zip(d, d.itervalues()))

        # Generate all ints starting at constructor arg.
        class IntsFrom:
            def __init__(self, start):
                self.i = start

            def __iter__(self):
                return self

            def next(self):
                current = self.i
                self.i = current + 1
                return current

        with open(TESTFN, "w") as out:
            out.write("a\n" "bbb\n" "cc\n")
        f = open(TESTFN, "r")
        try:
            # The three-line file bounds the result; both IntsFrom
            # iterators are endless.  Close the file before cleanup.
            with f:
                self.assertEqual(zip(IntsFrom(0), f, IntsFrom(-100)),
                                 [(0, "a\n", -100),
                                  (1, "bbb\n", -99),
                                  (2, "cc\n", -98)])
        finally:
            try:
                unlink(TESTFN)
            except OSError:
                pass

        self.assertEqual(zip(xrange(5)), [(i,) for i in range(5)])

        # Classes that lie about their lengths.
        class NoGuessLen5:
            def __getitem__(self, i):
                if i < 5:
                    return i
                raise IndexError

        class Guess3Len5(NoGuessLen5):
            def __len__(self):
                return 3

        class Guess30Len5(NoGuessLen5):
            def __len__(self):
                return 30

        self.assertEqual(len(Guess3Len5()), 3)
        self.assertEqual(len(Guess30Len5()), 30)
        # Whatever __len__ claims, zip() must deliver all five items.
        self.assertEqual(zip(NoGuessLen5()), zip(range(5)))
        self.assertEqual(zip(Guess3Len5()), zip(range(5)))
        self.assertEqual(zip(Guess30Len5()), zip(range(5)))

        expected = [(i, i) for i in range(5)]
        for left in NoGuessLen5(), Guess3Len5(), Guess30Len5():
            for right in NoGuessLen5(), Guess3Len5(), Guess30Len5():
                self.assertEqual(zip(left, right), expected)
    511 
    # Test reduce()'s use of iterators.
    def test_deprecated_builtin_reduce(self):
        # Run the reduce() checks inside check_py3k_warnings(), since the
        # builtin is deprecated.
        with check_py3k_warnings():
            self._test_builtin_reduce()
    516 
    def _test_builtin_reduce(self):
        # reduce() over __getitem__-only sequences (with and without an
        # initial value) and over a dictionary's keys.
        from operator import add
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        # An empty sequence without an initial value is an error.
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        joined_keys = "".join(d.keys())
        self.assertEqual(reduce(add, d), joined_keys)
    528 
    529     # This test case will be removed if we don't have Unicode
    def test_unicode_join_endcase(self):

        # This class inserts a Unicode object into its argument's natural
        # iteration, in the 3rd position.
        class OhPhooey:
            def __init__(self, seq):
                self.it = iter(seq)
                self.i = 0

            def __iter__(self):
                return self

            def next(self):
                position = self.i
                self.i = position + 1
                if position != 2:
                    return self.it.next()
                return unicode("fooled you!")

        with open(TESTFN, "w") as out:
            out.write("a\n" + "b\n" + "c\n")

        f = open(TESTFN, "r")
        # Nasty:  string.join(s) can't know whether unicode.join() is needed
        # until it's seen all of s's elements.  But in this case, f's
        # iterator cannot be restarted.  So what we're testing here is
        # whether string.join() can manage to remember everything it's seen
        # and pass that on to unicode.join().
        try:
            # Close the file before the cleanup below tries to remove it.
            with f:
                got = " - ".join(OhPhooey(f))
                self.assertEqual(got,
                                 unicode("a\n - b\n - fooled you! - c\n"))
        finally:
            try:
                unlink(TESTFN)
            except OSError:
                pass
    # Without Unicode support, replace the test above with a no-op.
    if not have_unicode:
        def test_unicode_join_endcase(self): pass
    572 
    573     # Test iterators with 'x in y' and 'x not in y'.
    def test_in_and_not_in(self):
        # 'in' / 'not in' must work by iteration on objects that have no
        # __contains__: user iterables, dict iterators and files.
        for sc5 in IteratingSequenceClass(5), SequenceClass(5):
            for member in range(5):
                self.assertIn(member, sc5)
            for outsider in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5:
                self.assertNotIn(outsider, sc5)

        # Non-iterable right-hand operands must raise TypeError.
        self.assertRaises(TypeError, lambda: 3 in 12)
        self.assertRaises(TypeError, lambda: 3 not in map)

        d = {"one": 1, "two": 2, "three": 3, 1j: 2j}
        for key in d:
            self.assertIn(key, d)
            self.assertNotIn(key, d.itervalues())
        for value in d.values():
            self.assertIn(value, d.itervalues())
            self.assertNotIn(value, d)
        for key, value in d.iteritems():
            self.assertIn((key, value), d.iteritems())
            self.assertNotIn((value, key), d.iteritems())

        with open(TESTFN, "w") as out:
            out.write("a\n" "b\n" "c\n")
        f = open(TESTFN, "r")
        try:
            # Close the file before the cleanup below tries to remove it.
            with f:
                for chunk in "abc":
                    # Each membership test consumes the file, so rewind
                    # before every one.  Lines carry their newline.
                    f.seek(0, 0)
                    self.assertNotIn(chunk, f)
                    f.seek(0, 0)
                    self.assertIn((chunk + "\n"), f)
        finally:
            try:
                unlink(TESTFN)
            except OSError:
                pass
    613 
    614     # Test iterators with operator.countOf (PySequence_Count).
    def test_countOf(self):
        # operator.countOf() must count matches by iterating lists,
        # tuples, strings, dictionaries, dict iterators and files.
        from operator import countOf
        self.assertEqual(countOf([1,2,2,3,2,5], 2), 3)
        self.assertEqual(countOf((1,2,2,3,2,5), 2), 3)
        self.assertEqual(countOf("122325", "2"), 3)
        self.assertEqual(countOf("122325", "6"), 0)

        # Non-iterable first arguments are rejected.
        self.assertRaises(TypeError, countOf, 42, 1)
        self.assertRaises(TypeError, countOf, countOf, countOf)

        d = {"one": 3, "two": 3, "three": 3, 1j: 2j}
        for key in d:
            self.assertEqual(countOf(d, key), 1)
        self.assertEqual(countOf(d.itervalues(), 3), 3)
        self.assertEqual(countOf(d.itervalues(), 2j), 1)
        self.assertEqual(countOf(d.itervalues(), 1j), 0)

        with open(TESTFN, "w") as out:
            out.write("a\n" "b\n" "c\n" "b\n")
        f = open(TESTFN, "r")
        try:
            # Close the file before the cleanup below tries to remove it.
            with f:
                for letter, count in ("a", 1), ("b", 2), ("c", 1), ("d", 0):
                    # countOf() reads to end of file; rewind each time.
                    f.seek(0, 0)
                    self.assertEqual(countOf(f, letter + "\n"), count)
        finally:
            try:
                unlink(TESTFN)
            except OSError:
                pass
    648 
    649     # Test iterators with operator.indexOf (PySequence_Index).
    def test_indexOf(self):
        # operator.indexOf() must locate the first match by iterating
        # lists, tuples, strings, file iterators and user iterables.
        from operator import indexOf
        numbers = (1,2,2,3,2,5)
        self.assertEqual(indexOf([1,2,2,3,2,5], 1), 0)
        self.assertEqual(indexOf(numbers, 2), 1)
        self.assertEqual(indexOf(numbers, 3), 3)
        self.assertEqual(indexOf(numbers, 5), 5)
        self.assertRaises(ValueError, indexOf, numbers, 0)
        self.assertRaises(ValueError, indexOf, numbers, 6)

        self.assertEqual(indexOf("122325", "2"), 1)
        self.assertEqual(indexOf("122325", "5"), 5)
        self.assertRaises(ValueError, indexOf, "122325", "6")

        # Non-iterable first arguments are rejected.
        self.assertRaises(TypeError, indexOf, 42, 1)
        self.assertRaises(TypeError, indexOf, indexOf, indexOf)

        with open(TESTFN, "w") as out:
            out.write("a\n" "b\n" "c\n" "d\n" "e\n")
        f = open(TESTFN, "r")
        try:
            # Close the file before the cleanup below tries to remove it.
            with f:
                # The same iterator is reused, so each search resumes
                # where the previous one stopped and indexes restart at 0.
                fiter = iter(f)
                self.assertEqual(indexOf(fiter, "b\n"), 1)
                self.assertEqual(indexOf(fiter, "d\n"), 1)
                self.assertEqual(indexOf(fiter, "e\n"), 0)
                self.assertRaises(ValueError, indexOf, fiter, "a\n")
        finally:
            try:
                unlink(TESTFN)
            except OSError:
                pass

        iclass = IteratingSequenceClass(3)
        for position in range(3):
            self.assertEqual(indexOf(iclass, position), position)
        self.assertRaises(ValueError, indexOf, iclass, -1)
    689 
    690     # Test iterators with file.writelines().
    def test_writelines(self):
        # file.writelines() must accept any iterable of strings: lists,
        # tuples, dictionaries (keys) and user-defined iterables.
        f = file(TESTFN, "w")

        try:
            for not_iterable in (None, 42):
                self.assertRaises(TypeError, f.writelines, not_iterable)

            f.writelines(["1\n", "2\n"])
            f.writelines(("3\n", "4\n"))
            f.writelines({'5\n': None})
            f.writelines({})

            # Try a big chunk too.
            class Iterator:
                def __init__(self, start, finish):
                    self.start = start
                    self.finish = finish
                    self.i = self.start

                def next(self):
                    if self.i < self.finish:
                        line = str(self.i) + '\n'
                        self.i += 1
                        return line
                    raise StopIteration

                def __iter__(self):
                    return self

            class Whatever:
                def __init__(self, start, finish):
                    self.start = start
                    self.finish = finish

                def __iter__(self):
                    return Iterator(self.start, self.finish)

            f.writelines(Whatever(6, 6+2000))
            f.close()

            # Lines 1..5 came from the small writes above, and
            # 6..2005 from Whatever.
            f = file(TESTFN)
            expected = [str(number) + "\n" for number in range(1, 2006)]
            self.assertEqual(list(f), expected)

        finally:
            f.close()
            try:
                unlink(TESTFN)
            except OSError:
                pass
    741 
    742 
    743     # Test iterators on RHS of unpacking assignments.
    def test_unpack_iter(self):
        # Unpacking assignments must accept any iterable on the right
        # and raise ValueError / TypeError for a wrong item count or a
        # non-iterable.
        a, b = 1, 2
        self.assertEqual((a, b), (1, 2))

        a, b, c = IteratingSequenceClass(3)
        self.assertEqual((a, b, c), (0, 1, 2))

        try:    # too many values
            a, b = IteratingSequenceClass(3)
        except ValueError:
            pass
        else:
            self.fail("should have raised ValueError")

        try:    # not enough values
            a, b, c = IteratingSequenceClass(2)
        except ValueError:
            pass
        else:
            self.fail("should have raised ValueError")

        try:    # not iterable
            a, b, c = len
        except TypeError:
            pass
        else:
            self.fail("should have raised TypeError")

        all_42 = {1: 42, 2: 42, 3: 42}
        a, b, c = all_42.itervalues()
        self.assertEqual((a, b, c), (42, 42, 42))

        lines = ("a\n", "bb\n", "ccc\n")
        with open(TESTFN, "w") as out:
            for line in lines:
                out.write(line)
        f = open(TESTFN, "r")
        try:
            # Close the file before the cleanup below tries to remove it.
            with f:
                a, b, c = f
                self.assertEqual((a, b, c), lines)
        finally:
            try:
                unlink(TESTFN)
            except OSError:
                pass

        # Nested targets: the dictionary unpacks to its single key.
        (a, b), (c,) = IteratingSequenceClass(2), {42: 24}
        self.assertEqual((a, b, c), (0, 1, 42))
    795 
    796 
    @cpython_only
    def test_ref_counting_behavior(self):
        # C tracks its live instances in the class attribute `count`:
        # __new__ increments it and __del__ decrements it.
        class C(object):
            count = 0
            def __new__(cls):
                cls.count += 1
                return object.__new__(cls)
            def __del__(self):
                cls = self.__class__
                assert cls.count > 0
                cls.count -= 1
        # Sanity check of the counter: one instance created, then deleted.
        x = C()
        self.assertEqual(C.count, 1)
        del x
        self.assertEqual(C.count, 0)
        l = [C(), C(), C()]
        self.assertEqual(C.count, 3)
        # Unpacking three items into two names raises ValueError.  The
        # final assertion checks that, once the list is deleted, no
        # reference to any of the items is left behind.
        try:
            a, b = iter(l)
        except ValueError:
            pass
        del l
        self.assertEqual(C.count, 0)
    820 
    821 
    822     # Make sure StopIteration is a "sink state".
    823     # This tests various things that weren't sink states in Python 2.2.1,
    824     # plus various things that always were fine.
    825 
    826     def test_sinkstate_list(self):
    827         # This used to fail
    828         a = range(5)
    829         b = iter(a)
    830         self.assertEqual(list(b), range(5))
    831         a.extend(range(5, 10))
    832         self.assertEqual(list(b), [])
    833 
    834     def test_sinkstate_tuple(self):
    835         a = (0, 1, 2, 3, 4)
    836         b = iter(a)
    837         self.assertEqual(list(b), range(5))
    838         self.assertEqual(list(b), [])
    839 
    840     def test_sinkstate_string(self):
    841         a = "abcde"
    842         b = iter(a)
    843         self.assertEqual(list(b), ['a', 'b', 'c', 'd', 'e'])
    844         self.assertEqual(list(b), [])
    845 
    def test_sinkstate_sequence(self):
        # This used to fail
        seq = SequenceClass(5)
        it = iter(seq)
        drained = list(it)
        self.assertEqual(drained, range(5))
        # Raising the sequence's bound afterwards must not revive the
        # already exhausted iterator.
        seq.n = 10
        leftover = list(it)
        self.assertEqual(leftover, [])
    853 
    854     def test_sinkstate_callable(self):
    855         # This used to fail
    856         def spam(state=[0]):
    857             i = state[0]
    858             state[0] = i+1
    859             if i == 10:
    860                 raise AssertionError, "shouldn't have gotten this far"
    861             return i
    862         b = iter(spam, 5)
    863         self.assertEqual(list(b), range(5))
    864         self.assertEqual(list(b), [])
    865 
    866     def test_sinkstate_dict(self):
    867         # XXX For a more thorough test, see towards the end of:
    868         # http://mail.python.org/pipermail/python-dev/2002-July/026512.html
    869         a = {1:1, 2:2, 0:0, 4:4, 3:3}
    870         for b in iter(a), a.iterkeys(), a.iteritems(), a.itervalues():
    871             b = iter(a)
    872             self.assertEqual(len(list(b)), 5)
    873             self.assertEqual(list(b), [])
    874 
    875     def test_sinkstate_yield(self):
    876         def gen():
    877             for i in range(5):
    878                 yield i
    879         b = gen()
    880         self.assertEqual(list(b), range(5))
    881         self.assertEqual(list(b), [])
    882 
    883     def test_sinkstate_range(self):
    884         a = xrange(5)
    885         b = iter(a)
    886         self.assertEqual(list(b), range(5))
    887         self.assertEqual(list(b), [])
    888 
    889     def test_sinkstate_enumerate(self):
    890         a = range(5)
    891         e = enumerate(a)
    892         b = iter(e)
    893         self.assertEqual(list(b), zip(range(5), range(5)))
    894         self.assertEqual(list(b), [])
    895 
    896     def test_3720(self):
    897         # Avoid a crash, when an iterator deletes its next() method.
    898         class BadIterator(object):
    899             def __iter__(self):
    900                 return self
    901             def next(self):
    902                 del BadIterator.next
    903                 return 1
    904 
    905         try:
    906             for i in BadIterator() :
    907                 pass
    908         except TypeError:
    909             pass
    910 
    911     def test_extending_list_with_iterator_does_not_segfault(self):
    912         # The code to extend a list with an iterator has a fair
    913         # amount of nontrivial logic in terms of guessing how
    914         # much memory to allocate in advance, "stealing" refs,
    915         # and then shrinking at the end.  This is a basic smoke
    916         # test for that scenario.
    917         def gen():
    918             for i in range(500):
    919                 yield i
    920         lst = [0] * 500
    921         for i in range(240):
    922             lst.pop(0)
    923         lst.extend(gen())
    924         self.assertEqual(len(lst), 760)
    925 
    926 
def test_main():
    # Run the TestCase suite through test_support's run_unittest helper.
    run_unittest(TestCase)
    929 
    930 
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    test_main()
    933