Home | History | Annotate | Download | only in test
      1 import unittest
      2 from test import test_support
      3 import operator
      4 from sys import maxint
      5 maxsize = test_support.MAX_Py_ssize_t
      6 minsize = -maxsize-1
      7 
      8 class oldstyle:
      9     def __index__(self):
     10         return self.ind
     11 
     12 class newstyle(object):
     13     def __index__(self):
     14         return self.ind
     15 
     16 class TrapInt(int):
     17     def __index__(self):
     18         return self
     19 
     20 class TrapLong(long):
     21     def __index__(self):
     22         return self
     23 
     24 class BaseTestCase(unittest.TestCase):
     25     def setUp(self):
     26         self.o = oldstyle()
     27         self.n = newstyle()
     28 
     29     def test_basic(self):
     30         self.o.ind = -2
     31         self.n.ind = 2
     32         self.assertEqual(operator.index(self.o), -2)
     33         self.assertEqual(operator.index(self.n), 2)
     34 
     35     def test_slice(self):
     36         self.o.ind = 1
     37         self.n.ind = 2
     38         slc = slice(self.o, self.o, self.o)
     39         check_slc = slice(1, 1, 1)
     40         self.assertEqual(slc.indices(self.o), check_slc.indices(1))
     41         slc = slice(self.n, self.n, self.n)
     42         check_slc = slice(2, 2, 2)
     43         self.assertEqual(slc.indices(self.n), check_slc.indices(2))
     44 
     45     def test_wrappers(self):
     46         self.o.ind = 4
     47         self.n.ind = 5
     48         self.assertEqual(6 .__index__(), 6)
     49         self.assertEqual(-7L.__index__(), -7)
     50         self.assertEqual(self.o.__index__(), 4)
     51         self.assertEqual(self.n.__index__(), 5)
     52         self.assertEqual(True.__index__(), 1)
     53         self.assertEqual(False.__index__(), 0)
     54 
     55     def test_subclasses(self):
     56         r = range(10)
     57         self.assertEqual(r[TrapInt(5):TrapInt(10)], r[5:10])
     58         self.assertEqual(r[TrapLong(5):TrapLong(10)], r[5:10])
     59         self.assertEqual(slice(TrapInt()).indices(0), (0,0,1))
     60         self.assertEqual(slice(TrapLong(0)).indices(0), (0,0,1))
     61 
     62     def test_error(self):
     63         self.o.ind = 'dumb'
     64         self.n.ind = 'bad'
     65         self.assertRaises(TypeError, operator.index, self.o)
     66         self.assertRaises(TypeError, operator.index, self.n)
     67         self.assertRaises(TypeError, slice(self.o).indices, 0)
     68         self.assertRaises(TypeError, slice(self.n).indices, 0)
     69 
     70 
     71 class SeqTestCase(unittest.TestCase):
     72     # This test case isn't run directly. It just defines common tests
     73     # to the different sequence types below
     74     def setUp(self):
     75         self.o = oldstyle()
     76         self.n = newstyle()
     77         self.o2 = oldstyle()
     78         self.n2 = newstyle()
     79 
     80     def test_index(self):
     81         self.o.ind = -2
     82         self.n.ind = 2
     83         self.assertEqual(self.seq[self.n], self.seq[2])
     84         self.assertEqual(self.seq[self.o], self.seq[-2])
     85 
     86     def test_slice(self):
     87         self.o.ind = 1
     88         self.o2.ind = 3
     89         self.n.ind = 2
     90         self.n2.ind = 4
     91         self.assertEqual(self.seq[self.o:self.o2], self.seq[1:3])
     92         self.assertEqual(self.seq[self.n:self.n2], self.seq[2:4])
     93 
     94     def test_slice_bug7532(self):
     95         seqlen = len(self.seq)
     96         self.o.ind = int(seqlen * 1.5)
     97         self.n.ind = seqlen + 2
     98         self.assertEqual(self.seq[self.o:], self.seq[0:0])
     99         self.assertEqual(self.seq[:self.o], self.seq)
    100         self.assertEqual(self.seq[self.n:], self.seq[0:0])
    101         self.assertEqual(self.seq[:self.n], self.seq)
    102         if isinstance(self.seq, ClassicSeq):
    103             return
    104         # These tests fail for ClassicSeq (see bug #7532)
    105         self.o2.ind = -seqlen - 2
    106         self.n2.ind = -int(seqlen * 1.5)
    107         self.assertEqual(self.seq[self.o2:], self.seq)
    108         self.assertEqual(self.seq[:self.o2], self.seq[0:0])
    109         self.assertEqual(self.seq[self.n2:], self.seq)
    110         self.assertEqual(self.seq[:self.n2], self.seq[0:0])
    111 
    112     def test_repeat(self):
    113         self.o.ind = 3
    114         self.n.ind = 2
    115         self.assertEqual(self.seq * self.o, self.seq * 3)
    116         self.assertEqual(self.seq * self.n, self.seq * 2)
    117         self.assertEqual(self.o * self.seq, self.seq * 3)
    118         self.assertEqual(self.n * self.seq, self.seq * 2)
    119 
    120     def test_wrappers(self):
    121         self.o.ind = 4
    122         self.n.ind = 5
    123         self.assertEqual(self.seq.__getitem__(self.o), self.seq[4])
    124         self.assertEqual(self.seq.__mul__(self.o), self.seq * 4)
    125         self.assertEqual(self.seq.__rmul__(self.o), self.seq * 4)
    126         self.assertEqual(self.seq.__getitem__(self.n), self.seq[5])
    127         self.assertEqual(self.seq.__mul__(self.n), self.seq * 5)
    128         self.assertEqual(self.seq.__rmul__(self.n), self.seq * 5)
    129 
    130     def test_subclasses(self):
    131         self.assertEqual(self.seq[TrapInt()], self.seq[0])
    132         self.assertEqual(self.seq[TrapLong()], self.seq[0])
    133 
    134     def test_error(self):
    135         self.o.ind = 'dumb'
    136         self.n.ind = 'bad'
    137         indexobj = lambda x, obj: obj.seq[x]
    138         self.assertRaises(TypeError, indexobj, self.o, self)
    139         self.assertRaises(TypeError, indexobj, self.n, self)
    140         sliceobj = lambda x, obj: obj.seq[x:]
    141         self.assertRaises(TypeError, sliceobj, self.o, self)
    142         self.assertRaises(TypeError, sliceobj, self.n, self)
    143 
    144 
    145 class ListTestCase(SeqTestCase):
    146     seq = [0,10,20,30,40,50]
    147 
    148     def test_setdelitem(self):
    149         self.o.ind = -2
    150         self.n.ind = 2
    151         lst = list('ab!cdefghi!j')
    152         del lst[self.o]
    153         del lst[self.n]
    154         lst[self.o] = 'X'
    155         lst[self.n] = 'Y'
    156         self.assertEqual(lst, list('abYdefghXj'))
    157 
    158         lst = [5, 6, 7, 8, 9, 10, 11]
    159         lst.__setitem__(self.n, "here")
    160         self.assertEqual(lst, [5, 6, "here", 8, 9, 10, 11])
    161         lst.__delitem__(self.n)
    162         self.assertEqual(lst, [5, 6, 8, 9, 10, 11])
    163 
    164     def test_inplace_repeat(self):
    165         self.o.ind = 2
    166         self.n.ind = 3
    167         lst = [6, 4]
    168         lst *= self.o
    169         self.assertEqual(lst, [6, 4, 6, 4])
    170         lst *= self.n
    171         self.assertEqual(lst, [6, 4, 6, 4] * 3)
    172 
    173         lst = [5, 6, 7, 8, 9, 11]
    174         l2 = lst.__imul__(self.n)
    175         self.assertIs(l2, lst)
    176         self.assertEqual(lst, [5, 6, 7, 8, 9, 11] * 3)
    177 
    178 
    179 class _BaseSeq:
    180 
    181     def __init__(self, iterable):
    182         self._list = list(iterable)
    183 
    184     def __repr__(self):
    185         return repr(self._list)
    186 
    187     def __eq__(self, other):
    188         return self._list == other
    189 
    190     def __len__(self):
    191         return len(self._list)
    192 
    193     def __mul__(self, n):
    194         return self.__class__(self._list*n)
    195     __rmul__ = __mul__
    196 
    197     def __getitem__(self, index):
    198         return self._list[index]
    199 
    200 
    201 class _GetSliceMixin:
    202 
    203     def __getslice__(self, i, j):
    204         return self._list.__getslice__(i, j)
    205 
    206 
    207 class ClassicSeq(_BaseSeq): pass
    208 class NewSeq(_BaseSeq, object): pass
    209 class ClassicSeqDeprecated(_GetSliceMixin, ClassicSeq): pass
    210 class NewSeqDeprecated(_GetSliceMixin, NewSeq): pass
    211 
    212 
    213 class TupleTestCase(SeqTestCase):
    214     seq = (0,10,20,30,40,50)
    215 
    216 class StringTestCase(SeqTestCase):
    217     seq = "this is a test"
    218 
    219 class ByteArrayTestCase(SeqTestCase):
    220     seq = bytearray("this is a test")
    221 
    222 class UnicodeTestCase(SeqTestCase):
    223     seq = u"this is a test"
    224 
    225 class ClassicSeqTestCase(SeqTestCase):
    226     seq = ClassicSeq((0,10,20,30,40,50))
    227 
    228 class NewSeqTestCase(SeqTestCase):
    229     seq = NewSeq((0,10,20,30,40,50))
    230 
    231 class ClassicSeqDeprecatedTestCase(SeqTestCase):
    232     seq = ClassicSeqDeprecated((0,10,20,30,40,50))
    233 
    234 class NewSeqDeprecatedTestCase(SeqTestCase):
    235     seq = NewSeqDeprecated((0,10,20,30,40,50))
    236 
    237 
    238 class XRangeTestCase(unittest.TestCase):
    239 
    240     def test_xrange(self):
    241         n = newstyle()
    242         n.ind = 5
    243         self.assertEqual(xrange(1, 20)[n], 6)
    244         self.assertEqual(xrange(1, 20).__getitem__(n), 6)
    245 
    246 class OverflowTestCase(unittest.TestCase):
    247 
    248     def setUp(self):
    249         self.pos = 2**100
    250         self.neg = -self.pos
    251 
    252     def test_large_longs(self):
    253         self.assertEqual(self.pos.__index__(), self.pos)
    254         self.assertEqual(self.neg.__index__(), self.neg)
    255 
    256     def _getitem_helper(self, base):
    257         class GetItem(base):
    258             def __len__(self):
    259                 return maxint # cannot return long here
    260             def __getitem__(self, key):
    261                 return key
    262         x = GetItem()
    263         self.assertEqual(x[self.pos], self.pos)
    264         self.assertEqual(x[self.neg], self.neg)
    265         self.assertEqual(x[self.neg:self.pos].indices(maxsize),
    266                          (0, maxsize, 1))
    267         self.assertEqual(x[self.neg:self.pos:1].indices(maxsize),
    268                          (0, maxsize, 1))
    269 
    270     def _getslice_helper_deprecated(self, base):
    271         class GetItem(base):
    272             def __len__(self):
    273                 return maxint # cannot return long here
    274             def __getitem__(self, key):
    275                 return key
    276             def __getslice__(self, i, j):
    277                 return i, j
    278         x = GetItem()
    279         self.assertEqual(x[self.pos], self.pos)
    280         self.assertEqual(x[self.neg], self.neg)
    281         self.assertEqual(x[self.neg:self.pos], (maxint+minsize, maxsize))
    282         self.assertEqual(x[self.neg:self.pos:1].indices(maxsize),
    283                          (0, maxsize, 1))
    284 
    285     def test_getitem(self):
    286         self._getitem_helper(object)
    287         with test_support.check_py3k_warnings():
    288             self._getslice_helper_deprecated(object)
    289 
    290     def test_getitem_classic(self):
    291         class Empty: pass
    292         # XXX This test fails (see bug #7532)
    293         #self._getitem_helper(Empty)
    294         with test_support.check_py3k_warnings():
    295             self._getslice_helper_deprecated(Empty)
    296 
    297     def test_sequence_repeat(self):
    298         self.assertRaises(OverflowError, lambda: "a" * self.pos)
    299         self.assertRaises(OverflowError, lambda: "a" * self.neg)
    300 
    301 
    302 def test_main():
    303     test_support.run_unittest(
    304         BaseTestCase,
    305         ListTestCase,
    306         TupleTestCase,
    307         ByteArrayTestCase,
    308         StringTestCase,
    309         UnicodeTestCase,
    310         ClassicSeqTestCase,
    311         NewSeqTestCase,
    312         XRangeTestCase,
    313         OverflowTestCase,
    314     )
    315     with test_support.check_py3k_warnings():
    316         test_support.run_unittest(
    317             ClassicSeqDeprecatedTestCase,
    318             NewSeqDeprecatedTestCase,
    319         )
    320 
    321 
    322 if __name__ == "__main__":
    323     test_main()
    324