Home | History | Annotate | Download | only in test
      1 import copy_reg
      2 import unittest
      3 
      4 from test import test_support
      5 from test.pickletester import ExtensionSaver
      6 
      7 class C:
      8     pass
      9 
     10 
     11 class WithoutSlots(object):
     12     pass
     13 
     14 class WithWeakref(object):
     15     __slots__ = ('__weakref__',)
     16 
     17 class WithPrivate(object):
     18     __slots__ = ('__spam',)
     19 
     20 class WithSingleString(object):
     21     __slots__ = 'spam'
     22 
     23 class WithInherited(WithSingleString):
     24     __slots__ = ('eggs',)
     25 
     26 
     27 class CopyRegTestCase(unittest.TestCase):
     28 
     29     def test_class(self):
     30         self.assertRaises(TypeError, copy_reg.pickle,
     31                           C, None, None)
     32 
     33     def test_noncallable_reduce(self):
     34         self.assertRaises(TypeError, copy_reg.pickle,
     35                           type(1), "not a callable")
     36 
     37     def test_noncallable_constructor(self):
     38         self.assertRaises(TypeError, copy_reg.pickle,
     39                           type(1), int, "not a callable")
     40 
     41     def test_bool(self):
     42         import copy
     43         self.assertEqual(True, copy.copy(True))
     44 
     45     def test_extension_registry(self):
     46         mod, func, code = 'junk1 ', ' junk2', 0xabcd
     47         e = ExtensionSaver(code)
     48         try:
     49             # Shouldn't be in registry now.
     50             self.assertRaises(ValueError, copy_reg.remove_extension,
     51                               mod, func, code)
     52             copy_reg.add_extension(mod, func, code)
     53             # Should be in the registry.
     54             self.assertTrue(copy_reg._extension_registry[mod, func] == code)
     55             self.assertTrue(copy_reg._inverted_registry[code] == (mod, func))
     56             # Shouldn't be in the cache.
     57             self.assertNotIn(code, copy_reg._extension_cache)
     58             # Redundant registration should be OK.
     59             copy_reg.add_extension(mod, func, code)  # shouldn't blow up
     60             # Conflicting code.
     61             self.assertRaises(ValueError, copy_reg.add_extension,
     62                               mod, func, code + 1)
     63             self.assertRaises(ValueError, copy_reg.remove_extension,
     64                               mod, func, code + 1)
     65             # Conflicting module name.
     66             self.assertRaises(ValueError, copy_reg.add_extension,
     67                               mod[1:], func, code )
     68             self.assertRaises(ValueError, copy_reg.remove_extension,
     69                               mod[1:], func, code )
     70             # Conflicting function name.
     71             self.assertRaises(ValueError, copy_reg.add_extension,
     72                               mod, func[1:], code)
     73             self.assertRaises(ValueError, copy_reg.remove_extension,
     74                               mod, func[1:], code)
     75             # Can't remove one that isn't registered at all.
     76             if code + 1 not in copy_reg._inverted_registry:
     77                 self.assertRaises(ValueError, copy_reg.remove_extension,
     78                                   mod[1:], func[1:], code + 1)
     79 
     80         finally:
     81             e.restore()
     82 
     83         # Shouldn't be there anymore.
     84         self.assertNotIn((mod, func), copy_reg._extension_registry)
     85         # The code *may* be in copy_reg._extension_registry, though, if
     86         # we happened to pick on a registered code.  So don't check for
     87         # that.
     88 
     89         # Check valid codes at the limits.
     90         for code in 1, 0x7fffffff:
     91             e = ExtensionSaver(code)
     92             try:
     93                 copy_reg.add_extension(mod, func, code)
     94                 copy_reg.remove_extension(mod, func, code)
     95             finally:
     96                 e.restore()
     97 
     98         # Ensure invalid codes blow up.
     99         for code in -1, 0, 0x80000000L:
    100             self.assertRaises(ValueError, copy_reg.add_extension,
    101                               mod, func, code)
    102 
    103     def test_slotnames(self):
    104         self.assertEqual(copy_reg._slotnames(WithoutSlots), [])
    105         self.assertEqual(copy_reg._slotnames(WithWeakref), [])
    106         expected = ['_WithPrivate__spam']
    107         self.assertEqual(copy_reg._slotnames(WithPrivate), expected)
    108         self.assertEqual(copy_reg._slotnames(WithSingleString), ['spam'])
    109         expected = ['eggs', 'spam']
    110         expected.sort()
    111         result = copy_reg._slotnames(WithInherited)
    112         result.sort()
    113         self.assertEqual(result, expected)
    114 
    115 
    116 def test_main():
    117     test_support.run_unittest(CopyRegTestCase)
    118 
    119 
    120 if __name__ == "__main__":
    121     test_main()
    122