"""Unit tests for the memoryview type.

XXX We need more tests! Some tests are in test_bytes.
"""

import array
import gc
import io
import sys
import unittest
import weakref
from test import test_support


     15 class AbstractMemoryTests:
     16     source_bytes = b"abcdef"
     17 
     18     @property
     19     def _source(self):
     20         return self.source_bytes
     21 
     22     @property
     23     def _types(self):
     24         return filter(None, [self.ro_type, self.rw_type])
     25 
     26     def check_getitem_with_type(self, tp):
     27         item = self.getitem_type
     28         b = tp(self._source)
     29         oldrefcount = sys.getrefcount(b)
     30         m = self._view(b)
     31         self.assertEqual(m[0], item(b"a"))
     32         self.assertIsInstance(m[0], bytes)
     33         self.assertEqual(m[5], item(b"f"))
     34         self.assertEqual(m[-1], item(b"f"))
     35         self.assertEqual(m[-6], item(b"a"))
     36         # Bounds checking
     37         self.assertRaises(IndexError, lambda: m[6])
     38         self.assertRaises(IndexError, lambda: m[-7])
     39         self.assertRaises(IndexError, lambda: m[sys.maxsize])
     40         self.assertRaises(IndexError, lambda: m[-sys.maxsize])
     41         # Type checking
     42         self.assertRaises(TypeError, lambda: m[None])
     43         self.assertRaises(TypeError, lambda: m[0.0])
     44         self.assertRaises(TypeError, lambda: m["a"])
     45         m = None
     46         self.assertEqual(sys.getrefcount(b), oldrefcount)
     47 
     48     def test_getitem(self):
     49         for tp in self._types:
     50             self.check_getitem_with_type(tp)
     51 
     52     def test_iter(self):
     53         for tp in self._types:
     54             b = tp(self._source)
     55             m = self._view(b)
     56             self.assertEqual(list(m), [m[i] for i in range(len(m))])
     57 
     58     def test_repr(self):
     59         for tp in self._types:
     60             b = tp(self._source)
     61             m = self._view(b)
     62             self.assertIsInstance(m.__repr__(), str)
     63 
     64     def test_setitem_readonly(self):
     65         if not self.ro_type:
     66             return
     67         b = self.ro_type(self._source)
     68         oldrefcount = sys.getrefcount(b)
     69         m = self._view(b)
     70         def setitem(value):
     71             m[0] = value
     72         self.assertRaises(TypeError, setitem, b"a")
     73         self.assertRaises(TypeError, setitem, 65)
     74         self.assertRaises(TypeError, setitem, memoryview(b"a"))
     75         m = None
     76         self.assertEqual(sys.getrefcount(b), oldrefcount)
     77 
     78     def test_setitem_writable(self):
     79         if not self.rw_type:
     80             return
     81         tp = self.rw_type
     82         b = self.rw_type(self._source)
     83         oldrefcount = sys.getrefcount(b)
     84         m = self._view(b)
     85         m[0] = tp(b"0")
     86         self._check_contents(tp, b, b"0bcdef")
     87         m[1:3] = tp(b"12")
     88         self._check_contents(tp, b, b"012def")
     89         m[1:1] = tp(b"")
     90         self._check_contents(tp, b, b"012def")
     91         m[:] = tp(b"abcdef")
     92         self._check_contents(tp, b, b"abcdef")
     93 
     94         # Overlapping copies of a view into itself
     95         m[0:3] = m[2:5]
     96         self._check_contents(tp, b, b"cdedef")
     97         m[:] = tp(b"abcdef")
     98         m[2:5] = m[0:3]
     99         self._check_contents(tp, b, b"ababcf")
    100 
    101         def setitem(key, value):
    102             m[key] = tp(value)
    103         # Bounds checking
    104         self.assertRaises(IndexError, setitem, 6, b"a")
    105         self.assertRaises(IndexError, setitem, -7, b"a")
    106         self.assertRaises(IndexError, setitem, sys.maxsize, b"a")
    107         self.assertRaises(IndexError, setitem, -sys.maxsize, b"a")
    108         # Wrong index/slice types
    109         self.assertRaises(TypeError, setitem, 0.0, b"a")
    110         self.assertRaises(TypeError, setitem, (0,), b"a")
    111         self.assertRaises(TypeError, setitem, "a", b"a")
    112         # Trying to resize the memory object
    113         self.assertRaises(ValueError, setitem, 0, b"")
    114         self.assertRaises(ValueError, setitem, 0, b"ab")
    115         self.assertRaises(ValueError, setitem, slice(1,1), b"a")
    116         self.assertRaises(ValueError, setitem, slice(0,2), b"a")
    117 
    118         m = None
    119         self.assertEqual(sys.getrefcount(b), oldrefcount)
    120 
    121     def test_delitem(self):
    122         for tp in self._types:
    123             b = tp(self._source)
    124             m = self._view(b)
    125             with self.assertRaises(TypeError):
    126                 del m[1]
    127             with self.assertRaises(TypeError):
    128                 del m[1:4]
    129 
    130     def test_tobytes(self):
    131         for tp in self._types:
    132             m = self._view(tp(self._source))
    133             b = m.tobytes()
    134             # This calls self.getitem_type() on each separate byte of b"abcdef"
    135             expected = b"".join(
    136                 self.getitem_type(c) for c in b"abcdef")
    137             self.assertEqual(b, expected)
    138             self.assertIsInstance(b, bytes)
    139 
    140     def test_tolist(self):
    141         for tp in self._types:
    142             m = self._view(tp(self._source))
    143             l = m.tolist()
    144             self.assertEqual(l, map(ord, b"abcdef"))
    145 
    146     def test_compare(self):
    147         # memoryviews can compare for equality with other objects
    148         # having the buffer interface.
    149         for tp in self._types:
    150             m = self._view(tp(self._source))
    151             for tp_comp in self._types:
    152                 self.assertTrue(m == tp_comp(b"abcdef"))
    153                 self.assertFalse(m != tp_comp(b"abcdef"))
    154                 self.assertFalse(m == tp_comp(b"abcde"))
    155                 self.assertTrue(m != tp_comp(b"abcde"))
    156                 self.assertFalse(m == tp_comp(b"abcde1"))
    157                 self.assertTrue(m != tp_comp(b"abcde1"))
    158             self.assertTrue(m == m)
    159             self.assertTrue(m == m[:])
    160             self.assertTrue(m[0:6] == m[:])
    161             self.assertFalse(m[0:5] == m)
    162 
    163             # Comparison with objects which don't support the buffer API
    164             self.assertFalse(m == u"abcdef")
    165             self.assertTrue(m != u"abcdef")
    166             self.assertFalse(u"abcdef" == m)
    167             self.assertTrue(u"abcdef" != m)
    168 
    169             # Unordered comparisons are unimplemented, and therefore give
    170             # arbitrary results (they raise a TypeError in py3k)
    171 
    172     def check_attributes_with_type(self, tp):
    173         m = self._view(tp(self._source))
    174         self.assertEqual(m.format, self.format)
    175         self.assertIsInstance(m.format, str)
    176         self.assertEqual(m.itemsize, self.itemsize)
    177         self.assertEqual(m.ndim, 1)
    178         self.assertEqual(m.shape, (6,))
    179         self.assertEqual(len(m), 6)
    180         self.assertEqual(m.strides, (self.itemsize,))
    181         self.assertEqual(m.suboffsets, None)
    182         return m
    183 
    184     def test_attributes_readonly(self):
    185         if not self.ro_type:
    186             return
    187         m = self.check_attributes_with_type(self.ro_type)
    188         self.assertEqual(m.readonly, True)
    189 
    190     def test_attributes_writable(self):
    191         if not self.rw_type:
    192             return
    193         m = self.check_attributes_with_type(self.rw_type)
    194         self.assertEqual(m.readonly, False)
    195 
    196     # Disabled: unicode uses the old buffer API in 2.x
    197 
    198     #def test_getbuffer(self):
    199         ## Test PyObject_GetBuffer() on a memoryview object.
    200         #for tp in self._types:
    201             #b = tp(self._source)
    202             #oldrefcount = sys.getrefcount(b)
    203             #m = self._view(b)
    204             #oldviewrefcount = sys.getrefcount(m)
    205             #s = unicode(m, "utf-8")
    206             #self._check_contents(tp, b, s.encode("utf-8"))
    207             #self.assertEqual(sys.getrefcount(m), oldviewrefcount)
    208             #m = None
    209             #self.assertEqual(sys.getrefcount(b), oldrefcount)
    210 
    211     def test_gc(self):
    212         for tp in self._types:
    213             if not isinstance(tp, type):
    214                 # If tp is a factory rather than a plain type, skip
    215                 continue
    216 
    217             class MySource(tp):
    218                 pass
    219             class MyObject:
    220                 pass
    221 
    222             # Create a reference cycle through a memoryview object
    223             b = MySource(tp(b'abc'))
    224             m = self._view(b)
    225             o = MyObject()
    226             b.m = m
    227             b.o = o
    228             wr = weakref.ref(o)
    229             b = m = o = None
    230             # The cycle must be broken
    231             gc.collect()
    232             self.assertTrue(wr() is None, wr())
    233 
    234     def test_writable_readonly(self):
    235         # Issue #10451: memoryview incorrectly exposes a readonly
    236         # buffer as writable causing a segfault if using mmap
    237         tp = self.ro_type
    238         if tp is None:
    239             return
    240         b = tp(self._source)
    241         m = self._view(b)
    242         i = io.BytesIO(b'ZZZZ')
    243         self.assertRaises(TypeError, i.readinto, m)
    244 
