1 import functools 2 import sys 3 import unittest 4 from test import test_support 5 from weakref import proxy 6 import pickle 7 8 @staticmethod 9 def PythonPartial(func, *args, **keywords): 10 'Pure Python approximation of partial()' 11 def newfunc(*fargs, **fkeywords): 12 newkeywords = keywords.copy() 13 newkeywords.update(fkeywords) 14 return func(*(args + fargs), **newkeywords) 15 newfunc.func = func 16 newfunc.args = args 17 newfunc.keywords = keywords 18 return newfunc 19 20 def capture(*args, **kw): 21 """capture all positional and keyword arguments""" 22 return args, kw 23 24 def signature(part): 25 """ return the signature of a partial object """ 26 return (part.func, part.args, part.keywords, part.__dict__) 27 28 class TestPartial(unittest.TestCase): 29 30 thetype = functools.partial 31 32 def test_basic_examples(self): 33 p = self.thetype(capture, 1, 2, a=10, b=20) 34 self.assertEqual(p(3, 4, b=30, c=40), 35 ((1, 2, 3, 4), dict(a=10, b=30, c=40))) 36 p = self.thetype(map, lambda x: x*10) 37 self.assertEqual(p([1,2,3,4]), [10, 20, 30, 40]) 38 39 def test_attributes(self): 40 p = self.thetype(capture, 1, 2, a=10, b=20) 41 # attributes should be readable 42 self.assertEqual(p.func, capture) 43 self.assertEqual(p.args, (1, 2)) 44 self.assertEqual(p.keywords, dict(a=10, b=20)) 45 # attributes should not be writable 46 if not isinstance(self.thetype, type): 47 return 48 self.assertRaises(TypeError, setattr, p, 'func', map) 49 self.assertRaises(TypeError, setattr, p, 'args', (1, 2)) 50 self.assertRaises(TypeError, setattr, p, 'keywords', dict(a=1, b=2)) 51 52 p = self.thetype(hex) 53 try: 54 del p.__dict__ 55 except TypeError: 56 pass 57 else: 58 self.fail('partial object allowed __dict__ to be deleted') 59 60 def test_argument_checking(self): 61 self.assertRaises(TypeError, self.thetype) # need at least a func arg 62 try: 63 self.thetype(2)() 64 except TypeError: 65 pass 66 else: 67 self.fail('First arg not checked for callability') 68 69 def test_protection_of_callers_dict_argument(self): 70 # a caller's dictionary should not be altered by partial 71 def func(a=10, b=20): 72 return a 73 d = {'a':3} 74 p = self.thetype(func, a=5) 75 self.assertEqual(p(**d), 3) 76 self.assertEqual(d, {'a':3}) 77 p(b=7) 78 self.assertEqual(d, {'a':3}) 79 80 def test_arg_combinations(self): 81 # exercise special code paths for zero args in either partial 82 # object or the caller 83 p = self.thetype(capture) 84 self.assertEqual(p(), ((), {})) 85 self.assertEqual(p(1,2), ((1,2), {})) 86 p = self.thetype(capture, 1, 2) 87 self.assertEqual(p(), ((1,2), {})) 88 self.assertEqual(p(3,4), ((1,2,3,4), {})) 89 90 def test_kw_combinations(self): 91 # exercise special code paths for no keyword args in 92 # either the partial object or the caller 93 p = self.thetype(capture) 94 self.assertEqual(p(), ((), {})) 95 self.assertEqual(p(a=1), ((), {'a':1})) 96 p = self.thetype(capture, a=1) 97 self.assertEqual(p(), ((), {'a':1})) 98 self.assertEqual(p(b=2), ((), {'a':1, 'b':2})) 99 # keyword args in the call override those in the partial object 100 self.assertEqual(p(a=3, b=2), ((), {'a':3, 'b':2})) 101 102 def test_positional(self): 103 # make sure positional arguments are captured correctly 104 for args in [(), (0,), (0,1), (0,1,2), (0,1,2,3)]: 105 p = self.thetype(capture, *args) 106 expected = args + ('x',) 107 got, empty = p('x') 108 self.assertTrue(expected == got and empty == {}) 109 110 def test_keyword(self): 111 # make sure keyword arguments are captured correctly 112 for a in ['a', 0, None, 3.5]: 113 p = self.thetype(capture, a=a) 114 expected = {'a':a,'x':None} 115 empty, got = p(x=None) 116 self.assertTrue(expected == got and empty == ()) 117 118 def test_no_side_effects(self): 119 # make sure there are no side effects that affect subsequent calls 120 p = self.thetype(capture, 0, a=1) 121 args1, kw1 = p(1, b=2) 122 self.assertTrue(args1 == (0,1) and kw1 == {'a':1,'b':2}) 123 args2, kw2 = p() 124 self.assertTrue(args2 == (0,) and kw2 == {'a':1}) 125 126 def test_error_propagation(self): 127 def f(x, y): 128 x // y 129 self.assertRaises(ZeroDivisionError, self.thetype(f, 1, 0)) 130 self.assertRaises(ZeroDivisionError, self.thetype(f, 1), 0) 131 self.assertRaises(ZeroDivisionError, self.thetype(f), 1, 0) 132 self.assertRaises(ZeroDivisionError, self.thetype(f, y=0), 1) 133 134 def test_weakref(self): 135 f = self.thetype(int, base=16) 136 p = proxy(f) 137 self.assertEqual(f.func, p.func) 138 f = None 139 self.assertRaises(ReferenceError, getattr, p, 'func') 140 141 def test_with_bound_and_unbound_methods(self): 142 data = map(str, range(10)) 143 join = self.thetype(str.join, '') 144 self.assertEqual(join(data), '0123456789') 145 join = self.thetype(''.join) 146 self.assertEqual(join(data), '0123456789') 147 148 def test_pickle(self): 149 f = self.thetype(signature, 'asdf', bar=True) 150 f.add_something_to__dict__ = True 151 f_copy = pickle.loads(pickle.dumps(f)) 152 self.assertEqual(signature(f), signature(f_copy)) 153 154 class PartialSubclass(functools.partial): 155 pass 156 157 class TestPartialSubclass(TestPartial): 158 159 thetype = PartialSubclass 160 161 class TestPythonPartial(TestPartial): 162 163 thetype = PythonPartial 164 165 # the python version isn't picklable 166 def test_pickle(self): pass 167 168 class TestUpdateWrapper(unittest.TestCase): 169 170 def check_wrapper(self, wrapper, wrapped, 171 assigned=functools.WRAPPER_ASSIGNMENTS, 172 updated=functools.WRAPPER_UPDATES): 173 # Check attributes were assigned 174 for name in assigned: 175 self.assertTrue(getattr(wrapper, name) is getattr(wrapped, name)) 176 # Check attributes were updated 177 for name in updated: 178 wrapper_attr = getattr(wrapper, name) 179 wrapped_attr = getattr(wrapped, name) 180 for key in wrapped_attr: 181 self.assertTrue(wrapped_attr[key] is wrapper_attr[key]) 182 183 def _default_update(self): 184 def f(): 185 """This is a test""" 186 pass 187 f.attr = 'This is also a test' 188 def wrapper(): 189 pass 190 functools.update_wrapper(wrapper, f) 191 return wrapper, f 192 193 def test_default_update(self): 194 wrapper, f = self._default_update() 195 self.check_wrapper(wrapper, f) 196 self.assertEqual(wrapper.__name__, 'f') 197 self.assertEqual(wrapper.attr, 'This is also a test') 198 199 @unittest.skipIf(sys.flags.optimize >= 2, 200 "Docstrings are omitted with -O2 and above") 201 def test_default_update_doc(self): 202 wrapper, f = self._default_update() 203 self.assertEqual(wrapper.__doc__, 'This is a test') 204 205 def test_no_update(self): 206 def f(): 207 """This is a test""" 208 pass 209 f.attr = 'This is also a test' 210 def wrapper(): 211 pass 212 functools.update_wrapper(wrapper, f, (), ()) 213 self.check_wrapper(wrapper, f, (), ()) 214 self.assertEqual(wrapper.__name__, 'wrapper') 215 self.assertEqual(wrapper.__doc__, None) 216 self.assertFalse(hasattr(wrapper, 'attr')) 217 218 def test_selective_update(self): 219 def f(): 220 pass 221 f.attr = 'This is a different test' 222 f.dict_attr = dict(a=1, b=2, c=3) 223 def wrapper(): 224 pass 225 wrapper.dict_attr = {} 226 assign = ('attr',) 227 update = ('dict_attr',) 228 functools.update_wrapper(wrapper, f, assign, update) 229 self.check_wrapper(wrapper, f, assign, update) 230 self.assertEqual(wrapper.__name__, 'wrapper') 231 self.assertEqual(wrapper.__doc__, None) 232 self.assertEqual(wrapper.attr, 'This is a different test') 233 self.assertEqual(wrapper.dict_attr, f.dict_attr) 234 235 def test_builtin_update(self): 236 # Test for bug #1576241 237 def wrapper(): 238 pass 239 functools.update_wrapper(wrapper, max) 240 self.assertEqual(wrapper.__name__, 'max') 241 self.assertTrue(wrapper.__doc__.startswith('max(')) 242 243 class TestWraps(TestUpdateWrapper): 244 245 def _default_update(self): 246 def f(): 247 """This is a test""" 248 pass 249 f.attr = 'This is also a test' 250 @functools.wraps(f) 251 def wrapper(): 252 pass 253 self.check_wrapper(wrapper, f) 254 return wrapper 255 256 def test_default_update(self): 257 wrapper = self._default_update() 258 self.assertEqual(wrapper.__name__, 'f') 259 self.assertEqual(wrapper.attr, 'This is also a test') 260 261 @unittest.skipIf(not sys.flags.optimize <= 1, 262 "Docstrings are omitted with -O2 and above") 263 def test_default_update_doc(self): 264 wrapper = self._default_update() 265 self.assertEqual(wrapper.__doc__, 'This is a test') 266 267 def test_no_update(self): 268 def f(): 269 """This is a test""" 270 pass 271 f.attr = 'This is also a test' 272 @functools.wraps(f, (), ()) 273 def wrapper(): 274 pass 275 self.check_wrapper(wrapper, f, (), ()) 276 self.assertEqual(wrapper.__name__, 'wrapper') 277 self.assertEqual(wrapper.__doc__, None) 278 self.assertFalse(hasattr(wrapper, 'attr')) 279 280 def test_selective_update(self): 281 def f(): 282 pass 283 f.attr = 'This is a different test' 284 f.dict_attr = dict(a=1, b=2, c=3) 285 def add_dict_attr(f): 286 f.dict_attr = {} 287 return f 288 assign = ('attr',) 289 update = ('dict_attr',) 290 @functools.wraps(f, assign, update) 291 @add_dict_attr 292 def wrapper(): 293 pass 294 self.check_wrapper(wrapper, f, assign, update) 295 self.assertEqual(wrapper.__name__, 'wrapper') 296 self.assertEqual(wrapper.__doc__, None) 297 self.assertEqual(wrapper.attr, 'This is a different test') 298 self.assertEqual(wrapper.dict_attr, f.dict_attr) 299 300 301 class TestReduce(unittest.TestCase): 302 303 def test_reduce(self): 304 class Squares: 305 306 def __init__(self, max): 307 self.max = max 308 self.sofar = [] 309 310 def __len__(self): return len(self.sofar) 311 312 def __getitem__(self, i): 313 if not 0 <= i < self.max: raise IndexError 314 n = len(self.sofar) 315 while n <= i: 316 self.sofar.append(n*n) 317 n += 1 318 return self.sofar[i] 319 320 reduce = functools.reduce 321 self.assertEqual(reduce(lambda x, y: x+y, ['a', 'b', 'c'], ''), 'abc') 322 self.assertEqual( 323 reduce(lambda x, y: x+y, [['a', 'c'], [], ['d', 'w']], []), 324 ['a','c','d','w'] 325 ) 326 self.assertEqual(reduce(lambda x, y: x*y, range(2,8), 1), 5040) 327 self.assertEqual( 328 reduce(lambda x, y: x*y, range(2,21), 1L), 329 2432902008176640000L 330 ) 331 self.assertEqual(reduce(lambda x, y: x+y, Squares(10)), 285) 332 self.assertEqual(reduce(lambda x, y: x+y, Squares(10), 0), 285) 333 self.assertEqual(reduce(lambda x, y: x+y, Squares(0), 0), 0) 334 self.assertRaises(TypeError, reduce) 335 self.assertRaises(TypeError, reduce, 42, 42) 336 self.assertRaises(TypeError, reduce, 42, 42, 42) 337 self.assertEqual(reduce(42, "1"), "1") # func is never called with one item 338 self.assertEqual(reduce(42, "", "1"), "1") # func is never called with one item 339 self.assertRaises(TypeError, reduce, 42, (42, 42)) 340 341 class TestCmpToKey(unittest.TestCase): 342 def test_cmp_to_key(self): 343 def mycmp(x, y): 344 return y - x 345 self.assertEqual(sorted(range(5), key=functools.cmp_to_key(mycmp)), 346 [4, 3, 2, 1, 0]) 347 348 def test_hash(self): 349 def mycmp(x, y): 350 return y - x 351 key = functools.cmp_to_key(mycmp) 352 k = key(10) 353 self.assertRaises(TypeError, hash(k)) 354 355 class TestTotalOrdering(unittest.TestCase): 356 357 def test_total_ordering_lt(self): 358 @functools.total_ordering 359 class A: 360 def __init__(self, value): 361 self.value = value 362 def __lt__(self, other): 363 return self.value < other.value 364 def __eq__(self, other): 365 return self.value == other.value 366 self.assertTrue(A(1) < A(2)) 367 self.assertTrue(A(2) > A(1)) 368 self.assertTrue(A(1) <= A(2)) 369 self.assertTrue(A(2) >= A(1)) 370 self.assertTrue(A(2) <= A(2)) 371 self.assertTrue(A(2) >= A(2)) 372 373 def test_total_ordering_le(self): 374 @functools.total_ordering 375 class A: 376 def __init__(self, value): 377 self.value = value 378 def __le__(self, other): 379 return self.value <= other.value 380 def __eq__(self, other): 381 return self.value == other.value 382 self.assertTrue(A(1) < A(2)) 383 self.assertTrue(A(2) > A(1)) 384 self.assertTrue(A(1) <= A(2)) 385 self.assertTrue(A(2) >= A(1)) 386 self.assertTrue(A(2) <= A(2)) 387 self.assertTrue(A(2) >= A(2)) 388 389 def test_total_ordering_gt(self): 390 @functools.total_ordering 391 class A: 392 def __init__(self, value): 393 self.value = value 394 def __gt__(self, other): 395 return self.value > other.value 396 def __eq__(self, other): 397 return self.value == other.value 398 self.assertTrue(A(1) < A(2)) 399 self.assertTrue(A(2) > A(1)) 400 self.assertTrue(A(1) <= A(2)) 401 self.assertTrue(A(2) >= A(1)) 402 self.assertTrue(A(2) <= A(2)) 403 self.assertTrue(A(2) >= A(2)) 404 405 def test_total_ordering_ge(self): 406 @functools.total_ordering 407 class A: 408 def __init__(self, value): 409 self.value = value 410 def __ge__(self, other): 411 return self.value >= other.value 412 def __eq__(self, other): 413 return self.value == other.value 414 self.assertTrue(A(1) < A(2)) 415 self.assertTrue(A(2) > A(1)) 416 self.assertTrue(A(1) <= A(2)) 417 self.assertTrue(A(2) >= A(1)) 418 self.assertTrue(A(2) <= A(2)) 419 self.assertTrue(A(2) >= A(2)) 420 421 def test_total_ordering_no_overwrite(self): 422 # new methods should not overwrite existing 423 @functools.total_ordering 424 class A(str): 425 pass 426 self.assertTrue(A("a") < A("b")) 427 self.assertTrue(A("b") > A("a")) 428 self.assertTrue(A("a") <= A("b")) 429 self.assertTrue(A("b") >= A("a")) 430 self.assertTrue(A("b") <= A("b")) 431 self.assertTrue(A("b") >= A("b")) 432 433 def test_no_operations_defined(self): 434 with self.assertRaises(ValueError): 435 @functools.total_ordering 436 class A: 437 pass 438 439 def test_bug_10042(self): 440 @functools.total_ordering 441 class TestTO: 442 def __init__(self, value): 443 self.value = value 444 def __eq__(self, other): 445 if isinstance(other, TestTO): 446 return self.value == other.value 447 return False 448 def __lt__(self, other): 449 if isinstance(other, TestTO): 450 return self.value < other.value 451 raise TypeError 452 with self.assertRaises(TypeError): 453 TestTO(8) <= () 454 455 def test_main(verbose=None): 456 test_classes = ( 457 TestPartial, 458 TestPartialSubclass, 459 TestPythonPartial, 460 TestUpdateWrapper, 461 TestTotalOrdering, 462 TestWraps, 463 TestReduce, 464 ) 465 test_support.run_unittest(*test_classes) 466 467 # verify reference counting 468 if verbose and hasattr(sys, "gettotalrefcount"): 469 import gc 470 counts = [None] * 5 471 for i in xrange(len(counts)): 472 test_support.run_unittest(*test_classes) 473 gc.collect() 474 counts[i] = sys.gettotalrefcount() 475 print counts 476 477 if __name__ == '__main__': 478 test_main(verbose=True) 479