Home | History | Annotate | Download | only in test
      1 import functools
      2 import sys
      3 import unittest
      4 from test import test_support
      5 from weakref import proxy
      6 import pickle
      7 
      8 @staticmethod
      9 def PythonPartial(func, *args, **keywords):
     10     'Pure Python approximation of partial()'
     11     def newfunc(*fargs, **fkeywords):
     12         newkeywords = keywords.copy()
     13         newkeywords.update(fkeywords)
     14         return func(*(args + fargs), **newkeywords)
     15     newfunc.func = func
     16     newfunc.args = args
     17     newfunc.keywords = keywords
     18     return newfunc
     19 
     20 def capture(*args, **kw):
     21     """capture all positional and keyword arguments"""
     22     return args, kw
     23 
     24 def signature(part):
     25     """ return the signature of a partial object """
     26     return (part.func, part.args, part.keywords, part.__dict__)
     27 
     28 class TestPartial(unittest.TestCase):
     29 
     30     thetype = functools.partial
     31 
     32     def test_basic_examples(self):
     33         p = self.thetype(capture, 1, 2, a=10, b=20)
     34         self.assertEqual(p(3, 4, b=30, c=40),
     35                          ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
     36         p = self.thetype(map, lambda x: x*10)
     37         self.assertEqual(p([1,2,3,4]), [10, 20, 30, 40])
     38 
     39     def test_attributes(self):
     40         p = self.thetype(capture, 1, 2, a=10, b=20)
     41         # attributes should be readable

     42         self.assertEqual(p.func, capture)
     43         self.assertEqual(p.args, (1, 2))
     44         self.assertEqual(p.keywords, dict(a=10, b=20))
     45         # attributes should not be writable

     46         if not isinstance(self.thetype, type):
     47             return
     48         self.assertRaises(TypeError, setattr, p, 'func', map)
     49         self.assertRaises(TypeError, setattr, p, 'args', (1, 2))
     50         self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2))
     51 
     52         p = self.thetype(hex)
     53         try:
     54             del p.__dict__
     55         except TypeError:
     56             pass
     57         else:
     58             self.fail('partial object allowed __dict__ to be deleted')
     59 
     60     def test_argument_checking(self):
     61         self.assertRaises(TypeError, self.thetype)     # need at least a func arg

     62         try:
     63             self.thetype(2)()
     64         except TypeError:
     65             pass
     66         else:
     67             self.fail('First arg not checked for callability')
     68 
     69     def test_protection_of_callers_dict_argument(self):
     70         # a caller's dictionary should not be altered by partial

     71         def func(a=10, b=20):
     72             return a
     73         d = {'a':3}
     74         p = self.thetype(func, a=5)
     75         self.assertEqual(p(**d), 3)
     76         self.assertEqual(d, {'a':3})
     77         p(b=7)
     78         self.assertEqual(d, {'a':3})
     79 
     80     def test_arg_combinations(self):
     81         # exercise special code paths for zero args in either partial

     82         # object or the caller

     83         p = self.thetype(capture)
     84         self.assertEqual(p(), ((), {}))
     85         self.assertEqual(p(1,2), ((1,2), {}))
     86         p = self.thetype(capture, 1, 2)
     87         self.assertEqual(p(), ((1,2), {}))
     88         self.assertEqual(p(3,4), ((1,2,3,4), {}))
     89 
     90     def test_kw_combinations(self):
     91         # exercise special code paths for no keyword args in

     92         # either the partial object or the caller

     93         p = self.thetype(capture)
     94         self.assertEqual(p(), ((), {}))
     95         self.assertEqual(p(a=1), ((), {'a':1}))
     96         p = self.thetype(capture, a=1)
     97         self.assertEqual(p(), ((), {'a':1}))
     98         self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
     99         # keyword args in the call override those in the partial object

    100         self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))
    101 
    102     def test_positional(self):
    103         # make sure positional arguments are captured correctly

    104         for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
    105             p = self.thetype(capture, *args)
    106             expected = args + ('x',)
    107             got, empty = p('x')
    108             self.assertTrue(expected == got and empty == {})
    109 
    110     def test_keyword(self):
    111         # make sure keyword arguments are captured correctly

    112         for a in ['a', 0, None, 3.5]:
    113             p = self.thetype(capture, a=a)
    114             expected = {'a':a,'x':None}
    115             empty, got = p(x=None)
    116             self.assertTrue(expected == got and empty == ())
    117 
    118     def test_no_side_effects(self):
    119         # make sure there are no side effects that affect subsequent calls

    120         p = self.thetype(capture, 0, a=1)
    121         args1, kw1 = p(1, b=2)
    122         self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2})
    123         args2, kw2 = p()
    124         self.assertTrue(args2 == (0,) and kw2 == {'a':1})
    125 
    126     def test_error_propagation(self):
    127         def f(x, y):
    128             x // y
    129         self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0))
    130         self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0)
    131         self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0)
    132         self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1)
    133 
    134     def test_weakref(self):
    135         f = self.thetype(int, base=16)
    136         p = proxy(f)
    137         self.assertEqual(f.func, p.func)
    138         f = None
    139         self.assertRaises(ReferenceError, getattr, p, 'func')
    140 
    141     def test_with_bound_and_unbound_methods(self):
    142         data = map(str, range(10))
    143         join = self.thetype(str.join, '')
    144         self.assertEqual(join(data), '0123456789')
    145         join = self.thetype(''.join)
    146         self.assertEqual(join(data), '0123456789')
    147 
    148     def test_pickle(self):
    149         f = self.thetype(signature, 'asdf', bar=True)
    150         f.add_something_to__dict__ = True
    151         f_copy = pickle.loads(pickle.dumps(f))
    152         self.assertEqual(signature(f), signature(f_copy))
    153 
    154 class PartialSubclass(functools.partial):
    155     pass
    156 
    157 class TestPartialSubclass(TestPartial):
    158 
    159     thetype = PartialSubclass
    160 
    161 class TestPythonPartial(TestPartial):
    162 
    163     thetype = PythonPartial
    164 
    165     # the python version isn't picklable

    166     def test_pickle(self): pass
    167 
    168 class TestUpdateWrapper(unittest.TestCase):
    169 
    170     def check_wrapper(self, wrapper, wrapped,
    171                       assigned=functools.WRAPPER_ASSIGNMENTS,
    172                       updated=functools.WRAPPER_UPDATES):
    173         # Check attributes were assigned

    174         for name in assigned:
    175             self.assertTrue(getattr(wrapper, name) is getattr(wrapped, name))
    176         # Check attributes were updated

    177         for name in updated:
    178             wrapper_attr = getattr(wrapper, name)
    179             wrapped_attr = getattr(wrapped, name)
    180             for key in wrapped_attr:
    181                 self.assertTrue(wrapped_attr[key] is wrapper_attr[key])
    182 
    183     def _default_update(self):
    184         def f():
    185             """This is a test"""
    186             pass
    187         f.attr = 'This is also a test'
    188         def wrapper():
    189             pass
    190         functools.update_wrapper(wrapper, f)
    191         return wrapper, f
    192 
    193     def test_default_update(self):
    194         wrapper, f = self._default_update()
    195         self.check_wrapper(wrapper, f)
    196         self.assertEqual(wrapper.__name__, 'f')
    197         self.assertEqual(wrapper.attr, 'This is also a test')
    198 
    199     @unittest.skipIf(sys.flags.optimize >= 2,
    200                      "Docstrings are omitted with -O2 and above")
    201     def test_default_update_doc(self):
    202         wrapper, f = self._default_update()
    203         self.assertEqual(wrapper.__doc__, 'This is a test')
    204 
    205     def test_no_update(self):
    206         def f():
    207             """This is a test"""
    208             pass
    209         f.attr = 'This is also a test'
    210         def wrapper():
    211             pass
    212         functools.update_wrapper(wrapper, f, (), ())
    213         self.check_wrapper(wrapper, f, (), ())
    214         self.assertEqual(wrapper.__name__, 'wrapper')
    215         self.assertEqual(wrapper.__doc__, None)
    216         self.assertFalse(hasattr(wrapper, 'attr'))
    217 
    218     def test_selective_update(self):
    219         def f():
    220             pass
    221         f.attr = 'This is a different test'
    222         f.dict_attr = dict(a=1, b=2, c=3)
    223         def wrapper():
    224             pass
    225         wrapper.dict_attr = {}
    226         assign = ('attr',)
    227         update = ('dict_attr',)
    228         functools.update_wrapper(wrapper, f, assign, update)
    229         self.check_wrapper(wrapper, f, assign, update)
    230         self.assertEqual(wrapper.__name__, 'wrapper')
    231         self.assertEqual(wrapper.__doc__, None)
    232         self.assertEqual(wrapper.attr, 'This is a different test')
    233         self.assertEqual(wrapper.dict_attr, f.dict_attr)
    234 
    235     def test_builtin_update(self):
    236         # Test for bug #1576241

    237         def wrapper():
    238             pass
    239         functools.update_wrapper(wrapper, max)
    240         self.assertEqual(wrapper.__name__, 'max')
    241         self.assertTrue(wrapper.__doc__.startswith('max('))
    242 
    243 class TestWraps(TestUpdateWrapper):
    244 
    245     def _default_update(self):
    246         def f():
    247             """This is a test"""
    248             pass
    249         f.attr = 'This is also a test'
    250         @functools.wraps(f)
    251         def wrapper():
    252             pass
    253         self.check_wrapper(wrapper, f)
    254         return wrapper
    255 
    256     def test_default_update(self):
    257         wrapper = self._default_update()
    258         self.assertEqual(wrapper.__name__, 'f')
    259         self.assertEqual(wrapper.attr, 'This is also a test')
    260 
    261     @unittest.skipIf(not sys.flags.optimize <= 1,
    262                      "Docstrings are omitted with -O2 and above")
    263     def test_default_update_doc(self):
    264         wrapper = self._default_update()
    265         self.assertEqual(wrapper.__doc__, 'This is a test')
    266 
    267     def test_no_update(self):
    268         def f():
    269             """This is a test"""
    270             pass
    271         f.attr = 'This is also a test'
    272         @functools.wraps(f, (), ())
    273         def wrapper():
    274             pass
    275         self.check_wrapper(wrapper, f, (), ())
    276         self.assertEqual(wrapper.__name__, 'wrapper')
    277         self.assertEqual(wrapper.__doc__, None)
    278         self.assertFalse(hasattr(wrapper, 'attr'))
    279 
    280     def test_selective_update(self):
    281         def f():
    282             pass
    283         f.attr = 'This is a different test'
    284         f.dict_attr = dict(a=1, b=2, c=3)
    285         def add_dict_attr(f):
    286             f.dict_attr = {}
    287             return f
    288         assign = ('attr',)
    289         update = ('dict_attr',)
    290         @functools.wraps(f, assign, update)
    291         @add_dict_attr
    292         def wrapper():
    293             pass
    294         self.check_wrapper(wrapper, f, assign, update)
    295         self.assertEqual(wrapper.__name__, 'wrapper')
    296         self.assertEqual(wrapper.__doc__, None)
    297         self.assertEqual(wrapper.attr, 'This is a different test')
    298         self.assertEqual(wrapper.dict_attr, f.dict_attr)
    299 
    300 
    301 class TestReduce(unittest.TestCase):
    302 
    303     def test_reduce(self):
    304         class Squares:
    305 
    306             def __init__(self, max):
    307                 self.max = max
    308                 self.sofar = []
    309 
    310             def __len__(self): return len(self.sofar)
    311 
    312             def __getitem__(self, i):
    313                 if not 0 <= i < self.max: raise IndexError
    314                 n = len(self.sofar)
    315                 while n <= i:
    316                     self.sofar.append(n*n)
    317                     n += 1
    318                 return self.sofar[i]
    319 
    320         reduce = functools.reduce
    321         self.assertEqual(reduce(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc')
    322         self.assertEqual(
    323             reduce(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []),
    324             ['a','c','d','w']
    325         )
    326         self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040)
    327         self.assertEqual(
    328             reduce(lambda x, y: x*y, range(2,21), 1L),
    329             2432902008176640000L
    330         )
    331         self.assertEqual(reduce(lambda x, y: x+y, Squares(10)), 285)
    332         self.assertEqual(reduce(lambda x, y: x+y, Squares(10), 0), 285)
    333         self.assertEqual(reduce(lambda x, y: x+y, Squares(0), 0), 0)
    334         self.assertRaises(TypeError, reduce)
    335         self.assertRaises(TypeError, reduce, 42, 42)
    336         self.assertRaises(TypeError, reduce, 42, 42, 42)
    337         self.assertEqual(reduce(42, "1"), "1") # func is never called with one item

    338         self.assertEqual(reduce(42, "", "1"), "1") # func is never called with one item

    339         self.assertRaises(TypeError, reduce, 42, (42, 42))
    340 
    341 class TestCmpToKey(unittest.TestCase):
    342     def test_cmp_to_key(self):
    343         def mycmp(x, y):
    344             return y - x
    345         self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)),
    346                          [4, 3, 2, 1, 0])
    347 
    348     def test_hash(self):
    349         def mycmp(x, y):
    350             return y - x
    351         key = functools.cmp_to_key(mycmp)
    352         k = key(10)
    353         self.assertRaises(TypeError, hash(k))
    354 
    355 class TestTotalOrdering(unittest.TestCase):
    356 
    357     def test_total_ordering_lt(self):
    358         @functools.total_ordering
    359         class A:
    360             def __init__(self, value):
    361                 self.value = value
    362             def __lt__(self, other):
    363                 return self.value < other.value
    364             def __eq__(self, other):
    365                 return self.value == other.value
    366         self.assertTrue(A(1) < A(2))
    367         self.assertTrue(A(2) > A(1))
    368         self.assertTrue(A(1) <= A(2))
    369         self.assertTrue(A(2) >= A(1))
    370         self.assertTrue(A(2) <= A(2))
    371         self.assertTrue(A(2) >= A(2))
    372 
    373     def test_total_ordering_le(self):
    374         @functools.total_ordering
    375         class A:
    376             def __init__(self, value):
    377                 self.value = value
    378             def __le__(self, other):
    379                 return self.value <= other.value
    380             def __eq__(self, other):
    381                 return self.value == other.value
    382         self.assertTrue(A(1) < A(2))
    383         self.assertTrue(A(2) > A(1))
    384         self.assertTrue(A(1) <= A(2))
    385         self.assertTrue(A(2) >= A(1))
    386         self.assertTrue(A(2) <= A(2))
    387         self.assertTrue(A(2) >= A(2))
    388 
    389     def test_total_ordering_gt(self):
    390         @functools.total_ordering
    391         class A:
    392             def __init__(self, value):
    393                 self.value = value
    394             def __gt__(self, other):
    395                 return self.value > other.value
    396             def __eq__(self, other):
    397                 return self.value == other.value
    398         self.assertTrue(A(1) < A(2))
    399         self.assertTrue(A(2) > A(1))
    400         self.assertTrue(A(1) <= A(2))
    401         self.assertTrue(A(2) >= A(1))
    402         self.assertTrue(A(2) <= A(2))
    403         self.assertTrue(A(2) >= A(2))
    404 
    405     def test_total_ordering_ge(self):
    406         @functools.total_ordering
    407         class A:
    408             def __init__(self, value):
    409                 self.value = value
    410             def __ge__(self, other):
    411                 return self.value >= other.value
    412             def __eq__(self, other):
    413                 return self.value == other.value
    414         self.assertTrue(A(1) < A(2))
    415         self.assertTrue(A(2) > A(1))
    416         self.assertTrue(A(1) <= A(2))
    417         self.assertTrue(A(2) >= A(1))
    418         self.assertTrue(A(2) <= A(2))
    419         self.assertTrue(A(2) >= A(2))
    420 
    421     def test_total_ordering_no_overwrite(self):
    422         # new methods should not overwrite existing

    423         @functools.total_ordering
    424         class A(str):
    425             pass
    426         self.assertTrue(A("a") < A("b"))
    427         self.assertTrue(A("b") > A("a"))
    428         self.assertTrue(A("a") <= A("b"))
    429         self.assertTrue(A("b") >= A("a"))
    430         self.assertTrue(A("b") <= A("b"))
    431         self.assertTrue(A("b") >= A("b"))
    432 
    433     def test_no_operations_defined(self):
    434         with self.assertRaises(ValueError):
    435             @functools.total_ordering
    436             class A:
    437                 pass
    438 
    439     def test_bug_10042(self):
    440         @functools.total_ordering
    441         class TestTO:
    442             def __init__(self, value):
    443                 self.value = value
    444             def __eq__(self, other):
    445                 if isinstance(other, TestTO):
    446                     return self.value == other.value
    447                 return False
    448             def __lt__(self, other):
    449                 if isinstance(other, TestTO):
    450                     return self.value < other.value
    451                 raise TypeError
    452         with self.assertRaises(TypeError):
    453             TestTO(8) <= ()
    454 
    455 def test_main(verbose=None):
    456     test_classes = (
    457         TestPartial,
    458         TestPartialSubclass,
    459         TestPythonPartial,
    460         TestUpdateWrapper,
    461         TestTotalOrdering,
    462         TestWraps,
    463         TestReduce,
    464     )
    465     test_support.run_unittest(*test_classes)
    466 
    467     # verify reference counting

    468     if verbose and hasattr(sys, "gettotalrefcount"):
    469         import gc
    470         counts = [None] * 5
    471         for i in xrange(len(counts)):
    472             test_support.run_unittest(*test_classes)
    473             gc.collect()
    474             counts[i] = sys.gettotalrefcount()
    475         print counts
    476 
    477 if __name__ == '__main__':
    478     test_main(verbose=True)
    479