Home | History | Annotate | Download | only in test
      1 import abc
      2 import builtins
      3 import collections
      4 import collections.abc
      5 import copy
      6 from itertools import permutations
      7 import pickle
      8 from random import choice
      9 import sys
     10 from test import support
     11 import threading
     12 import time
     13 import typing
     14 import unittest
     15 import unittest.mock
     16 from weakref import proxy
     17 import contextlib
     18 
     19 import functools
     20 
# Two separately imported copies of functools: one with the _functools C
# accelerator blocked (pure-Python implementation) and one that imports
# _functools fresh (C implementation).
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
# NOTE(review): the `if c_functools:` / skipUnless(c_functools, ...) guards in
# this file imply this can be falsy (presumably None) when the C module is
# unavailable -- confirm against test.support.import_fresh_module.
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# Fresh decimal module with the _decimal C accelerator; not referenced in the
# part of the file shown here -- presumably used by later tests.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
     25 
     26 @contextlib.contextmanager
     27 def replaced_module(name, replacement):
     28     original_module = sys.modules[name]
     29     sys.modules[name] = replacement
     30     try:
     31         yield
     32     finally:
     33         sys.modules[name] = original_module
     34 
     35 def capture(*args, **kw):
     36     """capture all positional and keyword arguments"""
     37     return args, kw
     38 
     39 
     40 def signature(part):
     41     """ return the signature of a partial object """
     42     return (part.func, part.args, part.keywords, part.__dict__)
     43 
     44 class MyTuple(tuple):
     45     pass
     46 
     47 class BadTuple(tuple):
     48     def __add__(self, other):
     49         return list(self) + list(other)
     50 
     51 class MyDict(dict):
     52     pass
     53 
     54 
     55 class TestPartial:
     56 
     57     def test_basic_examples(self):
     58         p = self.partial(capture, 1, 2, a=10, b=20)
     59         self.assertTrue(callable(p))
     60         self.assertEqual(p(3, 4, b=30, c=40),
     61                          ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
     62         p = self.partial(map, lambda x: x*10)
     63         self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])
     64 
     65     def test_attributes(self):
     66         p = self.partial(capture, 1, 2, a=10, b=20)
     67         # attributes should be readable
     68         self.assertEqual(p.func, capture)
     69         self.assertEqual(p.args, (1, 2))
     70         self.assertEqual(p.keywords, dict(a=10, b=20))
     71 
     72     def test_argument_checking(self):
     73         self.assertRaises(TypeError, self.partial)     # need at least a func arg
     74         try:
     75             self.partial(2)()
     76         except TypeError:
     77             pass
     78         else:
     79             self.fail('First arg not checked for callability')
     80 
     81     def test_protection_of_callers_dict_argument(self):
     82         # a caller's dictionary should not be altered by partial
     83         def func(a=10, b=20):
     84             return a
     85         d = {'a':3}
     86         p = self.partial(func, a=5)
     87         self.assertEqual(p(**d), 3)
     88         self.assertEqual(d, {'a':3})
     89         p(b=7)
     90         self.assertEqual(d, {'a':3})
     91 
     92     def test_kwargs_copy(self):
     93         # Issue #29532: Altering a kwarg dictionary passed to a constructor
     94         # should not affect a partial object after creation
     95         d = {'a': 3}
     96         p = self.partial(capture, **d)
     97         self.assertEqual(p(), ((), {'a': 3}))
     98         d['a'] = 5
     99         self.assertEqual(p(), ((), {'a': 3}))
    100 
    101     def test_arg_combinations(self):
    102         # exercise special code paths for zero args in either partial
    103         # object or the caller
    104         p = self.partial(capture)
    105         self.assertEqual(p(), ((), {}))
    106         self.assertEqual(p(1,2), ((1,2), {}))
    107         p = self.partial(capture, 1, 2)
    108         self.assertEqual(p(), ((1,2), {}))
    109         self.assertEqual(p(3,4), ((1,2,3,4), {}))
    110 
    111     def test_kw_combinations(self):
    112         # exercise special code paths for no keyword args in
    113         # either the partial object or the caller
    114         p = self.partial(capture)
    115         self.assertEqual(p.keywords, {})
    116         self.assertEqual(p(), ((), {}))
    117         self.assertEqual(p(a=1), ((), {'a':1}))
    118         p = self.partial(capture, a=1)
    119         self.assertEqual(p.keywords, {'a':1})
    120         self.assertEqual(p(), ((), {'a':1}))
    121         self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
    122         # keyword args in the call override those in the partial object
    123         self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))
    124 
    125     def test_positional(self):
    126         # make sure positional arguments are captured correctly
    127         for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
    128             p = self.partial(capture, *args)
    129             expected = args + ('x',)
    130             got, empty = p('x')
    131             self.assertTrue(expected == got and empty == {})
    132 
    133     def test_keyword(self):
    134         # make sure keyword arguments are captured correctly
    135         for a in ['a', 0, None, 3.5]:
    136             p = self.partial(capture, a=a)
    137             expected = {'a':a,'x':None}
    138             empty, got = p(x=None)
    139             self.assertTrue(expected == got and empty == ())
    140 
    141     def test_no_side_effects(self):
    142         # make sure there are no side effects that affect subsequent calls
    143         p = self.partial(capture, 0, a=1)
    144         args1, kw1 = p(1, b=2)
    145         self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2})
    146         args2, kw2 = p()
    147         self.assertTrue(args2 == (0,) and kw2 == {'a':1})
    148 
    149     def test_error_propagation(self):
    150         def f(x, y):
    151             x / y
    152         self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
    153         self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
    154         self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
    155         self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)
    156 
    157     def test_weakref(self):
    158         f = self.partial(int, base=16)
    159         p = proxy(f)
    160         self.assertEqual(f.func, p.func)
    161         f = None
    162         self.assertRaises(ReferenceError, getattr, p, 'func')
    163 
    164     def test_with_bound_and_unbound_methods(self):
    165         data = list(map(str, range(10)))
    166         join = self.partial(str.join, '')
    167         self.assertEqual(join(data), '0123456789')
    168         join = self.partial(''.join)
    169         self.assertEqual(join(data), '0123456789')
    170 
    171     def test_nested_optimization(self):
    172         partial = self.partial
    173         inner = partial(signature, 'asdf')
    174         nested = partial(inner, bar=True)
    175         flat = partial(signature, 'asdf', bar=True)
    176         self.assertEqual(signature(nested), signature(flat))
    177 
    178     def test_nested_partial_with_attribute(self):
    179         # see issue 25137
    180         partial = self.partial
    181 
    182         def foo(bar):
    183             return bar
    184 
    185         p = partial(foo, 'first')
    186         p2 = partial(p, 'second')
    187         p2.new_attr = 'spam'
    188         self.assertEqual(p2.new_attr, 'spam')
    189 
    190     def test_repr(self):
    191         args = (object(), object())
    192         args_repr = ', '.join(repr(a) for a in args)
    193         kwargs = {'a': object(), 'b': object()}
    194         kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
    195                         'b={b!r}, a={a!r}'.format_map(kwargs)]
    196         if self.partial in (c_functools.partial, py_functools.partial):
    197             name = 'functools.partial'
    198         else:
    199             name = self.partial.__name__
    200 
    201         f = self.partial(capture)
    202         self.assertEqual(f'{name}({capture!r})', repr(f))
    203 
    204         f = self.partial(capture, *args)
    205         self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))
    206 
    207         f = self.partial(capture, **kwargs)
    208         self.assertIn(repr(f),
    209                       [f'{name}({capture!r}, {kwargs_repr})'
    210                        for kwargs_repr in kwargs_reprs])
    211 
    212         f = self.partial(capture, *args, **kwargs)
    213         self.assertIn(repr(f),
    214                       [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
    215                        for kwargs_repr in kwargs_reprs])
    216 
    217     def test_recursive_repr(self):
    218         if self.partial in (c_functools.partial, py_functools.partial):
    219             name = 'functools.partial'
    220         else:
    221             name = self.partial.__name__
    222 
    223         f = self.partial(capture)
    224         f.__setstate__((f, (), {}, {}))
    225         try:
    226             self.assertEqual(repr(f), '%s(...)' % (name,))
    227         finally:
    228             f.__setstate__((capture, (), {}, {}))
    229 
    230         f = self.partial(capture)
    231         f.__setstate__((capture, (f,), {}, {}))
    232         try:
    233             self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
    234         finally:
    235             f.__setstate__((capture, (), {}, {}))
    236 
    237         f = self.partial(capture)
    238         f.__setstate__((capture, (), {'a': f}, {}))
    239         try:
    240             self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
    241         finally:
    242             f.__setstate__((capture, (), {}, {}))
    243 
    244     def test_pickle(self):
    245         with self.AllowPickle():
    246             f = self.partial(signature, ['asdf'], bar=[True])
    247             f.attr = []
    248             for proto in range(pickle.HIGHEST_PROTOCOL + 1):
    249                 f_copy = pickle.loads(pickle.dumps(f, proto))
    250                 self.assertEqual(signature(f_copy), signature(f))
    251 
    252     def test_copy(self):
    253         f = self.partial(signature, ['asdf'], bar=[True])
    254         f.attr = []
    255         f_copy = copy.copy(f)
    256         self.assertEqual(signature(f_copy), signature(f))
    257         self.assertIs(f_copy.attr, f.attr)
    258         self.assertIs(f_copy.args, f.args)
    259         self.assertIs(f_copy.keywords, f.keywords)
    260 
    261     def test_deepcopy(self):
    262         f = self.partial(signature, ['asdf'], bar=[True])
    263         f.attr = []
    264         f_copy = copy.deepcopy(f)
    265         self.assertEqual(signature(f_copy), signature(f))
    266         self.assertIsNot(f_copy.attr, f.attr)
    267         self.assertIsNot(f_copy.args, f.args)
    268         self.assertIsNot(f_copy.args[0], f.args[0])
    269         self.assertIsNot(f_copy.keywords, f.keywords)
    270         self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])
    271 
    272     def test_setstate(self):
    273         f = self.partial(signature)
    274         f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))
    275 
    276         self.assertEqual(signature(f),
    277                          (capture, (1,), dict(a=10), dict(attr=[])))
    278         self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
    279 
    280         f.__setstate__((capture, (1,), dict(a=10), None))
    281 
    282         self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
    283         self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))
    284 
    285         f.__setstate__((capture, (1,), None, None))
    286         #self.assertEqual(signature(f), (capture, (1,), {}, {}))
    287         self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
    288         self.assertEqual(f(2), ((1, 2), {}))
    289         self.assertEqual(f(), ((1,), {}))
    290 
    291         f.__setstate__((capture, (), {}, None))
    292         self.assertEqual(signature(f), (capture, (), {}, {}))
    293         self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
    294         self.assertEqual(f(2), ((2,), {}))
    295         self.assertEqual(f(), ((), {}))
    296 
    297     def test_setstate_errors(self):
    298         f = self.partial(signature)
    299         self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
    300         self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
    301         self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
    302         self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
    303         self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
    304         self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
    305         self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))
    306 
    307     def test_setstate_subclasses(self):
    308         f = self.partial(signature)
    309         f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
    310         s = signature(f)
    311         self.assertEqual(s, (capture, (1,), dict(a=10), {}))
    312         self.assertIs(type(s[1]), tuple)
    313         self.assertIs(type(s[2]), dict)
    314         r = f()
    315         self.assertEqual(r, ((1,), {'a': 10}))
    316         self.assertIs(type(r[0]), tuple)
    317         self.assertIs(type(r[1]), dict)
    318 
    319         f.__setstate__((capture, BadTuple((1,)), {}, None))
    320         s = signature(f)
    321         self.assertEqual(s, (capture, (1,), {}, {}))
    322         self.assertIs(type(s[1]), tuple)
    323         r = f(2)
    324         self.assertEqual(r, ((1, 2), {}))
    325         self.assertIs(type(r[0]), tuple)
    326 
    327     def test_recursive_pickle(self):
    328         with self.AllowPickle():
    329             f = self.partial(capture)
    330             f.__setstate__((f, (), {}, {}))
    331             try:
    332                 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
    333                     with self.assertRaises(RecursionError):
    334                         pickle.dumps(f, proto)
    335             finally:
    336                 f.__setstate__((capture, (), {}, {}))
    337 
    338             f = self.partial(capture)
    339             f.__setstate__((capture, (f,), {}, {}))
    340             try:
    341                 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
    342                     f_copy = pickle.loads(pickle.dumps(f, proto))
    343                     try:
    344                         self.assertIs(f_copy.args[0], f_copy)
    345                     finally:
    346                         f_copy.__setstate__((capture, (), {}, {}))
    347             finally:
    348                 f.__setstate__((capture, (), {}, {}))
    349 
    350             f = self.partial(capture)
    351             f.__setstate__((capture, (), {'a': f}, {}))
    352             try:
    353                 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
    354                     f_copy = pickle.loads(pickle.dumps(f, proto))
    355                     try:
    356                         self.assertIs(f_copy.keywords['a'], f_copy)
    357                     finally:
    358                         f_copy.__setstate__((capture, (), {}, {}))
    359             finally:
    360                 f.__setstate__((capture, (), {}, {}))
    361 
    362     # Issue 6083: Reference counting bug
    363     def test_setstate_refcount(self):
    364         class BadSequence:
    365             def __len__(self):
    366                 return 4
    367             def __getitem__(self, key):
    368                 if key == 0:
    369                     return max
    370                 elif key == 1:
    371                     return tuple(range(1000000))
    372                 elif key in (2, 3):
    373                     return {}
    374                 raise IndexError
    375 
    376         f = self.partial(object)
    377         self.assertRaises(TypeError, f.__setstate__, BadSequence())
    378 
    379 @unittest.skipUnless(c_functools, 'requires the C _functools module')
    380 class TestPartialC(TestPartial, unittest.TestCase):
    381     if c_functools:
    382         partial = c_functools.partial
    383 
    384     class AllowPickle:
    385         def __enter__(self):
    386             return self
    387         def __exit__(self, type, value, tb):
    388             return False
    389 
    390     def test_attributes_unwritable(self):
    391         # attributes should not be writable
    392         p = self.partial(capture, 1, 2, a=10, b=20)
    393         self.assertRaises(AttributeError, setattr, p, 'func', map)
    394         self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
    395         self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))
    396 
    397         p = self.partial(hex)
    398         try:
    399             del p.__dict__
    400         except TypeError:
    401             pass
    402         else:
    403             self.fail('partial object allowed __dict__ to be deleted')
    404 
    405     def test_manually_adding_non_string_keyword(self):
    406         p = self.partial(capture)
    407         # Adding a non-string/unicode keyword to partial kwargs
    408         p.keywords[1234] = 'value'
    409         r = repr(p)
    410         self.assertIn('1234', r)
    411         self.assertIn("'value'", r)
    412         with self.assertRaises(TypeError):
    413             p()
    414 
    415     def test_keystr_replaces_value(self):
    416         p = self.partial(capture)
    417 
    418         class MutatesYourDict(object):
    419             def __str__(self):
    420                 p.keywords[self] = ['sth2']
    421                 return 'astr'
    422 
    423         # Replacing the value during key formatting should keep the original
    424         # value alive (at least long enough).
    425         p.keywords[MutatesYourDict()] = ['sth']
    426         r = repr(p)
    427         self.assertIn('astr', r)
    428         self.assertIn("['sth']", r)
    429 
    430 
    431 class TestPartialPy(TestPartial, unittest.TestCase):
    432     partial = py_functools.partial
    433 
    434     class AllowPickle:
    435         def __init__(self):
    436             self._cm = replaced_module("functools", py_functools)
    437         def __enter__(self):
    438             return self._cm.__enter__()
    439         def __exit__(self, type, value, tb):
    440             return self._cm.__exit__(type, value, tb)
    441 
    442 if c_functools:
    443     class CPartialSubclass(c_functools.partial):
    444         pass
    445 
    446 class PyPartialSubclass(py_functools.partial):
    447     pass
    448 
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Rerun the TestPartialC suite using a trivial subclass of the C partial."""
    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls
    # (Setting the inherited test to None hides it: unittest only collects
    # callable test_* attributes.)
    test_nested_optimization = None
    456 
class TestPartialPySubclass(TestPartialPy):
    """Rerun the TestPartialPy suite using a trivial subclass of the Python partial."""
    partial = PyPartialSubclass
    459 
    460 class TestPartialMethod(unittest.TestCase):
    461 
    462     class A(object):
    463         nothing = functools.partialmethod(capture)
    464         positional = functools.partialmethod(capture, 1)
    465         keywords = functools.partialmethod(capture, a=2)
    466         both = functools.partialmethod(capture, 3, b=4)
    467 
    468         nested = functools.partialmethod(positional, 5)
    469 
    470         over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)
    471 
    472         static = functools.partialmethod(staticmethod(capture), 8)
    473         cls = functools.partialmethod(classmethod(capture), d=9)
    474 
    475     a = A()
    476 
    477     def test_arg_combinations(self):
    478         self.assertEqual(self.a.nothing(), ((self.a,), {}))
    479         self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
    480         self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
    481         self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))
    482 
    483         self.assertEqual(self.a.positional(), ((self.a, 1), {}))
    484         self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
    485         self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
    486         self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))
    487 
    488         self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
    489         self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
    490         self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
    491         self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))
    492 
    493         self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
    494         self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
    495         self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
    496         self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
    497 
    498         self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))
    499 
    500     def test_nested(self):
    501         self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
    502         self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
    503         self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
    504         self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
    505 
    506         self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))
    507 
    508     def test_over_partial(self):
    509         self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
    510         self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
    511         self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
    512         self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
    513 
    514         self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))
    515 
    516     def test_bound_method_introspection(self):
    517         obj = self.a
    518         self.assertIs(obj.both.__self__, obj)
    519         self.assertIs(obj.nested.__self__, obj)
    520         self.assertIs(obj.over_partial.__self__, obj)
    521         self.assertIs(obj.cls.__self__, self.A)
    522         self.assertIs(self.A.cls.__self__, self.A)
    523 
    524     def test_unbound_method_retrieval(self):
    525         obj = self.A
    526         self.assertFalse(hasattr(obj.both, "__self__"))
    527         self.assertFalse(hasattr(obj.nested, "__self__"))
    528         self.assertFalse(hasattr(obj.over_partial, "__self__"))
    529         self.assertFalse(hasattr(obj.static, "__self__"))
    530         self.assertFalse(hasattr(self.a.static, "__self__"))
    531 
    532     def test_descriptors(self):
    533         for obj in [self.A, self.a]:
    534             with self.subTest(obj=obj):
    535                 self.assertEqual(obj.static(), ((8,), {}))
    536                 self.assertEqual(obj.static(5), ((8, 5), {}))
    537                 self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
    538                 self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))
    539 
    540                 self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
    541                 self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
    542                 self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
    543                 self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))
    544 
    545     def test_overriding_keywords(self):
    546         self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
    547         self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))
    548 
    549     def test_invalid_args(self):
    550         with self.assertRaises(TypeError):
    551             class B(object):
    552                 method = functools.partialmethod(None, 1)
    553 
    554     def test_repr(self):
    555         self.assertEqual(repr(vars(self.A)['both']),
    556                          'functools.partialmethod({}, 3, b=4)'.format(capture))
    557 
    558     def test_abstract(self):
    559         class Abstract(abc.ABCMeta):
    560 
    561             @abc.abstractmethod
    562             def add(self, x, y):
    563                 pass
    564 
    565             add5 = functools.partialmethod(add, 5)
    566 
    567         self.assertTrue(Abstract.add.__isabstractmethod__)
    568         self.assertTrue(Abstract.add5.__isabstractmethod__)
    569 
    570         for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
    571             self.assertFalse(getattr(func, '__isabstractmethod__', False))
    572 
    573 
    574 class TestUpdateWrapper(unittest.TestCase):
    575 
    576     def check_wrapper(self, wrapper, wrapped,
    577                       assigned=functools.WRAPPER_ASSIGNMENTS,
    578                       updated=functools.WRAPPER_UPDATES):
    579         # Check attributes were assigned
    580         for name in assigned:
    581             self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
    582         # Check attributes were updated
    583         for name in updated:
    584             wrapper_attr = getattr(wrapper, name)
    585             wrapped_attr = getattr(wrapped, name)
    586             for key in wrapped_attr:
    587                 if name == "__dict__" and key == "__wrapped__":
    588                     # __wrapped__ is overwritten by the update code
    589                     continue
    590                 self.assertIs(wrapped_attr[key], wrapper_attr[key])
    591         # Check __wrapped__
    592         self.assertIs(wrapper.__wrapped__, wrapped)
    593 
    594 
    595     def _default_update(self):
    596         def f(a:'This is a new annotation'):
    597             """This is a test"""
    598             pass
    599         f.attr = 'This is also a test'
    600         f.__wrapped__ = "This is a bald faced lie"
    601         def wrapper(b:'This is the prior annotation'):
    602             pass
    603         functools.update_wrapper(wrapper, f)
    604         return wrapper, f
    605 
    606     def test_default_update(self):
    607         wrapper, f = self._default_update()
    608         self.check_wrapper(wrapper, f)
    609         self.assertIs(wrapper.__wrapped__, f)
    610         self.assertEqual(wrapper.__name__, 'f')
    611         self.assertEqual(wrapper.__qualname__, f.__qualname__)
    612         self.assertEqual(wrapper.attr, 'This is also a test')
    613         self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
    614         self.assertNotIn('b', wrapper.__annotations__)
    615 
    616     @unittest.skipIf(sys.flags.optimize >= 2,
    617                      "Docstrings are omitted with -O2 and above")
    618     def test_default_update_doc(self):
    619         wrapper, f = self._default_update()
    620         self.assertEqual(wrapper.__doc__, 'This is a test')
    621 
    622     def test_no_update(self):
    623         def f():
    624             """This is a test"""
    625             pass
    626         f.attr = 'This is also a test'
    627         def wrapper():
    628             pass
    629         functools.update_wrapper(wrapper, f, (), ())
    630         self.check_wrapper(wrapper, f, (), ())
    631         self.assertEqual(wrapper.__name__, 'wrapper')
    632         self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
    633         self.assertEqual(wrapper.__doc__, None)
    634         self.assertEqual(wrapper.__annotations__, {})
    635         self.assertFalse(hasattr(wrapper, 'attr'))
    636 
    637     def test_selective_update(self):
    638         def f():
    639             pass
    640         f.attr = 'This is a different test'
    641         f.dict_attr = dict(a=1, b=2, c=3)
    642         def wrapper():
    643             pass
    644         wrapper.dict_attr = {}
    645         assign = ('attr',)
    646         update = ('dict_attr',)
    647         functools.update_wrapper(wrapper, f, assign, update)
    648         self.check_wrapper(wrapper, f, assign, update)
    649         self.assertEqual(wrapper.__name__, 'wrapper')
    650         self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
    651         self.assertEqual(wrapper.__doc__, None)
    652         self.assertEqual(wrapper.attr, 'This is a different test')
    653         self.assertEqual(wrapper.dict_attr, f.dict_attr)
    654 
    655     def test_missing_attributes(self):
    656         def f():
    657             pass
    658         def wrapper():
    659             pass
    660         wrapper.dict_attr = {}
    661         assign = ('attr',)
    662         update = ('dict_attr',)
    663         # Missing attributes on wrapped object are ignored
    664         functools.update_wrapper(wrapper, f, assign, update)
    665         self.assertNotIn('attr', wrapper.__dict__)
    666         self.assertEqual(wrapper.dict_attr, {})
    667         # Wrapper must have expected attributes for updating
    668         del wrapper.dict_attr
    669         with self.assertRaises(AttributeError):
    670             functools.update_wrapper(wrapper, f, assign, update)
    671         wrapper.dict_attr = 1
    672         with self.assertRaises(AttributeError):
    673             functools.update_wrapper(wrapper, f, assign, update)
    674 
    675     @support.requires_docstrings
    676     @unittest.skipIf(sys.flags.optimize >= 2,
    677                      "Docstrings are omitted with -O2 and above")
    678     def test_builtin_update(self):
    679         # Test for bug #1576241
    680         def wrapper():
    681             pass
    682         functools.update_wrapper(wrapper, max)
    683         self.assertEqual(wrapper.__name__, 'max')
    684         self.assertTrue(wrapper.__doc__.startswith('max('))
    685         self.assertEqual(wrapper.__annotations__, {})
    686 
    687 
    688 class TestWraps(TestUpdateWrapper):
    689 
    690     def _default_update(self):
    691         def f():
    692             """This is a test"""
    693             pass
    694         f.attr = 'This is also a test'
    695         f.__wrapped__ = "This is still a bald faced lie"
    696         @functools.wraps(f)
    697         def wrapper():
    698             pass
    699         return wrapper, f
    700 
    701     def test_default_update(self):
    702         wrapper, f = self._default_update()
    703         self.check_wrapper(wrapper, f)
    704         self.assertEqual(wrapper.__name__, 'f')
    705         self.assertEqual(wrapper.__qualname__, f.__qualname__)
    706         self.assertEqual(wrapper.attr, 'This is also a test')
    707 
    708     @unittest.skipIf(sys.flags.optimize >= 2,
    709                      "Docstrings are omitted with -O2 and above")
    710     def test_default_update_doc(self):
    711         wrapper, _ = self._default_update()
    712         self.assertEqual(wrapper.__doc__, 'This is a test')
    713 
    714     def test_no_update(self):
    715         def f():
    716             """This is a test"""
    717             pass
    718         f.attr = 'This is also a test'
    719         @functools.wraps(f, (), ())
    720         def wrapper():
    721             pass
    722         self.check_wrapper(wrapper, f, (), ())
    723         self.assertEqual(wrapper.__name__, 'wrapper')
    724         self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
    725         self.assertEqual(wrapper.__doc__, None)
    726         self.assertFalse(hasattr(wrapper, 'attr'))
    727 
    728     def test_selective_update(self):
    729         def f():
    730             pass
    731         f.attr = 'This is a different test'
    732         f.dict_attr = dict(a=1, b=2, c=3)
    733         def add_dict_attr(f):
    734             f.dict_attr = {}
    735             return f
    736         assign = ('attr',)
    737         update = ('dict_attr',)
    738         @functools.wraps(f, assign, update)
    739         @add_dict_attr
    740         def wrapper():
    741             pass
    742         self.check_wrapper(wrapper, f, assign, update)
    743         self.assertEqual(wrapper.__name__, 'wrapper')
    744         self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
    745         self.assertEqual(wrapper.__doc__, None)
    746         self.assertEqual(wrapper.attr, 'This is a different test')
    747         self.assertEqual(wrapper.dict_attr, f.dict_attr)
    748 
    749 @unittest.skipUnless(c_functools, 'requires the C _functools module')
    750 class TestReduce(unittest.TestCase):
    751     if c_functools:
    752         func = c_functools.reduce
    753 
    754     def test_reduce(self):
    755         class Squares:
    756             def __init__(self, max):
    757                 self.max = max
    758                 self.sofar = []
    759 
    760             def __len__(self):
    761                 return len(self.sofar)
    762 
    763             def __getitem__(self, i):
    764                 if not 0 <= i < self.max: raise IndexError
    765                 n = len(self.sofar)
    766                 while n <= i:
    767                     self.sofar.append(n*n)
    768                     n += 1
    769                 return self.sofar[i]
    770         def add(x, y):
    771             return x + y
    772         self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc')
    773         self.assertEqual(
    774             self.func(add, [['a', 'c'], [], ['d', 'w']], []),
    775             ['a','c','d','w']
    776         )
    777         self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040)
    778         self.assertEqual(
    779             self.func(lambda x, y: x*y, range(2,21), 1),
    780             2432902008176640000
    781         )
    782         self.assertEqual(self.func(add, Squares(10)), 285)
    783         self.assertEqual(self.func(add, Squares(10), 0), 285)
    784         self.assertEqual(self.func(add, Squares(0), 0), 0)
    785         self.assertRaises(TypeError, self.func)
    786         self.assertRaises(TypeError, self.func, 42, 42)
    787         self.assertRaises(TypeError, self.func, 42, 42, 42)
    788         self.assertEqual(self.func(42, "1"), "1") # func is never called with one item
    789         self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item
    790         self.assertRaises(TypeError, self.func, 42, (42, 42))
    791         self.assertRaises(TypeError, self.func, add, []) # arg 2 must not be empty sequence with no initial value
    792         self.assertRaises(TypeError, self.func, add, "")
    793         self.assertRaises(TypeError, self.func, add, ())
    794         self.assertRaises(TypeError, self.func, add, object())
    795 
    796         class TestFailingIter:
    797             def __iter__(self):
    798                 raise RuntimeError
    799         self.assertRaises(RuntimeError, self.func, add, TestFailingIter())
    800 
    801         self.assertEqual(self.func(add, [], None), None)
    802         self.assertEqual(self.func(add, [], 42), 42)
    803 
    804         class BadSeq:
    805             def __getitem__(self, index):
    806                 raise ValueError
    807         self.assertRaises(ValueError, self.func, 42, BadSeq())
    808 
    809     # Test reduce()'s use of iterators.
    810     def test_iterator_usage(self):
    811         class SequenceClass:
    812             def __init__(self, n):
    813                 self.n = n
    814             def __getitem__(self, i):
    815                 if 0 <= i < self.n:
    816                     return i
    817                 else:
    818                     raise IndexError
    819 
    820         from operator import add
    821         self.assertEqual(self.func(add, SequenceClass(5)), 10)
    822         self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
    823         self.assertRaises(TypeError, self.func, add, SequenceClass(0))
    824         self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
    825         self.assertEqual(self.func(add, SequenceClass(1)), 0)
    826         self.assertEqual(self.func(add, SequenceClass(1), 42), 42)
    827 
    828         d = {"one": 1, "two": 2, "three": 3}
    829         self.assertEqual(self.func(add, d), "".join(d.keys()))
    830 
    831 
    832 class TestCmpToKey:
    833 
    834     def test_cmp_to_key(self):
    835         def cmp1(x, y):
    836             return (x > y) - (x < y)
    837         key = self.cmp_to_key(cmp1)
    838         self.assertEqual(key(3), key(3))
    839         self.assertGreater(key(3), key(1))
    840         self.assertGreaterEqual(key(3), key(3))
    841 
    842         def cmp2(x, y):
    843             return int(x) - int(y)
    844         key = self.cmp_to_key(cmp2)
    845         self.assertEqual(key(4.0), key('4'))
    846         self.assertLess(key(2), key('35'))
    847         self.assertLessEqual(key(2), key('35'))
    848         self.assertNotEqual(key(2), key('35'))
    849 
    850     def test_cmp_to_key_arguments(self):
    851         def cmp1(x, y):
    852             return (x > y) - (x < y)
    853         key = self.cmp_to_key(mycmp=cmp1)
    854         self.assertEqual(key(obj=3), key(obj=3))
    855         self.assertGreater(key(obj=3), key(obj=1))
    856         with self.assertRaises((TypeError, AttributeError)):
    857             key(3) > 1    # rhs is not a K object
    858         with self.assertRaises((TypeError, AttributeError)):
    859             1 < key(3)    # lhs is not a K object
    860         with self.assertRaises(TypeError):
    861             key = self.cmp_to_key()             # too few args
    862         with self.assertRaises(TypeError):
    863             key = self.cmp_to_key(cmp1, None)   # too many args
    864         key = self.cmp_to_key(cmp1)
    865         with self.assertRaises(TypeError):
    866             key()                                    # too few args
    867         with self.assertRaises(TypeError):
    868             key(None, None)                          # too many args
    869 
    870     def test_bad_cmp(self):
    871         def cmp1(x, y):
    872             raise ZeroDivisionError
    873         key = self.cmp_to_key(cmp1)
    874         with self.assertRaises(ZeroDivisionError):
    875             key(3) > key(1)
    876 
    877         class BadCmp:
    878             def __lt__(self, other):
    879                 raise ZeroDivisionError
    880         def cmp1(x, y):
    881             return BadCmp()
    882         with self.assertRaises(ZeroDivisionError):
    883             key(3) > key(1)
    884 
    885     def test_obj_field(self):
    886         def cmp1(x, y):
    887             return (x > y) - (x < y)
    888         key = self.cmp_to_key(mycmp=cmp1)
    889         self.assertEqual(key(50).obj, 50)
    890 
    891     def test_sort_int(self):
    892         def mycmp(x, y):
    893             return y - x
    894         self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
    895                          [4, 3, 2, 1, 0])
    896 
    897     def test_sort_int_str(self):
    898         def mycmp(x, y):
    899             x, y = int(x), int(y)
    900             return (x > y) - (x < y)
    901         values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
    902         values = sorted(values, key=self.cmp_to_key(mycmp))
    903         self.assertEqual([int(value) for value in values],
    904                          [0, 1, 1, 2, 3, 4, 5, 7, 10])
    905 
    906     def test_hash(self):
    907         def mycmp(x, y):
    908             return y - x
    909         key = self.cmp_to_key(mycmp)
    910         k = key(10)
    911         self.assertRaises(TypeError, hash, k)
    912         self.assertNotIsInstance(k, collections.abc.Hashable)
    913 
    914 
    915 @unittest.skipUnless(c_functools, 'requires the C _functools module')
    916 class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    917     if c_functools:
    918         cmp_to_key = c_functools.cmp_to_key
    919 
    920 
    921 class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    922     cmp_to_key = staticmethod(py_functools.cmp_to_key)
    923 
    924 
    925 class TestTotalOrdering(unittest.TestCase):
    926 
    927     def test_total_ordering_lt(self):
    928         @functools.total_ordering
    929         class A:
    930             def __init__(self, value):
    931                 self.value = value
    932             def __lt__(self, other):
    933                 return self.value < other.value
    934             def __eq__(self, other):
    935                 return self.value == other.value
    936         self.assertTrue(A(1) < A(2))
    937         self.assertTrue(A(2) > A(1))
    938         self.assertTrue(A(1) <= A(2))
    939         self.assertTrue(A(2) >= A(1))
    940         self.assertTrue(A(2) <= A(2))
    941         self.assertTrue(A(2) >= A(2))
    942         self.assertFalse(A(1) > A(2))
    943 
    944     def test_total_ordering_le(self):
    945         @functools.total_ordering
    946         class A:
    947             def __init__(self, value):
    948                 self.value = value
    949             def __le__(self, other):
    950                 return self.value <= other.value
    951             def __eq__(self, other):
    952                 return self.value == other.value
    953         self.assertTrue(A(1) < A(2))
    954         self.assertTrue(A(2) > A(1))
    955         self.assertTrue(A(1) <= A(2))
    956         self.assertTrue(A(2) >= A(1))
    957         self.assertTrue(A(2) <= A(2))
    958         self.assertTrue(A(2) >= A(2))
    959         self.assertFalse(A(1) >= A(2))
    960 
    961     def test_total_ordering_gt(self):
    962         @functools.total_ordering
    963         class A:
    964             def __init__(self, value):
    965                 self.value = value
    966             def __gt__(self, other):
    967                 return self.value > other.value
    968             def __eq__(self, other):
    969                 return self.value == other.value
    970         self.assertTrue(A(1) < A(2))
    971         self.assertTrue(A(2) > A(1))
    972         self.assertTrue(A(1) <= A(2))
    973         self.assertTrue(A(2) >= A(1))
    974         self.assertTrue(A(2) <= A(2))
    975         self.assertTrue(A(2) >= A(2))
    976         self.assertFalse(A(2) < A(1))
    977 
    978     def test_total_ordering_ge(self):
    979         @functools.total_ordering
    980         class A:
    981             def __init__(self, value):
    982                 self.value = value
    983             def __ge__(self, other):
    984                 return self.value >= other.value
    985             def __eq__(self, other):
    986                 return self.value == other.value
    987         self.assertTrue(A(1) < A(2))
    988         self.assertTrue(A(2) > A(1))
    989         self.assertTrue(A(1) <= A(2))
    990         self.assertTrue(A(2) >= A(1))
    991         self.assertTrue(A(2) <= A(2))
    992         self.assertTrue(A(2) >= A(2))
    993         self.assertFalse(A(2) <= A(1))
    994 
    995     def test_total_ordering_no_overwrite(self):
    996         # new methods should not overwrite existing
    997         @functools.total_ordering
    998         class A(int):
    999             pass
   1000         self.assertTrue(A(1) < A(2))
   1001         self.assertTrue(A(2) > A(1))
   1002         self.assertTrue(A(1) <= A(2))
   1003         self.assertTrue(A(2) >= A(1))
   1004         self.assertTrue(A(2) <= A(2))
   1005         self.assertTrue(A(2) >= A(2))
   1006 
   1007     def test_no_operations_defined(self):
   1008         with self.assertRaises(ValueError):
   1009             @functools.total_ordering
   1010             class A:
   1011                 pass
   1012 
   1013     def test_type_error_when_not_implemented(self):
   1014         # bug 10042; ensure stack overflow does not occur
   1015         # when decorated types return NotImplemented
   1016         @functools.total_ordering
   1017         class ImplementsLessThan:
   1018             def __init__(self, value):
   1019                 self.value = value
   1020             def __eq__(self, other):
   1021                 if isinstance(other, ImplementsLessThan):
   1022                     return self.value == other.value
   1023                 return False
   1024             def __lt__(self, other):
   1025                 if isinstance(other, ImplementsLessThan):
   1026                     return self.value < other.value
   1027                 return NotImplemented
   1028 
   1029         @functools.total_ordering
   1030         class ImplementsGreaterThan:
   1031             def __init__(self, value):
   1032                 self.value = value
   1033             def __eq__(self, other):
   1034                 if isinstance(other, ImplementsGreaterThan):
   1035                     return self.value == other.value
   1036                 return False
   1037             def __gt__(self, other):
   1038                 if isinstance(other, ImplementsGreaterThan):
   1039                     return self.value > other.value
   1040                 return NotImplemented
   1041 
   1042         @functools.total_ordering
   1043         class ImplementsLessThanEqualTo:
   1044             def __init__(self, value):
   1045                 self.value = value
   1046             def __eq__(self, other):
   1047                 if isinstance(other, ImplementsLessThanEqualTo):
   1048                     return self.value == other.value
   1049                 return False
   1050             def __le__(self, other):
   1051                 if isinstance(other, ImplementsLessThanEqualTo):
   1052                     return self.value <= other.value
   1053                 return NotImplemented
   1054 
   1055         @functools.total_ordering
   1056         class ImplementsGreaterThanEqualTo:
   1057             def __init__(self, value):
   1058                 self.value = value
   1059             def __eq__(self, other):
   1060                 if isinstance(other, ImplementsGreaterThanEqualTo):
   1061                     return self.value == other.value
   1062                 return False
   1063             def __ge__(self, other):
   1064                 if isinstance(other, ImplementsGreaterThanEqualTo):
   1065                     return self.value >= other.value
   1066                 return NotImplemented
   1067 
   1068         @functools.total_ordering
   1069         class ComparatorNotImplemented:
   1070             def __init__(self, value):
   1071                 self.value = value
   1072             def __eq__(self, other):
   1073                 if isinstance(other, ComparatorNotImplemented):
   1074                     return self.value == other.value
   1075                 return False
   1076             def __lt__(self, other):
   1077                 return NotImplemented
   1078 
   1079         with self.subTest("LT < 1"), self.assertRaises(TypeError):
   1080             ImplementsLessThan(-1) < 1
   1081 
   1082         with self.subTest("LT < LE"), self.assertRaises(TypeError):
   1083             ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)
   1084 
   1085         with self.subTest("LT < GT"), self.assertRaises(TypeError):
   1086             ImplementsLessThan(1) < ImplementsGreaterThan(1)
   1087 
   1088         with self.subTest("LE <= LT"), self.assertRaises(TypeError):
   1089             ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)
   1090 
   1091         with self.subTest("LE <= GE"), self.assertRaises(TypeError):
   1092             ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)
   1093 
   1094         with self.subTest("GT > GE"), self.assertRaises(TypeError):
   1095             ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)
   1096 
   1097         with self.subTest("GT > LT"), self.assertRaises(TypeError):
   1098             ImplementsGreaterThan(5) > ImplementsLessThan(5)
   1099 
   1100         with self.subTest("GE >= GT"), self.assertRaises(TypeError):
   1101             ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)
   1102 
   1103         with self.subTest("GE >= LE"), self.assertRaises(TypeError):
   1104             ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)
   1105 
   1106         with self.subTest("GE when equal"):
   1107             a = ComparatorNotImplemented(8)
   1108             b = ComparatorNotImplemented(8)
   1109             self.assertEqual(a, b)
   1110             with self.assertRaises(TypeError):
   1111                 a >= b
   1112 
   1113         with self.subTest("LE when equal"):
   1114             a = ComparatorNotImplemented(9)
   1115             b = ComparatorNotImplemented(9)
   1116             self.assertEqual(a, b)
   1117             with self.assertRaises(TypeError):
   1118                 a <= b
   1119 
   1120     def test_pickle(self):
   1121         for proto in range(pickle.HIGHEST_PROTOCOL + 1):
   1122             for name in '__lt__', '__gt__', '__le__', '__ge__':
   1123                 with self.subTest(method=name, proto=proto):
   1124                     method = getattr(Orderable_LT, name)
   1125                     method_copy = pickle.loads(pickle.dumps(method, proto))
   1126                     self.assertIs(method_copy, method)
   1127 
   1128 @functools.total_ordering
   1129 class Orderable_LT:
   1130     def __init__(self, value):
   1131         self.value = value
   1132     def __lt__(self, other):
   1133         return self.value < other.value
   1134     def __eq__(self, other):
   1135         return self.value == other.value
   1136 
   1137 
   1138 class TestLRU:
   1139 
   1140     def test_lru(self):
   1141         def orig(x, y):
   1142             return 3 * x + y
   1143         f = self.module.lru_cache(maxsize=20)(orig)
   1144         hits, misses, maxsize, currsize = f.cache_info()
   1145         self.assertEqual(maxsize, 20)
   1146         self.assertEqual(currsize, 0)
   1147         self.assertEqual(hits, 0)
   1148         self.assertEqual(misses, 0)
   1149 
   1150         domain = range(5)
   1151         for i in range(1000):
   1152             x, y = choice(domain), choice(domain)
   1153             actual = f(x, y)
   1154             expected = orig(x, y)
   1155             self.assertEqual(actual, expected)
   1156         hits, misses, maxsize, currsize = f.cache_info()
   1157         self.assertTrue(hits > misses)
   1158         self.assertEqual(hits + misses, 1000)
   1159         self.assertEqual(currsize, 20)
   1160 
   1161         f.cache_clear()   # test clearing
   1162         hits, misses, maxsize, currsize = f.cache_info()
   1163         self.assertEqual(hits, 0)
   1164         self.assertEqual(misses, 0)
   1165         self.assertEqual(currsize, 0)
   1166         f(x, y)
   1167         hits, misses, maxsize, currsize = f.cache_info()
   1168         self.assertEqual(hits, 0)
   1169         self.assertEqual(misses, 1)
   1170         self.assertEqual(currsize, 1)
   1171 
   1172         # Test bypassing the cache
   1173         self.assertIs(f.__wrapped__, orig)
   1174         f.__wrapped__(x, y)
   1175         hits, misses, maxsize, currsize = f.cache_info()
   1176         self.assertEqual(hits, 0)
   1177         self.assertEqual(misses, 1)
   1178         self.assertEqual(currsize, 1)
   1179 
   1180         # test size zero (which means "never-cache")
   1181         @self.module.lru_cache(0)
   1182         def f():
   1183             nonlocal f_cnt
   1184             f_cnt += 1
   1185             return 20
   1186         self.assertEqual(f.cache_info().maxsize, 0)
   1187         f_cnt = 0
   1188         for i in range(5):
   1189             self.assertEqual(f(), 20)
   1190         self.assertEqual(f_cnt, 5)
   1191         hits, misses, maxsize, currsize = f.cache_info()
   1192         self.assertEqual(hits, 0)
   1193         self.assertEqual(misses, 5)
   1194         self.assertEqual(currsize, 0)
   1195 
   1196         # test size one
   1197         @self.module.lru_cache(1)
   1198         def f():
   1199             nonlocal f_cnt
   1200             f_cnt += 1
   1201             return 20
   1202         self.assertEqual(f.cache_info().maxsize, 1)
   1203         f_cnt = 0
   1204         for i in range(5):
   1205             self.assertEqual(f(), 20)
   1206         self.assertEqual(f_cnt, 1)
   1207         hits, misses, maxsize, currsize = f.cache_info()
   1208         self.assertEqual(hits, 4)
   1209         self.assertEqual(misses, 1)
   1210         self.assertEqual(currsize, 1)
   1211 
   1212         # test size two
   1213         @self.module.lru_cache(2)
   1214         def f(x):
   1215             nonlocal f_cnt
   1216             f_cnt += 1
   1217             return x*10
   1218         self.assertEqual(f.cache_info().maxsize, 2)
   1219         f_cnt = 0
   1220         for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7:
   1221             #    *  *              *                          *
   1222             self.assertEqual(f(x), x*10)
   1223         self.assertEqual(f_cnt, 4)
   1224         hits, misses, maxsize, currsize = f.cache_info()
   1225         self.assertEqual(hits, 12)
   1226         self.assertEqual(misses, 4)
   1227         self.assertEqual(currsize, 2)
   1228 
   1229     def test_lru_bug_35780(self):
   1230         # C version of the lru_cache was not checking to see if
   1231         # the user function call has already modified the cache
   1232         # (this arises in recursive calls and in multi-threading).
   1233         # This cause the cache to have orphan links not referenced
   1234         # by the cache dictionary.
   1235 
   1236         once = True                 # Modified by f(x) below
   1237 
   1238         @self.module.lru_cache(maxsize=10)
   1239         def f(x):
   1240             nonlocal once
   1241             rv = f'.{x}.'
   1242             if x == 20 and once:
   1243                 once = False
   1244                 rv = f(x)
   1245             return rv
   1246 
   1247         # Fill the cache
   1248         for x in range(15):
   1249             self.assertEqual(f(x), f'.{x}.')
   1250         self.assertEqual(f.cache_info().currsize, 10)
   1251 
   1252         # Make a recursive call and make sure the cache remains full
   1253         self.assertEqual(f(20), '.20.')
   1254         self.assertEqual(f.cache_info().currsize, 10)
   1255 
   1256     def test_lru_hash_only_once(self):
   1257         # To protect against weird reentrancy bugs and to improve
   1258         # efficiency when faced with slow __hash__ methods, the
   1259         # LRU cache guarantees that it will only call __hash__
   1260         # only once per use as an argument to the cached function.
   1261 
   1262         @self.module.lru_cache(maxsize=1)
   1263         def f(x, y):
   1264             return x * 3 + y
   1265 
   1266         # Simulate the integer 5
   1267         mock_int = unittest.mock.Mock()
   1268         mock_int.__mul__ = unittest.mock.Mock(return_value=15)
   1269         mock_int.__hash__ = unittest.mock.Mock(return_value=999)
   1270 
   1271         # Add to cache:  One use as an argument gives one call
   1272         self.assertEqual(f(mock_int, 1), 16)
   1273         self.assertEqual(mock_int.__hash__.call_count, 1)
   1274         self.assertEqual(f.cache_info(), (0, 1, 1, 1))
   1275 
   1276         # Cache hit: One use as an argument gives one additional call
   1277         self.assertEqual(f(mock_int, 1), 16)
   1278         self.assertEqual(mock_int.__hash__.call_count, 2)
   1279         self.assertEqual(f.cache_info(), (1, 1, 1, 1))
   1280 
   1281         # Cache eviction: No use as an argument gives no additional call
   1282         self.assertEqual(f(6, 2), 20)
   1283         self.assertEqual(mock_int.__hash__.call_count, 2)
   1284         self.assertEqual(f.cache_info(), (1, 2, 1, 1))
   1285 
   1286         # Cache miss: One use as an argument gives one additional call
   1287         self.assertEqual(f(mock_int, 1), 16)
   1288         self.assertEqual(mock_int.__hash__.call_count, 3)
   1289         self.assertEqual(f.cache_info(), (1, 3, 1, 1))
   1290 
   1291     def test_lru_reentrancy_with_len(self):
   1292         # Test to make sure the LRU cache code isn't thrown-off by
   1293         # caching the built-in len() function.  Since len() can be
   1294         # cached, we shouldn't use it inside the lru code itself.
   1295         old_len = builtins.len
   1296         try:
   1297             builtins.len = self.module.lru_cache(4)(len)
   1298             for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]:
   1299                 self.assertEqual(len('abcdefghijklmn'[:i]), i)
   1300         finally:
   1301             builtins.len = old_len
   1302 
   1303     def test_lru_star_arg_handling(self):
   1304         # Test regression that arose in ea064ff3c10f
   1305         @functools.lru_cache()
   1306         def f(*args):
   1307             return args
   1308 
   1309         self.assertEqual(f(1, 2), (1, 2))
   1310         self.assertEqual(f((1, 2)), ((1, 2),))
   1311 
   1312     def test_lru_type_error(self):
   1313         # Regression test for issue #28653.
   1314         # lru_cache was leaking when one of the arguments
   1315         # wasn't cacheable.
   1316 
   1317         @functools.lru_cache(maxsize=None)
   1318         def infinite_cache(o):
   1319             pass
   1320 
   1321         @functools.lru_cache(maxsize=10)
   1322         def limited_cache(o):
   1323             pass
   1324 
   1325         with self.assertRaises(TypeError):
   1326             infinite_cache([])
   1327 
   1328         with self.assertRaises(TypeError):
   1329             limited_cache([])
   1330 
   1331     def test_lru_with_maxsize_none(self):
   1332         @self.module.lru_cache(maxsize=None)
   1333         def fib(n):
   1334             if n < 2:
   1335                 return n
   1336             return fib(n-1) + fib(n-2)
   1337         self.assertEqual([fib(n) for n in range(16)],
   1338             [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
   1339         self.assertEqual(fib.cache_info(),
   1340             self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
   1341         fib.cache_clear()
   1342         self.assertEqual(fib.cache_info(),
   1343             self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
   1344 
   1345     def test_lru_with_maxsize_negative(self):
   1346         @self.module.lru_cache(maxsize=-10)
   1347         def eq(n):
   1348             return n
   1349         for i in (0, 1):
   1350             self.assertEqual([eq(n) for n in range(150)], list(range(150)))
   1351         self.assertEqual(eq.cache_info(),
   1352             self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0))
   1353 
   1354     def test_lru_with_exceptions(self):
   1355         # Verify that user_function exceptions get passed through without
   1356         # creating a hard-to-read chained exception.
   1357         # http://bugs.python.org/issue13177
   1358         for maxsize in (None, 128):
   1359             @self.module.lru_cache(maxsize)
   1360             def func(i):
   1361                 return 'abc'[i]
   1362             self.assertEqual(func(0), 'a')
   1363             with self.assertRaises(IndexError) as cm:
   1364                 func(15)
   1365             self.assertIsNone(cm.exception.__context__)
   1366             # Verify that the previous exception did not result in a cached entry
   1367             with self.assertRaises(IndexError):
   1368                 func(15)
   1369 
   1370     def test_lru_with_types(self):
   1371         for maxsize in (None, 128):
   1372             @self.module.lru_cache(maxsize=maxsize, typed=True)
   1373             def square(x):
   1374                 return x * x
   1375             self.assertEqual(square(3), 9)
   1376             self.assertEqual(type(square(3)), type(9))
   1377             self.assertEqual(square(3.0), 9.0)
   1378             self.assertEqual(type(square(3.0)), type(9.0))
   1379             self.assertEqual(square(x=3), 9)
   1380             self.assertEqual(type(square(x=3)), type(9))
   1381             self.assertEqual(square(x=3.0), 9.0)
   1382             self.assertEqual(type(square(x=3.0)), type(9.0))
   1383             self.assertEqual(square.cache_info().hits, 4)
   1384             self.assertEqual(square.cache_info().misses, 4)
   1385 
   1386     def test_lru_with_keyword_args(self):
   1387         @self.module.lru_cache()
   1388         def fib(n):
   1389             if n < 2:
   1390                 return n
   1391             return fib(n=n-1) + fib(n=n-2)
   1392         self.assertEqual(
   1393             [fib(n=number) for number in range(16)],
   1394             [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]
   1395         )
   1396         self.assertEqual(fib.cache_info(),
   1397             self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16))
   1398         fib.cache_clear()
   1399         self.assertEqual(fib.cache_info(),
   1400             self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0))
   1401 
   1402     def test_lru_with_keyword_args_maxsize_none(self):
   1403         @self.module.lru_cache(maxsize=None)
   1404         def fib(n):
   1405             if n < 2:
   1406                 return n
   1407             return fib(n=n-1) + fib(n=n-2)
   1408         self.assertEqual([fib(n=number) for number in range(16)],
   1409             [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610])
   1410         self.assertEqual(fib.cache_info(),
   1411             self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16))
   1412         fib.cache_clear()
   1413         self.assertEqual(fib.cache_info(),
   1414             self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0))
   1415 
   1416     def test_kwargs_order(self):
   1417         # PEP 468: Preserving Keyword Argument Order
   1418         @self.module.lru_cache(maxsize=10)
   1419         def f(**kwargs):
   1420             return list(kwargs.items())
   1421         self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)])
   1422         self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)])
   1423         self.assertEqual(f.cache_info(),
   1424             self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2))
   1425 
   1426     def test_lru_cache_decoration(self):
   1427         def f(zomg: 'zomg_annotation'):
   1428             """f doc string"""
   1429             return 42
   1430         g = self.module.lru_cache()(f)
   1431         for attr in self.module.WRAPPER_ASSIGNMENTS:
   1432             self.assertEqual(getattr(g, attr), getattr(f, attr))
   1433 
    def test_lru_cache_threaded(self):
        # Hammer one cache from several threads with a tiny switch interval.
        # Phase 1: n threads each call f(k, 0) m times for their own key k,
        # and the hit/miss/size counters are checked afterwards.
        # Phase 2: the same workload runs while an extra thread repeatedly
        # clears the cache; only the per-call results are asserted there.
        n, m = 5, 11
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=n*m)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(currsize, 0)

        # Every worker blocks on this event (up to 10 s) so that all of them
        # start calling f at roughly the same moment.
        start = threading.Event()
        def full(k):
            start.wait(10)
            for _ in range(m):
                self.assertEqual(f(k, 0), orig(k, 0))

        def clear():
            start.wait(10)
            for _ in range(2*m):
                f.cache_clear()

        orig_si = sys.getswitchinterval()
        # Request a 1e-6 s switch interval to encourage frequent thread
        # switches; the original value is restored in the finally clause.
        support.setswitchinterval(1e-6)
        try:
            # Phase 1: start n threads, each of which fills one cache entry.
            threads = [threading.Thread(target=full, args=[k])
                       for k in range(n)]
            with support.start_threads(threads):
                start.set()

            # The checks below assume the with-block has waited for all
            # workers to finish.
            hits, misses, maxsize, currsize = f.cache_info()
            if self.module is py_functools:
                # XXX: for the pure-Python implementation the counters are
                # only bounded from above, because they can apparently fall
                # short of the exact values asserted for the C version; the
                # cause is not explained.
                self.assertLessEqual(misses, n)
                self.assertLessEqual(hits, m*n - misses)
            else:
                self.assertEqual(misses, n)
                self.assertEqual(hits, m*n - misses)
            self.assertEqual(currsize, n)

            # Phase 2: n threads filling the cache plus one thread clearing it.
            threads = [threading.Thread(target=clear)]
            threads += [threading.Thread(target=full, args=[k])
                        for k in range(n)]
            # Re-arm the event so the new threads block until start.set().
            start.clear()
            with support.start_threads(threads):
                start.set()
        finally:
            sys.setswitchinterval(orig_si)
   1481 
    def test_lru_cache_threaded2(self):
        # Simultaneous calls with the same argument: in each round i, all n
        # worker threads are held inside f(i) at the same time by the
        # `pause` barrier.  Every one of those calls is therefore a miss, yet
        # the cache gains only one entry per round (see the final assert).
        n, m = 5, 7
        # Each barrier has n + 1 parties: the n workers plus the main thread,
        # which drives the rounds.
        start = threading.Barrier(n+1)
        pause = threading.Barrier(n+1)
        stop = threading.Barrier(n+1)
        @self.module.lru_cache(maxsize=m*n)
        def f(x):
            pause.wait(10)
            return 3 * x
        self.assertEqual(f.cache_info(), (0, 0, m*n, 0))
        def test():
            # Worker loop: enter round i together, call f(i), then wait for
            # the main thread to check the counters.
            for i in range(m):
                start.wait(10)
                self.assertEqual(f(i), 3 * i)
                stop.wait(10)
        threads = [threading.Thread(target=test) for k in range(n)]
        with support.start_threads(threads):
            for i in range(m):
                # Release the workers into f(i), then let them out of f
                # once all n are inside it.  Each barrier is reset at a point
                # in the round when the workers are blocked elsewhere.
                start.wait(10)
                stop.reset()
                pause.wait(10)
                start.reset()
                stop.wait(10)
                pause.reset()
                # (hits, misses, maxsize, currsize): n misses per round, no
                # hits, one new entry per round.
                self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1))
   1508 
   1509     def test_lru_cache_threaded3(self):
   1510         @self.module.lru_cache(maxsize=2)
   1511         def f(x):
   1512             time.sleep(.01)
   1513             return 3 * x
   1514         def test(i, x):
   1515             with self.subTest(thread=i):
   1516                 self.assertEqual(f(x), 3 * x, i)
   1517         threads = [threading.Thread(target=test, args=(i, v))
   1518                    for i, v in enumerate([1, 2, 2, 3, 2])]
   1519         with support.start_threads(threads):
   1520             pass
   1521 
   1522     def test_need_for_rlock(self):
   1523         # This will deadlock on an LRU cache that uses a regular lock
   1524 
   1525         @self.module.lru_cache(maxsize=10)
   1526         def test_func(x):
   1527             'Used to demonstrate a reentrant lru_cache call within a single thread'
   1528             return x
   1529 
   1530         class DoubleEq:
   1531             'Demonstrate a reentrant lru_cache call within a single thread'
   1532             def __init__(self, x):
   1533                 self.x = x
   1534             def __hash__(self):
   1535                 return self.x
   1536             def __eq__(self, other):
   1537                 if self.x == 2:
   1538                     test_func(DoubleEq(1))
   1539                 return self.x == other.x
   1540 
   1541         test_func(DoubleEq(1))                      # Load the cache
   1542         test_func(DoubleEq(2))                      # Load the cache
   1543         self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call
   1544                          DoubleEq(2))               # Verify the correct return value
   1545 
   1546     def test_early_detection_of_bad_call(self):
   1547         # Issue #22184
   1548         with self.assertRaises(TypeError):
   1549             @functools.lru_cache
   1550             def f():
   1551                 pass
   1552 
   1553     def test_lru_method(self):
   1554         class X(int):
   1555             f_cnt = 0
   1556             @self.module.lru_cache(2)
   1557             def f(self, x):
   1558                 self.f_cnt += 1
   1559                 return x*10+self
   1560         a = X(5)
   1561         b = X(5)
   1562         c = X(7)
   1563         self.assertEqual(X.f.cache_info(), (0, 0, 2, 0))
   1564 
   1565         for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3:
   1566             self.assertEqual(a.f(x), x*10 + 5)
   1567         self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0))
   1568         self.assertEqual(X.f.cache_info(), (4, 6, 2, 2))
   1569 
   1570         for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2:
   1571             self.assertEqual(b.f(x), x*10 + 5)
   1572         self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0))
   1573         self.assertEqual(X.f.cache_info(), (10, 10, 2, 2))
   1574 
   1575         for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1:
   1576             self.assertEqual(c.f(x), x*10 + 7)
   1577         self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5))
   1578         self.assertEqual(X.f.cache_info(), (15, 15, 2, 2))
   1579 
   1580         self.assertEqual(a.f.cache_info(), X.f.cache_info())
   1581         self.assertEqual(b.f.cache_info(), X.f.cache_info())
   1582         self.assertEqual(c.f.cache_info(), X.f.cache_info())
   1583 
   1584     def test_pickle(self):
   1585         cls = self.__class__
   1586         for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth:
   1587             for proto in range(pickle.HIGHEST_PROTOCOL + 1):
   1588                 with self.subTest(proto=proto, func=f):
   1589                     f_copy = pickle.loads(pickle.dumps(f, proto))
   1590                     self.assertIs(f_copy, f)
   1591 
   1592     def test_copy(self):
   1593         cls = self.__class__
   1594         def orig(x, y):
   1595             return 3 * x + y
   1596         part = self.module.partial(orig, 2)
   1597         funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
   1598                  self.module.lru_cache(2)(part))
   1599         for f in funcs:
   1600             with self.subTest(func=f):
   1601                 f_copy = copy.copy(f)
   1602                 self.assertIs(f_copy, f)
   1603 
   1604     def test_deepcopy(self):
   1605         cls = self.__class__
   1606         def orig(x, y):
   1607             return 3 * x + y
   1608         part = self.module.partial(orig, 2)
   1609         funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth,
   1610                  self.module.lru_cache(2)(part))
   1611         for f in funcs:
   1612             with self.subTest(func=f):
   1613                 f_copy = copy.deepcopy(f)
   1614                 self.assertIs(f_copy, f)
   1615 
   1616 
   1617 @py_functools.lru_cache()
   1618 def py_cached_func(x, y):
   1619     return 3 * x + y
   1620 
   1621 @c_functools.lru_cache()
   1622 def c_cached_func(x, y):
   1623     return 3 * x + y
   1624 
   1625 
   1626 class TestLRUPy(TestLRU, unittest.TestCase):
   1627     module = py_functools
   1628     cached_func = py_cached_func,
   1629 
   1630     @module.lru_cache()
   1631     def cached_meth(self, x, y):
   1632         return 3 * x + y
   1633 
   1634     @staticmethod
   1635     @module.lru_cache()
   1636     def cached_staticmeth(x, y):
   1637         return 3 * x + y
   1638 
   1639 
   1640 class TestLRUC(TestLRU, unittest.TestCase):
   1641     module = c_functools
   1642     cached_func = c_cached_func,
   1643 
   1644     @module.lru_cache()
   1645     def cached_meth(self, x, y):
   1646         return 3 * x + y
   1647 
   1648     @staticmethod
   1649     @module.lru_cache()
   1650     def cached_staticmeth(x, y):
   1651         return 3 * x + y
   1652 
   1653 
   1654 class TestSingleDispatch(unittest.TestCase):
   1655     def test_simple_overloads(self):
   1656         @functools.singledispatch
   1657         def g(obj):
   1658             return "base"
   1659         def g_int(i):
   1660             return "integer"
   1661         g.register(int, g_int)
   1662         self.assertEqual(g("str"), "base")
   1663         self.assertEqual(g(1), "integer")
   1664         self.assertEqual(g([1,2,3]), "base")
   1665 
   1666     def test_mro(self):
   1667         @functools.singledispatch
   1668         def g(obj):
   1669             return "base"
   1670         class A:
   1671             pass
   1672         class C(A):
   1673             pass
   1674         class B(A):
   1675             pass
   1676         class D(C, B):
   1677             pass
   1678         def g_A(a):
   1679             return "A"
   1680         def g_B(b):
   1681             return "B"
   1682         g.register(A, g_A)
   1683         g.register(B, g_B)
   1684         self.assertEqual(g(A()), "A")
   1685         self.assertEqual(g(B()), "B")
   1686         self.assertEqual(g(C()), "A")
   1687         self.assertEqual(g(D()), "B")
   1688 
   1689     def test_register_decorator(self):
   1690         @functools.singledispatch
   1691         def g(obj):
   1692             return "base"
   1693         @g.register(int)
   1694         def g_int(i):
   1695             return "int %s" % (i,)
   1696         self.assertEqual(g(""), "base")
   1697         self.assertEqual(g(12), "int 12")
   1698         self.assertIs(g.dispatch(int), g_int)
   1699         self.assertIs(g.dispatch(object), g.dispatch(str))
   1700         # Note: in the assert above this is not g.
   1701         # @singledispatch returns the wrapper.
   1702 
   1703     def test_wrapping_attributes(self):
   1704         @functools.singledispatch
   1705         def g(obj):
   1706             "Simple test"
   1707             return "Test"
   1708         self.assertEqual(g.__name__, "g")
   1709         if sys.flags.optimize < 2:
   1710             self.assertEqual(g.__doc__, "Simple test")
   1711 
   1712     @unittest.skipUnless(decimal, 'requires _decimal')
   1713     @support.cpython_only
   1714     def test_c_classes(self):
   1715         @functools.singledispatch
   1716         def g(obj):
   1717             return "base"
   1718         @g.register(decimal.DecimalException)
   1719         def _(obj):
   1720             return obj.args
   1721         subn = decimal.Subnormal("Exponent < Emin")
   1722         rnd = decimal.Rounded("Number got rounded")
   1723         self.assertEqual(g(subn), ("Exponent < Emin",))
   1724         self.assertEqual(g(rnd), ("Number got rounded",))
   1725         @g.register(decimal.Subnormal)
   1726         def _(obj):
   1727             return "Too small to care."
   1728         self.assertEqual(g(subn), "Too small to care.")
   1729         self.assertEqual(g(rnd), ("Number got rounded",))
   1730 
   1731     def test_compose_mro(self):
   1732         # None of the examples in this test depend on haystack ordering.
   1733         c = collections.abc
   1734         mro = functools._compose_mro
   1735         bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
   1736         for haystack in permutations(bases):
   1737             m = mro(dict, haystack)
   1738             self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
   1739                                  c.Collection, c.Sized, c.Iterable,
   1740                                  c.Container, object])
   1741         bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
   1742         for haystack in permutations(bases):
   1743             m = mro(collections.ChainMap, haystack)
   1744             self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
   1745                                  c.Collection, c.Sized, c.Iterable,
   1746                                  c.Container, object])
   1747 
   1748         # If there's a generic function with implementations registered for
   1749         # both Sized and Container, passing a defaultdict to it results in an
   1750         # ambiguous dispatch which will cause a RuntimeError (see
   1751         # test_mro_conflicts).
   1752         bases = [c.Container, c.Sized, str]
   1753         for haystack in permutations(bases):
   1754             m = mro(collections.defaultdict, [c.Sized, c.Container, str])
   1755             self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
   1756                                  c.Container, object])
   1757 
   1758         # MutableSequence below is registered directly on D. In other words, it
   1759         # precedes MutableMapping which means single dispatch will always
   1760         # choose MutableSequence here.
   1761         class D(collections.defaultdict):
   1762             pass
   1763         c.MutableSequence.register(D)
   1764         bases = [c.MutableSequence, c.MutableMapping]
   1765         for haystack in permutations(bases):
   1766             m = mro(D, bases)
   1767             self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
   1768                                  collections.defaultdict, dict, c.MutableMapping, c.Mapping,
   1769                                  c.Collection, c.Sized, c.Iterable, c.Container,
   1770                                  object])
   1771 
   1772         # Container and Callable are registered on different base classes and
   1773         # a generic function supporting both should always pick the Callable
   1774         # implementation if a C instance is passed.
   1775         class C(collections.defaultdict):
   1776             def __call__(self):
   1777                 pass
   1778         bases = [c.Sized, c.Callable, c.Container, c.Mapping]
   1779         for haystack in permutations(bases):
   1780             m = mro(C, haystack)
   1781             self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
   1782                                  c.Collection, c.Sized, c.Iterable,
   1783                                  c.Container, object])
   1784 
   1785     def test_register_abc(self):
   1786         c = collections.abc
   1787         d = {"a": "b"}
   1788         l = [1, 2, 3]
   1789         s = {object(), None}
   1790         f = frozenset(s)
   1791         t = (1, 2, 3)
   1792         @functools.singledispatch
   1793         def g(obj):
   1794             return "base"
   1795         self.assertEqual(g(d), "base")
   1796         self.assertEqual(g(l), "base")
   1797         self.assertEqual(g(s), "base")
   1798         self.assertEqual(g(f), "base")
   1799         self.assertEqual(g(t), "base")
   1800         g.register(c.Sized, lambda obj: "sized")
   1801         self.assertEqual(g(d), "sized")
   1802         self.assertEqual(g(l), "sized")
   1803         self.assertEqual(g(s), "sized")
   1804         self.assertEqual(g(f), "sized")
   1805         self.assertEqual(g(t), "sized")
   1806         g.register(c.MutableMapping, lambda obj: "mutablemapping")
   1807         self.assertEqual(g(d), "mutablemapping")
   1808         self.assertEqual(g(l), "sized")
   1809         self.assertEqual(g(s), "sized")
   1810         self.assertEqual(g(f), "sized")
   1811         self.assertEqual(g(t), "sized")
   1812         g.register(collections.ChainMap, lambda obj: "chainmap")
   1813         self.assertEqual(g(d), "mutablemapping")  # irrelevant ABCs registered
   1814         self.assertEqual(g(l), "sized")
   1815         self.assertEqual(g(s), "sized")
   1816         self.assertEqual(g(f), "sized")
   1817         self.assertEqual(g(t), "sized")
   1818         g.register(c.MutableSequence, lambda obj: "mutablesequence")
   1819         self.assertEqual(g(d), "mutablemapping")
   1820         self.assertEqual(g(l), "mutablesequence")
   1821         self.assertEqual(g(s), "sized")
   1822         self.assertEqual(g(f), "sized")
   1823         self.assertEqual(g(t), "sized")
   1824         g.register(c.MutableSet, lambda obj: "mutableset")
   1825         self.assertEqual(g(d), "mutablemapping")
   1826         self.assertEqual(g(l), "mutablesequence")
   1827         self.assertEqual(g(s), "mutableset")
   1828         self.assertEqual(g(f), "sized")
   1829         self.assertEqual(g(t), "sized")
   1830         g.register(c.Mapping, lambda obj: "mapping")
   1831         self.assertEqual(g(d), "mutablemapping")  # not specific enough
   1832         self.assertEqual(g(l), "mutablesequence")
   1833         self.assertEqual(g(s), "mutableset")
   1834         self.assertEqual(g(f), "sized")
   1835         self.assertEqual(g(t), "sized")
   1836         g.register(c.Sequence, lambda obj: "sequence")
   1837         self.assertEqual(g(d), "mutablemapping")
   1838         self.assertEqual(g(l), "mutablesequence")
   1839         self.assertEqual(g(s), "mutableset")
   1840         self.assertEqual(g(f), "sized")
   1841         self.assertEqual(g(t), "sequence")
   1842         g.register(c.Set, lambda obj: "set")
   1843         self.assertEqual(g(d), "mutablemapping")
   1844         self.assertEqual(g(l), "mutablesequence")
   1845         self.assertEqual(g(s), "mutableset")
   1846         self.assertEqual(g(f), "set")
   1847         self.assertEqual(g(t), "sequence")
   1848         g.register(dict, lambda obj: "dict")
   1849         self.assertEqual(g(d), "dict")
   1850         self.assertEqual(g(l), "mutablesequence")
   1851         self.assertEqual(g(s), "mutableset")
   1852         self.assertEqual(g(f), "set")
   1853         self.assertEqual(g(t), "sequence")
   1854         g.register(list, lambda obj: "list")
   1855         self.assertEqual(g(d), "dict")
   1856         self.assertEqual(g(l), "list")
   1857         self.assertEqual(g(s), "mutableset")
   1858         self.assertEqual(g(f), "set")
   1859         self.assertEqual(g(t), "sequence")
   1860         g.register(set, lambda obj: "concrete-set")
   1861         self.assertEqual(g(d), "dict")
   1862         self.assertEqual(g(l), "list")
   1863         self.assertEqual(g(s), "concrete-set")
   1864         self.assertEqual(g(f), "set")
   1865         self.assertEqual(g(t), "sequence")
   1866         g.register(frozenset, lambda obj: "frozen-set")
   1867         self.assertEqual(g(d), "dict")
   1868         self.assertEqual(g(l), "list")
   1869         self.assertEqual(g(s), "concrete-set")
   1870         self.assertEqual(g(f), "frozen-set")
   1871         self.assertEqual(g(t), "sequence")
   1872         g.register(tuple, lambda obj: "tuple")
   1873         self.assertEqual(g(d), "dict")
   1874         self.assertEqual(g(l), "list")
   1875         self.assertEqual(g(s), "concrete-set")
   1876         self.assertEqual(g(f), "frozen-set")
   1877         self.assertEqual(g(t), "tuple")
   1878 
   1879     def test_c3_abc(self):
   1880         c = collections.abc
   1881         mro = functools._c3_mro
   1882         class A(object):
   1883             pass
   1884         class B(A):
   1885             def __len__(self):
   1886                 return 0   # implies Sized
   1887         @c.Container.register
   1888         class C(object):
   1889             pass
   1890         class D(object):
   1891             pass   # unrelated
   1892         class X(D, C, B):
   1893             def __call__(self):
   1894                 pass   # implies Callable
   1895         expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
   1896         for abcs in permutations([c.Sized, c.Callable, c.Container]):
   1897             self.assertEqual(mro(X, abcs=abcs), expected)
   1898         # unrelated ABCs don't appear in the resulting MRO
   1899         many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
   1900         self.assertEqual(mro(X, abcs=many_abcs), expected)
   1901 
   1902     def test_false_meta(self):
   1903         # see issue23572
   1904         class MetaA(type):
   1905             def __len__(self):
   1906                 return 0
   1907         class A(metaclass=MetaA):
   1908             pass
   1909         class AA(A):
   1910             pass
   1911         @functools.singledispatch
   1912         def fun(a):
   1913             return 'base A'
   1914         @fun.register(A)
   1915         def _(a):
   1916             return 'fun A'
   1917         aa = AA()
   1918         self.assertEqual(fun(aa), 'fun A')
   1919 
    def test_mro_conflicts(self):
        # Check resolution when a class relates to several ABCs that all have
        # implementations, through explicit bases, ABC.register(), or
        # special methods such as __len__.  When no candidate ABC is more
        # specific than the others, dispatch must raise
        # RuntimeError("Ambiguous dispatch: ...").
        # NOTE: the ABC.register() calls below modify the collections.abc
        # registries globally, but only for classes local to this test.
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        class O(c.Sized):
            def __len__(self):
                return 0
        o = O()
        self.assertEqual(g(o), "base")
        g.register(c.Iterable, lambda arg: "iterable")
        g.register(c.Container, lambda arg: "container")
        g.register(c.Sized, lambda arg: "sized")
        g.register(c.Set, lambda arg: "set")
        self.assertEqual(g(o), "sized")
        c.Iterable.register(O)
        self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
        c.Container.register(O)
        self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
        c.Set.register(O)
        self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Container
        # P relates to ABCs only through register(). Once it is registered
        # with both Iterable and Container, neither is more specific.
        class P:
            pass
        p = P()
        self.assertEqual(g(p), "base")
        c.Iterable.register(P)
        self.assertEqual(g(p), "iterable")
        c.Container.register(P)
        with self.assertRaises(RuntimeError) as re_one:
            g(p)
        # The two ABCs may be named in either order in the message.
        self.assertIn(
            str(re_one.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Iterable'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
              "or <class 'collections.abc.Container'>")),
        )
        class Q(c.Sized):
            def __len__(self):
                return 0
        q = Q()
        self.assertEqual(g(q), "sized")
        c.Iterable.register(Q)
        self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
        c.Set.register(Q)
        self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                          # c.Sized and c.Iterable
        @functools.singledispatch
        def h(arg):
            return "base"
        @h.register(c.Sized)
        def _(arg):
            return "sized"
        @h.register(c.Container)
        def _(arg):
            return "container"
        # Even though Sized and Container are explicit bases of MutableMapping,
        # this ABC is implicitly registered on defaultdict which makes all of
        # MutableMapping's bases implicit as well from defaultdict's
        # perspective.
        with self.assertRaises(RuntimeError) as re_two:
            h(collections.defaultdict(lambda: 0))
        self.assertIn(
            str(re_two.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # R inherits MutableMapping (via defaultdict), but MutableSequence is
        # registered directly on R and therefore wins.
        class R(collections.defaultdict):
            pass
        c.MutableSequence.register(R)
        @functools.singledispatch
        def i(arg):
            return "base"
        @i.register(c.MutableMapping)
        def _(arg):
            return "mapping"
        @i.register(c.MutableSequence)
        def _(arg):
            return "sequence"
        r = R()
        self.assertEqual(i(r), "sequence")
        class S:
            pass
        class T(S, c.Sized):
            def __len__(self):
                return 0
        t = T()
        self.assertEqual(h(t), "sized")
        c.Container.register(T)
        self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
        class U:
            def __len__(self):
                return 0
        u = U()
        self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                          # from the existence of __len__()
        c.Container.register(U)
        # There is no preference for registered versus inferred ABCs.
        with self.assertRaises(RuntimeError) as re_three:
            h(u)
        self.assertIn(
            str(re_three.exception),
            (("Ambiguous dispatch: <class 'collections.abc.Container'> "
              "or <class 'collections.abc.Sized'>"),
             ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
              "or <class 'collections.abc.Container'>")),
        )
        # V lists Sized before S explicitly. The implementation registered
        # for S is chosen until Container is registered on V.
        class V(c.Sized, S):
            def __len__(self):
                return 0
        @functools.singledispatch
        def j(arg):
            return "base"
        @j.register(S)
        def _(arg):
            return "s"
        @j.register(c.Container)
        def _(arg):
            return "container"
        v = V()
        self.assertEqual(j(v), "s")
        c.Container.register(V)
        self.assertEqual(j(v), "container")   # because it ends up right after
                                              # Sized in the MRO
   2047 
    def test_cache_invalidation(self):
        # Check exactly when singledispatch fills and invalidates its
        # per-type dispatch cache.  While the `with` block is active,
        # weakref.WeakKeyDictionary is replaced by a factory that returns
        # `td`.  The generic function created inside the block therefore
        # uses this tracing dict as its cache; the asserts on td.data rely
        # on that.
        from collections import UserDict
        import weakref

        class TracingDict(UserDict):
            # Records the keys of item reads (get_ops) and item writes
            # (set_ops).
            def __init__(self, *args, **kwargs):
                super(TracingDict, self).__init__(*args, **kwargs)
                self.set_ops = []
                self.get_ops = []
            def __getitem__(self, key):
                # Only successful reads are recorded: a missing key raises
                # before the append.
                result = self.data[key]
                self.get_ops.append(key)
                return result
            def __setitem__(self, key, value):
                self.set_ops.append(key)
                self.data[key] = value
            def clear(self):
                # Empties the backing dict directly; not recorded as an op.
                self.data.clear()

        td = TracingDict()
        with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
            c = collections.abc
            @functools.singledispatch
            def g(arg):
                return "base"
            d = {}
            l = []
            # First dispatch on a type stores the resolved implementation
            # (a set op) without any recorded read.
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(g(l), "base")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [])
            self.assertEqual(td.set_ops, [dict, list])
            self.assertEqual(td.data[dict], g.registry[object])
            self.assertEqual(td.data[list], g.registry[object])
            self.assertEqual(td.data[dict], td.data[list])
            # Repeated dispatch is served from the cache (reads only).
            self.assertEqual(g(l), "base")
            self.assertEqual(g(d), "base")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list])
            # Registering a new implementation empties the cache.
            g.register(list, lambda arg: "list")
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "base")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict])
            self.assertEqual(td.data[dict],
                             functools._find_impl(dict, g.registry))
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            self.assertEqual(td.data[list],
                             functools._find_impl(list, g.registry))
            class X:
                pass
            c.MutableMapping.register(X)   # Will not invalidate the cache,
                                           # not using ABCs yet.
            self.assertEqual(g(d), "base")
            self.assertEqual(g(l), "list")
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list])
            # From here on g has an ABC registered.
            g.register(c.Sized, lambda arg: "sized")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "sized")
            self.assertEqual(len(td), 1)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            self.assertEqual(td.get_ops, [list, dict, dict, list])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            self.assertEqual(g(l), "list")
            self.assertEqual(g(d), "sized")
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # g.dispatch() reads the same cache as calling g does.
            g.dispatch(list)
            g.dispatch(dict)
            self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                          list, dict])
            self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
            # The stale entries are only dropped on the next dispatch, which
            # then re-caches just the type it looked up.
            c.MutableSet.register(X)       # Will invalidate the cache.
            self.assertEqual(len(td), 2)   # Stale cache.
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 1)
            g.register(c.MutableMapping, lambda arg: "mutablemapping")
            self.assertEqual(len(td), 0)
            self.assertEqual(g(d), "mutablemapping")
            self.assertEqual(len(td), 1)
            self.assertEqual(g(l), "list")
            self.assertEqual(len(td), 2)
            g.register(dict, lambda arg: "dict")
            self.assertEqual(g(d), "dict")
            self.assertEqual(g(l), "list")
            # _clear_cache() empties the dispatch cache.
            g._clear_cache()
            self.assertEqual(len(td), 0)
   2149 
   2150     def test_annotations(self):
   2151         @functools.singledispatch
   2152         def i(arg):
   2153             return "base"
   2154         @i.register
   2155         def _(arg: collections.abc.Mapping):
   2156             return "mapping"
   2157         @i.register
   2158         def _(arg: "collections.abc.Sequence"):
   2159             return "sequence"
   2160         self.assertEqual(i(None), "base")
   2161         self.assertEqual(i({"a": 1}), "mapping")
   2162         self.assertEqual(i([1, 2, 3]), "sequence")
   2163         self.assertEqual(i((1, 2, 3)), "sequence")
   2164         self.assertEqual(i("str"), "sequence")
   2165 
   2166         # Registering classes as callables doesn't work with annotations,
   2167         # you need to pass the type explicitly.
   2168         @i.register(str)
   2169         class _:
   2170             def __init__(self, arg):
   2171                 self.arg = arg
   2172 
   2173             def __eq__(self, other):
   2174                 return self.arg == other
   2175         self.assertEqual(i("str"), "str")
   2176 
   2177     def test_invalid_registrations(self):
   2178         msg_prefix = "Invalid first argument to `register()`: "
   2179         msg_suffix = (
   2180             ". Use either `@register(some_class)` or plain `@register` on an "
   2181             "annotated function."
   2182         )
   2183         @functools.singledispatch
   2184         def i(arg):
   2185             return "base"
   2186         with self.assertRaises(TypeError) as exc:
   2187             @i.register(42)
   2188             def _(arg):
   2189                 return "I annotated with a non-type"
   2190         self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
   2191         self.assertTrue(str(exc.exception).endswith(msg_suffix))
   2192         with self.assertRaises(TypeError) as exc:
   2193             @i.register
   2194             def _(arg):
   2195                 return "I forgot to annotate"
   2196         self.assertTrue(str(exc.exception).startswith(msg_prefix +
   2197             "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
   2198         ))
   2199         self.assertTrue(str(exc.exception).endswith(msg_suffix))
   2200 
   2201         # FIXME: The following will only work after PEP 560 is implemented.
   2202         return
   2203 
   2204         with self.assertRaises(TypeError) as exc:
   2205             @i.register
   2206             def _(arg: typing.Iterable[str]):
   2207                 # At runtime, dispatching on generics is impossible.
   2208                 # When registering implementations with singledispatch, avoid
   2209                 # types from `typing`. Instead, annotate with regular types
   2210                 # or ABCs.
   2211                 return "I annotated with a generic collection"
   2212         self.assertTrue(str(exc.exception).startswith(msg_prefix +
   2213             "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
   2214         ))
   2215         self.assertTrue(str(exc.exception).endswith(msg_suffix))
   2216 
   2217     def test_invalid_positional_argument(self):
   2218         @functools.singledispatch
   2219         def f(*args):
   2220             pass
   2221         msg = 'f requires at least 1 positional argument'
   2222         with self.assertRaisesRegex(TypeError, msg):
   2223             f()
   2224 
   2225 if __name__ == '__main__':
   2226     unittest.main()
   2227