Home | History | Annotate | Download | only in tests
      1 """ Test suite for the code in fixer_util """
      2 
      3 # Testing imports
      4 from . import support
      5 
      6 # Python imports
      7 import os.path
      8 
      9 # Local imports
     10 from lib2to3.pytree import Node, Leaf
     11 from lib2to3 import fixer_util
     12 from lib2to3.fixer_util import Attr, Name, Call, Comma
     13 from lib2to3.pgen2 import token
     14 
     15 def parse(code, strip_levels=0):
     16     # The topmost node is file_input, which we don't care about.
     17     # The next-topmost node is a *_stmt node, which we also don't care about
     18     tree = support.parse_string(code)
     19     for i in range(strip_levels):
     20         tree = tree.children[0]
     21     tree.parent = None
     22     return tree
     23 
     24 class MacroTestCase(support.TestCase):
     25     def assertStr(self, node, string):
     26         if isinstance(node, (tuple, list)):
     27             node = Node(fixer_util.syms.simple_stmt, node)
     28         self.assertEqual(str(node), string)
     29 
     30 
     31 class Test_is_tuple(support.TestCase):
     32     def is_tuple(self, string):
     33         return fixer_util.is_tuple(parse(string, strip_levels=2))
     34 
     35     def test_valid(self):
     36         self.assertTrue(self.is_tuple("(a, b)"))
     37         self.assertTrue(self.is_tuple("(a, (b, c))"))
     38         self.assertTrue(self.is_tuple("((a, (b, c)),)"))
     39         self.assertTrue(self.is_tuple("(a,)"))
     40         self.assertTrue(self.is_tuple("()"))
     41 
     42     def test_invalid(self):
     43         self.assertFalse(self.is_tuple("(a)"))
     44         self.assertFalse(self.is_tuple("('foo') % (b, c)"))
     45 
     46 
     47 class Test_is_list(support.TestCase):
     48     def is_list(self, string):
     49         return fixer_util.is_list(parse(string, strip_levels=2))
     50 
     51     def test_valid(self):
     52         self.assertTrue(self.is_list("[]"))
     53         self.assertTrue(self.is_list("[a]"))
     54         self.assertTrue(self.is_list("[a, b]"))
     55         self.assertTrue(self.is_list("[a, [b, c]]"))
     56         self.assertTrue(self.is_list("[[a, [b, c]],]"))
     57 
     58     def test_invalid(self):
     59         self.assertFalse(self.is_list("[]+[]"))
     60 
     61 
     62 class Test_Attr(MacroTestCase):
     63     def test(self):
     64         call = parse("foo()", strip_levels=2)
     65 
     66         self.assertStr(Attr(Name("a"), Name("b")), "a.b")
     67         self.assertStr(Attr(call, Name("b")), "foo().b")
     68 
     69     def test_returns(self):
     70         attr = Attr(Name("a"), Name("b"))
     71         self.assertEqual(type(attr), list)
     72 
     73 
     74 class Test_Name(MacroTestCase):
     75     def test(self):
     76         self.assertStr(Name("a"), "a")
     77         self.assertStr(Name("foo.foo().bar"), "foo.foo().bar")
     78         self.assertStr(Name("a", prefix="b"), "ba")
     79 
     80 
     81 class Test_Call(MacroTestCase):
     82     def _Call(self, name, args=None, prefix=None):
     83         """Help the next test"""
     84         children = []
     85         if isinstance(args, list):
     86             for arg in args:
     87                 children.append(arg)
     88                 children.append(Comma())
     89             children.pop()
     90         return Call(Name(name), children, prefix)
     91 
     92     def test(self):
     93         kids = [None,
     94                 [Leaf(token.NUMBER, 1), Leaf(token.NUMBER, 2),
     95                  Leaf(token.NUMBER, 3)],
     96                 [Leaf(token.NUMBER, 1), Leaf(token.NUMBER, 3),
     97                  Leaf(token.NUMBER, 2), Leaf(token.NUMBER, 4)],
     98                 [Leaf(token.STRING, "b"), Leaf(token.STRING, "j", prefix=" ")]
     99                 ]
    100         self.assertStr(self._Call("A"), "A()")
    101         self.assertStr(self._Call("b", kids[1]), "b(1,2,3)")
    102         self.assertStr(self._Call("a.b().c", kids[2]), "a.b().c(1,3,2,4)")
    103         self.assertStr(self._Call("d", kids[3], prefix=" "), " d(b, j)")
    104 
    105 
    106 class Test_does_tree_import(support.TestCase):
    107     def _find_bind_rec(self, name, node):
    108         # Search a tree for a binding -- used to find the starting
    109         # point for these tests.
    110         c = fixer_util.find_binding(name, node)
    111         if c: return c
    112         for child in node.children:
    113             c = self._find_bind_rec(name, child)
    114             if c: return c
    115 
    116     def does_tree_import(self, package, name, string):
    117         node = parse(string)
    118         # Find the binding of start -- that's what we'll go from
    119         node = self._find_bind_rec('start', node)
    120         return fixer_util.does_tree_import(package, name, node)
    121 
    122     def try_with(self, string):
    123         failing_tests = (("a", "a", "from a import b"),
    124                          ("a.d", "a", "from a.d import b"),
    125                          ("d.a", "a", "from d.a import b"),
    126                          (None, "a", "import b"),
    127                          (None, "a", "import b, c, d"))
    128         for package, name, import_ in failing_tests:
    129             n = self.does_tree_import(package, name, import_ + "\n" + string)
    130             self.assertFalse(n)
    131             n = self.does_tree_import(package, name, string + "\n" + import_)
    132             self.assertFalse(n)
    133 
    134         passing_tests = (("a", "a", "from a import a"),
    135                          ("x", "a", "from x import a"),
    136                          ("x", "a", "from x import b, c, a, d"),
    137                          ("x.b", "a", "from x.b import a"),
    138                          ("x.b", "a", "from x.b import b, c, a, d"),
    139                          (None, "a", "import a"),
    140                          (None, "a", "import b, c, a, d"))
    141         for package, name, import_ in passing_tests:
    142             n = self.does_tree_import(package, name, import_ + "\n" + string)
    143             self.assertTrue(n)
    144             n = self.does_tree_import(package, name, string + "\n" + import_)
    145             self.assertTrue(n)
    146 
    147     def test_in_function(self):
    148         self.try_with("def foo():\n\tbar.baz()\n\tstart=3")
    149 
    150 class Test_find_binding(support.TestCase):
    151     def find_binding(self, name, string, package=None):
    152         return fixer_util.find_binding(name, parse(string), package)
    153 
    154     def test_simple_assignment(self):
    155         self.assertTrue(self.find_binding("a", "a = b"))
    156         self.assertTrue(self.find_binding("a", "a = [b, c, d]"))
    157         self.assertTrue(self.find_binding("a", "a = foo()"))
    158         self.assertTrue(self.find_binding("a", "a = foo().foo.foo[6][foo]"))
    159         self.assertFalse(self.find_binding("a", "foo = a"))
    160         self.assertFalse(self.find_binding("a", "foo = (a, b, c)"))
    161 
    162     def test_tuple_assignment(self):
    163         self.assertTrue(self.find_binding("a", "(a,) = b"))
    164         self.assertTrue(self.find_binding("a", "(a, b, c) = [b, c, d]"))
    165         self.assertTrue(self.find_binding("a", "(c, (d, a), b) = foo()"))
    166         self.assertTrue(self.find_binding("a", "(a, b) = foo().foo[6][foo]"))
    167         self.assertFalse(self.find_binding("a", "(foo, b) = (b, a)"))
    168         self.assertFalse(self.find_binding("a", "(foo, (b, c)) = (a, b, c)"))
    169 
    170     def test_list_assignment(self):
    171         self.assertTrue(self.find_binding("a", "[a] = b"))
    172         self.assertTrue(self.find_binding("a", "[a, b, c] = [b, c, d]"))
    173         self.assertTrue(self.find_binding("a", "[c, [d, a], b] = foo()"))
    174         self.assertTrue(self.find_binding("a", "[a, b] = foo().foo[a][foo]"))
    175         self.assertFalse(self.find_binding("a", "[foo, b] = (b, a)"))
    176         self.assertFalse(self.find_binding("a", "[foo, [b, c]] = (a, b, c)"))
    177 
    178     def test_invalid_assignments(self):
    179         self.assertFalse(self.find_binding("a", "foo.a = 5"))
    180         self.assertFalse(self.find_binding("a", "foo[a] = 5"))
    181         self.assertFalse(self.find_binding("a", "foo(a) = 5"))
    182         self.assertFalse(self.find_binding("a", "foo(a, b) = 5"))
    183 
    184     def test_simple_import(self):
    185         self.assertTrue(self.find_binding("a", "import a"))
    186         self.assertTrue(self.find_binding("a", "import b, c, a, d"))
    187         self.assertFalse(self.find_binding("a", "import b"))
    188         self.assertFalse(self.find_binding("a", "import b, c, d"))
    189 
    190     def test_from_import(self):
    191         self.assertTrue(self.find_binding("a", "from x import a"))
    192         self.assertTrue(self.find_binding("a", "from a import a"))
    193         self.assertTrue(self.find_binding("a", "from x import b, c, a, d"))
    194         self.assertTrue(self.find_binding("a", "from x.b import a"))
    195         self.assertTrue(self.find_binding("a", "from x.b import b, c, a, d"))
    196         self.assertFalse(self.find_binding("a", "from a import b"))
    197         self.assertFalse(self.find_binding("a", "from a.d import b"))
    198         self.assertFalse(self.find_binding("a", "from d.a import b"))
    199 
    200     def test_import_as(self):
    201         self.assertTrue(self.find_binding("a", "import b as a"))
    202         self.assertTrue(self.find_binding("a", "import b as a, c, a as f, d"))
    203         self.assertFalse(self.find_binding("a", "import a as f"))
    204         self.assertFalse(self.find_binding("a", "import b, c as f, d as e"))
    205 
    206     def test_from_import_as(self):
    207         self.assertTrue(self.find_binding("a", "from x import b as a"))
    208         self.assertTrue(self.find_binding("a", "from x import g as a, d as b"))
    209         self.assertTrue(self.find_binding("a", "from x.b import t as a"))
    210         self.assertTrue(self.find_binding("a", "from x.b import g as a, d"))
    211         self.assertFalse(self.find_binding("a", "from a import b as t"))
    212         self.assertFalse(self.find_binding("a", "from a.d import b as t"))
    213         self.assertFalse(self.find_binding("a", "from d.a import b as t"))
    214 
    215     def test_simple_import_with_package(self):
    216         self.assertTrue(self.find_binding("b", "import b"))
    217         self.assertTrue(self.find_binding("b", "import b, c, d"))
    218         self.assertFalse(self.find_binding("b", "import b", "b"))
    219         self.assertFalse(self.find_binding("b", "import b, c, d", "c"))
    220 
    221     def test_from_import_with_package(self):
    222         self.assertTrue(self.find_binding("a", "from x import a", "x"))
    223         self.assertTrue(self.find_binding("a", "from a import a", "a"))
    224         self.assertTrue(self.find_binding("a", "from x import *", "x"))
    225         self.assertTrue(self.find_binding("a", "from x import b, c, a, d", "x"))
    226         self.assertTrue(self.find_binding("a", "from x.b import a", "x.b"))
    227         self.assertTrue(self.find_binding("a", "from x.b import *", "x.b"))
    228         self.assertTrue(self.find_binding("a", "from x.b import b, c, a, d", "x.b"))
    229         self.assertFalse(self.find_binding("a", "from a import b", "a"))
    230         self.assertFalse(self.find_binding("a", "from a.d import b", "a.d"))
    231         self.assertFalse(self.find_binding("a", "from d.a import b", "a.d"))
    232         self.assertFalse(self.find_binding("a", "from x.y import *", "a.b"))
    233 
    234     def test_import_as_with_package(self):
    235         self.assertFalse(self.find_binding("a", "import b.c as a", "b.c"))
    236         self.assertFalse(self.find_binding("a", "import a as f", "f"))
    237         self.assertFalse(self.find_binding("a", "import a as f", "a"))
    238 
    239     def test_from_import_as_with_package(self):
    240         # Because it would take a lot of special-case code in the fixers
    241         # to deal with from foo import bar as baz, we'll simply always
    242         # fail if there is an "from ... import ... as ..."
    243         self.assertFalse(self.find_binding("a", "from x import b as a", "x"))
    244         self.assertFalse(self.find_binding("a", "from x import g as a, d as b", "x"))
    245         self.assertFalse(self.find_binding("a", "from x.b import t as a", "x.b"))
    246         self.assertFalse(self.find_binding("a", "from x.b import g as a, d", "x.b"))
    247         self.assertFalse(self.find_binding("a", "from a import b as t", "a"))
    248         self.assertFalse(self.find_binding("a", "from a import b as t", "b"))
    249         self.assertFalse(self.find_binding("a", "from a import b as t", "t"))
    250 
    251     def test_function_def(self):
    252         self.assertTrue(self.find_binding("a", "def a(): pass"))
    253         self.assertTrue(self.find_binding("a", "def a(b, c, d): pass"))
    254         self.assertTrue(self.find_binding("a", "def a(): b = 7"))
    255         self.assertFalse(self.find_binding("a", "def d(b, (c, a), e): pass"))
    256         self.assertFalse(self.find_binding("a", "def d(a=7): pass"))
    257         self.assertFalse(self.find_binding("a", "def d(a): pass"))
    258         self.assertFalse(self.find_binding("a", "def d(): a = 7"))
    259 
    260         s = """
    261             def d():
    262                 def a():
    263                     pass"""
    264         self.assertFalse(self.find_binding("a", s))
    265 
    266     def test_class_def(self):
    267         self.assertTrue(self.find_binding("a", "class a: pass"))
    268         self.assertTrue(self.find_binding("a", "class a(): pass"))
    269         self.assertTrue(self.find_binding("a", "class a(b): pass"))
    270         self.assertTrue(self.find_binding("a", "class a(b, c=8): pass"))
    271         self.assertFalse(self.find_binding("a", "class d: pass"))
    272         self.assertFalse(self.find_binding("a", "class d(a): pass"))
    273         self.assertFalse(self.find_binding("a", "class d(b, a=7): pass"))
    274         self.assertFalse(self.find_binding("a", "class d(b, *a): pass"))
    275         self.assertFalse(self.find_binding("a", "class d(b, **a): pass"))
    276         self.assertFalse(self.find_binding("a", "class d: a = 7"))
    277 
    278         s = """
    279             class d():
    280                 class a():
    281                     pass"""
    282         self.assertFalse(self.find_binding("a", s))
    283 
    284     def test_for(self):
    285         self.assertTrue(self.find_binding("a", "for a in r: pass"))
    286         self.assertTrue(self.find_binding("a", "for a, b in r: pass"))
    287         self.assertTrue(self.find_binding("a", "for (a, b) in r: pass"))
    288         self.assertTrue(self.find_binding("a", "for c, (a,) in r: pass"))
    289         self.assertTrue(self.find_binding("a", "for c, (a, b) in r: pass"))
    290         self.assertTrue(self.find_binding("a", "for c in r: a = c"))
    291         self.assertFalse(self.find_binding("a", "for c in a: pass"))
    292 
    293     def test_for_nested(self):
    294         s = """
    295             for b in r:
    296                 for a in b:
    297                     pass"""
    298         self.assertTrue(self.find_binding("a", s))
    299 
    300         s = """
    301             for b in r:
    302                 for a, c in b:
    303                     pass"""
    304         self.assertTrue(self.find_binding("a", s))
    305 
    306         s = """
    307             for b in r:
    308                 for (a, c) in b:
    309                     pass"""
    310         self.assertTrue(self.find_binding("a", s))
    311 
    312         s = """
    313             for b in r:
    314                 for (a,) in b:
    315                     pass"""
    316         self.assertTrue(self.find_binding("a", s))
    317 
    318         s = """
    319             for b in r:
    320                 for c, (a, d) in b:
    321                     pass"""
    322         self.assertTrue(self.find_binding("a", s))
    323 
    324         s = """
    325             for b in r:
    326                 for c in b:
    327                     a = 7"""
    328         self.assertTrue(self.find_binding("a", s))
    329 
    330         s = """
    331             for b in r:
    332                 for c in b:
    333                     d = a"""
    334         self.assertFalse(self.find_binding("a", s))
    335 
    336         s = """
    337             for b in r:
    338                 for c in a:
    339                     d = 7"""
    340         self.assertFalse(self.find_binding("a", s))
    341 
    342     def test_if(self):
    343         self.assertTrue(self.find_binding("a", "if b in r: a = c"))
    344         self.assertFalse(self.find_binding("a", "if a in r: d = e"))
    345 
    346     def test_if_nested(self):
    347         s = """
    348             if b in r:
    349                 if c in d:
    350                     a = c"""
    351         self.assertTrue(self.find_binding("a", s))
    352 
    353         s = """
    354             if b in r:
    355                 if c in d:
    356                     c = a"""
    357         self.assertFalse(self.find_binding("a", s))
    358 
    359     def test_while(self):
    360         self.assertTrue(self.find_binding("a", "while b in r: a = c"))
    361         self.assertFalse(self.find_binding("a", "while a in r: d = e"))
    362 
    363     def test_while_nested(self):
    364         s = """
    365             while b in r:
    366                 while c in d:
    367                     a = c"""
    368         self.assertTrue(self.find_binding("a", s))
    369 
    370         s = """
    371             while b in r:
    372                 while c in d:
    373                     c = a"""
    374         self.assertFalse(self.find_binding("a", s))
    375 
    376     def test_try_except(self):
    377         s = """
    378             try:
    379                 a = 6
    380             except:
    381                 b = 8"""
    382         self.assertTrue(self.find_binding("a", s))
    383 
    384         s = """
    385             try:
    386                 b = 8
    387             except:
    388                 a = 6"""
    389         self.assertTrue(self.find_binding("a", s))
    390 
    391         s = """
    392             try:
    393                 b = 8
    394             except KeyError:
    395                 pass
    396             except:
    397                 a = 6"""
    398         self.assertTrue(self.find_binding("a", s))
    399 
    400         s = """
    401             try:
    402                 b = 8
    403             except:
    404                 b = 6"""
    405         self.assertFalse(self.find_binding("a", s))
    406 
    407     def test_try_except_nested(self):
    408         s = """
    409             try:
    410                 try:
    411                     a = 6
    412                 except:
    413                     pass
    414             except:
    415                 b = 8"""
    416         self.assertTrue(self.find_binding("a", s))
    417 
    418         s = """
    419             try:
    420                 b = 8
    421             except:
    422                 try:
    423                     a = 6
    424                 except:
    425                     pass"""
    426         self.assertTrue(self.find_binding("a", s))
    427 
    428         s = """
    429             try:
    430                 b = 8
    431             except:
    432                 try:
    433                     pass
    434                 except:
    435                     a = 6"""
    436         self.assertTrue(self.find_binding("a", s))
    437 
    438         s = """
    439             try:
    440                 try:
    441                     b = 8
    442                 except KeyError:
    443                     pass
    444                 except:
    445                     a = 6
    446             except:
    447                 pass"""
    448         self.assertTrue(self.find_binding("a", s))
    449 
    450         s = """
    451             try:
    452                 pass
    453             except:
    454                 try:
    455                     b = 8
    456                 except KeyError:
    457                     pass
    458                 except:
    459                     a = 6"""
    460         self.assertTrue(self.find_binding("a", s))
    461 
    462         s = """
    463             try:
    464                 b = 8
    465             except:
    466                 b = 6"""
    467         self.assertFalse(self.find_binding("a", s))
    468 
    469         s = """
    470             try:
    471                 try:
    472                     b = 8
    473                 except:
    474                     c = d
    475             except:
    476                 try:
    477                     b = 6
    478                 except:
    479                     t = 8
    480                 except:
    481                     o = y"""
    482         self.assertFalse(self.find_binding("a", s))
    483 
    484     def test_try_except_finally(self):
    485         s = """
    486             try:
    487                 c = 6
    488             except:
    489                 b = 8
    490             finally:
    491                 a = 9"""
    492         self.assertTrue(self.find_binding("a", s))
    493 
    494         s = """
    495             try:
    496                 b = 8
    497             finally:
    498                 a = 6"""
    499         self.assertTrue(self.find_binding("a", s))
    500 
    501         s = """
    502             try:
    503                 b = 8
    504             finally:
    505                 b = 6"""
    506         self.assertFalse(self.find_binding("a", s))
    507 
    508         s = """
    509             try:
    510                 b = 8
    511             except:
    512                 b = 9
    513             finally:
    514                 b = 6"""
    515         self.assertFalse(self.find_binding("a", s))
    516 
    517     def test_try_except_finally_nested(self):
    518         s = """
    519             try:
    520                 c = 6
    521             except:
    522                 b = 8
    523             finally:
    524                 try:
    525                     a = 9
    526                 except:
    527                     b = 9
    528                 finally:
    529                     c = 9"""
    530         self.assertTrue(self.find_binding("a", s))
    531 
    532         s = """
    533             try:
    534                 b = 8
    535             finally:
    536                 try:
    537                     pass
    538                 finally:
    539                     a = 6"""
    540         self.assertTrue(self.find_binding("a", s))
    541 
    542         s = """
    543             try:
    544                 b = 8
    545             finally:
    546                 try:
    547                     b = 6
    548                 finally:
    549                     b = 7"""
    550         self.assertFalse(self.find_binding("a", s))
    551 
    552 class Test_touch_import(support.TestCase):
    553 
    554     def test_after_docstring(self):
    555         node = parse('"""foo"""\nbar()')
    556         fixer_util.touch_import(None, "foo", node)
    557         self.assertEqual(str(node), '"""foo"""\nimport foo\nbar()\n\n')
    558 
    559     def test_after_imports(self):
    560         node = parse('"""foo"""\nimport bar\nbar()')
    561         fixer_util.touch_import(None, "foo", node)
    562         self.assertEqual(str(node), '"""foo"""\nimport bar\nimport foo\nbar()\n\n')
    563 
    564     def test_beginning(self):
    565         node = parse('bar()')
    566         fixer_util.touch_import(None, "foo", node)
    567         self.assertEqual(str(node), 'import foo\nbar()\n\n')
    568 
    569     def test_from_import(self):
    570         node = parse('bar()')
    571         fixer_util.touch_import("html", "escape", node)
    572         self.assertEqual(str(node), 'from html import escape\nbar()\n\n')
    573 
    574     def test_name_import(self):
    575         node = parse('bar()')
    576         fixer_util.touch_import(None, "cgi", node)
    577         self.assertEqual(str(node), 'import cgi\nbar()\n\n')
    578 
    579 class Test_find_indentation(support.TestCase):
    580 
    581     def test_nothing(self):
    582         fi = fixer_util.find_indentation
    583         node = parse("node()")
    584         self.assertEqual(fi(node), u"")
    585         node = parse("")
    586         self.assertEqual(fi(node), u"")
    587 
    588     def test_simple(self):
    589         fi = fixer_util.find_indentation
    590         node = parse("def f():\n    x()")
    591         self.assertEqual(fi(node), u"")
    592         self.assertEqual(fi(node.children[0].children[4].children[2]), u"    ")
    593         node = parse("def f():\n    x()\n    y()")
    594         self.assertEqual(fi(node.children[0].children[4].children[4]), u"    ")
    595