Home | History | Annotate | Download | only in yapftests
      1 # Copyright 2015 Google Inc. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 """Tests for yapf.split_penalty."""
     15 
     16 import sys
     17 import textwrap
     18 import unittest
     19 
     20 from lib2to3 import pytree
     21 
     22 from yapf.yapflib import pytree_utils
     23 from yapf.yapflib import pytree_visitor
     24 from yapf.yapflib import split_penalty
     25 
     26 UNBREAKABLE = split_penalty.UNBREAKABLE
     27 VERY_STRONGLY_CONNECTED = split_penalty.VERY_STRONGLY_CONNECTED
     28 DOTTED_NAME = split_penalty.DOTTED_NAME
     29 STRONGLY_CONNECTED = split_penalty.STRONGLY_CONNECTED
     30 
     31 
     32 class SplitPenaltyTest(unittest.TestCase):
     33 
     34   def _ParseAndComputePenalties(self, code, dumptree=False):
     35     """Parses the code and computes split penalties.
     36 
     37     Arguments:
     38       code: code to parse as a string
     39       dumptree: if True, the parsed pytree (after penalty assignment) is dumped
     40         to stderr. Useful for debugging.
     41 
     42     Returns:
     43       Parse tree.
     44     """
     45     tree = pytree_utils.ParseCodeToTree(code)
     46     split_penalty.ComputeSplitPenalties(tree)
     47     if dumptree:
     48       pytree_visitor.DumpPyTree(tree, target_stream=sys.stderr)
     49     return tree
     50 
     51   def _CheckPenalties(self, tree, list_of_expected):
     52     """Check that the tokens in the tree have the correct penalties.
     53 
     54     Args:
     55       tree: the pytree.
     56       list_of_expected: list of (name, penalty) pairs. Non-semantic tokens are
     57         filtered out from the expected values.
     58     """
     59 
     60     def FlattenRec(tree):
     61       if pytree_utils.NodeName(tree) in pytree_utils.NONSEMANTIC_TOKENS:
     62         return []
     63       if isinstance(tree, pytree.Leaf):
     64         return [(tree.value,
     65                  pytree_utils.GetNodeAnnotation(
     66                      tree, pytree_utils.Annotation.SPLIT_PENALTY))]
     67       nodes = []
     68       for node in tree.children:
     69         nodes += FlattenRec(node)
     70       return nodes
     71 
     72     self.assertEqual(list_of_expected, FlattenRec(tree))
     73 
     74   def testUnbreakable(self):
     75     # Test function definitions.
     76     code = textwrap.dedent(r"""
     77       def foo(x):
     78         pass
     79       """)
     80     tree = self._ParseAndComputePenalties(code)
     81     self._CheckPenalties(tree, [
     82         ('def', None),
     83         ('foo', UNBREAKABLE),
     84         ('(', UNBREAKABLE),
     85         ('x', None),
     86         (')', STRONGLY_CONNECTED),
     87         (':', UNBREAKABLE),
     88         ('pass', None),
     89     ])
     90 
     91     # Test function definition with trailing comment.
     92     code = textwrap.dedent(r"""
     93       def foo(x):  # trailing comment
     94         pass
     95       """)
     96     tree = self._ParseAndComputePenalties(code)
     97     self._CheckPenalties(tree, [
     98         ('def', None),
     99         ('foo', UNBREAKABLE),
    100         ('(', UNBREAKABLE),
    101         ('x', None),
    102         (')', STRONGLY_CONNECTED),
    103         (':', UNBREAKABLE),
    104         ('pass', None),
    105     ])
    106 
    107     # Test class definitions.
    108     code = textwrap.dedent(r"""
    109       class A:
    110         pass
    111       class B(A):
    112         pass
    113       """)
    114     tree = self._ParseAndComputePenalties(code)
    115     self._CheckPenalties(tree, [
    116         ('class', None),
    117         ('A', UNBREAKABLE),
    118         (':', UNBREAKABLE),
    119         ('pass', None),
    120         ('class', None),
    121         ('B', UNBREAKABLE),
    122         ('(', UNBREAKABLE),
    123         ('A', None),
    124         (')', None),
    125         (':', UNBREAKABLE),
    126         ('pass', None),
    127     ])
    128 
    129     # Test lambda definitions.
    130     code = textwrap.dedent(r"""
    131       lambda a, b: None
    132       """)
    133     tree = self._ParseAndComputePenalties(code)
    134     self._CheckPenalties(tree, [
    135         ('lambda', None),
    136         ('a', UNBREAKABLE),
    137         (',', UNBREAKABLE),
    138         ('b', UNBREAKABLE),
    139         (':', UNBREAKABLE),
    140         ('None', UNBREAKABLE),
    141     ])
    142 
    143     # Test dotted names.
    144     code = textwrap.dedent(r"""
    145       import a.b.c
    146       """)
    147     tree = self._ParseAndComputePenalties(code)
    148     self._CheckPenalties(tree, [
    149         ('import', None),
    150         ('a', None),
    151         ('.', UNBREAKABLE),
    152         ('b', UNBREAKABLE),
    153         ('.', UNBREAKABLE),
    154         ('c', UNBREAKABLE),
    155     ])
    156 
    157   def testStronglyConnected(self):
    158     # Test dictionary keys.
    159     code = textwrap.dedent(r"""
    160       a = {
    161           'x': 42,
    162           y(lambda a: 23): 37,
    163       }
    164       """)
    165     tree = self._ParseAndComputePenalties(code)
    166     self._CheckPenalties(tree, [
    167         ('a', None),
    168         ('=', None),
    169         ('{', None),
    170         ("'x'", None),
    171         (':', STRONGLY_CONNECTED),
    172         ('42', None),
    173         (',', None),
    174         ('y', None),
    175         ('(', UNBREAKABLE),
    176         ('lambda', STRONGLY_CONNECTED),
    177         ('a', UNBREAKABLE),
    178         (':', UNBREAKABLE),
    179         ('23', UNBREAKABLE),
    180         (')', VERY_STRONGLY_CONNECTED),
    181         (':', STRONGLY_CONNECTED),
    182         ('37', None),
    183         (',', None),
    184         ('}', None),
    185     ])
    186 
    187     # Test list comprehension.
    188     code = textwrap.dedent(r"""
    189       [a for a in foo if a.x == 37]
    190       """)
    191     tree = self._ParseAndComputePenalties(code)
    192     self._CheckPenalties(tree, [
    193         ('[', None),
    194         ('a', None),
    195         ('for', 0),
    196         ('a', STRONGLY_CONNECTED),
    197         ('in', STRONGLY_CONNECTED),
    198         ('foo', STRONGLY_CONNECTED),
    199         ('if', 0),
    200         ('a', STRONGLY_CONNECTED),
    201         ('.', UNBREAKABLE),
    202         ('x', DOTTED_NAME),
    203         ('==', STRONGLY_CONNECTED),
    204         ('37', STRONGLY_CONNECTED),
    205         (']', None),
    206     ])
    207 
    208   def testFuncCalls(self):
    209     code = 'foo(1, 2, 3)\n'
    210     tree = self._ParseAndComputePenalties(code)
    211     self._CheckPenalties(tree, [
    212         ('foo', None),
    213         ('(', UNBREAKABLE),
    214         ('1', None),
    215         (',', UNBREAKABLE),
    216         ('2', None),
    217         (',', UNBREAKABLE),
    218         ('3', None),
    219         (')', VERY_STRONGLY_CONNECTED),
    220     ])
    221 
    222     # Now a method call, which has more than one trailer
    223     code = 'foo.bar.baz(1, 2, 3)\n'
    224     tree = self._ParseAndComputePenalties(code)
    225     self._CheckPenalties(tree, [
    226         ('foo', None),
    227         ('.', UNBREAKABLE),
    228         ('bar', DOTTED_NAME),
    229         ('.', STRONGLY_CONNECTED),
    230         ('baz', DOTTED_NAME),
    231         ('(', STRONGLY_CONNECTED),
    232         ('1', None),
    233         (',', UNBREAKABLE),
    234         ('2', None),
    235         (',', UNBREAKABLE),
    236         ('3', None),
    237         (')', VERY_STRONGLY_CONNECTED),
    238     ])
    239 
    240     # Test single generator argument.
    241     code = 'max(i for i in xrange(10))\n'
    242     tree = self._ParseAndComputePenalties(code)
    243     self._CheckPenalties(tree, [
    244         ('max', None),
    245         ('(', UNBREAKABLE),
    246         ('i', 0),
    247         ('for', 0),
    248         ('i', STRONGLY_CONNECTED),
    249         ('in', STRONGLY_CONNECTED),
    250         ('xrange', STRONGLY_CONNECTED),
    251         ('(', UNBREAKABLE),
    252         ('10', STRONGLY_CONNECTED),
    253         (')', VERY_STRONGLY_CONNECTED),
    254         (')', VERY_STRONGLY_CONNECTED),
    255     ])
    256 
    257 
# Allow running this test module directly; unittest.main() discovers and runs
# the TestCase classes defined above.
if __name__ == '__main__':
  unittest.main()
    260