Home | History | Annotate | Download | only in test
      1 import unittest
      2 from test import test_support
      3 
      4 def funcattrs(**kwds):
      5     def decorate(func):
      6         func.__dict__.update(kwds)
      7         return func
      8     return decorate
      9 
     10 class MiscDecorators (object):
     11     @staticmethod
     12     def author(name):
     13         def decorate(func):
     14             func.__dict__['author'] = name
     15             return func
     16         return decorate
     17 
     18 # -----------------------------------------------
     19 
     20 class DbcheckError (Exception):
     21     def __init__(self, exprstr, func, args, kwds):
     22         # A real version of this would set attributes here
     23         Exception.__init__(self, "dbcheck %r failed (func=%s args=%s kwds=%s)" %
     24                            (exprstr, func, args, kwds))
     25 
     26 
     27 def dbcheck(exprstr, globals=None, locals=None):
     28     "Decorator to implement debugging assertions"
     29     def decorate(func):
     30         expr = compile(exprstr, "dbcheck-%s" % func.func_name, "eval")
     31         def check(*args, **kwds):
     32             if not eval(expr, globals, locals):
     33                 raise DbcheckError(exprstr, func, args, kwds)
     34             return func(*args, **kwds)
     35         return check
     36     return decorate
     37 
     38 # -----------------------------------------------
     39 
     40 def countcalls(counts):
     41     "Decorator to count calls to a function"
     42     def decorate(func):
     43         func_name = func.func_name
     44         counts[func_name] = 0
     45         def call(*args, **kwds):
     46             counts[func_name] += 1
     47             return func(*args, **kwds)
     48         call.func_name = func_name
     49         return call
     50     return decorate
     51 
     52 # -----------------------------------------------
     53 
     54 def memoize(func):
     55     saved = {}
     56     def call(*args):
     57         try:
     58             return saved[args]
     59         except KeyError:
     60             res = func(*args)
     61             saved[args] = res
     62             return res
     63         except TypeError:
     64             # Unhashable argument
     65             return func(*args)
     66     call.func_name = func.func_name
     67     return call
     68 
     69 # -----------------------------------------------
     70 
