Home | History | Annotate | Download | only in test
      1 import unittest
      2 from test.test_support import check_syntax_error, check_py3k_warnings, \
      3                               check_warnings, run_unittest
      4 
      5 
      6 class ScopeTests(unittest.TestCase):
      7 
      8     def testSimpleNesting(self):
      9 
     10         def make_adder(x):
     11             def adder(y):
     12                 return x + y
     13             return adder
     14 
     15         inc = make_adder(1)
     16         plus10 = make_adder(10)
     17 
     18         self.assertEqual(inc(1), 2)
     19         self.assertEqual(plus10(-2), 8)
     20 
     21     def testExtraNesting(self):
     22 
     23         def make_adder2(x):
     24             def extra(): # check freevars passing through non-use scopes
     25                 def adder(y):
     26                     return x + y
     27                 return adder
     28             return extra()
     29 
     30         inc = make_adder2(1)
     31         plus10 = make_adder2(10)
     32 
     33         self.assertEqual(inc(1), 2)
     34         self.assertEqual(plus10(-2), 8)
     35 
     36     def testSimpleAndRebinding(self):
     37 
     38         def make_adder3(x):
     39             def adder(y):
     40                 return x + y
     41             x = x + 1 # check tracking of assignment to x in defining scope
     42             return adder
     43 
     44         inc = make_adder3(0)
     45         plus10 = make_adder3(9)
     46 
     47         self.assertEqual(inc(1), 2)
     48         self.assertEqual(plus10(-2), 8)
     49 
     50     def testNestingGlobalNoFree(self):
     51 
     52         def make_adder4(): # XXX add exta level of indirection
     53             def nest():
     54                 def nest():
     55                     def adder(y):
     56                         return global_x + y # check that plain old globals work
     57                     return adder
     58                 return nest()
     59             return nest()
     60 
     61         global_x = 1
     62         adder = make_adder4()
     63         self.assertEqual(adder(1), 2)
     64 
     65         global_x = 10
     66         self.assertEqual(adder(-2), 8)
     67 
     68     def testNestingThroughClass(self):
     69 
     70         def make_adder5(x):
     71             class Adder:
     72                 def __call__(self, y):
     73                     return x + y
     74             return Adder()
     75 
     76         inc = make_adder5(1)
     77         plus10 = make_adder5(10)
     78 
     79         self.assertEqual(inc(1), 2)
     80         self.assertEqual(plus10(-2), 8)
     81 
     82     def testNestingPlusFreeRefToGlobal(self):
     83 
     84         def make_adder6(x):
     85             global global_nest_x
     86             def adder(y):
     87                 return global_nest_x + y
     88             global_nest_x = x
     89             return adder
     90 
     91         inc = make_adder6(1)
     92         plus10 = make_adder6(10)
     93 
     94         self.assertEqual(inc(1), 11) # there's only one global
     95         self.assertEqual(plus10(-2), 8)
     96 
     97     def testNearestEnclosingScope(self):
     98 
     99         def f(x):
    100             def g(y):
    101                 x = 42 # check that this masks binding in f()
    102                 def h(z):
    103                     return x + z
    104                 return h
    105             return g(2)
    106 
    107         test_func = f(10)
    108         self.assertEqual(test_func(5), 47)
    109 
    110     def testMixedFreevarsAndCellvars(self):
    111 
    112         def identity(x):
    113             return x
    114 
    115         def f(x, y, z):
    116             def g(a, b, c):
    117                 a = a + x # 3
    118                 def h():
    119                     # z * (4 + 9)
    120                     # 3 * 13
    121                     return identity(z * (b + y))
    122                 y = c + z # 9
    123                 return h
    124             return g
    125 
    126         g = f(1, 2, 3)
    127         h = g(2, 4, 6)
    128         self.assertEqual(h(), 39)
    129 
    130     def testFreeVarInMethod(self):
    131 
    132         def test():
    133             method_and_var = "var"
    134             class Test:
    135                 def method_and_var(self):
    136                     return "method"
    137                 def test(self):
    138                     return method_and_var
    139                 def actual_global(self):
    140                     return str("global")
    141                 def str(self):
    142                     return str(self)
    143             return Test()
    144 
    145         t = test()
    146         self.assertEqual(t.test(), "var")
    147         self.assertEqual(t.method_and_var(), "method")
    148         self.assertEqual(t.actual_global(), "global")
    149 
    150         method_and_var = "var"
    151         class Test:
    152             # this class is not nested, so the rules are different
    153             def method_and_var(self):
    154                 return "method"
    155             def test(self):
    156                 return method_and_var
    157             def actual_global(self):
    158                 return str("global")
    159             def str(self):
    160                 return str(self)
    161 
    162         t = Test()
    163         self.assertEqual(t.test(), "var")
    164         self.assertEqual(t.method_and_var(), "method")
    165         self.assertEqual(t.actual_global(), "global")
    166 
    167     def testRecursion(self):
    168 
    169         def f(x):
    170             def fact(n):
    171                 if n == 0:
    172                     return 1
    173                 else:
    174                     return n * fact(n - 1)
    175             if x >= 0:
    176                 return fact(x)
    177             else:
    178                 raise ValueError, "x must be >= 0"
    179 
    180         self.assertEqual(f(6), 720)
    181 
    182 
    183     def testUnoptimizedNamespaces(self):
    184 
    185         check_syntax_error(self, """\
    186 def unoptimized_clash1(strip):
    187     def f(s):
    188         from string import *
    189         return strip(s) # ambiguity: free or local
    190     return f
    191 """)
    192 
    193         check_syntax_error(self, """\
    194 def unoptimized_clash2():
    195     from string import *
    196     def f(s):
    197         return strip(s) # ambiguity: global or local
    198     return f
    199 """)
    200 
    201         check_syntax_error(self, """\
    202 def unoptimized_clash2():
    203     from string import *
    204     def g():
    205         def f(s):
    206             return strip(s) # ambiguity: global or local
    207         return f
    208 """)
    209 
    210         # XXX could allow this for exec with const argument, but what's the point
    211         check_syntax_error(self, """\
    212 def error(y):
    213     exec "a = 1"
    214     def f(x):
    215         return x + y
    216     return f
    217 """)
    218 
    219         check_syntax_error(self, """\
    220 def f(x):
    221     def g():
    222         return x
    223     del x # can't del name
    224 """)
    225 
    226         check_syntax_error(self, """\
    227 def f():
    228     def g():
    229         from string import *
    230         return strip # global or local?
    231 """)
    232 
    233         # and verify a few cases that should work
    234 
    235         exec """
    236 def noproblem1():
    237     from string import *
    238     f = lambda x:x
    239 
    240 def noproblem2():
    241     from string import *
    242     def f(x):
    243         return x + 1
    244 
    245 def noproblem3():
    246     from string import *
    247     def f(x):
    248         global y
    249         y = x
    250 """
    251 
    252     def testLambdas(self):
    253 
    254         f1 = lambda x: lambda y: x + y
    255         inc = f1(1)
    256         plus10 = f1(10)
    257         self.assertEqual(inc(1), 2)
    258         self.assertEqual(plus10(5), 15)
    259 
    260         f2 = lambda x: (lambda : lambda y: x + y)()
    261         inc = f2(1)
    262         plus10 = f2(10)
    263         self.assertEqual(inc(1), 2)
    264         self.assertEqual(plus10(5), 15)
    265 
    266         f3 = lambda x: lambda y: global_x + y
    267         global_x = 1
    268         inc = f3(None)
    269         self.assertEqual(inc(2), 3)
    270 
    271         f8 = lambda x, y, z: lambda a, b, c: lambda : z * (b + y)
    272         g = f8(1, 2, 3)
    273         h = g(2, 4, 6)
    274         self.assertEqual(h(), 18)
    275 
    276     def testUnboundLocal(self):
    277 
    278         def errorInOuter():
    279             print y
    280             def inner():
    281                 return y
    282             y = 1
    283 
    284         def errorInInner():
    285             def inner():
    286                 return y
    287             inner()
    288             y = 1
    289 
    290         self.assertRaises(UnboundLocalError, errorInOuter)
    291         self.assertRaises(NameError, errorInInner)
    292 
    293         # test for bug #1501934: incorrect LOAD/STORE_GLOBAL generation
    294         exec """
    295 global_x = 1
    296 def f():
    297     global_x += 1
    298 try:
    299     f()
    300 except UnboundLocalError:
    301     pass
    302 else:
    303     fail('scope of global_x not correctly determined')
    304 """ in {'fail': self.fail}
    305 
    306     def testComplexDefinitions(self):
    307 
    308         def makeReturner(*lst):
    309             def returner():
    310                 return lst
    311             return returner
    312 
    313         self.assertEqual(makeReturner(1,2,3)(), (1,2,3))
    314 
    315         def makeReturner2(**kwargs):
    316             def returner():
    317                 return kwargs
    318             return returner
    319 
    320         self.assertEqual(makeReturner2(a=11)()['a'], 11)
    321 
    322         with check_py3k_warnings(("tuple parameter unpacking has been removed",
    323                                   SyntaxWarning)):
    324             exec """\
    325 def makeAddPair((a, b)):
    326     def addPair((c, d)):
    327         return (a + c, b + d)
    328     return addPair
    329 """ in locals()
    330         self.assertEqual(makeAddPair((1, 2))((100, 200)), (101,202))
    331 
    332     def testScopeOfGlobalStmt(self):
    333 # Examples posted by Samuele Pedroni to python-dev on 3/1/2001
    334 
    335         exec """\
    336 # I
    337 x = 7
    338 def f():
    339     x = 1
    340     def g():
    341         global x
    342         def i():
    343             def h():
    344                 return x
    345             return h()
    346         return i()
    347     return g()
    348 self.assertEqual(f(), 7)
    349 self.assertEqual(x, 7)
    350 
    351 # II
    352 x = 7
    353 def f():
    354     x = 1
    355     def g():
    356         x = 2
    357         def i():
    358             def h():
    359                 return x
    360             return h()
    361         return i()
    362     return g()
    363 self.assertEqual(f(), 2)
    364 self.assertEqual(x, 7)
    365 
    366 # III
    367 x = 7
    368 def f():
    369     x = 1
    370     def g():
    371         global x
    372         x = 2
    373         def i():
    374             def h():
    375                 return x
    376             return h()
    377         return i()
    378     return g()
    379 self.assertEqual(f(), 2)
    380 self.assertEqual(x, 2)
    381 
    382 # IV
    383 x = 7
    384 def f():
    385     x = 3
    386     def g():
    387         global x
    388         x = 2
    389         def i():
    390             def h():
    391                 return x
    392             return h()
    393         return i()
    394     return g()
    395 self.assertEqual(f(), 2)
    396 self.assertEqual(x, 2)
    397 
    398 # XXX what about global statements in class blocks?
    399 # do they affect methods?
    400 
    401 x = 12
    402 class Global:
    403     global x
    404     x = 13
    405     def set(self, val):
    406         x = val
    407     def get(self):
    408         return x
    409 
    410 g = Global()
    411 self.assertEqual(g.get(), 13)
    412 g.set(15)
    413 self.assertEqual(g.get(), 13)
    414 """
    415 
    416     def testLeaks(self):
    417 
    418         class Foo:
    419             count = 0
    420 
    421             def __init__(self):
    422                 Foo.count += 1
    423 
    424             def __del__(self):
    425                 Foo.count -= 1
    426 
    427         def f1():
    428             x = Foo()
    429             def f2():
    430                 return x
    431             f2()
    432 
    433         for i in range(100):
    434             f1()
    435 
    436         self.assertEqual(Foo.count, 0)
    437 
    438     def testClassAndGlobal(self):
    439 
    440         exec """\
    441 def test(x):
    442     class Foo:
    443         global x
    444         def __call__(self, y):
    445             return x + y
    446     return Foo()
    447 
    448 x = 0
    449 self.assertEqual(test(6)(2), 8)
    450 x = -1
    451 self.assertEqual(test(3)(2), 5)
    452 
    453 looked_up_by_load_name = False
    454 class X:
    455     # Implicit globals inside classes are be looked up by LOAD_NAME, not
    456     # LOAD_GLOBAL.
    457     locals()['looked_up_by_load_name'] = True
    458     passed = looked_up_by_load_name
    459 
    460 self.assertTrue(X.passed)
    461 """
    462 
    463     def testLocalsFunction(self):
    464 
    465         def f(x):
    466             def g(y):
    467                 def h(z):
    468                     return y + z
    469                 w = x + y
    470                 y += 3
    471                 return locals()
    472             return g
    473 
    474         d = f(2)(4)
    475         self.assertIn('h', d)
    476         del d['h']
    477         self.assertEqual(d, {'x': 2, 'y': 7, 'w': 6})
    478 
    479     def testLocalsClass(self):
    480         # This test verifies that calling locals() does not pollute
    481         # the local namespace of the class with free variables.  Old
    482         # versions of Python had a bug, where a free variable being
    483         # passed through a class namespace would be inserted into
    484         # locals() by locals() or exec or a trace function.
    485         #
    486         # The real bug lies in frame code that copies variables
    487         # between fast locals and the locals dict, e.g. when executing
    488         # a trace function.
    489 
    490         def f(x):
    491             class C:
    492                 x = 12
    493                 def m(self):
    494                     return x
    495                 locals()
    496             return C
    497 
    498         self.assertEqual(f(1).x, 12)
    499 
    500         def f(x):
    501             class C:
    502                 y = x
    503                 def m(self):
    504                     return x
    505                 z = list(locals())
    506             return C
    507 
    508         varnames = f(1).z
    509         self.assertNotIn("x", varnames)
    510         self.assertIn("y", varnames)
    511 
    512     def testLocalsClass_WithTrace(self):
    513         # Issue23728: after the trace function returns, the locals()
    514         # dictionary is used to update all variables, this used to
    515         # include free variables. But in class statements, free
    516         # variables are not inserted...
    517         import sys
    518         sys.settrace(lambda a,b,c:None)
    519         try:
    520             x = 12
    521 
    522             class C:
    523                 def f(self):
    524                     return x
    525 
    526             self.assertEqual(x, 12) # Used to raise UnboundLocalError
    527         finally:
    528             sys.settrace(None)
    529 
    530     def testBoundAndFree(self):
    531         # var is bound and free in class
    532 
    533         def f(x):
    534             class C:
    535                 def m(self):
    536                     return x
    537                 a = x
    538             return C
    539 
    540         inst = f(3)()
    541         self.assertEqual(inst.a, inst.m())
    542 
    543     def testInteractionWithTraceFunc(self):
    544 
    545         import sys
    546         def tracer(a,b,c):
    547             return tracer
    548 
    549         def adaptgetter(name, klass, getter):
    550             kind, des = getter
    551             if kind == 1:       # AV happens when stepping from this line to next
    552                 if des == "":
    553                     des = "_%s__%s" % (klass.__name__, name)
    554                 return lambda obj: getattr(obj, des)
    555 
    556         class TestClass:
    557             pass
    558 
    559         sys.settrace(tracer)
    560         adaptgetter("foo", TestClass, (1, ""))
    561         sys.settrace(None)
    562 
    563         self.assertRaises(TypeError, sys.settrace)
    564 
    565     def testEvalExecFreeVars(self):
    566 
    567         def f(x):
    568             return lambda: x + 1
    569 
    570         g = f(3)
    571         self.assertRaises(TypeError, eval, g.func_code)
    572 
    573         try:
    574             exec g.func_code in {}
    575         except TypeError:
    576             pass
    577         else:
    578             self.fail("exec should have failed, because code contained free vars")
    579 
    580     def testListCompLocalVars(self):
    581 
    582         try:
    583             print bad
    584         except NameError:
    585             pass
    586         else:
    587             print "bad should not be defined"
    588 
    589         def x():
    590             [bad for s in 'a b' for bad in s.split()]
    591 
    592         x()
    593         try:
    594             print bad
    595         except NameError:
    596             pass
    597 
    598     def testEvalFreeVars(self):
    599 
    600         def f(x):
    601             def g():
    602                 x
    603                 eval("x + 1")
    604             return g
    605 
    606         f(4)()
    607 
    608     def testFreeingCell(self):
    609         # Test what happens when a finalizer accesses
    610         # the cell where the object was stored.
    611         class Special:
    612             def __del__(self):
    613                 nestedcell_get()
    614 
    615         def f():
    616             global nestedcell_get
    617             def nestedcell_get():
    618                 return c
    619 
    620             c = (Special(),)
    621             c = 2
    622 
    623         f() # used to crash the interpreter...
    624 
    625     def testGlobalInParallelNestedFunctions(self):
    626         # A symbol table bug leaked the global statement from one
    627         # function to other nested functions in the same block.
    628         # This test verifies that a global statement in the first
    629         # function does not affect the second function.
    630         CODE = """def f():
    631     y = 1
    632     def g():
    633         global y
    634         return y
    635     def h():
    636         return y + 1
    637     return g, h
    638 
    639 y = 9
    640 g, h = f()
    641 result9 = g()
    642 result2 = h()
    643 """
    644         local_ns = {}
    645         global_ns = {}
    646         exec CODE in local_ns, global_ns
    647         self.assertEqual(2, global_ns["result2"])
    648         self.assertEqual(9, global_ns["result9"])
    649 
    650     def testTopIsNotSignificant(self):
    651         # See #9997.
    652         def top(a):
    653             pass
    654         def b():
    655             global a
    656 
    657 
    658 def test_main():
    659     with check_warnings(("import \* only allowed at module level",
    660                          SyntaxWarning)):
    661         run_unittest(ScopeTests)
    662 
    663 if __name__ == '__main__':
    664     test_main()
    665