Home | History | Annotate | Download | only in yapftests
      1 # Copyright 2015 Google Inc. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 """Tests for yapf.pytree_visitor."""
     15 
     16 import unittest
     17 
     18 from yapf.yapflib import py3compat
     19 from yapf.yapflib import pytree_utils
     20 from yapf.yapflib import pytree_visitor
     21 
     22 
     23 class _NodeNameCollector(pytree_visitor.PyTreeVisitor):
     24   """A tree visitor that collects the names of all tree nodes into a list.
     25 
     26   Attributes:
     27     all_node_names: collected list of the names, available when the traversal
     28       is over.
     29     name_node_values: collects a list of NAME leaves (in addition to those going
     30       into all_node_names).
     31   """
     32 
     33   def __init__(self):
     34     self.all_node_names = []
     35     self.name_node_values = []
     36 
     37   def DefaultNodeVisit(self, node):
     38     self.all_node_names.append(pytree_utils.NodeName(node))
     39     super(_NodeNameCollector, self).DefaultNodeVisit(node)
     40 
     41   def DefaultLeafVisit(self, leaf):
     42     self.all_node_names.append(pytree_utils.NodeName(leaf))
     43 
     44   def Visit_NAME(self, leaf):
     45     self.name_node_values.append(leaf.value)
     46     self.DefaultLeafVisit(leaf)
     47 
     48 
     49 _VISITOR_TEST_SIMPLE_CODE = r"""
     50 foo = bar
     51 baz = x
     52 """
     53 
     54 _VISITOR_TEST_NESTED_CODE = r"""
     55 if x:
     56   if y:
     57     return z
     58 """
     59 
     60 
     61 class PytreeVisitorTest(unittest.TestCase):
     62 
     63   def testCollectAllNodeNamesSimpleCode(self):
     64     tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
     65     collector = _NodeNameCollector()
     66     collector.Visit(tree)
     67     expected_names = [
     68         'file_input',
     69         'simple_stmt', 'expr_stmt', 'NAME', 'EQUAL', 'NAME', 'NEWLINE',
     70         'simple_stmt', 'expr_stmt', 'NAME', 'EQUAL', 'NAME', 'NEWLINE',
     71         'ENDMARKER',
     72     ]  # yapf: disable
     73     self.assertEqual(expected_names, collector.all_node_names)
     74 
     75     expected_name_node_values = ['foo', 'bar', 'baz', 'x']
     76     self.assertEqual(expected_name_node_values, collector.name_node_values)
     77 
     78   def testCollectAllNodeNamesNestedCode(self):
     79     tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_NESTED_CODE)
     80     collector = _NodeNameCollector()
     81     collector.Visit(tree)
     82     expected_names = [
     83         'file_input',
     84         'if_stmt', 'NAME', 'NAME', 'COLON',
     85         'suite', 'NEWLINE',
     86         'INDENT', 'if_stmt', 'NAME', 'NAME', 'COLON', 'suite', 'NEWLINE',
     87         'INDENT', 'simple_stmt', 'return_stmt', 'NAME', 'NAME', 'NEWLINE',
     88         'DEDENT', 'DEDENT', 'ENDMARKER',
     89     ]  # yapf: disable
     90     self.assertEqual(expected_names, collector.all_node_names)
     91 
     92     expected_name_node_values = ['if', 'x', 'if', 'y', 'return', 'z']
     93     self.assertEqual(expected_name_node_values, collector.name_node_values)
     94 
     95   def testDumper(self):
     96     # PyTreeDumper is mainly a debugging utility, so only do basic sanity
     97     # checking.
     98     tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
     99     stream = py3compat.StringIO()
    100     pytree_visitor.PyTreeDumper(target_stream=stream).Visit(tree)
    101 
    102     dump_output = stream.getvalue()
    103     self.assertIn('file_input [3 children]', dump_output)
    104     self.assertIn("NAME(Leaf(NAME, 'foo'))", dump_output)
    105     self.assertIn("EQUAL(Leaf(EQUAL, '='))", dump_output)
    106 
    107   def testDumpPyTree(self):
    108     # Similar sanity checking for the convenience wrapper DumpPyTree
    109     tree = pytree_utils.ParseCodeToTree(_VISITOR_TEST_SIMPLE_CODE)
    110     stream = py3compat.StringIO()
    111     pytree_visitor.DumpPyTree(tree, target_stream=stream)
    112 
    113     dump_output = stream.getvalue()
    114     self.assertIn('file_input [3 children]', dump_output)
    115     self.assertIn("NAME(Leaf(NAME, 'foo'))", dump_output)
    116     self.assertIn("EQUAL(Leaf(EQUAL, '='))", dump_output)
    117 
    118 
    119 if __name__ == '__main__':
    120   unittest.main()
    121