class TestDecorators(unittest.TestCase):
    """Function decorators: syntax, argument forms, stacking and order."""

    def test_single(self):
        # A single built-in decorator (staticmethod) in a class body: the
        # method is callable through both the class and an instance.
        class C(object):
            @staticmethod
            def foo(): return 42
        self.assertEqual(C.foo(), 42)
        self.assertEqual(C().foo(), 42)

    def test_staticmethod_function(self):
        # Applied outside a class, @staticmethod leaves a bare staticmethod
        # object; this test expects calling it directly to raise TypeError.
        # NOTE(review): staticmethod objects became callable in Python 3.10,
        # so this expectation only holds on older interpreters.
        @staticmethod
        def notamethod(x):
            return x
        self.assertRaises(TypeError, notamethod, 1)

    def test_dotted(self):
        # A dotted name (attribute lookup) followed by a call is accepted
        # as a decorator expression.
        decorators = MiscDecorators()
        @decorators.author('Cleese')
        def foo(): return 42
        self.assertEqual(foo(), 42)
        self.assertEqual(foo.author, 'Cleese')

    def test_argforms(self):
        # A few tests of argument passing, as we use restricted form
        # of expressions for decorators.

        # Decorator factory that records the exact (args, kwds) it was
        # called with on the decorated function as attribute 'dbval'.
        def noteargs(*args, **kwds):
            def decorate(func):
                setattr(func, 'dbval', (args, kwds))
                return func
            return decorate

        # *args / **kwds unpacking inside the decorator call.
        args = ( 'Now', 'is', 'the', 'time' )
        kwds = dict(one=1, two=2)
        @noteargs(*args, **kwds)
        def f1(): return 42
        self.assertEqual(f1(), 42)
        self.assertEqual(f1.dbval, (args, kwds))

        # Mixed positional and keyword arguments.
        @noteargs('terry', 'gilliam', eric='idle', john='cleese')
        def f2(): return 84
        self.assertEqual(f2(), 84)
        self.assertEqual(f2.dbval, (('terry', 'gilliam'),
                                     dict(eric='idle', john='cleese')))

        # A trailing comma in the decorator's argument list is allowed.
        @noteargs(1, 2,)
        def f3(): pass
        self.assertEqual(f3.dbval, ((1, 2), {}))

    def test_dbcheck(self):
        # The check expression sees the call's positional arguments as
        # 'args'; passing None as the second argument must trip it.
        @dbcheck('args[1] is not None')
        def f(a, b):
            return a + b
        self.assertEqual(f(1, 2), 3)
        self.assertRaises(DbcheckError, f, 1, None)

    def test_memoize(self):
        # Two stacked decorators: countcalls (nearest the def, applied
        # first) counts calls that actually reach 'double'; memoize
        # (applied second) answers repeated hashable arguments from its
        # cache without calling through.
        counts = {}

        @memoize
        @countcalls(counts)
        def double(x):
            return x * 2
        # The original name must survive both wrappers.
        self.assertEqual(double.func_name, 'double')

        # Decorating registers the counter without calling the function.
        self.assertEqual(counts, dict(double=0))

        # Only the first call with a given argument bumps the call count:
        #
        self.assertEqual(double(2), 4)
        self.assertEqual(counts['double'], 1)
        self.assertEqual(double(2), 4)
        self.assertEqual(counts['double'], 1)
        self.assertEqual(double(3), 6)
        self.assertEqual(counts['double'], 2)

        # Unhashable arguments do not get memoized:
        #
        self.assertEqual(double([10]), [10, 10])
        self.assertEqual(counts['double'], 3)
        self.assertEqual(double([10]), [10, 10])
        self.assertEqual(counts['double'], 4)

    def test_errors(self):
        # Test syntax restrictions - these are all compile-time errors:
        # NOTE(review): PEP 614 (Python 3.9) allows any expression after
        # '@', so these three forms compile on 3.9+.
        #
        for expr in [ "1+2", "x[3]", "(1, 2)" ]:
            # Sanity check: is expr a valid expression by itself?
            compile(expr, "testexpr", "exec")

            codestr = "@%s\ndef f(): pass" % expr
            self.assertRaises(SyntaxError, compile, codestr, "test", "exec")

        # You can't put multiple decorators on a single line:
        #
        self.assertRaises(SyntaxError, compile,
                          "@f1 @f2\ndef f(): pass", "test", "exec")

        # Test runtime errors

        # A decorator that always fails, to check its exception propagates.
        def unimp(func):
            raise NotImplementedError
        context = dict(nullval=None, unimp=unimp)

        # Each decorator expression compiles fine but fails when the
        # decorated definition executes with 'context' as its globals:
        # an undefined name, calling None, an attribute of None, and a
        # decorator that raises.  The exception must propagate unchanged.
        for expr, exc in [ ("undef", NameError),
                           ("nullval", TypeError),
                           ("nullval.attr", AttributeError),
                           ("unimp", NotImplementedError)]:
            codestr = "@%s\ndef f(): pass\nassert f() is None" % expr
            code = compile(codestr, "test", "exec")
            self.assertRaises(exc, eval, code, context)

    def test_double(self):
        # Two stacked decorator factories each add attributes; all of them
        # must be readable through the class attribute C.foo.
        class C(object):
            @funcattrs(abc=1, xyz="haha")
            @funcattrs(booh=42)
            def foo(self): return 42
        self.assertEqual(C().foo(), 42)
        self.assertEqual(C.foo.abc, 1)
        self.assertEqual(C.foo.xyz, "haha")
        self.assertEqual(C.foo.booh, 42)

    def test_order(self):
        # Test that decorators are applied in the proper order to the function
        # they are decorating.
        def callnum(num):
            """Decorator factory that returns a decorator that replaces the
            passed-in function with one that returns the value of 'num'"""
            def deco(func):
                return lambda: num
            return deco
        # The decorator nearest the def is applied first, so the outermost
        # replacement (callnum(2)) is the one left bound to foo.
        @callnum(2)
        @callnum(1)
        def foo(): return 42
        self.assertEqual(foo(), 2,
                            "Application order of decorators is incorrect")

    def test_eval_order(self):
        # Evaluating a decorated function involves four steps for each
        # decorator-maker (the function that returns a decorator):
        #
        #    1: Evaluate the decorator-maker name
        #    2: Evaluate the decorator-maker arguments (if any)
        #    3: Call the decorator-maker to make a decorator
        #    4: Call the decorator
        #
        # When there are multiple decorators, steps 1-3 are performed for
        # each decorator in the order they appear in the source (top to
        # bottom); only after that is step 4 performed, calling the
        # decorators in the reverse of source order (innermost first).
        # See expected_actions below.

        actions = []

        # Logs step 3 when called and step 4 when its decorator is called.
        def make_decorator(tag):
            actions.append('makedec' + tag)
            def decorate(func):
                actions.append('calldec' + tag)
                return func
            return decorate

        # Stand-in namespace whose attribute lookups are logged, so the test
        # can observe when each decorator-maker name (step 1) and argument
        # expression (step 2) is evaluated.
        class NameLookupTracer (object):
            def __init__(self, index):
                self.index = index

            def __getattr__(self, fname):
                if fname == 'make_decorator':
                    opname, res = ('evalname', make_decorator)
                elif fname == 'arg':
                    opname, res = ('evalargs', str(self.index))
                else:
                    assert False, "Unknown attrname %s" % fname
                actions.append('%s%d' % (opname, self.index))
                return res

        c1, c2, c3 = map(NameLookupTracer, [ 1, 2, 3 ])

        expected_actions = [ 'evalname1', 'evalargs1', 'makedec1',
                             'evalname2', 'evalargs2', 'makedec2',
                             'evalname3', 'evalargs3', 'makedec3',
                             'calldec3', 'calldec2', 'calldec1' ]

        # Rebinding 'actions' here is seen by make_decorator and the tracer,
        # which look the name up in this enclosing scope when they run.
        actions = []
        @c1.make_decorator(c1.arg)
        @c2.make_decorator(c2.arg)
        @c3.make_decorator(c3.arg)
        def foo(): return 42
        self.assertEqual(foo(), 42)

        self.assertEqual(actions, expected_actions)

        # Test the equivalence claim in chapter 7 of the reference manual.
        #
        actions = []
        def bar(): return 42
        bar = c1.make_decorator(c1.arg)(c2.make_decorator(c2.arg)(c3.make_decorator(c3.arg)(bar)))
        self.assertEqual(bar(), 42)
        self.assertEqual(actions, expected_actions)
    268 
    269 class TestClassDecorators(unittest.TestCase):
    270 
    271     def test_simple(self):
    272         def plain(x):
    273             x.extra = 'Hello'
    274             return x
    275         @plain
    276         class C(object): pass
    277         self.assertEqual(C.extra, 'Hello')
    278 
    279     def test_double(self):
    280         def ten(x):
    281             x.extra = 10
    282             return x
    283         def add_five(x):
    284             x.extra += 5
    285             return x
    286 
    287         @add_five
    288         @ten
    289         class C(object): pass
    290         self.assertEqual(C.extra, 15)
    291 
    292     def test_order(self):
    293         def applied_first(x):
    294             x.extra = 'first'
    295             return x
    296         def applied_second(x):
    297             x.extra = 'second'
    298             return x
    299         @applied_second
    300         @applied_first
    301         class C(object): pass
    302         self.assertEqual(C.extra, 'second')
    303 
    304 def test_main():
    305     test_support.run_unittest(TestDecorators)
    306     test_support.run_unittest(TestClassDecorators)
    307 
    308 if __name__=="__main__":
    309     test_main()
    310