import sys
import unittest


def funcattrs(**kwds):
    "Decorator factory: copy the given keyword arguments onto the function"
    def decorate(func):
        func.__dict__.update(kwds)
        return func
    return decorate


class MiscDecorators (object):
    "Namespace used to exercise dotted decorator names (obj.attr(...))"
    @staticmethod
    def author(name):
        "Decorator factory: record `name` as the function's 'author' attribute"
        def decorate(func):
            func.__dict__['author'] = name
            return func
        return decorate

# -----------------------------------------------

class DbcheckError (Exception):
    "Raised by a dbcheck-decorated function when its assertion is false"
    def __init__(self, exprstr, func, args, kwds):
        # A real version of this would set attributes here
        Exception.__init__(self, "dbcheck %r failed (func=%s args=%s kwds=%s)" %
                           (exprstr, func, args, kwds))


def dbcheck(exprstr, globals=None, locals=None):
    """Decorator to implement debugging assertions

    `exprstr` is compiled once, at decoration time, and evaluated before
    every call of the decorated function.  When `globals` and `locals` are
    left as None the expression is evaluated in the wrapper's own namespace,
    so it can refer to the call's positional arguments as `args` and to its
    keyword arguments as `kwds`.  A false result raises DbcheckError;
    otherwise the wrapped function's result is returned.

    NOTE: `exprstr` is passed to eval(); never build it from untrusted input.
    """
    def decorate(func):
        expr = compile(exprstr, "dbcheck-%s" % func.__name__, "eval")
        def check(*args, **kwds):
            if not eval(expr, globals, locals):
                raise DbcheckError(exprstr, func, args, kwds)
            return func(*args, **kwds)
        # Keep the wrapped function's name, as countcalls and memoize do.
        check.__name__ = func.__name__
        return check
    return decorate

# -----------------------------------------------

def countcalls(counts):
    """Decorator to count calls to a function

    `counts` is a caller-supplied mapping.  The entry keyed by the function's
    name is reset to 0 at decoration time and incremented on every call.
    """
    def decorate(func):
        func_name = func.__name__
        counts[func_name] = 0
        def call(*args, **kwds):
            counts[func_name] += 1
            return func(*args, **kwds)
        call.__name__ = func_name
        return call
    return decorate

# -----------------------------------------------

def memoize(func):
    """Decorator to cache results keyed by the positional arguments

    Calls whose arguments are unhashable bypass the cache and always invoke
    `func`.
    """
    saved = {}
    def call(*args):
        try:
            return saved[args]
        except KeyError:
            res = func(*args)
            saved[args] = res
            return res
        except TypeError:
            # Unhashable argument
            return func(*args)
    call.__name__ = func.__name__
    return call

# -----------------------------------------------

