Home | History | Annotate | Download | only in test
      1 import unittest
      2 
      3 def funcattrs(**kwds):
      4     def decorate(func):
      5         func.__dict__.update(kwds)
      6         return func
      7     return decorate
      8 
      9 class MiscDecorators (object):
     10     @staticmethod
     11     def author(name):
     12         def decorate(func):
     13             func.__dict__['author'] = name
     14             return func
     15         return decorate
     16 
     17 # -----------------------------------------------
     18 
     19 class DbcheckError (Exception):
     20     def __init__(self, exprstr, func, args, kwds):
     21         # A real version of this would set attributes here
     22         Exception.__init__(self, "dbcheck %r failed (func=%s args=%s kwds=%s)" %
     23                            (exprstr, func, args, kwds))
     24 
     25 
     26 def dbcheck(exprstr, globals=None, locals=None):
     27     "Decorator to implement debugging assertions"
     28     def decorate(func):
     29         expr = compile(exprstr, "dbcheck-%s" % func.__name__, "eval")
     30         def check(*args, **kwds):
     31             if not eval(expr, globals, locals):
     32                 raise DbcheckError(exprstr, func, args, kwds)
     33             return func(*args, **kwds)
     34         return check
     35     return decorate
     36 
     37 # -----------------------------------------------
     38 
     39 def countcalls(counts):
     40     "Decorator to count calls to a function"
     41     def decorate(func):
     42         func_name = func.__name__
     43         counts[func_name] = 0
     44         def call(*args, **kwds):
     45             counts[func_name] += 1
     46             return func(*args, **kwds)
     47         call.__name__ = func_name
     48         return call
     49     return decorate
     50 
     51 # -----------------------------------------------
     52 
     53 def memoize(func):
     54     saved = {}
     55     def call(*args):
     56         try:
     57             return saved[args]
     58         except KeyError:
     59             res = func(*args)
     60             saved[args] = res
     61             return res
     62         except TypeError:
     63             # Unhashable argument
     64             return func(*args)
     65     call.__name__ = func.__name__
     66     return call
     67 
     68 # -----------------------------------------------
     69 
     70 class TestDecorators(unittest.TestCase):
     71 
     72     def test_single(self):
     73         class C(object):
     74             @staticmethod
     75             def foo(): return 42
     76         self.assertEqual(C.foo(), 42)
     77         self.assertEqual(C().foo(), 42)
     78 
     79     def test_staticmethod_function(self):
     80         @staticmethod
     81         def notamethod(x):
     82             return x
     83         self.assertRaises(TypeError, notamethod, 1)
     84 
     85     def test_dotted(self):
     86         decorators = MiscDecorators()
     87         @decorators.author('Cleese')
     88         def foo(): return 42
     89         self.assertEqual(foo(), 42)
     90         self.assertEqual(foo.author, 'Cleese')
     91 
     92     def test_argforms(self):
     93         # A few tests of argument passing, as we use restricted form
     94         # of expressions for decorators.
     95 
     96         def noteargs(*args, **kwds):
     97             def decorate(func):
     98                 setattr(func, 'dbval', (args, kwds))
     99                 return func
    100             return decorate
    101 
    102         args = ( 'Now', 'is', 'the', 'time' )
    103         kwds = dict(one=1, two=2)
    104         @noteargs(*args, **kwds)
    105         def f1(): return 42
    106         self.assertEqual(f1(), 42)
    107         self.assertEqual(f1.dbval, (args, kwds))
    108 
    109         @noteargs('terry', 'gilliam', eric='idle', john='cleese')
    110         def f2(): return 84
    111         self.assertEqual(f2(), 84)
    112         self.assertEqual(f2.dbval, (('terry', 'gilliam'),
    113                                      dict(eric='idle', john='cleese')))
    114 
    115         @noteargs(1, 2,)
    116         def f3(): pass
    117         self.assertEqual(f3.dbval, ((1, 2), {}))
    118 
    119     def test_dbcheck(self):
    120         @dbcheck('args[1] is not None')
    121         def f(a, b):
    122             return a + b
    123         self.assertEqual(f(1, 2), 3)
    124         self.assertRaises(DbcheckError, f, 1, None)
    125 
    126     def test_memoize(self):
    127         counts = {}
    128 
    129         @memoize
    130         @countcalls(counts)
    131         def double(x):
    132             return x * 2
    133         self.assertEqual(double.__name__, 'double')
    134 
    135         self.assertEqual(counts, dict(double=0))
    136 
    137         # Only the first call with a given argument bumps the call count:
    138         #
    139         self.assertEqual(double(2), 4)
    140         self.assertEqual(counts['double'], 1)
    141         self.assertEqual(double(2), 4)
    142         self.assertEqual(counts['double'], 1)
    143         self.assertEqual(double(3), 6)
    144         self.assertEqual(counts['double'], 2)
    145 
    146         # Unhashable arguments do not get memoized:
    147         #
    148         self.assertEqual(double([10]), [10, 10])
    149         self.assertEqual(counts['double'], 3)
    150         self.assertEqual(double([10]), [10, 10])
    151         self.assertEqual(counts['double'], 4)
    152 
    153     def test_errors(self):
    154         # Test syntax restrictions - these are all compile-time errors:
    155         #
    156         for expr in [ "1+2", "x[3]", "(1, 2)" ]:
    157             # Sanity check: is expr is a valid expression by itself?
    158             compile(expr, "testexpr", "exec")
    159 
    160             codestr = "@%s\ndef f(): pass" % expr
    161             self.assertRaises(SyntaxError, compile, codestr, "test", "exec")
    162 
    163         # You can't put multiple decorators on a single line:
    164         #
    165         self.assertRaises(SyntaxError, compile,
    166                           "@f1 @f2\ndef f(): pass", "test", "exec")
    167 
    168         # Test runtime errors
    169 
    170         def unimp(func):
    171             raise NotImplementedError
    172         context = dict(nullval=None, unimp=unimp)
    173 
    174         for expr, exc in [ ("undef", NameError),
    175                            ("nullval", TypeError),
    176                            ("nullval.attr", AttributeError),
    177                            ("unimp", NotImplementedError)]:
    178             codestr = "@%s\ndef f(): pass\nassert f() is None" % expr
    179             code = compile(codestr, "test", "exec")
    180             self.assertRaises(exc, eval, code, context)
    181 
    182     def test_double(self):
    183         class C(object):
    184             @funcattrs(abc=1, xyz="haha")
    185             @funcattrs(booh=42)
    186             def foo(self): return 42
    187         self.assertEqual(C().foo(), 42)
    188         self.assertEqual(C.foo.abc, 1)
    189         self.assertEqual(C.foo.xyz, "haha")
    190         self.assertEqual(C.foo.booh, 42)
    191 
    192     def test_order(self):
    193         # Test that decorators are applied in the proper order to the function
    194         # they are decorating.
    195         def callnum(num):
    196             """Decorator factory that returns a decorator that replaces the
    197             passed-in function with one that returns the value of 'num'"""
    198             def deco(func):
    199                 return lambda: num
    200             return deco
    201         @callnum(2)
    202         @callnum(1)
    203         def foo(): return 42
    204         self.assertEqual(foo(), 2,
    205                             "Application order of decorators is incorrect")
    206 
    207     def test_eval_order(self):
    208         # Evaluating a decorated function involves four steps for each
    209         # decorator-maker (the function that returns a decorator):
    210         #
    211         #    1: Evaluate the decorator-maker name
    212         #    2: Evaluate the decorator-maker arguments (if any)
    213         #    3: Call the decorator-maker to make a decorator
    214         #    4: Call the decorator
    215         #
    216         # When there are multiple decorators, these steps should be
    217         # performed in the above order for each decorator, but we should
    218         # iterate through the decorators in the reverse of the order they
    219         # appear in the source.
    220 
    221         actions = []
    222 
    223         def make_decorator(tag):
    224             actions.append('makedec' + tag)
    225             def decorate(func):
    226                 actions.append('calldec' + tag)
    227                 return func
    228             return decorate
    229 
    230         class NameLookupTracer (object):
    231             def __init__(self, index):
    232                 self.index = index
    233 
    234             def __getattr__(self, fname):
    235                 if fname == 'make_decorator':
    236                     opname, res = ('evalname', make_decorator)
    237                 elif fname == 'arg':
    238                     opname, res = ('evalargs', str(self.index))
    239                 else:
    240                     assert False, "Unknown attrname %s" % fname
    241                 actions.append('%s%d' % (opname, self.index))
    242                 return res
    243 
    244         c1, c2, c3 = map(NameLookupTracer, [ 1, 2, 3 ])
    245 
    246         expected_actions = [ 'evalname1', 'evalargs1', 'makedec1',
    247                              'evalname2', 'evalargs2', 'makedec2',
    248                              'evalname3', 'evalargs3', 'makedec3',
    249                              'calldec3', 'calldec2', 'calldec1' ]
    250 
    251         actions = []
    252         @c1.make_decorator(c1.arg)
    253         @c2.make_decorator(c2.arg)
    254         @c3.make_decorator(c3.arg)
    255         def foo(): return 42
    256         self.assertEqual(foo(), 42)
    257 
    258         self.assertEqual(actions, expected_actions)
    259 
    260         # Test the equivalence claim in chapter 7 of the reference manual.
    261         #
    262         actions = []
    263         def bar(): return 42
    264         bar = c1.make_decorator(c1.arg)(c2.make_decorator(c2.arg)(c3.make_decorator(c3.arg)(bar)))
    265         self.assertEqual(bar(), 42)
    266         self.assertEqual(actions, expected_actions)
    267 
    268 class TestClassDecorators(unittest.TestCase):
    269 
    270     def test_simple(self):
    271         def plain(x):
    272             x.extra = 'Hello'
    273             return x
    274         @plain
    275         class C(object): pass
    276         self.assertEqual(C.extra, 'Hello')
    277 
    278     def test_double(self):
    279         def ten(x):
    280             x.extra = 10
    281             return x
    282         def add_five(x):
    283             x.extra += 5
    284             return x
    285 
    286         @add_five
    287         @ten
    288         class C(object): pass
    289         self.assertEqual(C.extra, 15)
    290 
    291     def test_order(self):
    292         def applied_first(x):
    293             x.extra = 'first'
    294             return x
    295         def applied_second(x):
    296             x.extra = 'second'
    297             return x
    298         @applied_second
    299         @applied_first
    300         class C(object): pass
    301         self.assertEqual(C.extra, 'second')
    302 
if __name__ == "__main__":
    # When executed as a script, hand control to unittest's command-line
    # runner, which collects the TestCase classes in this module.
    unittest.main()
    305