Home | History | Annotate | Download | only in test
      1 import functools
      2 import sys
      3 import unittest
      4 from test import test_support
      5 from weakref import proxy
      6 import pickle
      7 
      8 @staticmethod
      9 def PythonPartial(func, *args, **keywords):
     10     'Pure Python approximation of partial()'
     11     def newfunc(*fargs, **fkeywords):
     12         newkeywords = keywords.copy()
     13         newkeywords.update(fkeywords)
     14         return func(*(args + fargs), **newkeywords)
     15     newfunc.func = func
     16     newfunc.args = args
     17     newfunc.keywords = keywords
     18     return newfunc
     19 
     20 def capture(*args, **kw):
     21     """capture all positional and keyword arguments"""
     22     return args, kw
     23 
     24 def signature(part):
     25     """ return the signature of a partial object """
     26     return (part.func, part.args, part.keywords, part.__dict__)
     27 
     28 class TestPartial(unittest.TestCase):
     29 
     30     thetype = functools.partial
     31 
     32     def test_basic_examples(self):
     33         p = self.thetype(capture, 1, 2, a=10, b=20)
     34         self.assertEqual(p(3, 4, b=30, c=40),
     35                          ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
     36         p = self.thetype(map, lambda x: x*10)
     37         self.assertEqual(p([1,2,3,4]), [10, 20, 30, 40])
     38 
     39     def test_attributes(self):
     40         p = self.thetype(capture, 1, 2, a=10, b=20)
     41         # attributes should be readable
     42         self.assertEqual(p.func, capture)
     43         self.assertEqual(p.args, (1, 2))
     44         self.assertEqual(p.keywords, dict(a=10, b=20))
     45         # attributes should not be writable
     46         if not isinstance(self.thetype, type):
     47             return
     48         self.assertRaises(TypeError, setattr, p, 'func', map)
     49         self.assertRaises(TypeError, setattr, p, 'args', (1, 2))
     50         self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2))
     51 
     52         p = self.thetype(hex)
     53         try:
     54             del p.__dict__
     55         except TypeError:
     56             pass
     57         else:
     58             self.fail('partial object allowed __dict__ to be deleted')
     59 
     60     def test_argument_checking(self):
     61         self.assertRaises(TypeError, self.thetype)     # need at least a func arg
     62         try:
     63             self.thetype(2)()
     64         except TypeError:
     65             pass
     66         else:
     67             self.fail('First arg not checked for callability')
     68 
     69     def test_protection_of_callers_dict_argument(self):
     70         # a caller's dictionary should not be altered by partial
     71         def func(a=10, b=20):
     72             return a
     73         d = {'a':3}
     74         p = self.thetype(func, a=5)
     75         self.assertEqual(p(**d), 3)
     76         self.assertEqual(d, {'a':3})
     77         p(b=7)
     78         self.assertEqual(d, {'a':3})
     79 
     80     def test_arg_combinations(self):
     81         # exercise special code paths for zero args in either partial
     82         # object or the caller
     83         p = self.thetype(capture)
     84         self.assertEqual(p(), ((), {}))
     85         self.assertEqual(p(1,2), ((1,2), {}))
     86         p = self.thetype(capture, 1, 2)
     87         self.assertEqual(p(), ((1,2), {}))
     88         self.assertEqual(p(3,4), ((1,2,3,4), {}))
     89 
     90     def test_kw_combinations(self):
     91         # exercise special code paths for no keyword args in
     92         # either the partial object or the caller
     93         p = self.thetype(capture)
     94         self.assertEqual(p(), ((), {}))
     95         self.assertEqual(p(a=1), ((), {'a':1}))
     96         p = self.thetype(capture, a=1)
     97         self.assertEqual(p(), ((), {'a':1}))
     98         self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
     99         # keyword args in the call override those in the partial object
    100         self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))
    101 
    102     def test_positional(self):
    103         # make sure positional arguments are captured correctly
    104         for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
    105             p = self.thetype(capture, *args)
    106             expected = args + ('x',)
    107             got, empty = p('x')
    108             self.assertTrue(expected == got and empty == {})
    109 
    110     def test_keyword(self):
    111         # make sure keyword arguments are captured correctly
    112         for a in ['a', 0, None, 3.5]:
    113             p = self.thetype(capture, a=a)
    114             expected = {'a':a,'x':None}
    115             empty, got = p(x=None)
    116             self.assertTrue(expected == got and empty == ())
    117 
    118     def test_no_side_effects(self):
    119         # make sure there are no side effects that affect subsequent calls
    120         p = self.thetype(capture, 0, a=1)
    121         args1, kw1 = p(1, b=2)
    122         self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2})
    123         args2, kw2 = p()
    124         self.assertTrue(args2 == (0,) and kw2 == {'a':1})
    125 
    126     def test_error_propagation(self):
    127         def f(x, y):
    128             x // y
    129         self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
    130         self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
    131         self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
    132         self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)
    133 
    134     def test_weakref(self):
    135         f = self.thetype(int, base=16)
    136         p = proxy(f)
    137         self.assertEqual(f.func, p.func)
    138         f = None
    139         self.assertRaises(ReferenceError, getattr, p, 'func')
    140 
    141     def test_with_bound_and_unbound_methods(self):
    142         data = map(str, range(10))
    143         join = self.thetype(str.join, '')
    144         self.assertEqual(join(data), '0123456789')
    145         join = self.thetype(''.join)
    146         self.assertEqual(join(data), '0123456789')
    147 
    148     def test_pickle(self):
    149         f = self.thetype(signature, 'asdf', bar=True)
    150         f.add_something_to__dict__ = True
    151         f_copy = pickle.loads(pickle.dumps(f))
    152         self.assertEqual(signature(f), signature(f_copy))
    153 
    154     # Issue 6083: Reference counting bug
    155     def test_setstate_refcount(self):
    156         class BadSequence:
    157             def __len__(self):
    158                 return 4
    159             def __getitem__(self, key):
    160                 if key == 0:
    161                     return max
    162                 elif key == 1:
    163                     return tuple(range(1000000))
    164                 elif key in (2, 3):
    165                     return {}
    166                 raise IndexError
    167 
    168         f = self.thetype(object)
    169         self.assertRaises(SystemError, f.__setstate__, BadSequence())
    170 
    171 class PartialSubclass(functools.partial):
    172     pass
    173 
    174 class TestPartialSubclass(TestPartial):
    175 
    176     thetype = PartialSubclass
    177 
    178 class TestPythonPartial(TestPartial):
    179 
    180     thetype = PythonPartial
    181 
    182     # the python version isn't picklable
    183     def test_pickle(self): pass
    184     def test_setstate_refcount(self): pass
    185 
    186 class TestUpdateWrapper(unittest.TestCase):
    187 
    188     def check_wrapper(self, wrapper, wrapped,
    189                       assigned=functools.WRAPPER_ASSIGNMENTS,
    190                       updated=functools.WRAPPER_UPDATES):
    191         # Check attributes were assigned
    192         for name in assigned:
    193             self.assertTrue(getattr(wrapper, name) is getattr(wrapped, name))
    194         # Check attributes were updated
    195         for name in updated:
    196             wrapper_attr = getattr(wrapper, name)
    197             wrapped_attr = getattr(wrapped, name)
    198             for key in wrapped_attr:
    199                 self.assertTrue(wrapped_attr[key] is wrapper_attr[key])
    200 
    201     def _default_update(self):
    202         def f():
    203             """This is a test"""
    204             pass
    205         f.attr = 'This is also a test'
    206         def wrapper():
    207             pass
    208         functools.update_wrapper(wrapper, f)
    209         return wrapper, f
    210 
    211     def test_default_update(self):
    212         wrapper, f = self._default_update()
    213         self.check_wrapper(wrapper, f)
    214         self.assertEqual(wrapper.__name__, 'f')
    215         self.assertEqual(wrapper.attr, 'This is also a test')
    216 
    217     @unittest.skipIf(sys.flags.optimize >= 2,
    218                      "Docstrings are omitted with -O2 and above")
    219     def test_default_update_doc(self):
    220         wrapper, f = self._default_update()
    221         self.assertEqual(wrapper.__doc__, 'This is a test')
    222 
    223     def test_no_update(self):
    224         def f():
    225             """This is a test"""
    226             pass
    227         f.attr = 'This is also a test'
    228         def wrapper():
    229             pass
    230         functools.update_wrapper(wrapper, f, (), ())
    231         self.check_wrapper(wrapper, f, (), ())
    232         self.assertEqual(wrapper.__name__, 'wrapper')
    233         self.assertEqual(wrapper.__doc__, None)
    234         self.assertFalse(hasattr(wrapper, 'attr'))
    235 
    236     def test_selective_update(self):
    237         def f():
    238             pass
    239         f.attr = 'This is a different test'
    240         f.dict_attr = dict(a=1, b=2, c=3)
    241         def wrapper():
    242             pass
    243         wrapper.dict_attr = {}
    244         assign = ('attr',)
    245         update = ('dict_attr',)
    246         functools.update_wrapper(wrapper, f, assign, update)
    247         self.check_wrapper(wrapper, f, assign, update)
    248         self.assertEqual(wrapper.__name__, 'wrapper')
    249         self.assertEqual(wrapper.__doc__, None)
    250         self.assertEqual(wrapper.attr, 'This is a different test')
    251         self.assertEqual(wrapper.dict_attr, f.dict_attr)
    252 
    253     @test_support.requires_docstrings
    254     def test_builtin_update(self):
    255         # Test for bug #1576241
    256         def wrapper():
    257             pass
    258         functools.update_wrapper(wrapper, max)
    259         self.assertEqual(wrapper.__name__, 'max')
    260         self.assertTrue(wrapper.__doc__.startswith('max('))
    261 
    262 class TestWraps(TestUpdateWrapper):
    263 
    264     def _default_update(self):
    265         def f():
    266             """This is a test"""
    267             pass
    268         f.attr = 'This is also a test'
    269         @functools.wraps(f)
    270         def wrapper():
    271             pass
    272         self.check_wrapper(wrapper, f)
    273         return wrapper
    274 
    275     def test_default_update(self):
    276         wrapper = self._default_update()
    277         self.assertEqual(wrapper.__name__, 'f')
    278         self.assertEqual(wrapper.attr, 'This is also a test')
    279 
    280     @unittest.skipIf(sys.flags.optimize >= 2,
    281                      "Docstrings are omitted with -O2 and above")
    282     def test_default_update_doc(self):
    283         wrapper = self._default_update()
    284         self.assertEqual(wrapper.__doc__, 'This is a test')
    285 
    286     def test_no_update(self):
    287         def f():
    288             """This is a test"""
    289             pass
    290         f.attr = 'This is also a test'
    291         @functools.wraps(f, (), ())
    292         def wrapper():
    293             pass
    294         self.check_wrapper(wrapper, f, (), ())
    295         self.assertEqual(wrapper.__name__, 'wrapper')
    296         self.assertEqual(wrapper.__doc__, None)
    297         self.assertFalse(hasattr(wrapper, 'attr'))
    298 
    299     def test_selective_update(self):
    300         def f():
    301             pass
    302         f.attr = 'This is a different test'
    303         f.dict_attr = dict(a=1, b=2, c=3)
    304         def add_dict_attr(f):
    305             f.dict_attr = {}
    306             return f
    307         assign = ('attr',)
    308         update = ('dict_attr',)
    309         @functools.wraps(f, assign, update)
    310         @add_dict_attr
    311         def wrapper():
    312             pass
    313         self.check_wrapper(wrapper, f, assign, update)
    314         self.assertEqual(wrapper.__name__, 'wrapper')
    315         self.assertEqual(wrapper.__doc__, None)
    316         self.assertEqual(wrapper.attr, 'This is a different test')
    317         self.assertEqual(wrapper.dict_attr, f.dict_attr)
    318 
    319 
    320 class TestReduce(unittest.TestCase):
    321 
    322     def test_reduce(self):
    323         class Squares:
    324 
    325             def __init__(self, max):
    326                 self.max = max
    327                 self.sofar = []
    328 
    329             def __len__(self): return len(self.sofar)
    330 
    331             def __getitem__(self, i):
    332                 if not 0 <= i < self.max: raise IndexError
    333                 n = len(self.sofar)
    334                 while n <= i:
    335                     self.sofar.append(n*n)
    336                     n += 1
    337                 return self.sofar[i]
    338 
    339         reduce = functools.reduce
    340         self.assertEqual(reduce(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc')
    341         self.assertEqual(
    342             reduce(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []),
    343             ['a','c','d','w']
    344         )
    345         self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040)
    346         self.assertEqual(
    347             reduce(lambda x, y: x*y, range(2,21), 1L),
    348             2432902008176640000L
    349         )
    350         self.assertEqual(reduce(lambda x, y: x+y, Squares(10)), 285)
    351         self.assertEqual(reduce(lambda x, y: x+y, Squares(10), 0), 285)
    352         self.assertEqual(reduce(lambda x, y: x+y, Squares(0), 0), 0)
    353         self.assertRaises(TypeError, reduce)
    354         self.assertRaises(TypeError, reduce, 42, 42)
    355         self.assertRaises(TypeError, reduce, 42, 42, 42)
    356         self.assertEqual(reduce(42, "1"), "1") # func is never called with one item
    357         self.assertEqual(reduce(42, "", "1"), "1") # func is never called with one item
    358         self.assertRaises(TypeError, reduce, 42, (42, 42))
    359 
    360 class TestCmpToKey(unittest.TestCase):
    361     def test_cmp_to_key(self):
    362         def mycmp(x, y):
    363             return y - x
    364         self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
    365                          [4, 3, 2, 1, 0])
    366 
    367     def test_hash(self):
    368         def mycmp(x, y):
    369             return y - x
    370         key = functools.cmp_to_key(mycmp)
    371         k = key(10)
    372         self.assertRaises(TypeError, hash(k))
    373 
    374 class TestTotalOrdering(unittest.TestCase):
    375 
    376     def test_total_ordering_lt(self):
    377         @functools.total_ordering
    378         class A:
    379             def __init__(self, value):
    380                 self.value = value
    381             def __lt__(self, other):
    382                 return self.value < other.value
    383             def __eq__(self, other):
    384                 return self.value == other.value
    385         self.assertTrue(A(1) < A(2))
    386         self.assertTrue(A(2) > A(1))
    387         self.assertTrue(A(1) <= A(2))
    388         self.assertTrue(A(2) >= A(1))
    389         self.assertTrue(A(2) <= A(2))
    390         self.assertTrue(A(2) >= A(2))
    391 
    392     def test_total_ordering_le(self):
    393         @functools.total_ordering
    394         class A:
    395             def __init__(self, value):
    396                 self.value = value
    397             def __le__(self, other):
    398                 return self.value <= other.value
    399             def __eq__(self, other):
    400                 return self.value == other.value
    401         self.assertTrue(A(1) < A(2))
    402         self.assertTrue(A(2) > A(1))
    403         self.assertTrue(A(1) <= A(2))
    404         self.assertTrue(A(2) >= A(1))
    405         self.assertTrue(A(2) <= A(2))
    406         self.assertTrue(A(2) >= A(2))
    407 
    408     def test_total_ordering_gt(self):
    409         @functools.total_ordering
    410         class A:
    411             def __init__(self, value):
    412                 self.value = value
    413             def __gt__(self, other):
    414                 return self.value > other.value
    415             def __eq__(self, other):
    416                 return self.value == other.value
    417         self.assertTrue(A(1) < A(2))
    418         self.assertTrue(A(2) > A(1))
    419         self.assertTrue(A(1) <= A(2))
    420         self.assertTrue(A(2) >= A(1))
    421         self.assertTrue(A(2) <= A(2))
    422         self.assertTrue(A(2) >= A(2))
    423 
    424     def test_total_ordering_ge(self):
    425         @functools.total_ordering
    426         class A:
    427             def __init__(self, value):
    428                 self.value = value
    429             def __ge__(self, other):
    430                 return self.value >= other.value
    431             def __eq__(self, other):
    432                 return self.value == other.value
    433         self.assertTrue(A(1) < A(2))
    434         self.assertTrue(A(2) > A(1))
    435         self.assertTrue(A(1) <= A(2))
    436         self.assertTrue(A(2) >= A(1))
    437         self.assertTrue(A(2) <= A(2))
    438         self.assertTrue(A(2) >= A(2))
    439 
    440     def test_total_ordering_no_overwrite(self):
    441         # new methods should not overwrite existing
    442         @functools.total_ordering
    443         class A(str):
    444             pass
    445         self.assertTrue(A("a") < A("b"))
    446         self.assertTrue(A("b") > A("a"))
    447         self.assertTrue(A("a") <= A("b"))
    448         self.assertTrue(A("b") >= A("a"))
    449         self.assertTrue(A("b") <= A("b"))
    450         self.assertTrue(A("b") >= A("b"))
    451 
    452     def test_no_operations_defined(self):
    453         with self.assertRaises(ValueError):
    454             @functools.total_ordering
    455             class A:
    456                 pass
    457 
    458     def test_bug_10042(self):
    459         @functools.total_ordering
    460         class TestTO:
    461             def __init__(self, value):
    462                 self.value = value
    463             def __eq__(self, other):
    464                 if isinstance(other, TestTO):
    465                     return self.value == other.value
    466                 return False
    467             def __lt__(self, other):
    468                 if isinstance(other, TestTO):
    469                     return self.value < other.value
    470                 raise TypeError
    471         with self.assertRaises(TypeError):
    472             TestTO(8) <= ()
    473 
    474 def test_main(verbose=None):
    475     test_classes = (
    476         TestPartial,
    477         TestPartialSubclass,
    478         TestPythonPartial,
    479         TestUpdateWrapper,
    480         TestTotalOrdering,
    481         TestWraps,
    482         TestReduce,
    483     )
    484     test_support.run_unittest(*test_classes)
    485 
    486     # verify reference counting
    487     if verbose and hasattr(sys, "gettotalrefcount"):
    488         import gc
    489         counts = [None] * 5
    490         for i in xrange(len(counts)):
    491             test_support.run_unittest(*test_classes)
    492             gc.collect()
    493             counts[i] = sys.gettotalrefcount()
    494         print counts
    495 
    496 if __name__ == '__main__':
    497     test_main(verbose=True)
    498