# Variations on source objects for the buffer: bytes-like objects, then arrays
# with itemsize > 1.
# NOTE: support for multi-dimensional objects is unimplemented.

    249 class BaseBytesMemoryTests(AbstractMemoryTests):
    250     ro_type = bytes
    251     rw_type = bytearray
    252     getitem_type = bytes
    253     itemsize = 1
    254     format = 'B'
    255 
# Disabled: array.array() does not support the new buffer API in 2.x

#class BaseArrayMemoryTests(AbstractMemoryTests):
    #ro_type = None
    #rw_type = lambda self, b: array.array('i', map(ord, b))
    #getitem_type = lambda self, b: array.array('i', map(ord, b)).tostring()
    #itemsize = array.array('i').itemsize
    #format = 'i'

    #def test_getbuffer(self):
        ## XXX Test should be adapted for non-byte buffers
        #pass

    #def test_tolist(self):
        ## XXX NotImplementedError: tolist() only supports byte views
        #pass


# Variations on indirection levels: memoryview, slice of memoryview,
# slice of slice of memoryview.
# This is important to test allocation subtleties.

    278 class BaseMemoryviewTests:
    279     def _view(self, obj):
    280         return memoryview(obj)
    281 
    282     def _check_contents(self, tp, obj, contents):
    283         self.assertEqual(obj, tp(contents))
    284 
    285 class BaseMemorySliceTests:
    286     source_bytes = b"XabcdefY"
    287 
    288     def _view(self, obj):
    289         m = memoryview(obj)
    290         return m[1:7]
    291 
    292     def _check_contents(self, tp, obj, contents):
    293         self.assertEqual(obj[1:7], tp(contents))
    294 
    295     def test_refs(self):
    296         for tp in self._types:
    297             m = memoryview(tp(self._source))
    298             oldrefcount = sys.getrefcount(m)
    299             m[1:2]
    300             self.assertEqual(sys.getrefcount(m), oldrefcount)
    301 
    302 class BaseMemorySliceSliceTests:
    303     source_bytes = b"XabcdefY"
    304 
    305     def _view(self, obj):
    306         m = memoryview(obj)
    307         return m[:7][1:]
    308 
    309     def _check_contents(self, tp, obj, contents):
    310         self.assertEqual(obj[1:7], tp(contents))
    311 
    312 
