"""Tests for the functools module.

Most tests run against both the C-accelerated (``_functools``) and the
pure-Python implementations of :mod:`functools` where both are available.
"""

import abc
import builtins
import collections
import collections.abc
import copy
from itertools import permutations
import pickle
from random import choice
import sys
from test import support
import time
import unittest
from weakref import proxy
import contextlib
try:
    import threading
except ImportError:
    threading = None

import functools

# Fresh copies of functools: one with the C accelerator blocked (pure Python)
# and one forced to load _functools.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# NOTE(review): not used in the visible portion of this file; presumably
# needed by later tests -- confirm before removing.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])

@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily install *replacement* as ``sys.modules[name]``.

    The original entry is restored on exit, even if the body raises.
    ``name`` must already be present in ``sys.modules`` (it is looked up
    before being replaced).
    """
    original_module = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        sys.modules[name] = original_module

def capture(*args, **kw):
    """capture all positional and keyword arguments"""
    return args, kw


def signature(part):
    """ return the signature of a partial object """
    return (part.func, part.args, part.keywords, part.__dict__)

class MyTuple(tuple):
    # tuple subclass: partial state must be normalized back to a plain tuple
    pass

class BadTuple(tuple):
    # __add__ returns a list, so partial must not rely on tuple concatenation
    def __add__(self, other):
        return list(self) + list(other)

class MyDict(dict):
    # dict subclass: partial keywords must be normalized back to a plain dict
    pass


class TestPartial:
    """Implementation-independent tests for partial objects.

    Concrete subclasses provide ``partial`` (the class under test) and an
    ``AllowPickle`` context manager used around pickling.
    """

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1,2,3,4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        try:
            self.partial(2)()
        except TypeError:
            pass
        else:
            self.fail('First arg not checked for callability')

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a':3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a':3})
        p(b=7)
        self.assertEqual(d, {'a':3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1,2), ((1,2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1,2), {}))
        self.assertEqual(p(3,4), ((1,2,3,4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a':1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a':1})
        self.assertEqual(p(), ((), {'a':1}))
        self.assertEqual(p(b=2), ((), {'a':1, 'b':2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]:
            p = self.partial(capture, *args)
            expected = args + ('x',)
            got, empty = p('x')
            self.assertTrue(expected == got and empty == {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            expected = {'a':a,'x':None}
            empty, got = p(x=None)
            self.assertTrue(expected == got and empty == ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2})
        args2, kw2 = p()
        self.assertTrue(args2 == (0,) and kw2 == {'a':1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # keyword order in the repr is not pinned, so accept either order
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        if self.partial in (c_functools.partial, py_functools.partial):
            name = 'functools.partial'
        else:
            name = self.partial.__name__

        # Each case makes the partial refer to itself; the finally clause
        # breaks the cycle again.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): this check is disabled in the original; presumably the
        # implementations differ in how a None keywords state is exposed.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # A partial whose func is itself cannot be pickled.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references through args survive a pickle round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references through keywords survive a pickle round trip.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    # Issue 6083: Reference counting bug
    def test_setstate_refcount(self):
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())

@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Run the partial tests against the C implementation."""
    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        # No-op context manager: the C partial needs no module swapping.
        def __enter__(self):
            return self
        def __exit__(self, type, value, tb):
            return False

    def test_attributes_unwritable(self):
        # attributes should not be writable
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertRaises(AttributeError, setattr, p, 'func', map)
        self.assertRaises(AttributeError, setattr, p, 'args', (1, 2))
        self.assertRaises(AttributeError, setattr, p, 'keywords', dict(a=1, b=2))

        p = self.partial(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

class TestPartialPy(TestPartial, unittest.TestCase):
    """Run the partial tests against the pure-Python implementation."""
    partial = py_functools.partial

    class AllowPickle:
        # While active, sys.modules['functools'] is the pure-Python module,
        # presumably so pickle's by-name lookup of functools.partial finds
        # the class under test.
        def __init__(self):
            self._cm = replaced_module("functools", py_functools)
        def __enter__(self):
            return self._cm.__enter__()
        def __exit__(self, type, value, tb):
            return self._cm.__exit__(type, value, tb)

if c_functools:
    class CPartialSubclass(c_functools.partial):
        pass

class PyPartialSubclass(py_functools.partial):
    pass

@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Run the C partial tests against a subclass of the C partial."""
    if c_functools:
        partial = CPartialSubclass

    # partial subclasses are not optimized for nested calls
    test_nested_optimization = None

class TestPartialPySubclass(TestPartialPy):
    """Run the pure-Python partial tests against a subclass."""
    partial = PyPartialSubclass

class TestPartialMethod(unittest.TestCase):
    """Tests for functools.partialmethod used as a class attribute."""

    class A(object):
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        nested = functools.partialmethod(positional, 5)

        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Use abc.ABC as the base: subclassing abc.ABCMeta (as before) made
        # Abstract a metaclass rather than an abstract class.
        class Abstract(abc.ABC):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta must see the partialmethod as abstract too.
        self.assertEqual(Abstract.__abstractmethods__, {'add', 'add5'})

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))


class TestUpdateWrapper(unittest.TestCase):
    """Tests for functools.update_wrapper."""

    def check_wrapper(self, wrapper, wrapped,
                      assigned=functools.WRAPPER_ASSIGNMENTS,
                      updated=functools.WRAPPER_UPDATES):
        # Check attributes were assigned
        for name in assigned:
            self.assertIs(getattr(wrapper, name), getattr(wrapped, name))
        # Check attributes were updated
        for name in updated:
            wrapper_attr = getattr(wrapper, name)
            wrapped_attr = getattr(wrapped, name)
            for key in wrapped_attr:
                if name == "__dict__" and key == "__wrapped__":
                    # __wrapped__ is overwritten by the update code
                    continue
                self.assertIs(wrapped_attr[key], wrapper_attr[key])
        # Check __wrapped__
        self.assertIs(wrapper.__wrapped__, wrapped)


    def _default_update(self):
        # Build (wrapper, wrapped) using update_wrapper with default settings.
        def f(a:'This is a new annotation'):
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is a bald faced lie"
        def wrapper(b:'This is the prior annotation'):
            pass
        functools.update_wrapper(wrapper, f)
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertIs(wrapper.__wrapped__, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')
        self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation')
        self.assertNotIn('b', wrapper.__annotations__)

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, f = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        def wrapper():
            pass
        functools.update_wrapper(wrapper, f, (), ())
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.__annotations__, {})
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        functools.update_wrapper(wrapper, f, assign, update)
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

    def test_missing_attributes(self):
        def f():
            pass
        def wrapper():
            pass
        wrapper.dict_attr = {}
        assign = ('attr',)
        update = ('dict_attr',)
        # Missing attributes on wrapped object are ignored
        functools.update_wrapper(wrapper, f, assign, update)
        self.assertNotIn('attr', wrapper.__dict__)
        self.assertEqual(wrapper.dict_attr, {})
        # Wrapper must have expected attributes for updating
        del wrapper.dict_attr
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)
        wrapper.dict_attr = 1
        with self.assertRaises(AttributeError):
            functools.update_wrapper(wrapper, f, assign, update)

    @support.requires_docstrings
    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_builtin_update(self):
        # Test for bug #1576241
        def wrapper():
            pass
        functools.update_wrapper(wrapper, max)
        self.assertEqual(wrapper.__name__, 'max')
        self.assertTrue(wrapper.__doc__.startswith('max('))
        self.assertEqual(wrapper.__annotations__, {})


class TestWraps(TestUpdateWrapper):
    """Repeat the update_wrapper tests using the functools.wraps decorator."""

    def _default_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"
        @functools.wraps(f)
        def wrapper():
            pass
        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper, _ = self._default_update()
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
            pass
        f.attr = 'This is also a test'
        @functools.wraps(f, (), ())
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, (), ())
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)
        def add_dict_attr(f):
            f.dict_attr = {}
            return f
        assign = ('attr',)
        update = ('dict_attr',)
        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass
        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.__doc__, None)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)

@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduce(unittest.TestCase):
    """Tests for the C implementation of functools.reduce."""
    if c_functools:
        func = c_functools.reduce

    def test_reduce(self):
        class Squares:
            # Lazily computed sequence of squares, reachable via __getitem__.
            def __init__(self, max):
                self.max = max
                self.sofar = []

            def __len__(self):
                return len(self.sofar)

            def __getitem__(self, i):
                if not 0 <= i < self.max:
                    raise IndexError
                n = len(self.sofar)
                while n <= i:
                    self.sofar.append(n*n)
                    n += 1
                return self.sofar[i]
        def add(x, y):
            return x + y
        self.assertEqual(self.func(add, ['a', 'b', 'c'], ''), 'abc')
        self.assertEqual(
            self.func(add, [['a', 'c'], [], ['d', 'w']], []),
            ['a','c','d','w']
        )
        self.assertEqual(self.func(lambda x, y: x*y, range(2,8), 1), 5040)
        self.assertEqual(
            self.func(lambda x, y: x*y, range(2,21), 1),
            2432902008176640000
        )
        self.assertEqual(self.func(add, Squares(10)), 285)
        self.assertEqual(self.func(add, Squares(10), 0), 285)
        self.assertEqual(self.func(add, Squares(0), 0), 0)
        self.assertRaises(TypeError, self.func)
        self.assertRaises(TypeError, self.func, 42, 42)
        self.assertRaises(TypeError, self.func, 42, 42, 42)
        self.assertEqual(self.func(42, "1"), "1") # func is never called with one item
        self.assertEqual(self.func(42, "", "1"), "1") # func is never called with one item
        self.assertRaises(TypeError, self.func, 42, (42, 42))
        self.assertRaises(TypeError, self.func, add, []) # arg 2 must not be empty sequence with no initial value
        self.assertRaises(TypeError, self.func, add, "")
        self.assertRaises(TypeError, self.func, add, ())
        self.assertRaises(TypeError, self.func, add, object())

        class TestFailingIter:
            def __iter__(self):
                raise RuntimeError
        self.assertRaises(RuntimeError, self.func, add, TestFailingIter())

        self.assertEqual(self.func(add, [], None), None)
        self.assertEqual(self.func(add, [], 42), 42)

        class BadSeq:
            def __getitem__(self, index):
                raise ValueError
        self.assertRaises(ValueError, self.func, 42, BadSeq())

    # Test reduce()'s use of iterators.
    def test_iterator_usage(self):
        class SequenceClass:
            def __init__(self, n):
                self.n = n
            def __getitem__(self, i):
                if 0 <= i < self.n:
                    return i
                else:
                    raise IndexError

        from operator import add
        self.assertEqual(self.func(add, SequenceClass(5)), 10)
        self.assertEqual(self.func(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, self.func, add, SequenceClass(0))
        self.assertEqual(self.func(add, SequenceClass(0), 42), 42)
        self.assertEqual(self.func(add, SequenceClass(1)), 0)
        self.assertEqual(self.func(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(self.func(add, d), "".join(d.keys()))


class TestCmpToKey:
    """Implementation-independent tests for cmp_to_key.

    Concrete subclasses provide ``cmp_to_key``.
    """

    def test_cmp_to_key(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(cmp1)
        self.assertEqual(key(3), key(3))
        self.assertGreater(key(3), key(1))
        self.assertGreaterEqual(key(3), key(3))

        def cmp2(x, y):
            return int(x) - int(y)
        key = self.cmp_to_key(cmp2)
        self.assertEqual(key(4.0), key('4'))
        self.assertLess(key(2), key('35'))
        self.assertLessEqual(key(2), key('35'))
        self.assertNotEqual(key(2), key('35'))

    def test_cmp_to_key_arguments(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(obj=3), key(obj=3))
        self.assertGreater(key(obj=3), key(obj=1))
        with self.assertRaises((TypeError, AttributeError)):
            key(3) > 1    # rhs is not a K object
        with self.assertRaises((TypeError, AttributeError)):
            1 < key(3)    # lhs is not a K object
        with self.assertRaises(TypeError):
            key = self.cmp_to_key()             # too few args
        with self.assertRaises(TypeError):
            key = self.cmp_to_key(cmp1, None)   # too many args
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(TypeError):
            key()                                    # too few args
        with self.assertRaises(TypeError):
            key(None, None)                          # too many args

    def test_bad_cmp(self):
        # An exception raised by the comparison function propagates.
        def cmp1(x, y):
            raise ZeroDivisionError
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            key(3) > key(1)

        # An exception raised while comparing the comparison function's
        # *result* with zero propagates too.
        class BadCmp:
            def __lt__(self, other):
                raise ZeroDivisionError
        def cmp1(x, y):
            return BadCmp()
        # Rebuild the key: previously the old key was reused, so BadCmp was
        # never exercised and this half merely repeated the first check.
        key = self.cmp_to_key(cmp1)
        with self.assertRaises(ZeroDivisionError):
            # K.__lt__ evaluates ``cmp1(...) < 0``, i.e. BadCmp.__lt__.
            key(3) < key(1)

    def test_obj_field(self):
        def cmp1(x, y):
            return (x > y) - (x < y)
        key = self.cmp_to_key(mycmp=cmp1)
        self.assertEqual(key(50).obj, 50)

    def test_sort_int(self):
        def mycmp(x, y):
            return y - x
        self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)),
                         [4, 3, 2, 1, 0])

    def test_sort_int_str(self):
        def mycmp(x, y):
            x, y = int(x), int(y)
            return (x > y) - (x < y)
        values = [5, '3', 7, 2, '0', '1', 4, '10', 1]
        values = sorted(values, key=self.cmp_to_key(mycmp))
        self.assertEqual([int(value) for value in values],
                         [0, 1, 1, 2, 3, 4, 5, 7, 10])

    def test_hash(self):
        def mycmp(x, y):
            return y - x
        key = self.cmp_to_key(mycmp)
        k = key(10)
        self.assertRaises(TypeError, hash, k)
        # collections.Hashable was a deprecated alias (removed in 3.10).
        self.assertNotIsInstance(k, collections.abc.Hashable)


@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    """Run the cmp_to_key tests against the C implementation."""
    if c_functools:
        cmp_to_key = c_functools.cmp_to_key


class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    """Run the cmp_to_key tests against the pure-Python implementation."""
    # staticmethod: a plain Python function would otherwise bind to self.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)


class TestTotalOrdering(unittest.TestCase):
    """Tests for the functools.total_ordering class decorator."""

    def test_total_ordering_lt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __lt__(self, other):
                return self.value < other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(1) > A(2))

    def test_total_ordering_le(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __le__(self, other):
                return self.value <= other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(1) >= A(2))

    def test_total_ordering_gt(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __gt__(self, other):
                return self.value > other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(2) < A(1))

    def test_total_ordering_ge(self):
        @functools.total_ordering
        class A:
            def __init__(self, value):
                self.value = value
            def __ge__(self, other):
                return self.value >= other.value
            def __eq__(self, other):
                return self.value == other.value
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))
        self.assertFalse(A(2) <= A(1))

    def test_total_ordering_no_overwrite(self):
        # new methods should not overwrite existing
        @functools.total_ordering
        class A(int):
            pass
        self.assertTrue(A(1) < A(2))
        self.assertTrue(A(2) > A(1))
        self.assertTrue(A(1) <= A(2))
        self.assertTrue(A(2) >= A(1))
        self.assertTrue(A(2) <= A(2))
        self.assertTrue(A(2) >= A(2))

    def test_no_operations_defined(self):
        with self.assertRaises(ValueError):
            @functools.total_ordering
            class A:
                pass

    def test_type_error_when_not_implemented(self):
        # bug 10042; ensure stack overflow does not occur
        # when decorated types return NotImplemented
        @functools.total_ordering
        class ImplementsLessThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                if isinstance(other, ImplementsLessThan):
                    return self.value < other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThan:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value == other.value
                return False
            def __gt__(self, other):
                if isinstance(other, ImplementsGreaterThan):
                    return self.value > other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsLessThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value == other.value
                return False
            def __le__(self, other):
                if isinstance(other, ImplementsLessThanEqualTo):
                    return self.value <= other.value
                return NotImplemented

        @functools.total_ordering
        class ImplementsGreaterThanEqualTo:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value == other.value
                return False
            def __ge__(self, other):
                if isinstance(other, ImplementsGreaterThanEqualTo):
                    return self.value >= other.value
                return NotImplemented

        @functools.total_ordering
        class ComparatorNotImplemented:
            def __init__(self, value):
                self.value = value
            def __eq__(self, other):
                if isinstance(other, ComparatorNotImplemented):
                    return self.value == other.value
                return False
            def __lt__(self, other):
                return NotImplemented

        with self.subTest("LT < 1"), self.assertRaises(TypeError):
            ImplementsLessThan(-1) < 1

        with self.subTest("LT < LE"), self.assertRaises(TypeError):
            ImplementsLessThan(0) < ImplementsLessThanEqualTo(0)

        with self.subTest("LT < GT"), self.assertRaises(TypeError):
            ImplementsLessThan(1) < ImplementsGreaterThan(1)

        with self.subTest("LE <= LT"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2)

        with self.subTest("LE <= GE"), self.assertRaises(TypeError):
            ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3)

        with self.subTest("GT > GE"), self.assertRaises(TypeError):
            ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4)

        with self.subTest("GT > LT"), self.assertRaises(TypeError):
            ImplementsGreaterThan(5) > ImplementsLessThan(5)

        with self.subTest("GE >= GT"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6)

        with self.subTest("GE >= LE"), self.assertRaises(TypeError):
            ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7)

        with self.subTest("GE when equal"):
            a = ComparatorNotImplemented(8)
            b = ComparatorNotImplemented(8)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a >= b

        with self.subTest("LE when equal"):
            a = ComparatorNotImplemented(9)
            b = ComparatorNotImplemented(9)
            self.assertEqual(a, b)
            with self.assertRaises(TypeError):
                a <= b

    def test_pickle(self):
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            for name in '__lt__', '__gt__', '__le__', '__ge__':
                with self.subTest(method=name, proto=proto):
                    method = getattr(Orderable_LT, name)
                    method_copy = pickle.loads(pickle.dumps(method, proto))
                    self.assertIs(method_copy, method)

# Defined at module level (not inside a test) presumably so that
# TestTotalOrdering.test_pickle can pickle its methods by qualified name.
@functools.total_ordering
class Orderable_LT:
    def __init__(self, value):
        self.value = value
    def __lt__(self, other):
        return self.value < other.value
    def __eq__(self, other):
        return self.value == other.value


class TestLRU:
    # Mixin: uses ``self.module``, presumably supplied by concrete subclasses
    # defined later in the file -- confirm.

    def test_lru(self):
        def orig(x, y):
            return 3 * x + y
        f = self.module.lru_cache(maxsize=20)(orig)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(maxsize, 20)
        self.assertEqual(currsize, 0)
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)

        domain = range(5)
        for i in range(1000):
            x, y = choice(domain), choice(domain)
            actual = f(x, y)
            expected = orig(x, y)
            self.assertEqual(actual, expected)
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertTrue(hits > misses)
        self.assertEqual(hits + misses, 1000)
        self.assertEqual(currsize, 20)

        f.cache_clear()   # test clearing
        hits, misses, maxsize, currsize = f.cache_info()
        self.assertEqual(hits, 0)
        self.assertEqual(misses, 0)
        self.assertEqual(currsize, 0)
        f(x, y)
hits, misses, maxsize, currsize = f.cache_info() 1142 self.assertEqual(hits, 0) 1143 self.assertEqual(misses, 1) 1144 self.assertEqual(currsize, 1) 1145 1146 # Test bypassing the cache 1147 self.assertIs(f.__wrapped__, orig) 1148 f.__wrapped__(x, y) 1149 hits, misses, maxsize, currsize = f.cache_info() 1150 self.assertEqual(hits, 0) 1151 self.assertEqual(misses, 1) 1152 self.assertEqual(currsize, 1) 1153 1154 # test size zero (which means "never-cache") 1155 @self.module.lru_cache(0) 1156 def f(): 1157 nonlocal f_cnt 1158 f_cnt += 1 1159 return 20 1160 self.assertEqual(f.cache_info().maxsize, 0) 1161 f_cnt = 0 1162 for i in range(5): 1163 self.assertEqual(f(), 20) 1164 self.assertEqual(f_cnt, 5) 1165 hits, misses, maxsize, currsize = f.cache_info() 1166 self.assertEqual(hits, 0) 1167 self.assertEqual(misses, 5) 1168 self.assertEqual(currsize, 0) 1169 1170 # test size one 1171 @self.module.lru_cache(1) 1172 def f(): 1173 nonlocal f_cnt 1174 f_cnt += 1 1175 return 20 1176 self.assertEqual(f.cache_info().maxsize, 1) 1177 f_cnt = 0 1178 for i in range(5): 1179 self.assertEqual(f(), 20) 1180 self.assertEqual(f_cnt, 1) 1181 hits, misses, maxsize, currsize = f.cache_info() 1182 self.assertEqual(hits, 4) 1183 self.assertEqual(misses, 1) 1184 self.assertEqual(currsize, 1) 1185 1186 # test size two 1187 @self.module.lru_cache(2) 1188 def f(x): 1189 nonlocal f_cnt 1190 f_cnt += 1 1191 return x*10 1192 self.assertEqual(f.cache_info().maxsize, 2) 1193 f_cnt = 0 1194 for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7: 1195 # * * * * 1196 self.assertEqual(f(x), x*10) 1197 self.assertEqual(f_cnt, 4) 1198 hits, misses, maxsize, currsize = f.cache_info() 1199 self.assertEqual(hits, 12) 1200 self.assertEqual(misses, 4) 1201 self.assertEqual(currsize, 2) 1202 1203 def test_lru_reentrancy_with_len(self): 1204 # Test to make sure the LRU cache code isn't thrown-off by 1205 # caching the built-in len() function. 
Since len() can be 1206 # cached, we shouldn't use it inside the lru code itself. 1207 old_len = builtins.len 1208 try: 1209 builtins.len = self.module.lru_cache(4)(len) 1210 for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]: 1211 self.assertEqual(len('abcdefghijklmn'[:i]), i) 1212 finally: 1213 builtins.len = old_len 1214 1215 def test_lru_type_error(self): 1216 # Regression test for issue #28653. 1217 # lru_cache was leaking when one of the arguments 1218 # wasn't cacheable. 1219 1220 @functools.lru_cache(maxsize=None) 1221 def infinite_cache(o): 1222 pass 1223 1224 @functools.lru_cache(maxsize=10) 1225 def limited_cache(o): 1226 pass 1227 1228 with self.assertRaises(TypeError): 1229 infinite_cache([]) 1230 1231 with self.assertRaises(TypeError): 1232 limited_cache([]) 1233 1234 def test_lru_with_maxsize_none(self): 1235 @self.module.lru_cache(maxsize=None) 1236 def fib(n): 1237 if n < 2: 1238 return n 1239 return fib(n-1) + fib(n-2) 1240 self.assertEqual([fib(n) for n in range(16)], 1241 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1242 self.assertEqual(fib.cache_info(), 1243 self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1244 fib.cache_clear() 1245 self.assertEqual(fib.cache_info(), 1246 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1247 1248 def test_lru_with_maxsize_negative(self): 1249 @self.module.lru_cache(maxsize=-10) 1250 def eq(n): 1251 return n 1252 for i in (0, 1): 1253 self.assertEqual([eq(n) for n in range(150)], list(range(150))) 1254 self.assertEqual(eq.cache_info(), 1255 self.module._CacheInfo(hits=0, misses=300, maxsize=-10, currsize=1)) 1256 1257 def test_lru_with_exceptions(self): 1258 # Verify that user_function exceptions get passed through without 1259 # creating a hard-to-read chained exception. 
1260 # http://bugs.python.org/issue13177 1261 for maxsize in (None, 128): 1262 @self.module.lru_cache(maxsize) 1263 def func(i): 1264 return 'abc'[i] 1265 self.assertEqual(func(0), 'a') 1266 with self.assertRaises(IndexError) as cm: 1267 func(15) 1268 self.assertIsNone(cm.exception.__context__) 1269 # Verify that the previous exception did not result in a cached entry 1270 with self.assertRaises(IndexError): 1271 func(15) 1272 1273 def test_lru_with_types(self): 1274 for maxsize in (None, 128): 1275 @self.module.lru_cache(maxsize=maxsize, typed=True) 1276 def square(x): 1277 return x * x 1278 self.assertEqual(square(3), 9) 1279 self.assertEqual(type(square(3)), type(9)) 1280 self.assertEqual(square(3.0), 9.0) 1281 self.assertEqual(type(square(3.0)), type(9.0)) 1282 self.assertEqual(square(x=3), 9) 1283 self.assertEqual(type(square(x=3)), type(9)) 1284 self.assertEqual(square(x=3.0), 9.0) 1285 self.assertEqual(type(square(x=3.0)), type(9.0)) 1286 self.assertEqual(square.cache_info().hits, 4) 1287 self.assertEqual(square.cache_info().misses, 4) 1288 1289 def test_lru_with_keyword_args(self): 1290 @self.module.lru_cache() 1291 def fib(n): 1292 if n < 2: 1293 return n 1294 return fib(n=n-1) + fib(n=n-2) 1295 self.assertEqual( 1296 [fib(n=number) for number in range(16)], 1297 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] 1298 ) 1299 self.assertEqual(fib.cache_info(), 1300 self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16)) 1301 fib.cache_clear() 1302 self.assertEqual(fib.cache_info(), 1303 self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)) 1304 1305 def test_lru_with_keyword_args_maxsize_none(self): 1306 @self.module.lru_cache(maxsize=None) 1307 def fib(n): 1308 if n < 2: 1309 return n 1310 return fib(n=n-1) + fib(n=n-2) 1311 self.assertEqual([fib(n=number) for number in range(16)], 1312 [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) 1313 self.assertEqual(fib.cache_info(), 1314 
self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) 1315 fib.cache_clear() 1316 self.assertEqual(fib.cache_info(), 1317 self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) 1318 1319 def test_kwargs_order(self): 1320 # PEP 468: Preserving Keyword Argument Order 1321 @self.module.lru_cache(maxsize=10) 1322 def f(**kwargs): 1323 return list(kwargs.items()) 1324 self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)]) 1325 self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)]) 1326 self.assertEqual(f.cache_info(), 1327 self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2)) 1328 1329 def test_lru_cache_decoration(self): 1330 def f(zomg: 'zomg_annotation'): 1331 """f doc string""" 1332 return 42 1333 g = self.module.lru_cache()(f) 1334 for attr in self.module.WRAPPER_ASSIGNMENTS: 1335 self.assertEqual(getattr(g, attr), getattr(f, attr)) 1336 1337 @unittest.skipUnless(threading, 'This test requires threading.') 1338 def test_lru_cache_threaded(self): 1339 n, m = 5, 11 1340 def orig(x, y): 1341 return 3 * x + y 1342 f = self.module.lru_cache(maxsize=n*m)(orig) 1343 hits, misses, maxsize, currsize = f.cache_info() 1344 self.assertEqual(currsize, 0) 1345 1346 start = threading.Event() 1347 def full(k): 1348 start.wait(10) 1349 for _ in range(m): 1350 self.assertEqual(f(k, 0), orig(k, 0)) 1351 1352 def clear(): 1353 start.wait(10) 1354 for _ in range(2*m): 1355 f.cache_clear() 1356 1357 orig_si = sys.getswitchinterval() 1358 support.setswitchinterval(1e-6) 1359 try: 1360 # create n threads in order to fill cache 1361 threads = [threading.Thread(target=full, args=[k]) 1362 for k in range(n)] 1363 with support.start_threads(threads): 1364 start.set() 1365 1366 hits, misses, maxsize, currsize = f.cache_info() 1367 if self.module is py_functools: 1368 # XXX: Why can be not equal? 
1369 self.assertLessEqual(misses, n) 1370 self.assertLessEqual(hits, m*n - misses) 1371 else: 1372 self.assertEqual(misses, n) 1373 self.assertEqual(hits, m*n - misses) 1374 self.assertEqual(currsize, n) 1375 1376 # create n threads in order to fill cache and 1 to clear it 1377 threads = [threading.Thread(target=clear)] 1378 threads += [threading.Thread(target=full, args=[k]) 1379 for k in range(n)] 1380 start.clear() 1381 with support.start_threads(threads): 1382 start.set() 1383 finally: 1384 sys.setswitchinterval(orig_si) 1385 1386 @unittest.skipUnless(threading, 'This test requires threading.') 1387 def test_lru_cache_threaded2(self): 1388 # Simultaneous call with the same arguments 1389 n, m = 5, 7 1390 start = threading.Barrier(n+1) 1391 pause = threading.Barrier(n+1) 1392 stop = threading.Barrier(n+1) 1393 @self.module.lru_cache(maxsize=m*n) 1394 def f(x): 1395 pause.wait(10) 1396 return 3 * x 1397 self.assertEqual(f.cache_info(), (0, 0, m*n, 0)) 1398 def test(): 1399 for i in range(m): 1400 start.wait(10) 1401 self.assertEqual(f(i), 3 * i) 1402 stop.wait(10) 1403 threads = [threading.Thread(target=test) for k in range(n)] 1404 with support.start_threads(threads): 1405 for i in range(m): 1406 start.wait(10) 1407 stop.reset() 1408 pause.wait(10) 1409 start.reset() 1410 stop.wait(10) 1411 pause.reset() 1412 self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1)) 1413 1414 @unittest.skipUnless(threading, 'This test requires threading.') 1415 def test_lru_cache_threaded3(self): 1416 @self.module.lru_cache(maxsize=2) 1417 def f(x): 1418 time.sleep(.01) 1419 return 3 * x 1420 def test(i, x): 1421 with self.subTest(thread=i): 1422 self.assertEqual(f(x), 3 * x, i) 1423 threads = [threading.Thread(target=test, args=(i, v)) 1424 for i, v in enumerate([1, 2, 2, 3, 2])] 1425 with support.start_threads(threads): 1426 pass 1427 1428 def test_need_for_rlock(self): 1429 # This will deadlock on an LRU cache that uses a regular lock 1430 1431 
@self.module.lru_cache(maxsize=10) 1432 def test_func(x): 1433 'Used to demonstrate a reentrant lru_cache call within a single thread' 1434 return x 1435 1436 class DoubleEq: 1437 'Demonstrate a reentrant lru_cache call within a single thread' 1438 def __init__(self, x): 1439 self.x = x 1440 def __hash__(self): 1441 return self.x 1442 def __eq__(self, other): 1443 if self.x == 2: 1444 test_func(DoubleEq(1)) 1445 return self.x == other.x 1446 1447 test_func(DoubleEq(1)) # Load the cache 1448 test_func(DoubleEq(2)) # Load the cache 1449 self.assertEqual(test_func(DoubleEq(2)), # Trigger a re-entrant __eq__ call 1450 DoubleEq(2)) # Verify the correct return value 1451 1452 def test_early_detection_of_bad_call(self): 1453 # Issue #22184 1454 with self.assertRaises(TypeError): 1455 @functools.lru_cache 1456 def f(): 1457 pass 1458 1459 def test_lru_method(self): 1460 class X(int): 1461 f_cnt = 0 1462 @self.module.lru_cache(2) 1463 def f(self, x): 1464 self.f_cnt += 1 1465 return x*10+self 1466 a = X(5) 1467 b = X(5) 1468 c = X(7) 1469 self.assertEqual(X.f.cache_info(), (0, 0, 2, 0)) 1470 1471 for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3: 1472 self.assertEqual(a.f(x), x*10 + 5) 1473 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0)) 1474 self.assertEqual(X.f.cache_info(), (4, 6, 2, 2)) 1475 1476 for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2: 1477 self.assertEqual(b.f(x), x*10 + 5) 1478 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0)) 1479 self.assertEqual(X.f.cache_info(), (10, 10, 2, 2)) 1480 1481 for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1: 1482 self.assertEqual(c.f(x), x*10 + 7) 1483 self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5)) 1484 self.assertEqual(X.f.cache_info(), (15, 15, 2, 2)) 1485 1486 self.assertEqual(a.f.cache_info(), X.f.cache_info()) 1487 self.assertEqual(b.f.cache_info(), X.f.cache_info()) 1488 self.assertEqual(c.f.cache_info(), X.f.cache_info()) 1489 1490 def test_pickle(self): 1491 cls = self.__class__ 1492 for f in cls.cached_func[0], 
cls.cached_meth, cls.cached_staticmeth: 1493 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1494 with self.subTest(proto=proto, func=f): 1495 f_copy = pickle.loads(pickle.dumps(f, proto)) 1496 self.assertIs(f_copy, f) 1497 1498 def test_copy(self): 1499 cls = self.__class__ 1500 def orig(x, y): 1501 return 3 * x + y 1502 part = self.module.partial(orig, 2) 1503 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, 1504 self.module.lru_cache(2)(part)) 1505 for f in funcs: 1506 with self.subTest(func=f): 1507 f_copy = copy.copy(f) 1508 self.assertIs(f_copy, f) 1509 1510 def test_deepcopy(self): 1511 cls = self.__class__ 1512 def orig(x, y): 1513 return 3 * x + y 1514 part = self.module.partial(orig, 2) 1515 funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, 1516 self.module.lru_cache(2)(part)) 1517 for f in funcs: 1518 with self.subTest(func=f): 1519 f_copy = copy.deepcopy(f) 1520 self.assertIs(f_copy, f) 1521 1522 1523 @py_functools.lru_cache() 1524 def py_cached_func(x, y): 1525 return 3 * x + y 1526 1527 @c_functools.lru_cache() 1528 def c_cached_func(x, y): 1529 return 3 * x + y 1530 1531 1532 class TestLRUPy(TestLRU, unittest.TestCase): 1533 module = py_functools 1534 cached_func = py_cached_func, 1535 1536 @module.lru_cache() 1537 def cached_meth(self, x, y): 1538 return 3 * x + y 1539 1540 @staticmethod 1541 @module.lru_cache() 1542 def cached_staticmeth(x, y): 1543 return 3 * x + y 1544 1545 1546 class TestLRUC(TestLRU, unittest.TestCase): 1547 module = c_functools 1548 cached_func = c_cached_func, 1549 1550 @module.lru_cache() 1551 def cached_meth(self, x, y): 1552 return 3 * x + y 1553 1554 @staticmethod 1555 @module.lru_cache() 1556 def cached_staticmeth(x, y): 1557 return 3 * x + y 1558 1559 1560 class TestSingleDispatch(unittest.TestCase): 1561 def test_simple_overloads(self): 1562 @functools.singledispatch 1563 def g(obj): 1564 return "base" 1565 def g_int(i): 1566 return "integer" 1567 g.register(int, g_int) 1568 
self.assertEqual(g("str"), "base") 1569 self.assertEqual(g(1), "integer") 1570 self.assertEqual(g([1,2,3]), "base") 1571 1572 def test_mro(self): 1573 @functools.singledispatch 1574 def g(obj): 1575 return "base" 1576 class A: 1577 pass 1578 class C(A): 1579 pass 1580 class B(A): 1581 pass 1582 class D(C, B): 1583 pass 1584 def g_A(a): 1585 return "A" 1586 def g_B(b): 1587 return "B" 1588 g.register(A, g_A) 1589 g.register(B, g_B) 1590 self.assertEqual(g(A()), "A") 1591 self.assertEqual(g(B()), "B") 1592 self.assertEqual(g(C()), "A") 1593 self.assertEqual(g(D()), "B") 1594 1595 def test_register_decorator(self): 1596 @functools.singledispatch 1597 def g(obj): 1598 return "base" 1599 @g.register(int) 1600 def g_int(i): 1601 return "int %s" % (i,) 1602 self.assertEqual(g(""), "base") 1603 self.assertEqual(g(12), "int 12") 1604 self.assertIs(g.dispatch(int), g_int) 1605 self.assertIs(g.dispatch(object), g.dispatch(str)) 1606 # Note: in the assert above this is not g. 1607 # @singledispatch returns the wrapper. 1608 1609 def test_wrapping_attributes(self): 1610 @functools.singledispatch 1611 def g(obj): 1612 "Simple test" 1613 return "Test" 1614 self.assertEqual(g.__name__, "g") 1615 if sys.flags.optimize < 2: 1616 self.assertEqual(g.__doc__, "Simple test") 1617 1618 @unittest.skipUnless(decimal, 'requires _decimal') 1619 @support.cpython_only 1620 def test_c_classes(self): 1621 @functools.singledispatch 1622 def g(obj): 1623 return "base" 1624 @g.register(decimal.DecimalException) 1625 def _(obj): 1626 return obj.args 1627 subn = decimal.Subnormal("Exponent < Emin") 1628 rnd = decimal.Rounded("Number got rounded") 1629 self.assertEqual(g(subn), ("Exponent < Emin",)) 1630 self.assertEqual(g(rnd), ("Number got rounded",)) 1631 @g.register(decimal.Subnormal) 1632 def _(obj): 1633 return "Too small to care." 
1634 self.assertEqual(g(subn), "Too small to care.") 1635 self.assertEqual(g(rnd), ("Number got rounded",)) 1636 1637 def test_compose_mro(self): 1638 # None of the examples in this test depend on haystack ordering. 1639 c = collections 1640 mro = functools._compose_mro 1641 bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set] 1642 for haystack in permutations(bases): 1643 m = mro(dict, haystack) 1644 self.assertEqual(m, [dict, c.MutableMapping, c.Mapping, 1645 c.Collection, c.Sized, c.Iterable, 1646 c.Container, object]) 1647 bases = [c.Container, c.Mapping, c.MutableMapping, c.OrderedDict] 1648 for haystack in permutations(bases): 1649 m = mro(c.ChainMap, haystack) 1650 self.assertEqual(m, [c.ChainMap, c.MutableMapping, c.Mapping, 1651 c.Collection, c.Sized, c.Iterable, 1652 c.Container, object]) 1653 1654 # If there's a generic function with implementations registered for 1655 # both Sized and Container, passing a defaultdict to it results in an 1656 # ambiguous dispatch which will cause a RuntimeError (see 1657 # test_mro_conflicts). 1658 bases = [c.Container, c.Sized, str] 1659 for haystack in permutations(bases): 1660 m = mro(c.defaultdict, [c.Sized, c.Container, str]) 1661 self.assertEqual(m, [c.defaultdict, dict, c.Sized, c.Container, 1662 object]) 1663 1664 # MutableSequence below is registered directly on D. In other words, it 1665 # precedes MutableMapping which means single dispatch will always 1666 # choose MutableSequence here. 
1667 class D(c.defaultdict): 1668 pass 1669 c.MutableSequence.register(D) 1670 bases = [c.MutableSequence, c.MutableMapping] 1671 for haystack in permutations(bases): 1672 m = mro(D, bases) 1673 self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible, 1674 c.defaultdict, dict, c.MutableMapping, c.Mapping, 1675 c.Collection, c.Sized, c.Iterable, c.Container, 1676 object]) 1677 1678 # Container and Callable are registered on different base classes and 1679 # a generic function supporting both should always pick the Callable 1680 # implementation if a C instance is passed. 1681 class C(c.defaultdict): 1682 def __call__(self): 1683 pass 1684 bases = [c.Sized, c.Callable, c.Container, c.Mapping] 1685 for haystack in permutations(bases): 1686 m = mro(C, haystack) 1687 self.assertEqual(m, [C, c.Callable, c.defaultdict, dict, c.Mapping, 1688 c.Collection, c.Sized, c.Iterable, 1689 c.Container, object]) 1690 1691 def test_register_abc(self): 1692 c = collections 1693 d = {"a": "b"} 1694 l = [1, 2, 3] 1695 s = {object(), None} 1696 f = frozenset(s) 1697 t = (1, 2, 3) 1698 @functools.singledispatch 1699 def g(obj): 1700 return "base" 1701 self.assertEqual(g(d), "base") 1702 self.assertEqual(g(l), "base") 1703 self.assertEqual(g(s), "base") 1704 self.assertEqual(g(f), "base") 1705 self.assertEqual(g(t), "base") 1706 g.register(c.Sized, lambda obj: "sized") 1707 self.assertEqual(g(d), "sized") 1708 self.assertEqual(g(l), "sized") 1709 self.assertEqual(g(s), "sized") 1710 self.assertEqual(g(f), "sized") 1711 self.assertEqual(g(t), "sized") 1712 g.register(c.MutableMapping, lambda obj: "mutablemapping") 1713 self.assertEqual(g(d), "mutablemapping") 1714 self.assertEqual(g(l), "sized") 1715 self.assertEqual(g(s), "sized") 1716 self.assertEqual(g(f), "sized") 1717 self.assertEqual(g(t), "sized") 1718 g.register(c.ChainMap, lambda obj: "chainmap") 1719 self.assertEqual(g(d), "mutablemapping") # irrelevant ABCs registered 1720 self.assertEqual(g(l), "sized") 1721 
self.assertEqual(g(s), "sized") 1722 self.assertEqual(g(f), "sized") 1723 self.assertEqual(g(t), "sized") 1724 g.register(c.MutableSequence, lambda obj: "mutablesequence") 1725 self.assertEqual(g(d), "mutablemapping") 1726 self.assertEqual(g(l), "mutablesequence") 1727 self.assertEqual(g(s), "sized") 1728 self.assertEqual(g(f), "sized") 1729 self.assertEqual(g(t), "sized") 1730 g.register(c.MutableSet, lambda obj: "mutableset") 1731 self.assertEqual(g(d), "mutablemapping") 1732 self.assertEqual(g(l), "mutablesequence") 1733 self.assertEqual(g(s), "mutableset") 1734 self.assertEqual(g(f), "sized") 1735 self.assertEqual(g(t), "sized") 1736 g.register(c.Mapping, lambda obj: "mapping") 1737 self.assertEqual(g(d), "mutablemapping") # not specific enough 1738 self.assertEqual(g(l), "mutablesequence") 1739 self.assertEqual(g(s), "mutableset") 1740 self.assertEqual(g(f), "sized") 1741 self.assertEqual(g(t), "sized") 1742 g.register(c.Sequence, lambda obj: "sequence") 1743 self.assertEqual(g(d), "mutablemapping") 1744 self.assertEqual(g(l), "mutablesequence") 1745 self.assertEqual(g(s), "mutableset") 1746 self.assertEqual(g(f), "sized") 1747 self.assertEqual(g(t), "sequence") 1748 g.register(c.Set, lambda obj: "set") 1749 self.assertEqual(g(d), "mutablemapping") 1750 self.assertEqual(g(l), "mutablesequence") 1751 self.assertEqual(g(s), "mutableset") 1752 self.assertEqual(g(f), "set") 1753 self.assertEqual(g(t), "sequence") 1754 g.register(dict, lambda obj: "dict") 1755 self.assertEqual(g(d), "dict") 1756 self.assertEqual(g(l), "mutablesequence") 1757 self.assertEqual(g(s), "mutableset") 1758 self.assertEqual(g(f), "set") 1759 self.assertEqual(g(t), "sequence") 1760 g.register(list, lambda obj: "list") 1761 self.assertEqual(g(d), "dict") 1762 self.assertEqual(g(l), "list") 1763 self.assertEqual(g(s), "mutableset") 1764 self.assertEqual(g(f), "set") 1765 self.assertEqual(g(t), "sequence") 1766 g.register(set, lambda obj: "concrete-set") 1767 self.assertEqual(g(d), "dict") 
1768 self.assertEqual(g(l), "list") 1769 self.assertEqual(g(s), "concrete-set") 1770 self.assertEqual(g(f), "set") 1771 self.assertEqual(g(t), "sequence") 1772 g.register(frozenset, lambda obj: "frozen-set") 1773 self.assertEqual(g(d), "dict") 1774 self.assertEqual(g(l), "list") 1775 self.assertEqual(g(s), "concrete-set") 1776 self.assertEqual(g(f), "frozen-set") 1777 self.assertEqual(g(t), "sequence") 1778 g.register(tuple, lambda obj: "tuple") 1779 self.assertEqual(g(d), "dict") 1780 self.assertEqual(g(l), "list") 1781 self.assertEqual(g(s), "concrete-set") 1782 self.assertEqual(g(f), "frozen-set") 1783 self.assertEqual(g(t), "tuple") 1784 1785 def test_c3_abc(self): 1786 c = collections 1787 mro = functools._c3_mro 1788 class A(object): 1789 pass 1790 class B(A): 1791 def __len__(self): 1792 return 0 # implies Sized 1793 @c.Container.register 1794 class C(object): 1795 pass 1796 class D(object): 1797 pass # unrelated 1798 class X(D, C, B): 1799 def __call__(self): 1800 pass # implies Callable 1801 expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object] 1802 for abcs in permutations([c.Sized, c.Callable, c.Container]): 1803 self.assertEqual(mro(X, abcs=abcs), expected) 1804 # unrelated ABCs don't appear in the resulting MRO 1805 many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable] 1806 self.assertEqual(mro(X, abcs=many_abcs), expected) 1807 1808 def test_false_meta(self): 1809 # see issue23572 1810 class MetaA(type): 1811 def __len__(self): 1812 return 0 1813 class A(metaclass=MetaA): 1814 pass 1815 class AA(A): 1816 pass 1817 @functools.singledispatch 1818 def fun(a): 1819 return 'base A' 1820 @fun.register(A) 1821 def _(a): 1822 return 'fun A' 1823 aa = AA() 1824 self.assertEqual(fun(aa), 'fun A') 1825 1826 def test_mro_conflicts(self): 1827 c = collections 1828 @functools.singledispatch 1829 def g(arg): 1830 return "base" 1831 class O(c.Sized): 1832 def __len__(self): 1833 return 0 1834 o = O() 1835 self.assertEqual(g(o), "base") 
1836 g.register(c.Iterable, lambda arg: "iterable") 1837 g.register(c.Container, lambda arg: "container") 1838 g.register(c.Sized, lambda arg: "sized") 1839 g.register(c.Set, lambda arg: "set") 1840 self.assertEqual(g(o), "sized") 1841 c.Iterable.register(O) 1842 self.assertEqual(g(o), "sized") # because it's explicitly in __mro__ 1843 c.Container.register(O) 1844 self.assertEqual(g(o), "sized") # see above: Sized is in __mro__ 1845 c.Set.register(O) 1846 self.assertEqual(g(o), "set") # because c.Set is a subclass of 1847 # c.Sized and c.Container 1848 class P: 1849 pass 1850 p = P() 1851 self.assertEqual(g(p), "base") 1852 c.Iterable.register(P) 1853 self.assertEqual(g(p), "iterable") 1854 c.Container.register(P) 1855 with self.assertRaises(RuntimeError) as re_one: 1856 g(p) 1857 self.assertIn( 1858 str(re_one.exception), 1859 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 1860 "or <class 'collections.abc.Iterable'>"), 1861 ("Ambiguous dispatch: <class 'collections.abc.Iterable'> " 1862 "or <class 'collections.abc.Container'>")), 1863 ) 1864 class Q(c.Sized): 1865 def __len__(self): 1866 return 0 1867 q = Q() 1868 self.assertEqual(g(q), "sized") 1869 c.Iterable.register(Q) 1870 self.assertEqual(g(q), "sized") # because it's explicitly in __mro__ 1871 c.Set.register(Q) 1872 self.assertEqual(g(q), "set") # because c.Set is a subclass of 1873 # c.Sized and c.Iterable 1874 @functools.singledispatch 1875 def h(arg): 1876 return "base" 1877 @h.register(c.Sized) 1878 def _(arg): 1879 return "sized" 1880 @h.register(c.Container) 1881 def _(arg): 1882 return "container" 1883 # Even though Sized and Container are explicit bases of MutableMapping, 1884 # this ABC is implicitly registered on defaultdict which makes all of 1885 # MutableMapping's bases implicit as well from defaultdict's 1886 # perspective. 
1887 with self.assertRaises(RuntimeError) as re_two: 1888 h(c.defaultdict(lambda: 0)) 1889 self.assertIn( 1890 str(re_two.exception), 1891 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 1892 "or <class 'collections.abc.Sized'>"), 1893 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 1894 "or <class 'collections.abc.Container'>")), 1895 ) 1896 class R(c.defaultdict): 1897 pass 1898 c.MutableSequence.register(R) 1899 @functools.singledispatch 1900 def i(arg): 1901 return "base" 1902 @i.register(c.MutableMapping) 1903 def _(arg): 1904 return "mapping" 1905 @i.register(c.MutableSequence) 1906 def _(arg): 1907 return "sequence" 1908 r = R() 1909 self.assertEqual(i(r), "sequence") 1910 class S: 1911 pass 1912 class T(S, c.Sized): 1913 def __len__(self): 1914 return 0 1915 t = T() 1916 self.assertEqual(h(t), "sized") 1917 c.Container.register(T) 1918 self.assertEqual(h(t), "sized") # because it's explicitly in the MRO 1919 class U: 1920 def __len__(self): 1921 return 0 1922 u = U() 1923 self.assertEqual(h(u), "sized") # implicit Sized subclass inferred 1924 # from the existence of __len__() 1925 c.Container.register(U) 1926 # There is no preference for registered versus inferred ABCs. 
1927 with self.assertRaises(RuntimeError) as re_three: 1928 h(u) 1929 self.assertIn( 1930 str(re_three.exception), 1931 (("Ambiguous dispatch: <class 'collections.abc.Container'> " 1932 "or <class 'collections.abc.Sized'>"), 1933 ("Ambiguous dispatch: <class 'collections.abc.Sized'> " 1934 "or <class 'collections.abc.Container'>")), 1935 ) 1936 class V(c.Sized, S): 1937 def __len__(self): 1938 return 0 1939 @functools.singledispatch 1940 def j(arg): 1941 return "base" 1942 @j.register(S) 1943 def _(arg): 1944 return "s" 1945 @j.register(c.Container) 1946 def _(arg): 1947 return "container" 1948 v = V() 1949 self.assertEqual(j(v), "s") 1950 c.Container.register(V) 1951 self.assertEqual(j(v), "container") # because it ends up right after 1952 # Sized in the MRO 1953 1954 def test_cache_invalidation(self): 1955 from collections import UserDict 1956 class TracingDict(UserDict): 1957 def __init__(self, *args, **kwargs): 1958 super(TracingDict, self).__init__(*args, **kwargs) 1959 self.set_ops = [] 1960 self.get_ops = [] 1961 def __getitem__(self, key): 1962 result = self.data[key] 1963 self.get_ops.append(key) 1964 return result 1965 def __setitem__(self, key, value): 1966 self.set_ops.append(key) 1967 self.data[key] = value 1968 def clear(self): 1969 self.data.clear() 1970 _orig_wkd = functools.WeakKeyDictionary 1971 td = TracingDict() 1972 functools.WeakKeyDictionary = lambda: td 1973 c = collections 1974 @functools.singledispatch 1975 def g(arg): 1976 return "base" 1977 d = {} 1978 l = [] 1979 self.assertEqual(len(td), 0) 1980 self.assertEqual(g(d), "base") 1981 self.assertEqual(len(td), 1) 1982 self.assertEqual(td.get_ops, []) 1983 self.assertEqual(td.set_ops, [dict]) 1984 self.assertEqual(td.data[dict], g.registry[object]) 1985 self.assertEqual(g(l), "base") 1986 self.assertEqual(len(td), 2) 1987 self.assertEqual(td.get_ops, []) 1988 self.assertEqual(td.set_ops, [dict, list]) 1989 self.assertEqual(td.data[dict], g.registry[object]) 1990 
self.assertEqual(td.data[list], g.registry[object]) 1991 self.assertEqual(td.data[dict], td.data[list]) 1992 self.assertEqual(g(l), "base") 1993 self.assertEqual(g(d), "base") 1994 self.assertEqual(td.get_ops, [list, dict]) 1995 self.assertEqual(td.set_ops, [dict, list]) 1996 g.register(list, lambda arg: "list") 1997 self.assertEqual(td.get_ops, [list, dict]) 1998 self.assertEqual(len(td), 0) 1999 self.assertEqual(g(d), "base") 2000 self.assertEqual(len(td), 1) 2001 self.assertEqual(td.get_ops, [list, dict]) 2002 self.assertEqual(td.set_ops, [dict, list, dict]) 2003 self.assertEqual(td.data[dict], 2004 functools._find_impl(dict, g.registry)) 2005 self.assertEqual(g(l), "list") 2006 self.assertEqual(len(td), 2) 2007 self.assertEqual(td.get_ops, [list, dict]) 2008 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2009 self.assertEqual(td.data[list], 2010 functools._find_impl(list, g.registry)) 2011 class X: 2012 pass 2013 c.MutableMapping.register(X) # Will not invalidate the cache, 2014 # not using ABCs yet. 
2015 self.assertEqual(g(d), "base") 2016 self.assertEqual(g(l), "list") 2017 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2018 self.assertEqual(td.set_ops, [dict, list, dict, list]) 2019 g.register(c.Sized, lambda arg: "sized") 2020 self.assertEqual(len(td), 0) 2021 self.assertEqual(g(d), "sized") 2022 self.assertEqual(len(td), 1) 2023 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2024 self.assertEqual(td.set_ops, [dict, list, dict, list, dict]) 2025 self.assertEqual(g(l), "list") 2026 self.assertEqual(len(td), 2) 2027 self.assertEqual(td.get_ops, [list, dict, dict, list]) 2028 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2029 self.assertEqual(g(l), "list") 2030 self.assertEqual(g(d), "sized") 2031 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict]) 2032 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2033 g.dispatch(list) 2034 g.dispatch(dict) 2035 self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict, 2036 list, dict]) 2037 self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list]) 2038 c.MutableSet.register(X) # Will invalidate the cache. 2039 self.assertEqual(len(td), 2) # Stale cache. 2040 self.assertEqual(g(l), "list") 2041 self.assertEqual(len(td), 1) 2042 g.register(c.MutableMapping, lambda arg: "mutablemapping") 2043 self.assertEqual(len(td), 0) 2044 self.assertEqual(g(d), "mutablemapping") 2045 self.assertEqual(len(td), 1) 2046 self.assertEqual(g(l), "list") 2047 self.assertEqual(len(td), 2) 2048 g.register(dict, lambda arg: "dict") 2049 self.assertEqual(g(d), "dict") 2050 self.assertEqual(g(l), "list") 2051 g._clear_cache() 2052 self.assertEqual(len(td), 0) 2053 functools.WeakKeyDictionary = _orig_wkd 2054 2055 2056 if __name__ == '__main__': 2057 unittest.main() 2058