class TestDecorators(unittest.TestCase):

    def test_single(self):
        # A single decorator on a method.
        class C(object):
            @staticmethod
            def foo(): return 42
        self.assertEqual(C.foo(), 42)
        self.assertEqual(C().foo(), 42)

    def test_staticmethod_function(self):
        # A staticmethod object created outside of a class body.
        @staticmethod
        def notamethod(x):
            return x
        if sys.version_info >= (3, 10):
            # Static method objects are directly callable since Python 3.10.
            self.assertEqual(notamethod(1), 1)
        else:
            self.assertRaises(TypeError, notamethod, 1)

    def test_dotted(self):
        # The decorator is looked up as an attribute of an object.
        decorators = MiscDecorators()
        @decorators.author('Cleese')
        def foo(): return 42
        self.assertEqual(foo(), 42)
        self.assertEqual(foo.author, 'Cleese')

    def test_argforms(self):
        # A few tests of argument passing, as we use restricted form
        # of expressions for decorators.

        def noteargs(*args, **kwds):
            def decorate(func):
                setattr(func, 'dbval', (args, kwds))
                return func
            return decorate

        args = ( 'Now', 'is', 'the', 'time' )
        kwds = dict(one=1, two=2)
        @noteargs(*args, **kwds)
        def f1(): return 42
        self.assertEqual(f1(), 42)
        self.assertEqual(f1.dbval, (args, kwds))

        @noteargs('terry', 'gilliam', eric='idle', john='cleese')
        def f2(): return 84
        self.assertEqual(f2(), 84)
        self.assertEqual(f2.dbval, (('terry', 'gilliam'),
                                     dict(eric='idle', john='cleese')))

        @noteargs(1, 2,)
        def f3(): pass
        self.assertEqual(f3.dbval, ((1, 2), {}))

    def test_dbcheck(self):
        # The assertion expression sees the call's positional arguments
        # under the name 'args'.
        @dbcheck('args[1] is not None')
        def f(a, b):
            return a + b
        self.assertEqual(f(1, 2), 3)
        self.assertRaises(DbcheckError, f, 1, None)

    def test_memoize(self):
        counts = {}

        @memoize
        @countcalls(counts)
        def double(x):
            return x * 2
        self.assertEqual(double.__name__, 'double')

        self.assertEqual(counts, dict(double=0))

        # Only the first call with a given argument bumps the call count:
        #
        self.assertEqual(double(2), 4)
        self.assertEqual(counts['double'], 1)
        self.assertEqual(double(2), 4)
        self.assertEqual(counts['double'], 1)
        self.assertEqual(double(3), 6)
        self.assertEqual(counts['double'], 2)

        # Unhashable arguments do not get memoized:
        #
        self.assertEqual(double([10]), [10, 10])
        self.assertEqual(counts['double'], 3)
        self.assertEqual(double([10]), [10, 10])
        self.assertEqual(counts['double'], 4)

    def test_errors(self):
        # Test syntax restrictions - these are all compile-time errors:
        #
        if sys.version_info >= (3, 9):
            # PEP 614 (Python 3.9) allows any expression as a decorator, so
            # only things that are not a single expression are rejected by
            # the compiler.
            for stmt in [ "x,", "x, y", "x = y", "pass", "import sys" ]:
                # Sanity check: is stmt valid code by itself?
                compile(stmt, "teststmt", "exec")

                codestr = "@%s\ndef f(): pass" % stmt
                self.assertRaises(SyntaxError, compile, codestr, "test", "exec")

            # Expressions that used to be syntax errors now compile and
            # fail only when their (non-callable) value is applied to the
            # function:
            #
            for expr in [ "1+2", "[1, 2][-1]", "(1, 2)", "None" ]:
                # Sanity check: is expr a valid expression by itself?
                compile(expr, "testexpr", "eval")

                codestr = "@%s\ndef f(): pass" % expr
                code = compile(codestr, "test", "exec")
                self.assertRaises(TypeError, eval, code, {})
        else:
            for expr in [ "1+2", "x[3]", "(1, 2)" ]:
                # Sanity check: is expr a valid expression by itself?
                compile(expr, "testexpr", "exec")

                codestr = "@%s\ndef f(): pass" % expr
                self.assertRaises(SyntaxError, compile, codestr, "test", "exec")

            # You can't put multiple decorators on a single line.  (From
            # Python 3.9 on, "@f1 @f2" is the valid expression "f1 @ f2".)
            #
            self.assertRaises(SyntaxError, compile,
                              "@f1 @f2\ndef f(): pass", "test", "exec")

        # Test runtime errors

        def unimp(func):
            raise NotImplementedError
        context = dict(nullval=None, unimp=unimp)

        for expr, exc in [ ("undef", NameError),
                           ("nullval", TypeError),
                           ("nullval.attr", AttributeError),
                           ("unimp", NotImplementedError)]:
            codestr = "@%s\ndef f(): pass\nassert f() is None" % expr
            code = compile(codestr, "test", "exec")
            self.assertRaises(exc, eval, code, context)

    def test_double(self):
        # Two stacked decorators both leave their attributes on the method.
        class C(object):
            @funcattrs(abc=1, xyz="haha")
            @funcattrs(booh=42)
            def foo(self): return 42
        self.assertEqual(C().foo(), 42)
        self.assertEqual(C.foo.abc, 1)
        self.assertEqual(C.foo.xyz, "haha")
        self.assertEqual(C.foo.booh, 42)

    def test_order(self):
        # Test that decorators are applied in the proper order to the function
        # they are decorating.
        def callnum(num):
            """Decorator factory that returns a decorator that replaces the
            passed-in function with one that returns the value of 'num'"""
            def deco(func):
                return lambda: num
            return deco
        @callnum(2)
        @callnum(1)
        def foo(): return 42
        self.assertEqual(foo(), 2,
                            "Application order of decorators is incorrect")

    def test_eval_order(self):
        # Evaluating a decorated function involves four steps for each
        # decorator-maker (the function that returns a decorator):
        #
        #    1: Evaluate the decorator-maker name
        #    2: Evaluate the decorator-maker arguments (if any)
        #    3: Call the decorator-maker to make a decorator
        #    4: Call the decorator
        #
        # When there are multiple decorators, these steps should be
        # performed in the above order for each decorator, but we should
        # iterate through the decorators in the reverse of the order they
        # appear in the source.

        actions = []

        def make_decorator(tag):
            actions.append('makedec' + tag)
            def decorate(func):
                actions.append('calldec' + tag)
                return func
            return decorate

        class NameLookupTracer (object):
            # Records every attribute lookup in `actions`, so that the
            # order in which names and arguments are evaluated is visible.
            def __init__(self, index):
                self.index = index

            def __getattr__(self, fname):
                if fname == 'make_decorator':
                    opname, res = ('evalname', make_decorator)
                elif fname == 'arg':
                    opname, res = ('evalargs', str(self.index))
                else:
                    assert False, "Unknown attrname %s" % fname
                actions.append('%s%d' % (opname, self.index))
                return res

        c1, c2, c3 = map(NameLookupTracer, [ 1, 2, 3 ])

        expected_actions = [ 'evalname1', 'evalargs1', 'makedec1',
                             'evalname2', 'evalargs2', 'makedec2',
                             'evalname3', 'evalargs3', 'makedec3',
                             'calldec3', 'calldec2', 'calldec1' ]

        actions = []
        @c1.make_decorator(c1.arg)
        @c2.make_decorator(c2.arg)
        @c3.make_decorator(c3.arg)
        def foo(): return 42
        self.assertEqual(foo(), 42)

        self.assertEqual(actions, expected_actions)

        # Test the equivalence claim in chapter 7 of the reference manual.
        #
        actions = []
        def bar(): return 42
        bar = c1.make_decorator(c1.arg)(c2.make_decorator(c2.arg)(c3.make_decorator(c3.arg)(bar)))
        self.assertEqual(bar(), 42)
        self.assertEqual(actions, expected_actions)


class TestClassDecorators(unittest.TestCase):

    def test_simple(self):
        # A single decorator applied to a class.
        def plain(x):
            x.extra = 'Hello'
            return x
        @plain
        class C(object): pass
        self.assertEqual(C.extra, 'Hello')

    def test_double(self):
        # add_five relies on ten having already run, so this also checks
        # that the decorator nearest the class is applied first.
        def ten(x):
            x.extra = 10
            return x
        def add_five(x):
            x.extra += 5
            return x

        @add_five
        @ten
        class C(object): pass
        self.assertEqual(C.extra, 15)

    def test_order(self):
        # The decorator listed first is applied last, so its value wins.
        def applied_first(x):
            x.extra = 'first'
            return x
        def applied_second(x):
            x.extra = 'second'
            return x
        @applied_second
        @applied_first
        class C(object): pass
        self.assertEqual(C.extra, 'second')


if __name__ == "__main__":
    unittest.main()