Home | History | Annotate | Download | only in test
      1 import abc
      2 import builtins
      3 import collections
      4 import copy
      5 from itertools import permutations
      6 import pickle
      7 from random import choice
      8 import sys
      9 from test import support
     10 import time
     11 import unittest
     12 from weakref import proxy
     13 import contextlib
     14 try:
     15     import threading
     16 except ImportError:
     17     threading = None
     18 
     19 import functools
     20 
     21 py_functools = support.import_fresh_module('functools', blocked=['_functools'])
     22 c_functools = support.import_fresh_module('functools', fresh=['_functools'])
     23 
     24 decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
     25 
     26 @contextlib.contextmanager
     27 def replaced_module(name, replacement):
     28     original_module = sys.modules[name]
     29     sys.modules[name] = replacement
     30     try:
     31         yield
     32     finally:
     33         sys.modules[name] = original_module
     34 
     35 def capture(*args, **kw):
     36     """capture all positional and keyword arguments"""
     37     return args, kw
     38 
     39 
     40 def signature(part):
     41     """ return the signature of a partial object """
     42     return (part.func, part.args, part.keywords, part.__dict__)
     43 
     44 class MyTuple(tuple):
     45     pass
     46 
     47 class BadTuple(tuple):
     48     def __add__(self, other):
     49         return list(self) + list(other)
     50 
     51 class MyDict(dict):
     52     pass
     53 
     54 
     55 class TestPartial:
     56 
     57     def test_basic_examples(self):
     58         p = self.partial(capture, 1, 2, a=10, b=20)
     59         self.assertTrue(callable(p))
     60         self.assertEqual(p(3, 4, b=30, c=40),
     61                          ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
     62         p = self.partial(map, lambda x: x*10)
     63         self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])
     64 
     65     def test_attributes(self):
     66         p = self.partial(capture, 1, 2, a=10, b=20)
     67         # attributes should be readable
     68         self.assertEqual(p.func, capture)
     69         self.assertEqual(p.args, (1, 2))
     70         self.assertEqual(p.keywords, dict(a=10, b=20))
     71 
     72     def test_argument_checking(self):
     73         self.assertRaises(TypeError, self.partial)     # need at least a func arg
     74         try:
     75             self.partial(2)()
     76         except TypeError:
     77             pass
     78         else:
     79             self.fail('First arg not checked for callability')
     80 
     81     def test_protection_of_callers_dict_argument(self):
     82         # a caller's dictionary should not be altered by partial
     83         def func(a=10, b=20):
     84             return a
     85         d = {'a':3}
     86         p = self.partial(func, a=5)
     87         self.assertEqual(p(**d), 3)
     88         self.assertEqual(d, {'a':3})
     89         p(b=7)
     90         self.assertEqual(d, {'a':3})
     91 
     92     def test_kwargs_copy(self):
     93         # Issue #29532: Altering a kwarg dictionary passed to a constructor
     94         # should not affect a partial object after creation
     95         d = {'a': 3}
     96         p = self.partial(capture, **d)
     97         self.assertEqual(p(), ((), {'a': 3}))
     98         d['a'] = 5
     99         self.assertEqual(p(), ((), {'a': 3}))
    100 
    101     def test_arg_combinations(self):
    102         # exercise special code paths for zero args in either partial
    103         # object or the caller
    104         p = self.partial(capture)
    105         self.assertEqual(p(), ((), {}))
    106         self.assertEqual(p(1,2), ((1,2), {}))
    107         p = self.partial(capture, 1, 2)
    108         self.assertEqual(p(), ((1,2), {}))
    109         self.assertEqual(p(3,4), ((1,2,3,4), {}))
    110 
    111     def test_kw_combinations(self):
    112         # exercise special code paths for no keyword args in
    113         # either the partial object or the caller
    114         p = self.partial(capture)
    115         self.assertEqual(p.keywords, {})
    116         self.assertEqual(p(), ((), {}))
    117         self.assertEqual(p(a=1), ((), {'a':1}))
    118         p = self.partial(capture, a=1)
    119         self.assertEqual(p.keywords, {'a':1})
    120         self.assertEqual(p(), ((), {'a':1}))
    121         self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
    122         # keyword args in the call override those in the partial object
    123         self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))
    124 
    125     def test_positional(self):
    126         # make sure positional arguments are captured correctly
    127         for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
    128             p = self.partial(capture, *args)
    129             expected = args + ('x',)
    130             got, empty = p('x')
    131             self.assertTrue(expected == got and empty == {})
    132 
    133     def test_keyword(self):
    134         # make sure keyword arguments are captured correctly
    135         for a in ['a', 0, None, 3.5]:
    136             p = self.partial(capture, a=a)
    137             expected = {'a':a,'x':None}
    138             empty, got = p(x=None)
    139             self.assertTrue(expected == got and empty == ())
    140 
    141     def test_no_side_effects(self):
    142         # make sure there are no side effects that affect subsequent calls
    143         p = self.partial(capture, 0, a=1)
    144         args1, kw1 = p(1, b=2)
    145         self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2})
    146         args2, kw2 = p()
    147         self.assertTrue(args2 == (0,) and kw2 == {'a':1})
    148 
    149     def test_error_propagation(self):
    150         def f(x, y):
    151             x / y
    152         self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
    153         self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
    154         self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
    155         self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)
    156 
    157     def test_weakref(self):
    158         f = self.partial(int, base=16)
    159         p = proxy(f)
    160         self.assertEqual(f.func, p.func)
    161         f = None
    162         self.assertRaises(ReferenceError, getattr, p, 'func')
    163 
    164     def test_with_bound_and_unbound_methods(self):
    165         data = list(map(str, range(10)))
    166         join = self.partial(str.join, '')
    167         self.assertEqual(join(data), '0123456789')
    168         join = self.partial(''.join)
    169         self.assertEqual(join(data), '0123456789')
    170 
    171     def test_nested_optimization(self):
    172         partial = self.partial
    173         inner = partial(signature, 'asdf')
    174         nested = partial(inner, bar=True)
    175         flat = partial(signature, 'asdf', bar=True)
    176         self.assertEqual(signature(nested), signature(flat))
    177 
    178     def test_nested_partial_with_attribute(self):
    179         # see issue 25137
    180         partial = self.partial
    181 
    182         def foo(bar):
    183             return bar
    184 
    185         p = partial(foo, 'first')
    186         p2 = partial(p, 'second')
    187         p2.new_attr = 'spam'
    188         self.assertEqual(p2.new_attr, 'spam')
    189 
    190     def test_repr(self):
    191         args = (object(), object())
    192         args_repr = ', '.join(repr(a) for a in args)
    193         kwargs = {'a': object(), 'b': object()}
    194         kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
    195                         'b={b!r}, a={a!r}'.format_map(kwargs)]
    196         if self.partial in (c_functools.partial, py_functools.partial):
    197             name = 'functools.partial'
    198         else:
    199             name = self.partial.__name__
    200 
    201         f = self.partial(capture)
    202         self.assertEqual(f'{name}({capture!r})', repr(f))
    203 
    204         f = self.partial(capture, *args)
    205         self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))
    206 
    207         f = self.partial(capture, **kwargs)
    208         self.assertIn(repr(f),
    209                       [f'{name}({capture!r}, {kwargs_repr})'
    210                        for kwargs_repr in kwargs_reprs])
    211 
    212         f = self.partial(capture, *args, **kwargs)
    213         self.assertIn(repr(f),
    214                       [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
    215                        for kwargs_repr in kwargs_reprs])
    216 
    217     def test_recursive_repr(self):
    218         if self.partial in (c_functools.partial, py_functools.partial):
    219             name = 'functools.partial'
    220         else:
    221             name = self.partial.__name__
    222 
    223         f = self.partial(capture)
    224         f.__setstate__((f, (), {}, {}))
    225         try:
    226             self.assertEqual(repr(f), '%s(...)' % (name,))
    227         finally:
    228             f.__setstate__((capture, (), {}, {}))
    229 
    230         f = self.partial(capture)
    231         f.__setstate__((capture, (f,), {}, {}))
    232         try:
    233             self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
    234         finally:
    235             f.__setstate__((capture, (), {}, {}))
    236 
    237         f = self.partial(capture)
    238         f.__setstate__((capture, (), {'a': f}, {}))
    239         try:
    240             self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
    241         finally:
    242             f.__setstate__((capture, (), {}, {}))
    243 
    244     def test_pickle(self):
    245         with self.AllowPickle():
    246             f = self.partial(signature, ['asdf'], bar=[True])
    247             f.attr = []
    248             for proto in range(pickle.HIGHEST_PROTOCOL + 1):
    249                 f_copy = pickle.loads(pickle.dumps(f, proto))
    250                 self.assertEqual(signature(f_copy), signature(f))
    251 
    252     def test_copy(self):
    253         f = self.partial(signature, ['asdf'], bar=[True])
    254         f.attr = []
    255         f_copy = copy.copy(f)
    256         self.assertEqual(signature(f_copy), signature(f))
    257         self.assertIs(f_copy.attr, f.attr)
    258         self.assertIs(f_copy.args, f.args)
    259         self.assertIs(f_copy.keywords, f.keywords)
    260 
    261     def test_deepcopy(self):
    262         f = self.partial(signature, ['asdf'], bar=[True])
    263         f.attr = []
    264         f_copy = copy.deepcopy(f)
    265         self.assertEqual(signature(f_copy), signature(f))
    266         self.assertIsNot(f_copy.attr, f.attr)
    267         self.assertIsNot(f_copy.args, f.args)
    268         self.assertIsNot(f_copy.args[0], f.args[0])
    269         self.assertIsNot(f_copy.keywords, f.keywords)
    270         self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])
    271 
    272     def test_setstate(self):
    273         f = self.partial(signature)
    274         f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))
    275 
    276         self.assertEqual(signature(f),
    277                          (capture, (1,), dict(a=10), dict(attr=[])))
    278         self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
    279 
    280         f.__setstate__((capture, (1,), dict(a=10), None))
    281 
    282         self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
    283         self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
    284 
    285         f.__setstate__((capture, (1,), None, None))
    286         #self.assertEqual(signature(f), (capture, (1,), {}, {}))
    287         self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
    288         self.assertEqual(f(2), ((1, 2), {}))
    289         self.assertEqual(f(), ((1,), {}))
    290 
    291         f.__setstate__((capture, (), {}, None))
    292         self.assertEqual(signature(f), (capture, (), {}, {}))
    293         self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
    294         self.assertEqual(f(2), ((2,), {}))
    295         self.assertEqual(f(), ((), {}))
    296 
    297     def test_setstate_errors(self):
    298         f = self.partial(signature)
    299         self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
    300         self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
    301         self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
    302         self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
    303         self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
    304         self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
    305         self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))
    306 
    307     def test_setstate_subclasses(self):
    308         f = self.partial(signature)
    309         f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
    310         s = signature(f)
    311         self.assertEqual(s, (capture, (1,), dict(a=10), {}))
    312         self.assertIs(type(s[1]), tuple)
    313         self.assertIs(type(s[2]), dict)
    314         r = f()
    315         self.assertEqual(r, ((1,), {'a': 10}))
    316         self.assertIs(type(r[0]), tuple)
    317         self.assertIs(type(r[1]), dict)
    318 
    319         f.__setstate__((capture, BadTuple((1,)), {}, None))
    320         s = signature(f)
    321         self.assertEqual(s, (capture, (1,), {}, {}))
    322         self.assertIs(type(s[1]), tuple)
    323         r = f(2)
    324         self.assertEqual(r, ((1, 2), {}))
    325         self.assertIs(type(r[0]), tuple)
    326 
    327     def test_recursive_pickle(self):
    328         with self.AllowPickle():
    329             f = self.partial(capture)
    330             f.__setstate__((f, (), {}, {}))
    331             try:
    332                 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
    333                     with self.assertRaises(RecursionError):
    334                         pickle.dumps(f, proto)
    335             finally:
    336                 f.__setstate__((capture, (), {}, {}))
    337 
    338             f = self.partial(capture)
    339             f.__setstate__((capture, (f,), {}, {}))
    340             try:
    341                 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
    342                     f_copy = pickle.loads(pickle.dumps(f, proto))
    343                     try:
    344                         self.assertIs(f_copy.args[0], f_copy)
    345                     finally:
    346                         f_copy.__setstate__((capture, (), {}, {}))
    347             finally:
    348                 f.__setstate__((capture, (), {}, {}))
    349 
    350             f = self.partial(capture)
    351             f.__setstate__((capture, (), {'a': f}, {}))
    352             try:
    353                 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
    354                     f_copy = pickle.loads(pickle.dumps(f, proto))
    355                     try:
    356                         self.assertIs(f_copy.keywords['a'], f_copy)
    357                     finally:
    358                         f_copy.__setstate__((capture, (), {}, {}))
    359             finally:
    360                 f.__setstate__((capture, (), {}, {}))
    361 
    362     # Issue 6083: Reference counting bug
    363     def test_setstate_refcount(self):
    364         class BadSequence:
    365             def __len__(self):
    366                 return 4
    367             def __getitem__(self, key):
    368                 if key == 0:
    369                     return max
    370                 elif key == 1:
    371                     return tuple(range(1000000))
    372                 elif key in (2, 3):
    373                     return {}
    374                 raise IndexError
    375 
    376         f = self.partial(object)
    377         self.assertRaises(TypeError, f.__setstate__, BadSequence())
    378 
    379 @unittest.skipUnless(c_functools, 'requires the C _functools module')
    380 class TestPartialC(TestPartial, unittest.TestCase):
    381     if c_functools:
    382         partial = c_functools.partial
    383 
    384     class AllowPickle:
    385         def __enter__(self):
    386             return self
    387         def __exit__(self, type, value, tb):
    388             return False
    389 
    390     def test_attributes_unwritable(self):
    391         # attributes should not be writable
    392         p = self.partial(capture, 1, 2, a=10, b=20)
    393         self.assertRaises(AttributeError, setattr, p, 'func', map)
    394         self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
    395         self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
    396 
    397         p = self.partial(hex)
    398         try:
    399             del p.__dict__
    400         except TypeError:
    401             pass
    402         else:
    403             self.fail('partial object allowed __dict__ to be deleted')
    404 
    405 class TestPartialPy(TestPartial, unittest.TestCase):
    406     partial = py_functools.partial
    407 
    408     class AllowPickle:
    409         def __init__(self):
    410             self._cm = replaced_module("functools", py_functools)
    411         def __enter__(self):
    412             return self._cm.__enter__()
    413         def __exit__(self, type, value, tb):
    414             return self._cm.__exit__(type, value, tb)
    415 
    416 if c_functools:
    417     class CPartialSubclass(c_functools.partial):
    418         pass
    419 
    420 class PyPartialSubclass(py_functools.partial):
    421     pass
    422 
    423 @unittest.skipUnless(c_functools, 'requires the C _functools module')
    424 class TestPartialCSubclass(TestPartialC):
    425     if c_functools:
    426         partial = CPartialSubclass
    427 
    428     # partial subclasses are not optimized for nested calls
    429     test_nested_optimization = None
    430 
    431 class TestPartialPySubclass(TestPartialPy):
    432     partial = PyPartialSubclass
    433 
    434 class TestPartialMethod(unittest.TestCase):
    435 
    436     class A(object):
    437         nothing = functools.partialmethod(capture)
    438         positional = functools.partialmethod(capture, 1)
    439         keywords = functools.partialmethod(capture, a=2)
    440         both = functools.partialmethod(capture, 3, b=4)
    441 
    442         nested = functools.partialmethod(positional, 5)
    443 
    444         over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)
    445 
    446         static = functools.partialmethod(staticmethod(capture), 8)
    447         cls = functools.partialmethod(classmethod(capture), d=9)
    448 
    449     a = A()
    450 
    451     def test_arg_combinations(self):
    452         self.assertEqual(self.a.nothing(), ((self.a,), {}))
    453         self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
    454         self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
    455         self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))
    456 
    457         self.assertEqual(self.a.positional(), ((self.a, 1), {}))
    458         self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
    459         self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
    460         self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))
    461 
    462         self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
    463         self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
    464         self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
    465         self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))
    466 
    467         self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
    468         self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
    469         self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
    470         self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
    471 
    472         self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
    473 
    474     def test_nested(self):
    475         self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
    476         self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
    477         self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
    478         self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
    479 
    480         self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
    481 
    482     def test_over_partial(self):
    483         self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
    484         self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
    485         self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
    486         self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
    487 
    488         self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
    489 
    490     def test_bound_method_introspection(self):
    491         obj = self.a
    492         self.assertIs(obj.both.__self__, obj)
    493         self.assertIs(obj.nested.__self__, obj)
    494         self.assertIs(obj.over_partial.__self__, obj)
    495         self.assertIs(obj.cls.__self__, self.A)
    496         self.assertIs(self.A.cls.__self__, self.A)
    497 
    498     def test_unbound_method_retrieval(self):
    499         obj = self.A
    500         self.assertFalse(hasattr(obj.both, "__self__"))
    501         self.assertFalse(hasattr(obj.nested, "__self__"))
    502         self.assertFalse(hasattr(obj.over_partial, "__self__"))
    503         self.assertFalse(hasattr(obj.static, "__self__"))
    504         self.assertFalse(hasattr(self.a.static, "__self__"))
    505 
    506     def test_descriptors(self):
    507         for obj in [self.A, self.a]:
    508             with self.subTest(obj=obj):
    509                 self.assertEqual(obj.static(), ((8,), {}))
    510                 self.assertEqual(obj.static(5), ((8, 5), {}))
    511                 self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
    512                 self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))
    513 
    514                 self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
    515                 self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
    516                 self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
    517                 self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))
    518 
    519     def test_overriding_keywords(self):
    520         self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
    521         self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))
    522 
    523     def test_invalid_args(self):
    524         with self.assertRaises(TypeError):
    525             class B(object):
    526                 method = functools.partialmethod(None, 1)
    527 
    528     def test_repr(self):
    529         self.assertEqual(repr(vars(self.A)['both']),
    530                          'functools.partialmethod({}, 3, b=4)'.format(capture))
    531 
    532     def test_abstract(self):
    533         class Abstract(abc.ABCMeta):
    534 
    535             @abc.abstractmethod
    536             def add(self, x, y):
    537                 pass
    538 
    539             add5 = functools.partialmethod(add, 5)
    540 
    541         self.assertTrue(Abstract.add.__isabstractmethod__)
    542         self.assertTrue(Abstract.add5.__isabstractmethod__)
    543 
    544         for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
    545             self.assertFalse(getattr(func, '__isabstractmethod__', False))
    546 
    547 
    548 class TestUpdateWrapper(unittest.TestCase):
    549 
    550     def check_wrapper(self, wrapper, wrapped,
    551                       assigned=functools.WRAPPER_ASSIGNMENTS,
    552                       updated=functools.WRAPPER_UPDATES):
    553         # Check attributes were assigned
    554         for name in assigned:
    555             self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
    556         # Check attributes were updated
    557         for name in updated:
    558             wrapper_attr = getattr(wrapper, name)
    559             wrapped_attr = getattr(wrapped, name)
    560             for key in wrapped_attr:
    561                 if name == "__dict__" and key == "__wrapped__":
    562                     # __wrapped__ is overwritten by the update code
    563                     continue
    564                 self.assertIs(wrapped_attr[key], wrapper_attr[key])
    565         # Check __wrapped__
    566         self.assertIs(wrapper.__wrapped__, wrapped)
    567 
    568 
    569     def _default_update(self):
    570         def f(a:'This is a new annotation'):
    571             """This is a test"""
    572             pass
    573         f.attr = 'This is also a test'
    574         f.__wrapped__ = "This is a bald faced lie"
    575         def wrapper(b:'This is the prior annotation'):
    576             pass
    577         functools.update_wrapper(wrapper, f)
    578         return wrapper, f
    579 
    580     def test_default_update(self):
    581         wrapper, f = self._default_update()
    582         self.check_wrapper(wrapper, f)
    583         self.assertIs(wrapper.__wrapped__, f)
    584         self.assertEqual(wrapper.__name__, 'f')
    585         self.assertEqual(wrapper.__qualname__, f.__qualname__)
    586         self.assertEqual(wrapper.attr, 'This is also a test')
    587         self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
    588         self.assertNotIn('b', wrapper.__annotations__)
    589 
    590     @unittest.skipIf(sys.flags.optimize >= 2,
    591                      "Docstrings are omitted with -O2 and above")
    592     def test_default_update_doc(self):
    593         wrapper, f = self._default_update()
    594         self.assertEqual(wrapper.__doc__, 'This is a test')
    595 
    596     def test_no_update(self):
    597         def f():
    598             """This is a test"""
    599             pass
    600         f.attr = 'This is also a test'
    601         def wrapper():
    602             pass
    603         functools.update_wrapper(wrapper, f, (), ())
    604         self.check_wrapper(wrapper, f, (), ())
    605         self.assertEqual(wrapper.__name__, 'wrapper')
    606         self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
    607         self.assertEqual(wrapper.__doc__, None)
    608         self.assertEqual(wrapper.__annotations__, {})
    609         self.assertFalse(hasattr(wrapper, 'attr'))
    610 
    611     def test_selective_update(self):
    612         def f():
    613             pass
    614         f.attr = 'This is a different test'
    615         f.dict_attr = dict(a=1, b=2, c=3)
    616         def wrapper():
    617             pass
    618         wrapper.dict_attr = {}
    619         assign = ('attr',)
    620         update = ('dict_attr',)
    621         functools.update_wrapper(wrapper, f, assign, update)
    622         self.check_wrapper(wrapper, f, assign, update)
    623         self.assertEqual(wrapper.__name__, 'wrapper')
    624         self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
    625         self.assertEqual(wrapper.__doc__, None)
    626         self.assertEqual(wrapper.attr, 'This is a different test')
    627         self.assertEqual(wrapper.dict_attr, f.dict_attr)
    628 
    629     def test_missing_attributes(self):
    630         def f():
    631             pass
    632         def wrapper():
    633             pass
    634         wrapper.dict_attr = {}
    635         assign = ('attr',)
    636         update = ('dict_attr',)
    637         # Missing attributes on wrapped object are ignored
    638         functools.update_wrapper(wrapper, f, assign, update)
    639         self.assertNotIn('attr', wrapper.__dict__)
    640         self.assertEqual(wrapper.dict_attr, {})
    641         # Wrapper must have expected attributes for updating
    642         del wrapper.dict_attr
    643         with self.assertRaises(AttributeError):
    644             functools.update_wrapper(wrapper, f, assign, update)
    645         wrapper.dict_attr = 1
    646         with self.assertRaises(AttributeError):
    647             functools.update_wrapper(wrapper, f, assign, update)
    648 
    649     @support.requires_docstrings
    650     @unittest.skipIf(sys.flags.optimize >= 2,
    651                      "Docstrings are omitted with -O2 and above")
    652     def test_builtin_update(self):
    653         # Test for bug #1576241
    654         def wrapper():
    655             pass
    656         functools.update_wrapper(wrapper, max)
    657         self.assertEqual(wrapper.__name__, 'max')
    658         self.assertTrue(wrapper.__doc__.startswith('max('))
    659         self.assertEqual(wrapper.__annotations__, {})
    660 
    661 
    662 class TestWraps(TestUpdateWrapper):
    663 
    664     def _default_update(self):
    665         def f():
    666             """This is a test"""
    667             pass
    668         f.attr = 'This is also a test'
    669         f.__wrapped__ = "This is still a bald faced lie"
    670         @functools.wraps(f)
    671         def wrapper():
    672             pass
    673         return wrapper, f
    674 
    675     def test_default_update(self):
    676         wrapper, f = self._default_update()
    677         self.check_wrapper(wrapper, f)
    678         self.assertEqual(wrapper.__name__, 'f')
    679         self.assertEqual(wrapper.__qualname__, f.__qualname__)
    680         self.assertEqual(wrapper.attr, 'This is also a test')
    681 
    682     @unittest.skipIf(sys.flags.optimize >= 2,
    683                      "Docstrings are omitted with -O2 and above")
    684     def test_default_update_doc(self):
    685         wrapper, _ = self._default_update()
    686         self.assertEqual(wrapper.__doc__, 'This is a test')
    687 
    688     def test_no_update(self):
    689         def f():
    690             """This is a test"""
    691             pass
    692         f.attr = 'This is also a test'
    693         @functools.wraps(f, (), ())
    694         def wrapper():
    695             pass
    696         self.check_wrapper(wrapper, f, (), ())
    697         self.assertEqual(wrapper.__name__, 'wrapper')
    698         self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
    699         self.assertEqual(wrapper.__doc__, None)
    700         self.assertFalse(hasattr(wrapper, 'attr'))
    701 
    702     def test_selective_update(self):
    703         def f():
    704             pass
    705         f.attr = 'This is a different test'
    706         f.dict_attr = dict(a=1, b=2, c=3)
    707         def add_dict_attr(f):
    708             f.dict_attr = {}
    709             return f
    710         assign = ('attr',)
    711         update = ('dict_attr',)
    712         @functools.wraps(f, assign, update)
    713         @add_dict_attr
    714         def wrapper():
    715             pass
    716         self.check_wrapper(wrapper, f, assign, update)
    717         self.assertEqual(wrapper.__name__, 'wrapper')
    718         self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
    719         self.assertEqual(wrapper.__doc__, None)
    720         self.assertEqual(wrapper.attr, 'This is a different test')
    721         self.assertEqual(wrapper.dict_attr, f.dict_attr)
    722 
    723 @unittest.skipUnless(c_functools, 'requires the C _functools module')
    724 class TestReduce(unittest.TestCase):
    725     if c_functools:
    726         func = c_functools.reduce
    727 
    728     def test_reduce(self):
    729         class Squares:
    730             def __init__(self, max):
    731                 self.max = max
    732                 self.sofar = []
    733 
    734             def __len__(self):
    735                 return len(self.sofar)
    736 
    737             def __getitem__(self, i):
    738                 if not 0 <= i < self.max: raise IndexError
    739                 n = len(self.sofar)
    740                 while n <= i:
    741                     self.sofar.append(n*n)
    742                     n += 1
    743                 return self.sofar[i]
    744         def add(x, y):
    745             return x + y
    746         self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc')
    747         self.assertEqual(
    748             self.func(add, [['a', 'c'], [], ['d', 'w']], []),
    749             ['a','c','d','w']
    750         )
    751         self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040)
    752         self.assertEqual(
    753             self.func(lambda x, y: x*y, range(2,21), 1),
    754             2432902008176640000
    755         )
    756         self.assertEqual(self.func(add, Squares(10)), 285)
    757         self.assertEqual(self.func(add, Squares(10), 0), 285)
    758         self.assertEqual(self.func(add, Squares(0), 0), 0)
    759         self.assertRaises(TypeError, self.func)
    760         self.assertRaises(TypeError, self.func, 42, 42)
    761         self.assertRaises(TypeError, self.func, 42, 42, 42)
    762         self.assertEqual(self.func(42, "1"), "1") # func is never called with one item
    763         self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item
    764         self.assertRaises(TypeError, self.func, 42, (42, 42))
    765         self.assertRaises(TypeError, self.func, add, []) # arg 2 must not be empty sequence with no initial value
    766         self.assertRaises(TypeError, self.func, add, "")
    767         self.assertRaises(TypeError, self.func, add, ())
    768         self.assertRaises(TypeError, self.func, add, object())
    769 
    770         class TestFailingIter:
    771             def __iter__(self):
    772                 raise RuntimeError
    773         self.assertRaises(RuntimeError, self.func, add, TestFailingIter())
    774 
    775         self.assertEqual(self.func(add, [], None), None)
    776         self.assertEqual(self.func(add, [], 42), 42)
    777 
    778         class BadSeq:
    779             def __getitem__(self, index):
    780                 raise ValueError
    781         self.assertRaises(ValueError, self.func, 42, BadSeq())
    782 
    783     # Test reduce()'s use of iterators.
    784     def test_iterator_usage(self):
    785         class SequenceClass:
    786             def __init__(self, n):
    787                 self.n = n
    788             def __getitem__(self, i):
    789                 if 0 <= i < self.n:
    790                     return i
    791                 else:
    792                     raise IndexError
    793 
    794         from operator import add
    795         self.assertEqual(self.func(add, SequenceClass(5)), 10)
    796         self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
    797         self.assertRaises(TypeError, self.func, add, SequenceClass(0))
    798         self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
    799         self.assertEqual(self.func(add, SequenceClass(1)), 0)
    800         self.assertEqual(self.func(add, SequenceClass(1), 42), 42)
    801 
    802         d = {"one": 1, "two": 2, "three": 3}
    803         self.assertEqual(self.func(add, d), "".join(d.keys()))
    804 
    805 
    806 class TestCmpToKey:
    807 
    808     def test_cmp_to_key(self):
    809         def cmp1(x, y):
    810             return (x > y) - (x < y)
    811         key = self.cmp_to_key(cmp1)
    812         self.assertEqual(key(3), key(3))
    813         self.assertGreater(key(3), key(1))
    814         self.assertGreaterEqual(key(3), key(3))
    815 
    816         def cmp2(x, y):
    817             return int(x) - int(y)
    818         key = self.cmp_to_key(cmp2)
    819         self.assertEqual(key(4.0), key('4'))
    820         self.assertLess(key(2), key('35'))
    821         self.assertLessEqual(key(2), key('35'))
    822         self.assertNotEqual(key(2), key('35'))
    823 
    824     def test_cmp_to_key_arguments(self):
    825         def cmp1(x, y):
    826             return (x > y) - (x < y)
    827         key = self.cmp_to_key(mycmp=cmp1)
    828         self.assertEqual(key(obj=3), key(obj=3))
    829         self.assertGreater(key(obj=3), key(obj=1))
    830         with self.assertRaises((TypeError, AttributeError)):
    831             key(3) > 1    # rhs is not a K object
    832         with self.assertRaises((TypeError, AttributeError)):
    833             1 < key(3)    # lhs is not a K object
    834         with self.assertRaises(TypeError):
    835             key = self.cmp_to_key()             # too few args
    836         with self.assertRaises(TypeError):
    837             key = self.cmp_to_key(cmp1, None)   # too many args
    838         key = self.cmp_to_key(cmp1)
    839         with self.assertRaises(TypeError):
    840             key()                                    # too few args
    841         with self.assertRaises(TypeError):
    842             key(None, None)                          # too many args
    843 
    844     def test_bad_cmp(self):
    845         def cmp1(x, y):
    846             raise ZeroDivisionError
    847         key = self.cmp_to_key(cmp1)
    848         with self.assertRaises(ZeroDivisionError):
    849             key(3) > key(1)
    850 
    851         class BadCmp:
    852             def __lt__(self, other):
    853                 raise ZeroDivisionError
    854         def cmp1(x, y):
    855             return BadCmp()
    856         with self.assertRaises(ZeroDivisionError):
    857             key(3) > key(1)
    858 
    859     def test_obj_field(self):
    860         def cmp1(x, y):
    861             return (x > y) - (x < y)
    862         key = self.cmp_to_key(mycmp=cmp1)
    863         self.assertEqual(key(50).obj, 50)
    864 
    865     def test_sort_int(self):
    866         def mycmp(x, y):
    867             return y - x
    868         self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
    869                          [4, 3, 2, 1, 0])
    870 
    871     def test_sort_int_str(self):
    872         def mycmp(x, y):
    873             x, y = int(x), int(y)
    874             return (x > y) - (x < y)
    875         values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
    876         values = sorted(values, key=self.cmp_to_key(mycmp))
    877         self.assertEqual([int(value) for value in values],
    878                          [0, 1, 1, 2, 3, 4, 5, 7, 10])
    879 
    880     def test_hash(self):
    881         def mycmp(x, y):
    882             return y - x
    883         key = self.cmp_to_key(mycmp)
    884         k = key(10)
    885         self.assertRaises(TypeError, hash, k)
    886         self.assertNotIsInstance(k, collections.Hashable)
    887 
    888 
    889 @unittest.skipUnless(c_functools, 'requires the C _functools module')
    890 class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    891     if c_functools:
    892         cmp_to_key = c_functools.cmp_to_key
    893 
    894 
    895 class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    896     cmp_to_key = staticmethod(py_functools.cmp_to_key)
    897 
    898 
    899 class TestTotalOrdering(unittest.TestCase):
    900 
    901     def test_total_ordering_lt(self):
    902         @functools.total_ordering
    903         class A:
    904             def __init__(self, value):
    905                 self.value = value
    906             def __lt__(self, other):
    907                 return self.value < other.value
    908             def __eq__(self, other):
    909                 return self.value == other.value
    910         self.assertTrue(A(1) < A(2))
    911         self.assertTrue(A(2) > A(1))
    912         self.assertTrue(A(1) <= A(2))
    913         self.assertTrue(A(2) >= A(1))
    914         self.assertTrue(A(2) <= A(2))
    915         self.assertTrue(A(2) >= A(2))
    916         self.assertFalse(A(1) > A(2))
    917 
    918     def test_total_ordering_le(self):
    919         @functools.total_ordering
    920         class A:
    921             def __init__(self, value):
    922                 self.value = value
    923             def __le__(self, other):
    924                 return self.value <= other.value
    925             def __eq__(self, other):
    926                 return self.value == other.value
    927         self.assertTrue(A(1) < A(2))
    928         self.assertTrue(A(2) > A(1))
    929         self.assertTrue(A(1) <= A(2))
    930         self.assertTrue(A(2) >= A(1))
    931         self.assertTrue(A(2) <= A(2))
    932         self.assertTrue(A(2) >= A(2))
    933         self.assertFalse(A(1) >= A(2))
    934 
    935     def test_total_ordering_gt(self):
    936         @functools.total_ordering
    937         class A:
    938             def __init__(self, value):
    939                 self.value = value
    940             def __gt__(self, other):
    941                 return self.value > other.value
    942             def __eq__(self, other):
    943                 return self.value == other.value
    944         self.assertTrue(A(1) < A(2))
    945         self.assertTrue(A(2) > A(1))
    946         self.assertTrue(A(1) <= A(2))
    947         self.assertTrue(A(2) >= A(1))
    948         self.assertTrue(A(2) <= A(2))
    949         self.assertTrue(A(2) >= A(2))
    950         self.assertFalse(A(2) < A(1))
    951 
    952     def test_total_ordering_ge(self):
    953         @functools.total_ordering
    954         class A:
    955             def __init__(self, value):
    956                 self.value = value
    957             def __ge__(self, other):
    958                 return self.value >= other.value
    959             def __eq__(self, other):
    960                 return self.value == other.value
    961         self.assertTrue(A(1) < A(2))
    962         self.assertTrue(A(2) > A(1))
    963         self.assertTrue(A(1) <= A(2))
    964         self.assertTrue(A(2) >= A(1))
    965         self.assertTrue(A(2) <= A(2))
    966         self.assertTrue(A(2) >= A(2))
    967         self.assertFalse(A(2) <= A(1))
    968 
    969     def test_total_ordering_no_overwrite(self):
    970         # new methods should not overwrite existing
    971         @functools.total_ordering
    972         class A(int):
    973             pass
    974         self.assertTrue(A(1) < A(2))
    975         self.assertTrue(A(2) > A(1))
    976         self.assertTrue(A(1) <= A(2))
    977         self.assertTrue(A(2) >= A(1))
    978         self.assertTrue(A(2) <= A(2))
    979         self.assertTrue(A(2) >= A(2))
    980 
    981     def test_no_operations_defined(self):
    982         with self.assertRaises(ValueError):
    983             @functools.total_ordering
    984             class A:
    985                 pass
    986 
    987     def test_type_error_when_not_implemented(self):
    988         # bug 10042; ensure stack overflow does not occur
    989         # when decorated types return NotImplemented
    990         @functools.total_ordering
    991         class ImplementsLessThan:
    992             def __init__(self, value):
    993                 self.value = value
    994             def __eq__(self, other):
    995                 if isinstance(other, ImplementsLessThan):
    996                     return self.value == other.value
    997                 return False
    998             def __lt__(self, other):
    999                 if isinstance(other, ImplementsLessThan):
   1000                     return self.value < other.value
   1001                 return NotImplemented
   1002 
   1003         @functools.total_ordering
   1004         class ImplementsGreaterThan:
   1005             def __init__(self, value):
   1006                 self.value = value
   1007             def __eq__(self, other):
   1008                 if isinstance(other, ImplementsGreaterThan):
   1009                     return self.value == other.value
   1010                 return False
   1011             def __gt__(self, other):
   1012                 if isinstance(other, ImplementsGreaterThan):
   1013                     return self.value > other.value
   1014                 return NotImplemented
   1015 
   1016         @functools.total_ordering
   1017         class ImplementsLessThanEqualTo:
   1018             def __init__(self, value):
   1019                 self.value = value
   1020             def __eq__(self, other):
   1021                 if isinstance(other, ImplementsLessThanEqualTo):
   1022                     return self.value == other.value
   1023                 return False
   1024             def __le__(self, other):
   1025                 if isinstance(other, ImplementsLessThanEqualTo):
   1026                     return self.value <= other.value
   1027                 return NotImplemented
   1028 
   1029         @functools.total_ordering
   1030         class ImplementsGreaterThanEqualTo:
   1031             def __init__(self, value):
   1032                 self.value = value
   1033             def __eq__(self, other):
   1034                 if isinstance(other, ImplementsGreaterThanEqualTo):
   1035                     return self.value == other.value
   1036                 return False
   1037             def __ge__(self, other):
   1038                 if isinstance(other, ImplementsGreaterThanEqualTo):
   1039                     return self.value >= other.value
   1040                 return NotImplemented
   1041 
   1042         @functools.total_ordering
   1043         class ComparatorNotImplemented:
   1044             def __init__(self, value):
   1045                 self.value = value
   1046             def __eq__(self, other):
   1047                 if isinstance(other, ComparatorNotImplemented):
   1048                     return self.value == other.value
   1049                 return False
   1050             def __lt__(self, other):
   1051                 return NotImplemented
   1052 
   1053         with self.subTest("LT < 1"), self.assertRaises(TypeError):
   1054             ImplementsLessThan(-1) < 1
   1055 
   1056         with self.subTest("LT < LE"), self.assertRaises(TypeError):
   1057             ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
   1058 
   1059         with self.subTest("LT < GT"), self.assertRaises(TypeError):
   1060             ImplementsLessThan(1) < ImplementsGreaterThan(1)
   1061 
   1062         with self.subTest("LE <= LT"), self.assertRaises(TypeError):
   1063             ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
   1064 
   1065         with self.subTest("LE <= GE"), self.assertRaises(TypeError):
   1066             ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
   1067 
   1068         with self.subTest("GT > GE"), self.assertRaises(TypeError):
   1069             ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
   1070 
   1071         with self.subTest("GT > LT"), self.assertRaises(TypeError):
   1072             ImplementsGreaterThan(5) > ImplementsLessThan(5)
   1073 
   1074         with self.subTest("GE >= GT"), self.assertRaises(TypeError):
   1075             ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
   1076 
   1077         with self.subTest("GE >= LE"), self.assertRaises(TypeError):
   1078             ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
   1079 
   1080         with self.subTest("GE when equal"):
   1081             a = ComparatorNotImplemented(8)
   1082             b = ComparatorNotImplemented(8)
   1083             self.assertEqual(a, b)
   1084             with self.assertRaises(TypeError):
   1085                 a >= b
   1086 
   1087         with self.subTest("LE when equal"):
   1088             a = ComparatorNotImplemented(9)
   1089             b = ComparatorNotImplemented(9)
   1090             self.assertEqual(a, b)
   1091             with self.assertRaises(TypeError):
   1092                 a <= b
   1093 
   1094     def test_pickle(self):
   1095         for proto in range(pickle.HIGHEST_PROTOCOL + 1):
   1096             for name in '__lt__', '__gt__', '__le__', '__ge__':
   1097                 with self.subTest(method=name, proto=proto):
   1098                     method = getattr(Orderable_LT, name)
   1099                     method_copy = pickle.loads(pickle.dumps(method, proto))
   1100                     self.assertIs(method_copy, method)
   1101 
   1102 @functools.total_ordering
   1103 class Orderable_LT:
   1104     def __init__(self, value):
   1105         self.value = value
   1106     def __lt__(self, other):
   1107         return self.value < other.value
   1108     def __eq__(self, other):
   1109         return self.value == other.value
   1110 
   1111 
   1112 class TestLRU:
   1113 
   1114     def test_lru(self):
   1115         def orig(x, y):
   1116             return 3 * x + y
   1117         f = self.module.lru_cache(maxsize=20)(orig)
   1118         hits, misses, maxsize, currsize = f.cache_info()
   1119         self.assertEqual(maxsize, 20)
   1120         self.assertEqual(currsize, 0)
   1121         self.assertEqual(hits, 0)
   1122         self.assertEqual(misses, 0)
   1123 
   1124         domain = range(5)
   1125         for i in range(1000):
   1126             x, y = choice(domain), choice(domain)
   1127             actual = f(x, y)
   1128             expected = orig(x, y)
   1129             self.assertEqual(actual, expected)
   1130         hits, misses, maxsize, currsize = f.cache_info()
   1131         self.assertTrue(hits > misses)
   1132         self.assertEqual(hits + misses, 1000)
   1133         self.assertEqual(currsize, 20)
   1134 
   1135         f.cache_clear()   # test clearing
   1136         hits, misses, maxsize, currsize = f.cache_info()
   1137         self.assertEqual(hits, 0)
   1138         self.assertEqual(misses, 0)
   1139         self.assertEqual(currsize, 0)
   1140         f(x, y)
   1141         hits, misses, maxsize, currsize = f.cache_info()
   1142         self.assertEqual(hits, 0)
   1143         self.assertEqual(misses, 1)
   1144         self.assertEqual(currsize, 1)
   1145 
   1146         # Test bypassing the cache
   1147         self.assertIs(f.__wrapped__, orig)
   1148         f.__wrapped__(x, y)
   1149         hits, misses, maxsize, currsize = f.cache_info()
   1150         self.assertEqual(hits, 0)
   1151         self.assertEqual(misses, 1)
   1152         self.assertEqual(currsize, 1)
   1153 
   1154         # test size zero (which means "never-cache")
   1155         @self.module.lru_cache(0)
   1156         def f():
   1157             nonlocal f_cnt
   1158             f_cnt += 1
   1159             return 20
   1160         self.assertEqual(f.cache_info().maxsize, 0)
   1161         f_cnt = 0
   1162         for i in range(5):
   1163             self.assertEqual(f(), 20)
   1164         self.assertEqual(f_cnt, 5)
   1165         hits, misses, maxsize, currsize = f.cache_info()
   1166         self.assertEqual(hits, 0)
   1167         self.assertEqual(misses, 5)
   1168         self.assertEqual(currsize, 0)
   1169 
   1170         # test size one
   1171         @self.module.lru_cache(1)
   1172         def f():
   1173             nonlocal f_cnt
   1174             f_cnt += 1
   1175             return 20
   1176         self.assertEqual(f.cache_info().maxsize, 1)
   1177         f_cnt = 0
   1178         for i in range(5):
   1179             self.assertEqual(f(), 20)
   1180         self.assertEqual(f_cnt, 1)
   1181         hits, misses, maxsize, currsize = f.cache_info()
   1182         self.assertEqual(hits, 4)
   1183         self.assertEqual(misses, 1)
   1184         self.assertEqual(currsize, 1)
   1185 
   1186         # test size two
   1187         @self.module.lru_cache(2)
   1188         def f(x):
   1189             nonlocal f_cnt
   1190             f_cnt += 1
   1191             return x*10
   1192         self.assertEqual(f.cache_info().maxsize, 2)
   1193         f_cnt = 0
   1194         for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
   1195             #    *  *              *                          *
   1196             self.assertEqual(f(x), x*10)
   1197         self.assertEqual(f_cnt, 4)
   1198         hits, misses, maxsize, currsize = f.cache_info()
   1199         self.assertEqual(hits, 12)
   1200         self.assertEqual(misses, 4)
   1201         self.assertEqual(currsize, 2)
   1202 
   1203     def test_lru_reentrancy_with_len(self):
   1204         # Test to make sure the LRU cache code isn't thrown-off by
   1205         # caching the built-in len() function.  Since len() can be
   1206         # cached, we shouldn't use it inside the lru code itself.
   1207         old_len = builtins.len
   1208         try:
   1209             builtins.len = self.module.lru_cache(4)(len)
   1210             for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
   1211                 self.assertEqual(len('abcdefghijklmn'[:i]), i)
   1212         finally:
   1213             builtins.len = old_len
   1214 
   1215     def test_lru_type_error(self):
   1216         # Regression test for issue #28653.
   1217         # lru_cache was leaking when one of the arguments
   1218         # wasn't cacheable.
   1219 
   1220         @functools.lru_cache(maxsize=None)
   1221         def infinite_cache(o):
   1222             pass
   1223 
   1224         @functools.lru_cache(maxsize=10)
   1225         def limited_cache(o):
   1226             pass
   1227 
   1228         with self.assertRaises(TypeError):
   1229             infinite_cache([])
   1230 
   1231         with self.assertRaises(TypeError):
   1232             limited_cache([])
   1233 
   1234     def test_lru_with_maxsize_none(self):
   1235         @self.module.lru_cache(maxsize=None)
   1236         def fib(n):
   1237             if n < 2:
   1238                 return n
   1239             return fib(n-1) + fib(n-2)
   1240         self.assertEqual([fib(n) for n in range(16)],
   1241             [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
   1242         self.assertEqual(fib.cache_info(),
   1243             self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
   1244         fib.cache_clear()
   1245         self.assertEqual(fib.cache_info(),
   1246             self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
   1247 
   1248     def test_lru_with_maxsize_negative(self):
   1249         @self.module.lru_cache(maxsize=-10)
   1250         def eq(n):
   1251             return n
   1252         for i in (0, 1):
   1253             self.assertEqual([eq(n) for n in range(150)], list(range(150)))
   1254         self.assertEqual(eq.cache_info(),
   1255             self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1))
   1256 
   1257     def test_lru_with_exceptions(self):
   1258         # Verify that user_function exceptions get passed through without
   1259         # creating a hard-to-read chained exception.
   1260         # http://bugs.python.org/issue13177
   1261         for maxsize in (None, 128):
   1262             @self.module.lru_cache(maxsize)
   1263             def func(i):
   1264                 return 'abc'[i]
   1265             self.assertEqual(func(0), 'a')
   1266             with self.assertRaises(IndexError) as cm:
   1267                 func(15)
   1268             self.assertIsNone(cm.exception.__context__)
   1269             # Verify that the previous exception did not result in a cached entry
   1270             with self.assertRaises(IndexError):
   1271                 func(15)
   1272 
   1273     def test_lru_with_types(self):
   1274         for maxsize in (None, 128):
   1275             @self.module.lru_cache(maxsize=maxsize, typed=True)
   1276             def square(x):
   1277                 return x * x
   1278             self.assertEqual(square(3), 9)
   1279             self.assertEqual(type(square(3)), type(9))
   1280             self.assertEqual(square(3.0), 9.0)
   1281             self.assertEqual(type(square(3.0)), type(9.0))
   1282             self.assertEqual(square(x=3), 9)
   1283             self.assertEqual(type(square(x=3)), type(9))
   1284             self.assertEqual(square(x=3.0), 9.0)
   1285             self.assertEqual(type(square(x=3.0)), type(9.0))
   1286             self.assertEqual(square.cache_info().hits, 4)
   1287             self.assertEqual(square.cache_info().misses, 4)
   1288 
   1289     def test_lru_with_keyword_args(self):
   1290         @self.module.lru_cache()
   1291         def fib(n):
   1292             if n < 2:
   1293                 return n
   1294             return fib(n=n-1) + fib(n=n-2)
   1295         self.assertEqual(
   1296             [fib(n=number) for number in range(16)],
   1297             [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
   1298         )
   1299         self.assertEqual(fib.cache_info(),
   1300             self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
   1301         fib.cache_clear()
   1302         self.assertEqual(fib.cache_info(),
   1303             self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
   1304 
   1305     def test_lru_with_keyword_args_maxsize_none(self):
   1306         @self.module.lru_cache(maxsize=None)
   1307         def fib(n):
   1308             if n < 2:
   1309                 return n
   1310             return fib(n=n-1) + fib(n=n-2)
   1311         self.assertEqual([fib(n=number) for number in range(16)],
   1312             [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
   1313         self.assertEqual(fib.cache_info(),
   1314             self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
   1315         fib.cache_clear()
   1316         self.assertEqual(fib.cache_info(),
   1317             self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
   1318 
   1319     def test_kwargs_order(self):
   1320         # PEP 468: Preserving Keyword Argument Order
   1321         @self.module.lru_cache(maxsize=10)
   1322         def f(**kwargs):
   1323             return list(kwargs.items())
   1324         self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
   1325         self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
   1326         self.assertEqual(f.cache_info(),
   1327             self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
   1328 
   1329     def test_lru_cache_decoration(self):
   1330         def f(zomg: 'zomg_annotation'):
   1331             """f doc string"""
   1332             return 42
   1333         g = self.module.lru_cache()(f)
   1334         for attr in self.module.WRAPPER_ASSIGNMENTS:
   1335             self.assertEqual(getattr(g, attr), getattr(f, attr))
   1336 
    @unittest.skipUnless(threading, 'This test requires threading.')
    def test_lru_cache_threaded(self):
        # Exercise one bounded cache from several threads.  First n workers
        # each fill it with their own key; then the same workers run again
        # while an extra thread repeatedly clears the cache.
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        # Capacity n*m exceeds the n distinct keys used below, so no entry
        # should be evicted.
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        # Threads block on this event (for at most 10 seconds each) so that
        # they start calling f at roughly the same time.
        start = threading.Event()
        def full(k):
            # Worker: call f m times with the thread-specific key (k, 0).
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            # Clearer: empty the cache 2*m times while workers fill it.
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        # Shrink the thread switch interval, presumably to make thread
        # interleavings (and so races) more frequent; restored in finally.
        # NOTE(review): support.setswitchinterval is assumed to wrap
        # sys.setswitchinterval, which is what the finally clause uses.
        orig_si = sys.getswitchinterval()
        support.setswitchinterval(1e-6)
        try:
            # create n threads in order to fill cache
            # (support.start_threads presumably starts the threads on entry
            # and joins them on exit -- confirm.)
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with support.start_threads(threads):
                start.set()

            # Ideally each worker misses once and then hits m-1 times.  That
            # is asserted exactly unless the module under test is
            # py_functools, for which only upper bounds are checked.
            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: Why can be not equal?
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # create n threads in order to fill cache and 1 to clear it
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            # Re-arm the event so the second batch also starts together.
            start.clear()
            with support.start_threads(threads):
                start.set()
            # No cache statistics are asserted after this concurrent
            # clear/fill phase.
        finally:
            sys.setswitchinterval(orig_si)
   1385 
    @unittest.skipUnless(threading, 'This test requires threading.')
    def test_lru_cache_threaded2(self):
        # Simultaneous call with the same arguments
        # In each of m rounds, all n workers call f(i) at once.  f blocks
        # until every worker is inside it, so no call in a round can find a
        # cached result: hits stay 0, misses grow by n per round, and one
        # entry per round ends up stored.
        n, m = 5, 7
        # Each barrier has n+1 parties: the n workers plus the main thread.
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            # Hold every caller until all workers and the main thread have
            # arrived, so no call returns (and gets cached) early.
            pause.wait(10)
            return 3 * x
        # cache_info() is compared as a plain
        # (hits, misses, maxsize, currsize) tuple.
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            # Worker: in round i, call f(i) together with the other workers.
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with support.start_threads(threads):
            # The main thread drives the rounds.  Each reset() happens at a
            # point where, by this protocol, no thread can be waiting on
            # that barrier.
            for i in range(m):
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
   1413 
   1414     @unittest.skipUnless(threading, 'This test requires threading.')
   1415     def test_lru_cache_threaded3(self):
   1416         @self.module.lru_cache(maxsize=2)
   1417         def f(x):
   1418             time.sleep(.01)
   1419             return 3 * x
   1420         def test(i, x):
   1421             with self.subTest(thread=i):
   1422                 self.assertEqual(f(x), 3 * x, i)
   1423         threads = [threading.Thread(target=test, args=(i, v))
   1424                    for i, v in enumerate([1, 2, 2, 3, 2])]
   1425         with support.start_threads(threads):
   1426             pass
   1427 
   1428     def test_need_for_rlock(self):
   1429         # This will deadlock on an LRU cache that uses a regular lock
   1430 
   1431         @self.module.lru_cache(maxsize=10)
   1432         def test_func(x):
   1433             'Used to demonstrate a reentrant lru_cache call within a single thread'
   1434             return x
   1435 
   1436         class DoubleEq:
   1437             'Demonstrate a reentrant lru_cache call within a single thread'
   1438             def __init__(self, x):
   1439                 self.x = x
   1440             def __hash__(self):
   1441                 return self.x
   1442             def __eq__(self, other):
   1443                 if self.x == 2:
   1444                     test_func(DoubleEq(1))
   1445                 return self.x == other.x
   1446 
   1447         test_func(DoubleEq(1))                      # Load the cache
   1448         test_func(DoubleEq(2))                      # Load the cache
   1449         self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
   1450                          DoubleEq(2))               # Verify the correct return value
   1451 
   1452     def test_early_detection_of_bad_call(self):
   1453         # Issue #22184
   1454         with self.assertRaises(TypeError):
   1455             @functools.lru_cache
   1456             def f():
   1457                 pass
   1458 
   1459     def test_lru_method(self):
   1460         class X(int):
   1461             f_cnt = 0
   1462             @self.module.lru_cache(2)
   1463             def f(self, x):
   1464                 self.f_cnt += 1
   1465                 return x*10+self
   1466         a = X(5)
   1467         b = X(5)
   1468         c = X(7)
   1469         self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
   1470 
   1471         for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
   1472             self.assertEqual(a.f(x), x*10 + 5)
   1473         self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
   1474         self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
   1475 
   1476         for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
   1477             self.assertEqual(b.f(x), x*10 + 5)
   1478         self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
   1479         self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
   1480 
   1481         for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
   1482             self.assertEqual(c.f(x), x*10 + 7)
   1483         self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
   1484         self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
   1485 
   1486         self.assertEqual(a.f.cache_info(), X.f.cache_info())
   1487         self.assertEqual(b.f.cache_info(), X.f.cache_info())
   1488         self.assertEqual(c.f.cache_info(), X.f.cache_info())
   1489 
   1490     def test_pickle(self):
   1491         cls = self.__class__
   1492         for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
   1493             for proto in range(pickle.HIGHEST_PROTOCOL + 1):
   1494                 with self.subTest(proto=proto, func=f):
   1495                     f_copy = pickle.loads(pickle.dumps(f, proto))
   1496                     self.assertIs(f_copy, f)
   1497 
   1498     def test_copy(self):
   1499         cls = self.__class__
   1500         def orig(x, y):
   1501             return 3 * x + y
   1502         part = self.module.partial(orig, 2)
   1503         funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
   1504                  self.module.lru_cache(2)(part))
   1505         for f in funcs:
   1506             with self.subTest(func=f):
   1507                 f_copy = copy.copy(f)
   1508                 self.assertIs(f_copy, f)
   1509 
   1510     def test_deepcopy(self):
   1511         cls = self.__class__
   1512         def orig(x, y):
   1513             return 3 * x + y
   1514         part = self.module.partial(orig, 2)
   1515         funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
   1516                  self.module.lru_cache(2)(part))
   1517         for f in funcs:
   1518             with self.subTest(func=f):
   1519                 f_copy = copy.deepcopy(f)
   1520                 self.assertIs(f_copy, f)
   1521 
   1522 
   1523 @py_functools.lru_cache()
   1524 def py_cached_func(x, y):
   1525     return 3 * x + y
   1526 
   1527 @c_functools.lru_cache()
   1528 def c_cached_func(x, y):
   1529     return 3 * x + y
   1530 
   1531 
   1532 class TestLRUPy(TestLRU, unittest.TestCase):
   1533     module = py_functools
   1534     cached_func = py_cached_func,
   1535 
   1536     @module.lru_cache()
   1537     def cached_meth(self, x, y):
   1538         return 3 * x + y
   1539 
   1540     @staticmethod
   1541     @module.lru_cache()
   1542     def cached_staticmeth(x, y):
   1543         return 3 * x + y
   1544 
   1545 
   1546 class TestLRUC(TestLRU, unittest.TestCase):
   1547     module = c_functools
   1548     cached_func = c_cached_func,
   1549 
   1550     @module.lru_cache()
   1551     def cached_meth(self, x, y):
   1552         return 3 * x + y
   1553 
   1554     @staticmethod
   1555     @module.lru_cache()
   1556     def cached_staticmeth(x, y):
   1557         return 3 * x + y
   1558 
   1559 
   1560 class TestSingleDispatch(unittest.TestCase):
   1561     def test_simple_overloads(self):
   1562         @functools.singledispatch
   1563         def g(obj):
   1564             return "base"
   1565         def g_int(i):
   1566             return "integer"
   1567         g.register(int, g_int)
   1568         self.assertEqual(g("str"), "base")
   1569         self.assertEqual(g(1), "integer")
   1570         self.assertEqual(g([1,2,3]), "base")
   1571 
   1572     def test_mro(self):
   1573         @functools.singledispatch
   1574         def g(obj):
   1575             return "base"
   1576         class A:
   1577             pass
   1578         class C(A):
   1579             pass
   1580         class B(A):
   1581             pass
   1582         class D(C, B):
   1583             pass
   1584         def g_A(a):
   1585             return "A"
   1586         def g_B(b):
   1587             return "B"
   1588         g.register(A, g_A)
   1589         g.register(B, g_B)
   1590         self.assertEqual(g(A()), "A")
   1591         self.assertEqual(g(B()), "B")
   1592         self.assertEqual(g(C()), "A")
   1593         self.assertEqual(g(D()), "B")
   1594 
   1595     def test_register_decorator(self):
   1596         @functools.singledispatch
   1597         def g(obj):
   1598             return "base"
   1599         @g.register(int)
   1600         def g_int(i):
   1601             return "int %s" % (i,)
   1602         self.assertEqual(g(""), "base")
   1603         self.assertEqual(g(12), "int 12")
   1604         self.assertIs(g.dispatch(int), g_int)
   1605         self.assertIs(g.dispatch(object), g.dispatch(str))
   1606         # Note: in the assert above this is not g.
   1607         # @singledispatch returns the wrapper.
   1608 
   1609     def test_wrapping_attributes(self):
   1610         @functools.singledispatch
   1611         def g(obj):
   1612             "Simple test"
   1613             return "Test"
   1614         self.assertEqual(g.__name__, "g")
   1615         if sys.flags.optimize < 2:
   1616             self.assertEqual(g.__doc__, "Simple test")
   1617 
   1618     @unittest.skipUnless(decimal, 'requires _decimal')
   1619     @support.cpython_only
   1620     def test_c_classes(self):
   1621         @functools.singledispatch
   1622         def g(obj):
   1623             return "base"
   1624         @g.register(decimal.DecimalException)
   1625         def _(obj):
   1626             return obj.args
   1627         subn = decimal.Subnormal("Exponent < Emin")
   1628         rnd = decimal.Rounded("Number got rounded")
   1629         self.assertEqual(g(subn), ("Exponent < Emin",))
   1630         self.assertEqual(g(rnd), ("Number got rounded",))
   1631         @g.register(decimal.Subnormal)
   1632         def _(obj):
   1633             return "Too small to care."
   1634         self.assertEqual(g(subn), "Too small to care.")
   1635         self.assertEqual(g(rnd), ("Number got rounded",))
   1636 
   1637     def test_compose_mro(self):
   1638         # None of the examples in this test depend on haystack ordering.
   1639         c = collections
   1640         mro = functools._compose_mro
   1641         bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
   1642         for haystack in permutations(bases):
   1643             m = mro(dict, haystack)
   1644             self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
   1645                                  c.Collection, c.Sized, c.Iterable,
   1646                                  c.Container, object])
   1647         bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict]
   1648         for haystack in permutations(bases):
   1649             m = mro(c.ChainMap, haystack)
   1650             self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping,
   1651                                  c.Collection, c.Sized, c.Iterable,
   1652                                  c.Container, object])
   1653 
   1654         # If there's a generic function with implementations registered for
   1655         # both Sized and Container, passing a defaultdict to it results in an
   1656         # ambiguous dispatch which will cause a RuntimeError (see
   1657         # test_mro_conflicts).
   1658         bases = [c.Container, c.Sized, str]
   1659         for haystack in permutations(bases):
   1660             m = mro(c.defaultdict, [c.Sized, c.Container, str])
   1661             self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container,
   1662                                  object])
   1663 
   1664         # MutableSequence below is registered directly on D. In other words, it
   1665         # precedes MutableMapping which means single dispatch will always
   1666         # choose MutableSequence here.
   1667         class D(c.defaultdict):
   1668             pass
   1669         c.MutableSequence.register(D)
   1670         bases = [c.MutableSequence, c.MutableMapping]
   1671         for haystack in permutations(bases):
   1672             m = mro(D, bases)
   1673             self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
   1674                                  c.defaultdict, dict, c.MutableMapping, c.Mapping,
   1675                                  c.Collection, c.Sized, c.Iterable, c.Container,
   1676                                  object])
   1677 
   1678         # Container and Callable are registered on different base classes and
   1679         # a generic function supporting both should always pick the Callable
   1680         # implementation if a C instance is passed.
   1681         class C(c.defaultdict):
   1682             def __call__(self):
   1683                 pass
   1684         bases = [c.Sized, c.Callable, c.Container, c.Mapping]
   1685         for haystack in permutations(bases):
   1686             m = mro(C, haystack)
   1687             self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping,
   1688                                  c.Collection, c.Sized, c.Iterable,
   1689                                  c.Container, object])
   1690 
   1691     def test_register_abc(self):
   1692         c = collections
   1693         d = {"a": "b"}
   1694         l = [1, 2, 3]
   1695         s = {object(), None}
   1696         f = frozenset(s)
   1697         t = (1, 2, 3)
   1698         @functools.singledispatch
   1699         def g(obj):
   1700             return "base"
   1701         self.assertEqual(g(d), "base")
   1702         self.assertEqual(g(l), "base")
   1703         self.assertEqual(g(s), "base")
   1704         self.assertEqual(g(f), "base")
   1705         self.assertEqual(g(t), "base")
   1706         g.register(c.Sized, lambda obj: "sized")
   1707         self.assertEqual(g(d), "sized")
   1708         self.assertEqual(g(l), "sized")
   1709         self.assertEqual(g(s), "sized")
   1710         self.assertEqual(g(f), "sized")
   1711         self.assertEqual(g(t), "sized")
   1712         g.register(c.MutableMapping, lambda obj: "mutablemapping")
   1713         self.assertEqual(g(d), "mutablemapping")
   1714         self.assertEqual(g(l), "sized")
   1715         self.assertEqual(g(s), "sized")
   1716         self.assertEqual(g(f), "sized")
   1717         self.assertEqual(g(t), "sized")
   1718         g.register(c.ChainMap, lambda obj: "chainmap")
   1719         self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
   1720         self.assertEqual(g(l), "sized")
   1721         self.assertEqual(g(s), "sized")
   1722         self.assertEqual(g(f), "sized")
   1723         self.assertEqual(g(t), "sized")
   1724         g.register(c.MutableSequence, lambda obj: "mutablesequence")
   1725         self.assertEqual(g(d), "mutablemapping")
   1726         self.assertEqual(g(l), "mutablesequence")
   1727         self.assertEqual(g(s), "sized")
   1728         self.assertEqual(g(f), "sized")
   1729         self.assertEqual(g(t), "sized")
   1730         g.register(c.MutableSet, lambda obj: "mutableset")
   1731         self.assertEqual(g(d), "mutablemapping")
   1732         self.assertEqual(g(l), "mutablesequence")
   1733         self.assertEqual(g(s), "mutableset")
   1734         self.assertEqual(g(f), "sized")
   1735         self.assertEqual(g(t), "sized")
   1736         g.register(c.Mapping, lambda obj: "mapping")
   1737         self.assertEqual(g(d), "mutablemapping")  # not specific enough
   1738         self.assertEqual(g(l), "mutablesequence")
   1739         self.assertEqual(g(s), "mutableset")
   1740         self.assertEqual(g(f), "sized")
   1741         self.assertEqual(g(t), "sized")
   1742         g.register(c.Sequence, lambda obj: "sequence")
   1743         self.assertEqual(g(d), "mutablemapping")
   1744         self.assertEqual(g(l), "mutablesequence")
   1745         self.assertEqual(g(s), "mutableset")
   1746         self.assertEqual(g(f), "sized")
   1747         self.assertEqual(g(t), "sequence")
   1748         g.register(c.Set, lambda obj: "set")
   1749         self.assertEqual(g(d), "mutablemapping")
   1750         self.assertEqual(g(l), "mutablesequence")
   1751         self.assertEqual(g(s), "mutableset")
   1752         self.assertEqual(g(f), "set")
   1753         self.assertEqual(g(t), "sequence")
   1754         g.register(dict, lambda obj: "dict")
   1755         self.assertEqual(g(d), "dict")
   1756         self.assertEqual(g(l), "mutablesequence")
   1757         self.assertEqual(g(s), "mutableset")
   1758         self.assertEqual(g(f), "set")
   1759         self.assertEqual(g(t), "sequence")
   1760         g.register(list, lambda obj: "list")
   1761         self.assertEqual(g(d), "dict")
   1762         self.assertEqual(g(l), "list")
   1763         self.assertEqual(g(s), "mutableset")
   1764         self.assertEqual(g(f), "set")
   1765         self.assertEqual(g(t), "sequence")
   1766         g.register(set, lambda obj: "concrete-set")
   1767         self.assertEqual(g(d), "dict")
   1768         self.assertEqual(g(l), "list")
   1769         self.assertEqual(g(s), "concrete-set")
   1770         self.assertEqual(g(f), "set")
   1771         self.assertEqual(g(t), "sequence")
   1772         g.register(frozenset, lambda obj: "frozen-set")
   1773         self.assertEqual(g(d), "dict")
   1774         self.assertEqual(g(l), "list")
   1775         self.assertEqual(g(s), "concrete-set")
   1776         self.assertEqual(g(f), "frozen-set")
   1777         self.assertEqual(g(t), "sequence")
   1778         g.register(tuple, lambda obj: "tuple")
   1779         self.assertEqual(g(d), "dict")
   1780         self.assertEqual(g(l), "list")
   1781         self.assertEqual(g(s), "concrete-set")
   1782         self.assertEqual(g(f), "frozen-set")
   1783         self.assertEqual(g(t), "tuple")
   1784 
   1785     def test_c3_abc(self):
   1786         c = collections
   1787         mro = functools._c3_mro
   1788         class A(object):
   1789             pass
   1790         class B(A):
   1791             def __len__(self):
   1792                 return 0   # implies Sized
   1793         @c.Container.register
   1794         class C(object):
   1795             pass
   1796         class D(object):
   1797             pass   # unrelated
   1798         class X(D, C, B):
   1799             def __call__(self):
   1800                 pass   # implies Callable
   1801         expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
   1802         for abcs in permutations([c.Sized, c.Callable, c.Container]):
   1803             self.assertEqual(mro(X, abcs=abcs), expected)
   1804         # unrelated ABCs don't appear in the resulting MRO
   1805         many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
   1806         self.assertEqual(mro(X, abcs=many_abcs), expected)
   1807 
   1808     def test_false_meta(self):
   1809         # see issue23572
   1810         class MetaA(type):
   1811             def __len__(self):
   1812                 return 0
   1813         class A(metaclass=MetaA):
   1814             pass
   1815         class AA(A):
   1816             pass
   1817         @functools.singledispatch
   1818         def fun(a):
   1819             return 'base A'
   1820         @fun.register(A)
   1821         def _(a):
   1822             return 'fun A'
   1823         aa = AA()
   1824         self.assertEqual(fun(aa), 'fun A')
   1825 
   1826     def test_mro_conflicts(self):
   1827         c = collections
   1828         @functools.singledispatch
   1829         def g(arg):
   1830             return "base"
   1831         class O(c.Sized):
   1832             def __len__(self):
   1833                 return 0
   1834         o = O()
   1835         self.assertEqual(g(o), "base")
   1836         g.register(c.Iterable, lambda arg: "iterable")
   1837         g.register(c.Container, lambda arg: "container")
   1838         g.register(c.Sized, lambda arg: "sized")
   1839         g.register(c.Set, lambda arg: "set")
   1840         self.assertEqual(g(o), "sized")
   1841         c.Iterable.register(O)
   1842         self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
   1843         c.Container.register(O)
   1844         self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
   1845         c.Set.register(O)
   1846         self.assertEqual(g(o), "set")     # because c.Set is a subclass of
   1847                                           # c.Sized and c.Container
   1848         class P:
   1849             pass
   1850         p = P()
   1851         self.assertEqual(g(p), "base")
   1852         c.Iterable.register(P)
   1853         self.assertEqual(g(p), "iterable")
   1854         c.Container.register(P)
   1855         with self.assertRaises(RuntimeError) as re_one:
   1856             g(p)
   1857         self.assertIn(
   1858             str(re_one.exception),
   1859             (("Ambiguous dispatch: <class 'collections.abc.Container'> "
   1860               "or <class 'collections.abc.Iterable'>"),
   1861              ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
   1862               "or <class 'collections.abc.Container'>")),
   1863         )
   1864         class Q(c.Sized):
   1865             def __len__(self):
   1866                 return 0
   1867         q = Q()
   1868         self.assertEqual(g(q), "sized")
   1869         c.Iterable.register(Q)
   1870         self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
   1871         c.Set.register(Q)
   1872         self.assertEqual(g(q), "set")     # because c.Set is a subclass of
   1873                                           # c.Sized and c.Iterable
   1874         @functools.singledispatch
   1875         def h(arg):
   1876             return "base"
   1877         @h.register(c.Sized)
   1878         def _(arg):
   1879             return "sized"
   1880         @h.register(c.Container)
   1881         def _(arg):
   1882             return "container"
   1883         # Even though Sized and Container are explicit bases of MutableMapping,
   1884         # this ABC is implicitly registered on defaultdict which makes all of
   1885         # MutableMapping's bases implicit as well from defaultdict's
   1886         # perspective.
   1887         with self.assertRaises(RuntimeError) as re_two:
   1888             h(c.defaultdict(lambda: 0))
   1889         self.assertIn(
   1890             str(re_two.exception),
   1891             (("Ambiguous dispatch: <class 'collections.abc.Container'> "
   1892               "or <class 'collections.abc.Sized'>"),
   1893              ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
   1894               "or <class 'collections.abc.Container'>")),
   1895         )
   1896         class R(c.defaultdict):
   1897             pass
   1898         c.MutableSequence.register(R)
   1899         @functools.singledispatch
   1900         def i(arg):
   1901             return "base"
   1902         @i.register(c.MutableMapping)
   1903         def _(arg):
   1904             return "mapping"
   1905         @i.register(c.MutableSequence)
   1906         def _(arg):
   1907             return "sequence"
   1908         r = R()
   1909         self.assertEqual(i(r), "sequence")
   1910         class S:
   1911             pass
   1912         class T(S, c.Sized):
   1913             def __len__(self):
   1914                 return 0
   1915         t = T()
   1916         self.assertEqual(h(t), "sized")
   1917         c.Container.register(T)
   1918         self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
   1919         class U:
   1920             def __len__(self):
   1921                 return 0
   1922         u = U()
   1923         self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
   1924                                           # from the existence of __len__()
   1925         c.Container.register(U)
   1926         # There is no preference for registered versus inferred ABCs.
   1927         with self.assertRaises(RuntimeError) as re_three:
   1928             h(u)
   1929         self.assertIn(
   1930             str(re_three.exception),
   1931             (("Ambiguous dispatch: <class 'collections.abc.Container'> "
   1932               "or <class 'collections.abc.Sized'>"),
   1933              ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
   1934               "or <class 'collections.abc.Container'>")),
   1935         )
   1936         class V(c.Sized, S):
   1937             def __len__(self):
   1938                 return 0
   1939         @functools.singledispatch
   1940         def j(arg):
   1941             return "base"
   1942         @j.register(S)
   1943         def _(arg):
   1944             return "s"
   1945         @j.register(c.Container)
   1946         def _(arg):
   1947             return "container"
   1948         v = V()
   1949         self.assertEqual(j(v), "s")
   1950         c.Container.register(V)
   1951         self.assertEqual(j(v), "container")   # because it ends up right after
   1952                                               # Sized in the MRO
   1953 
   1954     def test_cache_invalidation(self):
   1955         from collections import UserDict
   1956         class TracingDict(UserDict):
   1957             def __init__(self, *args, **kwargs):
   1958                 super(TracingDict, self).__init__(*args, **kwargs)
   1959                 self.set_ops = []
   1960                 self.get_ops = []
   1961             def __getitem__(self, key):
   1962                 result = self.data[key]
   1963                 self.get_ops.append(key)
   1964                 return result
   1965             def __setitem__(self, key, value):
   1966                 self.set_ops.append(key)
   1967                 self.data[key] = value
   1968             def clear(self):
   1969                 self.data.clear()
   1970         _orig_wkd = functools.WeakKeyDictionary
   1971         td = TracingDict()
   1972         functools.WeakKeyDictionary = lambda: td
   1973         c = collections
   1974         @functools.singledispatch
   1975         def g(arg):
   1976             return "base"
   1977         d = {}
   1978         l = []
   1979         self.assertEqual(len(td), 0)
   1980         self.assertEqual(g(d), "base")
   1981         self.assertEqual(len(td), 1)
   1982         self.assertEqual(td.get_ops, [])
   1983         self.assertEqual(td.set_ops, [dict])
   1984         self.assertEqual(td.data[dict], g.registry[object])
   1985         self.assertEqual(g(l), "base")
   1986         self.assertEqual(len(td), 2)
   1987         self.assertEqual(td.get_ops, [])
   1988         self.assertEqual(td.set_ops, [dict, list])
   1989         self.assertEqual(td.data[dict], g.registry[object])
   1990         self.assertEqual(td.data[list], g.registry[object])
   1991         self.assertEqual(td.data[dict], td.data[list])
   1992         self.assertEqual(g(l), "base")
   1993         self.assertEqual(g(d), "base")
   1994         self.assertEqual(td.get_ops, [list, dict])
   1995         self.assertEqual(td.set_ops, [dict, list])
   1996         g.register(list, lambda arg: "list")
   1997         self.assertEqual(td.get_ops, [list, dict])
   1998         self.assertEqual(len(td), 0)
   1999         self.assertEqual(g(d), "base")
   2000         self.assertEqual(len(td), 1)
   2001         self.assertEqual(td.get_ops, [list, dict])
   2002         self.assertEqual(td.set_ops, [dict, list, dict])
   2003         self.assertEqual(td.data[dict],
   2004                          functools._find_impl(dict, g.registry))
   2005         self.assertEqual(g(l), "list")
   2006         self.assertEqual(len(td), 2)
   2007         self.assertEqual(td.get_ops, [list, dict])
   2008         self.assertEqual(td.set_ops, [dict, list, dict, list])
   2009         self.assertEqual(td.data[list],
   2010                          functools._find_impl(list, g.registry))
   2011         class X:
   2012             pass
   2013         c.MutableMapping.register(X)   # Will not invalidate the cache,
   2014                                        # not using ABCs yet.
   2015         self.assertEqual(g(d), "base")
   2016         self.assertEqual(g(l), "list")
   2017         self.assertEqual(td.get_ops, [list, dict, dict, list])
   2018         self.assertEqual(td.set_ops, [dict, list, dict, list])
   2019         g.register(c.Sized, lambda arg: "sized")
   2020         self.assertEqual(len(td), 0)
   2021         self.assertEqual(g(d), "sized")
   2022         self.assertEqual(len(td), 1)
   2023         self.assertEqual(td.get_ops, [list, dict, dict, list])
   2024         self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
   2025         self.assertEqual(g(l), "list")
   2026         self.assertEqual(len(td), 2)
   2027         self.assertEqual(td.get_ops, [list, dict, dict, list])
   2028         self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
   2029         self.assertEqual(g(l), "list")
   2030         self.assertEqual(g(d), "sized")
   2031         self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
   2032         self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
   2033         g.dispatch(list)
   2034         g.dispatch(dict)
   2035         self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
   2036                                       list, dict])
   2037         self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
   2038         c.MutableSet.register(X)       # Will invalidate the cache.
   2039         self.assertEqual(len(td), 2)   # Stale cache.
   2040         self.assertEqual(g(l), "list")
   2041         self.assertEqual(len(td), 1)
   2042         g.register(c.MutableMapping, lambda arg: "mutablemapping")
   2043         self.assertEqual(len(td), 0)
   2044         self.assertEqual(g(d), "mutablemapping")
   2045         self.assertEqual(len(td), 1)
   2046         self.assertEqual(g(l), "list")
   2047         self.assertEqual(len(td), 2)
   2048         g.register(dict, lambda arg: "dict")
   2049         self.assertEqual(g(d), "dict")
   2050         self.assertEqual(g(l), "list")
   2051         g._clear_cache()
   2052         self.assertEqual(len(td), 0)
   2053         functools.WeakKeyDictionary = _orig_wkd
   2054 
   2055 
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
   2058