Home | History | Annotate | Download | only in tests
      1 """ Test suite for the code in fixer_util """
      2 
      3 # Testing imports
      4 from . import support
      5 
      6 # Local imports
      7 from lib2to3.pytree import Node, Leaf
      8 from lib2to3 import fixer_util
      9 from lib2to3.fixer_util import Attr, Name, Call, Comma
     10 from lib2to3.pgen2 import token
     11 
     12 def parse(code, strip_levels=0):
     13     # The topmost node is file_input, which we don't care about.
     14     # The next-topmost node is a *_stmt node, which we also don't care about
     15     tree = support.parse_string(code)
     16     for i in range(strip_levels):
     17         tree = tree.children[0]
     18     tree.parent = None
     19     return tree
     20 
     21 class MacroTestCase(support.TestCase):
     22     def assertStr(self, node, string):
     23         if isinstance(node, (tuple, list)):
     24             node = Node(fixer_util.syms.simple_stmt, node)
     25         self.assertEqual(str(node), string)
     26 
     27 
     28 class Test_is_tuple(support.TestCase):
     29     def is_tuple(self, string):
     30         return fixer_util.is_tuple(parse(string, strip_levels=2))
     31 
     32     def test_valid(self):
     33         self.assertTrue(self.is_tuple("(a, b)"))
     34         self.assertTrue(self.is_tuple("(a, (b, c))"))
     35         self.assertTrue(self.is_tuple("((a, (b, c)),)"))
     36         self.assertTrue(self.is_tuple("(a,)"))
     37         self.assertTrue(self.is_tuple("()"))
     38 
     39     def test_invalid(self):
     40         self.assertFalse(self.is_tuple("(a)"))
     41         self.assertFalse(self.is_tuple("('foo') % (b, c)"))
     42 
     43 
     44 class Test_is_list(support.TestCase):
     45     def is_list(self, string):
     46         return fixer_util.is_list(parse(string, strip_levels=2))
     47 
     48     def test_valid(self):
     49         self.assertTrue(self.is_list("[]"))
     50         self.assertTrue(self.is_list("[a]"))
     51         self.assertTrue(self.is_list("[a, b]"))
     52         self.assertTrue(self.is_list("[a, [b, c]]"))
     53         self.assertTrue(self.is_list("[[a, [b, c]],]"))
     54 
     55     def test_invalid(self):
     56         self.assertFalse(self.is_list("[]+[]"))
     57 
     58 
     59 class Test_Attr(MacroTestCase):
     60     def test(self):
     61         call = parse("foo()", strip_levels=2)
     62 
     63         self.assertStr(Attr(Name("a"), Name("b")), "a.b")
     64         self.assertStr(Attr(call, Name("b")), "foo().b")
     65 
     66     def test_returns(self):
     67         attr = Attr(Name("a"), Name("b"))
     68         self.assertEqual(type(attr), list)
     69 
     70 
     71 class Test_Name(MacroTestCase):
     72     def test(self):
     73         self.assertStr(Name("a"), "a")
     74         self.assertStr(Name("foo.foo().bar"), "foo.foo().bar")
     75         self.assertStr(Name("a", prefix="b"), "ba")
     76 
     77 
     78 class Test_Call(MacroTestCase):
     79     def _Call(self, name, args=None, prefix=None):
     80         """Help the next test"""
     81         children = []
     82         if isinstance(args, list):
     83             for arg in args:
     84                 children.append(arg)
     85                 children.append(Comma())
     86             children.pop()
     87         return Call(Name(name), children, prefix)
     88 
     89     def test(self):
     90         kids = [None,
     91                 [Leaf(token.NUMBER, 1), Leaf(token.NUMBER, 2),
     92                  Leaf(token.NUMBER, 3)],
     93                 [Leaf(token.NUMBER, 1), Leaf(token.NUMBER, 3),
     94                  Leaf(token.NUMBER, 2), Leaf(token.NUMBER, 4)],
     95                 [Leaf(token.STRING, "b"), Leaf(token.STRING, "j", prefix=" ")]
     96                 ]
     97         self.assertStr(self._Call("A"), "A()")
     98         self.assertStr(self._Call("b", kids[1]), "b(1,2,3)")
     99         self.assertStr(self._Call("a.b().c", kids[2]), "a.b().c(1,3,2,4)")
    100         self.assertStr(self._Call("d", kids[3], prefix=" "), " d(b, j)")
    101 
    102 
    103 class Test_does_tree_import(support.TestCase):
    104     def _find_bind_rec(self, name, node):
    105         # Search a tree for a binding -- used to find the starting
    106         # point for these tests.
    107         c = fixer_util.find_binding(name, node)
    108         if c: return c
    109         for child in node.children:
    110             c = self._find_bind_rec(name, child)
    111             if c: return c
    112 
    113     def does_tree_import(self, package, name, string):
    114         node = parse(string)
    115         # Find the binding of start -- that's what we'll go from
    116         node = self._find_bind_rec('start', node)
    117         return fixer_util.does_tree_import(package, name, node)
    118 
    119     def try_with(self, string):
    120         failing_tests = (("a", "a", "from a import b"),
    121                          ("a.d", "a", "from a.d import b"),
    122                          ("d.a", "a", "from d.a import b"),
    123                          (None, "a", "import b"),
    124                          (None, "a", "import b, c, d"))
    125         for package, name, import_ in failing_tests:
    126             n = self.does_tree_import(package, name, import_ + "\n" + string)
    127             self.assertFalse(n)
    128             n = self.does_tree_import(package, name, string + "\n" + import_)
    129             self.assertFalse(n)
    130 
    131         passing_tests = (("a", "a", "from a import a"),
    132                          ("x", "a", "from x import a"),
    133                          ("x", "a", "from x import b, c, a, d"),
    134                          ("x.b", "a", "from x.b import a"),
    135                          ("x.b", "a", "from x.b import b, c, a, d"),
    136                          (None, "a", "import a"),
    137                          (None, "a", "import b, c, a, d"))
    138         for package, name, import_ in passing_tests:
    139             n = self.does_tree_import(package, name, import_ + "\n" + string)
    140             self.assertTrue(n)
    141             n = self.does_tree_import(package, name, string + "\n" + import_)
    142             self.assertTrue(n)
    143 
    144     def test_in_function(self):
    145         self.try_with("def foo():\n\tbar.baz()\n\tstart=3")
    146 
    147 class Test_find_binding(support.TestCase):
    148     def find_binding(self, name, string, package=None):
    149         return fixer_util.find_binding(name, parse(string), package)
    150 
    151     def test_simple_assignment(self):
    152         self.assertTrue(self.find_binding("a", "a = b"))
    153         self.assertTrue(self.find_binding("a", "a = [b, c, d]"))
    154         self.assertTrue(self.find_binding("a", "a = foo()"))
    155         self.assertTrue(self.find_binding("a", "a = foo().foo.foo[6][foo]"))
    156         self.assertFalse(self.find_binding("a", "foo = a"))
    157         self.assertFalse(self.find_binding("a", "foo = (a, b, c)"))
    158 
    159     def test_tuple_assignment(self):
    160         self.assertTrue(self.find_binding("a", "(a,) = b"))
    161         self.assertTrue(self.find_binding("a", "(a, b, c) = [b, c, d]"))
    162         self.assertTrue(self.find_binding("a", "(c, (d, a), b) = foo()"))
    163         self.assertTrue(self.find_binding("a", "(a, b) = foo().foo[6][foo]"))
    164         self.assertFalse(self.find_binding("a", "(foo, b) = (b, a)"))
    165         self.assertFalse(self.find_binding("a", "(foo, (b, c)) = (a, b, c)"))
    166 
    167     def test_list_assignment(self):
    168         self.assertTrue(self.find_binding("a", "[a] = b"))
    169         self.assertTrue(self.find_binding("a", "[a, b, c] = [b, c, d]"))
    170         self.assertTrue(self.find_binding("a", "[c, [d, a], b] = foo()"))
    171         self.assertTrue(self.find_binding("a", "[a, b] = foo().foo[a][foo]"))
    172         self.assertFalse(self.find_binding("a", "[foo, b] = (b, a)"))
    173         self.assertFalse(self.find_binding("a", "[foo, [b, c]] = (a, b, c)"))
    174 
    175     def test_invalid_assignments(self):
    176         self.assertFalse(self.find_binding("a", "foo.a = 5"))
    177         self.assertFalse(self.find_binding("a", "foo[a] = 5"))
    178         self.assertFalse(self.find_binding("a", "foo(a) = 5"))
    179         self.assertFalse(self.find_binding("a", "foo(a, b) = 5"))
    180 
    181     def test_simple_import(self):
    182         self.assertTrue(self.find_binding("a", "import a"))
    183         self.assertTrue(self.find_binding("a", "import b, c, a, d"))
    184         self.assertFalse(self.find_binding("a", "import b"))
    185         self.assertFalse(self.find_binding("a", "import b, c, d"))
    186 
    187     def test_from_import(self):
    188         self.assertTrue(self.find_binding("a", "from x import a"))
    189         self.assertTrue(self.find_binding("a", "from a import a"))
    190         self.assertTrue(self.find_binding("a", "from x import b, c, a, d"))
    191         self.assertTrue(self.find_binding("a", "from x.b import a"))
    192         self.assertTrue(self.find_binding("a", "from x.b import b, c, a, d"))
    193         self.assertFalse(self.find_binding("a", "from a import b"))
    194         self.assertFalse(self.find_binding("a", "from a.d import b"))
    195         self.assertFalse(self.find_binding("a", "from d.a import b"))
    196 
    197     def test_import_as(self):
    198         self.assertTrue(self.find_binding("a", "import b as a"))
    199         self.assertTrue(self.find_binding("a", "import b as a, c, a as f, d"))
    200         self.assertFalse(self.find_binding("a", "import a as f"))
    201         self.assertFalse(self.find_binding("a", "import b, c as f, d as e"))
    202 
    203     def test_from_import_as(self):
    204         self.assertTrue(self.find_binding("a", "from x import b as a"))
    205         self.assertTrue(self.find_binding("a", "from x import g as a, d as b"))
    206         self.assertTrue(self.find_binding("a", "from x.b import t as a"))
    207         self.assertTrue(self.find_binding("a", "from x.b import g as a, d"))
    208         self.assertFalse(self.find_binding("a", "from a import b as t"))
    209         self.assertFalse(self.find_binding("a", "from a.d import b as t"))
    210         self.assertFalse(self.find_binding("a", "from d.a import b as t"))
    211 
    212     def test_simple_import_with_package(self):
    213         self.assertTrue(self.find_binding("b", "import b"))
    214         self.assertTrue(self.find_binding("b", "import b, c, d"))
    215         self.assertFalse(self.find_binding("b", "import b", "b"))
    216         self.assertFalse(self.find_binding("b", "import b, c, d", "c"))
    217 
    218     def test_from_import_with_package(self):
    219         self.assertTrue(self.find_binding("a", "from x import a", "x"))
    220         self.assertTrue(self.find_binding("a", "from a import a", "a"))
    221         self.assertTrue(self.find_binding("a", "from x import *", "x"))
    222         self.assertTrue(self.find_binding("a", "from x import b, c, a, d", "x"))
    223         self.assertTrue(self.find_binding("a", "from x.b import a", "x.b"))
    224         self.assertTrue(self.find_binding("a", "from x.b import *", "x.b"))
    225         self.assertTrue(self.find_binding("a", "from x.b import b, c, a, d", "x.b"))
    226         self.assertFalse(self.find_binding("a", "from a import b", "a"))
    227         self.assertFalse(self.find_binding("a", "from a.d import b", "a.d"))
    228         self.assertFalse(self.find_binding("a", "from d.a import b", "a.d"))
    229         self.assertFalse(self.find_binding("a", "from x.y import *", "a.b"))
    230 
    231     def test_import_as_with_package(self):
    232         self.assertFalse(self.find_binding("a", "import b.c as a", "b.c"))
    233         self.assertFalse(self.find_binding("a", "import a as f", "f"))
    234         self.assertFalse(self.find_binding("a", "import a as f", "a"))
    235 
    236     def test_from_import_as_with_package(self):
    237         # Because it would take a lot of special-case code in the fixers
    238         # to deal with from foo import bar as baz, we'll simply always
    239         # fail if there is an "from ... import ... as ..."
    240         self.assertFalse(self.find_binding("a", "from x import b as a", "x"))
    241         self.assertFalse(self.find_binding("a", "from x import g as a, d as b", "x"))
    242         self.assertFalse(self.find_binding("a", "from x.b import t as a", "x.b"))
    243         self.assertFalse(self.find_binding("a", "from x.b import g as a, d", "x.b"))
    244         self.assertFalse(self.find_binding("a", "from a import b as t", "a"))
    245         self.assertFalse(self.find_binding("a", "from a import b as t", "b"))
    246         self.assertFalse(self.find_binding("a", "from a import b as t", "t"))
    247 
    248     def test_function_def(self):
    249         self.assertTrue(self.find_binding("a", "def a(): pass"))
    250         self.assertTrue(self.find_binding("a", "def a(b, c, d): pass"))
    251         self.assertTrue(self.find_binding("a", "def a(): b = 7"))
    252         self.assertFalse(self.find_binding("a", "def d(b, (c, a), e): pass"))
    253         self.assertFalse(self.find_binding("a", "def d(a=7): pass"))
    254         self.assertFalse(self.find_binding("a", "def d(a): pass"))
    255         self.assertFalse(self.find_binding("a", "def d(): a = 7"))
    256 
    257         s = """
    258             def d():
    259                 def a():
    260                     pass"""
    261         self.assertFalse(self.find_binding("a", s))
    262 
    263     def test_class_def(self):
    264         self.assertTrue(self.find_binding("a", "class a: pass"))
    265         self.assertTrue(self.find_binding("a", "class a(): pass"))
    266         self.assertTrue(self.find_binding("a", "class a(b): pass"))
    267         self.assertTrue(self.find_binding("a", "class a(b, c=8): pass"))
    268         self.assertFalse(self.find_binding("a", "class d: pass"))
    269         self.assertFalse(self.find_binding("a", "class d(a): pass"))
    270         self.assertFalse(self.find_binding("a", "class d(b, a=7): pass"))
    271         self.assertFalse(self.find_binding("a", "class d(b, *a): pass"))
    272         self.assertFalse(self.find_binding("a", "class d(b, **a): pass"))
    273         self.assertFalse(self.find_binding("a", "class d: a = 7"))
    274 
    275         s = """
    276             class d():
    277                 class a():
    278                     pass"""
    279         self.assertFalse(self.find_binding("a", s))
    280 
    281     def test_for(self):
    282         self.assertTrue(self.find_binding("a", "for a in r: pass"))
    283         self.assertTrue(self.find_binding("a", "for a, b in r: pass"))
    284         self.assertTrue(self.find_binding("a", "for (a, b) in r: pass"))
    285         self.assertTrue(self.find_binding("a", "for c, (a,) in r: pass"))
    286         self.assertTrue(self.find_binding("a", "for c, (a, b) in r: pass"))
    287         self.assertTrue(self.find_binding("a", "for c in r: a = c"))
    288         self.assertFalse(self.find_binding("a", "for c in a: pass"))
    289 
    290     def test_for_nested(self):
    291         s = """
    292             for b in r:
    293                 for a in b:
    294                     pass"""
    295         self.assertTrue(self.find_binding("a", s))
    296 
    297         s = """
    298             for b in r:
    299                 for a, c in b:
    300                     pass"""
    301         self.assertTrue(self.find_binding("a", s))
    302 
    303         s = """
    304             for b in r:
    305                 for (a, c) in b:
    306                     pass"""
    307         self.assertTrue(self.find_binding("a", s))
    308 
    309         s = """
    310             for b in r:
    311                 for (a,) in b:
    312                     pass"""
    313         self.assertTrue(self.find_binding("a", s))
    314 
    315         s = """
    316             for b in r:
    317                 for c, (a, d) in b:
    318                     pass"""
    319         self.assertTrue(self.find_binding("a", s))
    320 
    321         s = """
    322             for b in r:
    323                 for c in b:
    324                     a = 7"""
    325         self.assertTrue(self.find_binding("a", s))
    326 
    327         s = """
    328             for b in r:
    329                 for c in b:
    330                     d = a"""
    331         self.assertFalse(self.find_binding("a", s))
    332 
    333         s = """
    334             for b in r:
    335                 for c in a:
    336                     d = 7"""
    337         self.assertFalse(self.find_binding("a", s))
    338 
    339     def test_if(self):
    340         self.assertTrue(self.find_binding("a", "if b in r: a = c"))
    341         self.assertFalse(self.find_binding("a", "if a in r: d = e"))
    342 
    343     def test_if_nested(self):
    344         s = """
    345             if b in r:
    346                 if c in d:
    347                     a = c"""
    348         self.assertTrue(self.find_binding("a", s))
    349 
    350         s = """
    351             if b in r:
    352                 if c in d:
    353                     c = a"""
    354         self.assertFalse(self.find_binding("a", s))
    355 
    356     def test_while(self):
    357         self.assertTrue(self.find_binding("a", "while b in r: a = c"))
    358         self.assertFalse(self.find_binding("a", "while a in r: d = e"))
    359 
    360     def test_while_nested(self):
    361         s = """
    362             while b in r:
    363                 while c in d:
    364                     a = c"""
    365         self.assertTrue(self.find_binding("a", s))
    366 
    367         s = """
    368             while b in r:
    369                 while c in d:
    370                     c = a"""
    371         self.assertFalse(self.find_binding("a", s))
    372 
    373     def test_try_except(self):
    374         s = """
    375             try:
    376                 a = 6
    377             except:
    378                 b = 8"""
    379         self.assertTrue(self.find_binding("a", s))
    380 
    381         s = """
    382             try:
    383                 b = 8
    384             except:
    385                 a = 6"""
    386         self.assertTrue(self.find_binding("a", s))
    387 
    388         s = """
    389             try:
    390                 b = 8
    391             except KeyError:
    392                 pass
    393             except:
    394                 a = 6"""
    395         self.assertTrue(self.find_binding("a", s))
    396 
    397         s = """
    398             try:
    399                 b = 8
    400             except:
    401                 b = 6"""
    402         self.assertFalse(self.find_binding("a", s))
    403 
    404     def test_try_except_nested(self):
    405         s = """
    406             try:
    407                 try:
    408                     a = 6
    409                 except:
    410                     pass
    411             except:
    412                 b = 8"""
    413         self.assertTrue(self.find_binding("a", s))
    414 
    415         s = """
    416             try:
    417                 b = 8
    418             except:
    419                 try:
    420                     a = 6
    421                 except:
    422                     pass"""
    423         self.assertTrue(self.find_binding("a", s))
    424 
    425         s = """
    426             try:
    427                 b = 8
    428             except:
    429                 try:
    430                     pass
    431                 except:
    432                     a = 6"""
    433         self.assertTrue(self.find_binding("a", s))
    434 
    435         s = """
    436             try:
    437                 try:
    438                     b = 8
    439                 except KeyError:
    440                     pass
    441                 except:
    442                     a = 6
    443             except:
    444                 pass"""
    445         self.assertTrue(self.find_binding("a", s))
    446 
    447         s = """
    448             try:
    449                 pass
    450             except:
    451                 try:
    452                     b = 8
    453                 except KeyError:
    454                     pass
    455                 except:
    456                     a = 6"""
    457         self.assertTrue(self.find_binding("a", s))
    458 
    459         s = """
    460             try:
    461                 b = 8
    462             except:
    463                 b = 6"""
    464         self.assertFalse(self.find_binding("a", s))
    465 
    466         s = """
    467             try:
    468                 try:
    469                     b = 8
    470                 except:
    471                     c = d
    472             except:
    473                 try:
    474                     b = 6
    475                 except:
    476                     t = 8
    477                 except:
    478                     o = y"""
    479         self.assertFalse(self.find_binding("a", s))
    480 
    481     def test_try_except_finally(self):
    482         s = """
    483             try:
    484                 c = 6
    485             except:
    486                 b = 8
    487             finally:
    488                 a = 9"""
    489         self.assertTrue(self.find_binding("a", s))
    490 
    491         s = """
    492             try:
    493                 b = 8
    494             finally:
    495                 a = 6"""
    496         self.assertTrue(self.find_binding("a", s))
    497 
    498         s = """
    499             try:
    500                 b = 8
    501             finally:
    502                 b = 6"""
    503         self.assertFalse(self.find_binding("a", s))
    504 
    505         s = """
    506             try:
    507                 b = 8
    508             except:
    509                 b = 9
    510             finally:
    511                 b = 6"""
    512         self.assertFalse(self.find_binding("a", s))
    513 
    514     def test_try_except_finally_nested(self):
    515         s = """
    516             try:
    517                 c = 6
    518             except:
    519                 b = 8
    520             finally:
    521                 try:
    522                     a = 9
    523                 except:
    524                     b = 9
    525                 finally:
    526                     c = 9"""
    527         self.assertTrue(self.find_binding("a", s))
    528 
    529         s = """
    530             try:
    531                 b = 8
    532             finally:
    533                 try:
    534                     pass
    535                 finally:
    536                     a = 6"""
    537         self.assertTrue(self.find_binding("a", s))
    538 
    539         s = """
    540             try:
    541                 b = 8
    542             finally:
    543                 try:
    544                     b = 6
    545                 finally:
    546                     b = 7"""
    547         self.assertFalse(self.find_binding("a", s))
    548 
    549 class Test_touch_import(support.TestCase):
    550 
    551     def test_after_docstring(self):
    552         node = parse('"""foo"""\nbar()')
    553         fixer_util.touch_import(None, "foo", node)
    554         self.assertEqual(str(node), '"""foo"""\nimport foo\nbar()\n\n')
    555 
    556     def test_after_imports(self):
    557         node = parse('"""foo"""\nimport bar\nbar()')
    558         fixer_util.touch_import(None, "foo", node)
    559         self.assertEqual(str(node), '"""foo"""\nimport bar\nimport foo\nbar()\n\n')
    560 
    561     def test_beginning(self):
    562         node = parse('bar()')
    563         fixer_util.touch_import(None, "foo", node)
    564         self.assertEqual(str(node), 'import foo\nbar()\n\n')
    565 
    566     def test_from_import(self):
    567         node = parse('bar()')
    568         fixer_util.touch_import("html", "escape", node)
    569         self.assertEqual(str(node), 'from html import escape\nbar()\n\n')
    570 
    571     def test_name_import(self):
    572         node = parse('bar()')
    573         fixer_util.touch_import(None, "cgi", node)
    574         self.assertEqual(str(node), 'import cgi\nbar()\n\n')
    575 
    576 class Test_find_indentation(support.TestCase):
    577 
    578     def test_nothing(self):
    579         fi = fixer_util.find_indentation
    580         node = parse("node()")
    581         self.assertEqual(fi(node), "")
    582         node = parse("")
    583         self.assertEqual(fi(node), "")
    584 
    585     def test_simple(self):
    586         fi = fixer_util.find_indentation
    587         node = parse("def f():\n    x()")
    588         self.assertEqual(fi(node), "")
    589         self.assertEqual(fi(node.children[0].children[4].children[2]), "    ")
    590         node = parse("def f():\n    x()\n    y()")
    591         self.assertEqual(fi(node.children[0].children[4].children[4]), "    ")
    592