1 import functools 2 import sys 3 import unittest 4 from test import test_support 5 from weakref import proxy 6 import pickle 7 8 @staticmethod 9 def PythonPartial(func, *args, **keywords): 10 'Pure Python approximation of partial()' 11 def newfunc(*fargs, **fkeywords): 12 newkeywords = keywords.copy() 13 newkeywords.update(fkeywords) 14 return func(*(args + fargs), **newkeywords) 15 newfunc.func = func 16 newfunc.args = args 17 newfunc.keywords = keywords 18 return newfunc 19 20 def capture(*args, **kw): 21 """capture all positional and keyword arguments""" 22 return args, kw 23 24 def signature(part): 25 """ return the signature of a partial object """ 26 return (part.func, part.args, part.keywords, part.__dict__) 27 28 class TestPartial(unittest.TestCase): 29 30 thetype = functools.partial 31 32 def test_basic_examples(self): 33 p = self.thetype(capture, 1, 2, a=10, b=20) 34 self.assertEqual(p(3, 4, b=30, c=40), 35 ((1, 2, 3, 4), dict(a=10, b=30, c=40))) 36 p = self.thetype(map, lambda x: x*10) 37 self.assertEqual(p([1,2,3,4]), [10, 20, 30, 40]) 38 39 def test_attributes(self): 40 p = self.thetype(capture, 1, 2, a=10, b=20) 41 # attributes should be readable 42 self.assertEqual(p.func, capture) 43 self.assertEqual(p.args, (1, 2)) 44 self.assertEqual(p.keywords, dict(a=10, b=20)) 45 # attributes should not be writable 46 if not isinstance(self.thetype, type): 47 return 48 self.assertRaises(TypeError, setattr, p, 'func', map) 49 self.assertRaises(TypeError, setattr, p, 'args', (1, 2)) 50 self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2)) 51 52 p = self.thetype(hex) 53 try: 54 del p.__dict__ 55 except TypeError: 56 pass 57 else: 58 self.fail('partial object allowed __dict__ to be deleted') 59 60 def test_argument_checking(self): 61 self.assertRaises(TypeError, self.thetype) # need at least a func arg 62 try: 63 self.thetype(2)() 64 except TypeError: 65 pass 66 else: 67 self.fail('First arg not checked for callability') 68 69 def test_protection_of_callers_dict_argument(self): 70 # a caller's dictionary should not be altered by partial 71 def func(a=10, b=20): 72 return a 73 d = {'a':3} 74 p = self.thetype(func, a=5) 75 self.assertEqual(p(**d), 3) 76 self.assertEqual(d, {'a':3}) 77 p(b=7) 78 self.assertEqual(d, {'a':3}) 79 80 def test_arg_combinations(self): 81 # exercise special code paths for zero args in either partial 82 # object or the caller 83 p = self.thetype(capture) 84 self.assertEqual(p(), ((), {})) 85 self.assertEqual(p(1,2), ((1,2), {})) 86 p = self.thetype(capture, 1, 2) 87 self.assertEqual(p(), ((1,2), {})) 88 self.assertEqual(p(3,4), ((1,2,3,4), {})) 89 90 def test_kw_combinations(self): 91 # exercise special code paths for no keyword args in 92 # either the partial object or the caller 93 p = self.thetype(capture) 94 self.assertEqual(p(), ((), {})) 95 self.assertEqual(p(a=1), ((), {'a':1})) 96 p = self.thetype(capture, a=1) 97 self.assertEqual(p(), ((), {'a':1})) 98 self.assertEqual(p(b=2), ((), {'a':1, 'b':2})) 99 # keyword args in the call override those in the partial object 100 self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2})) 101 102 def test_positional(self): 103 # make sure positional arguments are captured correctly 104 for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]: 105 p = self.thetype(capture, *args) 106 expected = args + ('x',) 107 got, empty = p('x') 108 self.assertTrue(expected == got and empty == {}) 109 110 def test_keyword(self): 111 # make sure keyword arguments are captured correctly 112 for a in ['a', 0, None, 3.5]: 113 p = self.thetype(capture, a=a) 114 expected = {'a':a,'x':None} 115 empty, got = p(x=None) 116 self.assertTrue(expected == got and empty == ()) 117 118 def test_no_side_effects(self): 119 # make sure there are no side effects that affect subsequent calls 120 p = self.thetype(capture, 0, a=1) 121 args1, kw1 = p(1, b=2) 122 self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2}) 123 args2, kw2 = p() 124 self.assertTrue(args2 == (0,) and kw2 == {'a':1}) 125 126 def test_error_propagation(self): 127 def f(x, y): 128 x // y 129 self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0)) 130 self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0) 131 self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0) 132 self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1) 133 134 def test_weakref(self): 135 f = self.thetype(int, base=16) 136 p = proxy(f) 137 self.assertEqual(f.func, p.func) 138 f = None 139 self.assertRaises(ReferenceError, getattr, p, 'func') 140 141 def test_with_bound_and_unbound_methods(self): 142 data = map(str, range(10)) 143 join = self.thetype(str.join, '') 144 self.assertEqual(join(data), '0123456789') 145 join = self.thetype(''.join) 146 self.assertEqual(join(data), '0123456789') 147 148 def test_pickle(self): 149 f = self.thetype(signature, 'asdf', bar=True) 150 f.add_something_to__dict__ = True 151 f_copy = pickle.loads(pickle.dumps(f)) 152 self.assertEqual(signature(f), signature(f_copy)) 153 154 # Issue 6083: Reference counting bug 155 def test_setstate_refcount(self): 156 class BadSequence: 157 def __len__(self): 158 return 4 159 def __getitem__(self, key): 160 if key == 0: 161 return max 162 elif key == 1: 163 return tuple(range(1000000)) 164 elif key in (2, 3): 165 return {} 166 raise IndexError 167 168 f = self.thetype(object) 169 self.assertRaises(SystemError, f.__setstate__, BadSequence()) 170 171 class PartialSubclass(functools.partial): 172 pass 173 174 class TestPartialSubclass(TestPartial): 175 176 thetype = PartialSubclass 177 178 class TestPythonPartial(TestPartial): 179 180 thetype = PythonPartial 181 182 # the python version isn't picklable 183 def test_pickle(self): pass 184 def test_setstate_refcount(self): pass 185 186 class TestUpdateWrapper(unittest.TestCase): 187 188 def check_wrapper(self, wrapper, wrapped, 189 assigned=functools.WRAPPER_ASSIGNMENTS, 190 updated=functools.WRAPPER_UPDATES): 191 # Check attributes were assigned 192 for name in assigned: 193 self.assertTrue(getattr(wrapper, name) is getattr(wrapped, name)) 194 # Check attributes were updated 195 for name in updated: 196 wrapper_attr = getattr(wrapper, name) 197 wrapped_attr = getattr(wrapped, name) 198 for key in wrapped_attr: 199 self.assertTrue(wrapped_attr[key] is wrapper_attr[key]) 200 201 def _default_update(self): 202 def f(): 203 """This is a test""" 204 pass 205 f.attr = 'This is also a test' 206 def wrapper(): 207 pass 208 functools.update_wrapper(wrapper, f) 209 return wrapper, f 210 211 def test_default_update(self): 212 wrapper, f = self._default_update() 213 self.check_wrapper(wrapper, f) 214 self.assertEqual(wrapper.__name__, 'f') 215 self.assertEqual(wrapper.attr, 'This is also a test') 216 217 @unittest.skipIf(sys.flags.optimize >= 2, 218 "Docstrings are omitted with -O2 and above") 219 def test_default_update_doc(self): 220 wrapper, f = self._default_update() 221 self.assertEqual(wrapper.__doc__, 'This is a test') 222 223 def test_no_update(self): 224 def f(): 225 """This is a test""" 226 pass 227 f.attr = 'This is also a test' 228 def wrapper(): 229 pass 230 functools.update_wrapper(wrapper, f, (), ()) 231 self.check_wrapper(wrapper, f, (), ()) 232 self.assertEqual(wrapper.__name__, 'wrapper') 233 self.assertEqual(wrapper.__doc__, None) 234 self.assertFalse(hasattr(wrapper, 'attr')) 235 236 def test_selective_update(self): 237 def f(): 238 pass 239 f.attr = 'This is a different test' 240 f.dict_attr = dict(a=1, b=2, c=3) 241 def wrapper(): 242 pass 243 wrapper.dict_attr = {} 244 assign = ('attr',) 245 update = ('dict_attr',) 246 functools.update_wrapper(wrapper, f, assign, update) 247 self.check_wrapper(wrapper, f, assign, update) 248 self.assertEqual(wrapper.__name__, 'wrapper') 249 self.assertEqual(wrapper.__doc__, None) 250 self.assertEqual(wrapper.attr, 'This is a different test') 251 self.assertEqual(wrapper.dict_attr, f.dict_attr) 252 253 @test_support.requires_docstrings 254 def test_builtin_update(self): 255 # Test for bug #1576241 256 def wrapper(): 257 pass 258 functools.update_wrapper(wrapper, max) 259 self.assertEqual(wrapper.__name__, 'max') 260 self.assertTrue(wrapper.__doc__.startswith('max(')) 261 262 class TestWraps(TestUpdateWrapper): 263 264 def _default_update(self): 265 def f(): 266 """This is a test""" 267 pass 268 f.attr = 'This is also a test' 269 @functools.wraps(f) 270 def wrapper(): 271 pass 272 self.check_wrapper(wrapper, f) 273 return wrapper 274 275 def test_default_update(self): 276 wrapper = self._default_update() 277 self.assertEqual(wrapper.__name__, 'f') 278 self.assertEqual(wrapper.attr, 'This is also a test') 279 280 @unittest.skipIf(sys.flags.optimize >= 2, 281 "Docstrings are omitted with -O2 and above") 282 def test_default_update_doc(self): 283 wrapper = self._default_update() 284 self.assertEqual(wrapper.__doc__, 'This is a test') 285 286 def test_no_update(self): 287 def f(): 288 """This is a test""" 289 pass 290 f.attr = 'This is also a test' 291 @functools.wraps(f, (), ()) 292 def wrapper(): 293 pass 294 self.check_wrapper(wrapper, f, (), ()) 295 self.assertEqual(wrapper.__name__, 'wrapper') 296 self.assertEqual(wrapper.__doc__, None) 297 self.assertFalse(hasattr(wrapper, 'attr')) 298 299 def test_selective_update(self): 300 def f(): 301 pass 302 f.attr = 'This is a different test' 303 f.dict_attr = dict(a=1, b=2, c=3) 304 def add_dict_attr(f): 305 f.dict_attr = {} 306 return f 307 assign = ('attr',) 308 update = ('dict_attr',) 309 @functools.wraps(f, assign, update) 310 @add_dict_attr 311 def wrapper(): 312 pass 313 self.check_wrapper(wrapper, f, assign, update) 314 self.assertEqual(wrapper.__name__, 'wrapper') 315 self.assertEqual(wrapper.__doc__, None) 316 self.assertEqual(wrapper.attr, 'This is a different test') 317 self.assertEqual(wrapper.dict_attr, f.dict_attr) 318 319 320 class TestReduce(unittest.TestCase): 321 322 def test_reduce(self): 323 class Squares: 324 325 def __init__(self, max): 326 self.max = max 327 self.sofar = [] 328 329 def __len__(self): return len(self.sofar) 330 331 def __getitem__(self, i): 332 if not 0 <= i < self.max: raise IndexError 333 n = len(self.sofar) 334 while n <= i: 335 self.sofar.append(n*n) 336 n += 1 337 return self.sofar[i] 338 339 reduce = functools.reduce 340 self.assertEqual(reduce(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc') 341 self.assertEqual( 342 reduce(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []), 343 ['a','c','d','w'] 344 ) 345 self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040) 346 self.assertEqual( 347 reduce(lambda x, y: x*y, range(2,21), 1L), 348 2432902008176640000L 349 ) 350 self.assertEqual(reduce(lambda x, y: x+y, Squares(10)), 285) 351 self.assertEqual(reduce(lambda x, y: x+y, Squares(10), 0), 285) 352 self.assertEqual(reduce(lambda x, y: x+y, Squares(0), 0), 0) 353 self.assertRaises(TypeError, reduce) 354 self.assertRaises(TypeError, reduce, 42, 42) 355 self.assertRaises(TypeError, reduce, 42, 42, 42) 356 self.assertEqual(reduce(42, "1"), "1") # func is never called with one item 357 self.assertEqual(reduce(42, "", "1"), "1") # func is never called with one item 358 self.assertRaises(TypeError, reduce, 42, (42, 42)) 359 360 class TestCmpToKey(unittest.TestCase): 361 def test_cmp_to_key(self): 362 def mycmp(x, y): 363 return y - x 364 self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)), 365 [4, 3, 2, 1, 0]) 366 367 def test_hash(self): 368 def mycmp(x, y): 369 return y - x 370 key = functools.cmp_to_key(mycmp) 371 k = key(10) 372 self.assertRaises(TypeError, hash(k)) 373 374 class TestTotalOrdering(unittest.TestCase): 375 376 def test_total_ordering_lt(self): 377 @functools.total_ordering 378 class A: 379 def __init__(self, value): 380 self.value = value 381 def __lt__(self, other): 382 return self.value < other.value 383 def __eq__(self, other): 384 return self.value == other.value 385 self.assertTrue(A(1) < A(2)) 386 self.assertTrue(A(2) > A(1)) 387 self.assertTrue(A(1) <= A(2)) 388 self.assertTrue(A(2) >= A(1)) 389 self.assertTrue(A(2) <= A(2)) 390 self.assertTrue(A(2) >= A(2)) 391 392 def test_total_ordering_le(self): 393 @functools.total_ordering 394 class A: 395 def __init__(self, value): 396 self.value = value 397 def __le__(self, other): 398 return self.value <= other.value 399 def __eq__(self, other): 400 return self.value == other.value 401 self.assertTrue(A(1) < A(2)) 402 self.assertTrue(A(2) > A(1)) 403 self.assertTrue(A(1) <= A(2)) 404 self.assertTrue(A(2) >= A(1)) 405 self.assertTrue(A(2) <= A(2)) 406 self.assertTrue(A(2) >= A(2)) 407 408 def test_total_ordering_gt(self): 409 @functools.total_ordering 410 class A: 411 def __init__(self, value): 412 self.value = value 413 def __gt__(self, other): 414 return self.value > other.value 415 def __eq__(self, other): 416 return self.value == other.value 417 self.assertTrue(A(1) < A(2)) 418 self.assertTrue(A(2) > A(1)) 419 self.assertTrue(A(1) <= A(2)) 420 self.assertTrue(A(2) >= A(1)) 421 self.assertTrue(A(2) <= A(2)) 422 self.assertTrue(A(2) >= A(2)) 423 424 def test_total_ordering_ge(self): 425 @functools.total_ordering 426 class A: 427 def __init__(self, value): 428 self.value = value 429 def __ge__(self, other): 430 return self.value >= other.value 431 def __eq__(self, other): 432 return self.value == other.value 433 self.assertTrue(A(1) < A(2)) 434 self.assertTrue(A(2) > A(1)) 435 self.assertTrue(A(1) <= A(2)) 436 self.assertTrue(A(2) >= A(1)) 437 self.assertTrue(A(2) <= A(2)) 438 self.assertTrue(A(2) >= A(2)) 439 440 def test_total_ordering_no_overwrite(self): 441 # new methods should not overwrite existing 442 @functools.total_ordering 443 class A(str): 444 pass 445 self.assertTrue(A("a") < A("b")) 446 self.assertTrue(A("b") > A("a")) 447 self.assertTrue(A("a") <= A("b")) 448 self.assertTrue(A("b") >= A("a")) 449 self.assertTrue(A("b") <= A("b")) 450 self.assertTrue(A("b") >= A("b")) 451 452 def test_no_operations_defined(self): 453 with self.assertRaises(ValueError): 454 @functools.total_ordering 455 class A: 456 pass 457 458 def test_bug_10042(self): 459 @functools.total_ordering 460 class TestTO: 461 def __init__(self, value): 462 self.value = value 463 def __eq__(self, other): 464 if isinstance(other, TestTO): 465 return self.value == other.value 466 return False 467 def __lt__(self, other): 468 if isinstance(other, TestTO): 469 return self.value < other.value 470 raise TypeError 471 with self.assertRaises(TypeError): 472 TestTO(8) <= () 473 474 def test_main(verbose=None): 475 test_classes = ( 476 TestPartial, 477 TestPartialSubclass, 478 TestPythonPartial, 479 TestUpdateWrapper, 480 TestTotalOrdering, 481 TestWraps, 482 TestReduce, 483 ) 484 test_support.run_unittest(*test_classes) 485 486 # verify reference counting 487 if verbose and hasattr(sys, "gettotalrefcount"): 488 import gc 489 counts = [None] * 5 490 for i in xrange(len(counts)): 491 test_support.run_unittest(*test_classes) 492 gc.collect() 493 counts[i] = sys.gettotalrefcount() 494 print counts 495 496 if __name__ == '__main__': 497 test_main(verbose=True) 498