# Concrete test classes

    315 class BytesMemoryviewTest(unittest.TestCase,
    316     BaseMemoryviewTests, BaseBytesMemoryTests):
    317 
    318     def test_constructor(self):
    319         for tp in self._types:
    320             ob = tp(self._source)
    321             self.assertTrue(memoryview(ob))
    322             self.assertTrue(memoryview(object=ob))
    323             self.assertRaises(TypeError, memoryview)
    324             self.assertRaises(TypeError, memoryview, ob, ob)
    325             self.assertRaises(TypeError, memoryview, argument=ob)
    326             self.assertRaises(TypeError, memoryview, ob, argument=True)
    327 
#class ArrayMemoryviewTest(unittest.TestCase,
    #BaseMemoryviewTests, BaseArrayMemoryTests):

    #def test_array_assign(self):
        ## Issue #4569: segfault when mutating a memoryview with itemsize != 1
        #a = array.array('i', range(10))
        #m = memoryview(a)
        #new_a = array.array('i', range(9, -1, -1))
        #m[:] = new_a
        #self.assertEqual(a, new_a)


    340 class BytesMemorySliceTest(unittest.TestCase,
    341     BaseMemorySliceTests, BaseBytesMemoryTests):
    342     pass
    343 
#class ArrayMemorySliceTest(unittest.TestCase,
    #BaseMemorySliceTests, BaseArrayMemoryTests):
    #pass

    348 class BytesMemorySliceSliceTest(unittest.TestCase,
    349     BaseMemorySliceSliceTests, BaseBytesMemoryTests):
    350     pass
    351 
#class ArrayMemorySliceSliceTest(unittest.TestCase,
    #BaseMemorySliceSliceTests, BaseArrayMemoryTests):
    #pass


    357 def test_main():
    358     test_support.run_unittest(__name__)
    359 
    360 if __name__ == "__main__":
    361     test_main()
    362