Home | History | Annotate | Download | only in pyct
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for qual_names module."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import textwrap
     22 
     23 from tensorflow.contrib.py2tf.pyct import anno
     24 from tensorflow.contrib.py2tf.pyct import parser
     25 from tensorflow.contrib.py2tf.pyct import qual_names
     26 from tensorflow.python.platform import test
     27 
     28 
     29 class QNTest(test.TestCase):
     30 
     31   def test_basic(self):
     32     a = qual_names.QN('a')
     33     self.assertEqual(a.qn, ('a',))
     34     self.assertEqual(str(a), 'a')
     35     self.assertEqual(a.ssf(), 'a')
     36     self.assertEqual(a.ast().id, 'a')
     37     self.assertFalse(a.is_composite())
     38     with self.assertRaises(ValueError):
     39       _ = a.parent
     40 
     41     a_b = qual_names.QN(a, 'b')
     42     self.assertEqual(a_b.qn, ('a', 'b'))
     43     self.assertEqual(str(a_b), 'a.b')
     44     self.assertEqual(a_b.ssf(), 'a_b')
     45     self.assertEqual(a_b.ast().value.id, 'a')
     46     self.assertEqual(a_b.ast().attr, 'b')
     47     self.assertTrue(a_b.is_composite())
     48     self.assertEqual(a_b.parent.qn, ('a',))
     49 
     50     a2 = qual_names.QN(a)
     51     self.assertEqual(a2.qn, ('a',))
     52     with self.assertRaises(ValueError):
     53       _ = a.parent
     54 
     55     a_b2 = qual_names.QN(a_b)
     56     self.assertEqual(a_b2.qn, ('a', 'b'))
     57     self.assertEqual(a_b2.parent.qn, ('a',))
     58 
     59     self.assertTrue(a2 == a)
     60     self.assertFalse(a2 is a)
     61 
     62     self.assertTrue(a_b.parent == a)
     63     self.assertTrue(a_b2.parent == a)
     64 
     65     self.assertTrue(a_b2 == a_b)
     66     self.assertFalse(a_b2 is a_b)
     67     self.assertFalse(a_b2 == a)
     68 
     69     with self.assertRaises(ValueError):
     70       qual_names.QN('a', 'b')
     71 
     72   def test_hashable(self):
     73     d = {qual_names.QN('a'): 'a', qual_names.QN('b'): 'b'}
     74 
     75     self.assertEqual(d[qual_names.QN('a')], 'a')
     76     self.assertEqual(d[qual_names.QN('b')], 'b')
     77     self.assertTrue(qual_names.QN('c') not in d)
     78 
     79 
     80 class QNResolverTest(test.TestCase):
     81 
     82   def assertQNStringIs(self, node, qn_str):
     83     self.assertEqual(str(anno.getanno(node, anno.Basic.QN)), qn_str)
     84 
     85   def test_resolve(self):
     86     samples = """
     87       a
     88       a.b
     89       (c, d.e)
     90       [f, (g.h.i)]
     91       j(k, l)
     92     """
     93     nodes = qual_names.resolve(parser.parse_str(textwrap.dedent(samples)))
     94     nodes = tuple(n.value for n in nodes.body)
     95 
     96     self.assertQNStringIs(nodes[0], 'a')
     97     self.assertQNStringIs(nodes[1], 'a.b')
     98     self.assertQNStringIs(nodes[2].elts[0], 'c')
     99     self.assertQNStringIs(nodes[2].elts[1], 'd.e')
    100     self.assertQNStringIs(nodes[3].elts[0], 'f')
    101     self.assertQNStringIs(nodes[3].elts[1], 'g.h.i')
    102     self.assertQNStringIs(nodes[4].func, 'j')
    103     self.assertQNStringIs(nodes[4].args[0], 'k')
    104     self.assertQNStringIs(nodes[4].args[1], 'l')
    105 
    106 
if __name__ == '__main__':
  # Run the test cases defined in this module via TensorFlow's test runner.
  test.main()
    109