1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for qual_names module.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import textwrap 22 23 from tensorflow.contrib.py2tf.pyct import anno 24 from tensorflow.contrib.py2tf.pyct import parser 25 from tensorflow.contrib.py2tf.pyct import qual_names 26 from tensorflow.python.platform import test 27 28 29 class QNTest(test.TestCase): 30 31 def test_basic(self): 32 a = qual_names.QN('a') 33 self.assertEqual(a.qn, ('a',)) 34 self.assertEqual(str(a), 'a') 35 self.assertEqual(a.ssf(), 'a') 36 self.assertEqual(a.ast().id, 'a') 37 self.assertFalse(a.is_composite()) 38 with self.assertRaises(ValueError): 39 _ = a.parent 40 41 a_b = qual_names.QN(a, 'b') 42 self.assertEqual(a_b.qn, ('a', 'b')) 43 self.assertEqual(str(a_b), 'a.b') 44 self.assertEqual(a_b.ssf(), 'a_b') 45 self.assertEqual(a_b.ast().value.id, 'a') 46 self.assertEqual(a_b.ast().attr, 'b') 47 self.assertTrue(a_b.is_composite()) 48 self.assertEqual(a_b.parent.qn, ('a',)) 49 50 a2 = qual_names.QN(a) 51 self.assertEqual(a2.qn, ('a',)) 52 with self.assertRaises(ValueError): 53 _ = a.parent 54 55 a_b2 = qual_names.QN(a_b) 56 self.assertEqual(a_b2.qn, ('a', 'b')) 57 self.assertEqual(a_b2.parent.qn, ('a',)) 58 59 self.assertTrue(a2 == a) 60 self.assertFalse(a2 is a) 61 62 self.assertTrue(a_b.parent == a) 63 self.assertTrue(a_b2.parent == a) 64 65 self.assertTrue(a_b2 == a_b) 66 self.assertFalse(a_b2 is a_b) 67 self.assertFalse(a_b2 == a) 68 69 with self.assertRaises(ValueError): 70 qual_names.QN('a', 'b') 71 72 def test_hashable(self): 73 d = {qual_names.QN('a'): 'a', qual_names.QN('b'): 'b'} 74 75 self.assertEqual(d[qual_names.QN('a')], 'a') 76 self.assertEqual(d[qual_names.QN('b')], 'b') 77 self.assertTrue(qual_names.QN('c') not in d) 78 79 80 class QNResolverTest(test.TestCase): 81 82 def assertQNStringIs(self, node, qn_str): 83 self.assertEqual(str(anno.getanno(node, anno.Basic.QN)), qn_str) 84 85 def test_resolve(self): 86 samples = """ 87 a 88 a.b 89 (c, d.e) 90 [f, (g.h.i)] 91 j(k, l) 92 """ 93 nodes = qual_names.resolve(parser.parse_str(textwrap.dedent(samples))) 94 nodes = tuple(n.value for n in nodes.body) 95 96 self.assertQNStringIs(nodes[0], 'a') 97 self.assertQNStringIs(nodes[1], 'a.b') 98 self.assertQNStringIs(nodes[2].elts[0], 'c') 99 self.assertQNStringIs(nodes[2].elts[1], 'd.e') 100 self.assertQNStringIs(nodes[3].elts[0], 'f') 101 self.assertQNStringIs(nodes[3].elts[1], 'g.h.i') 102 self.assertQNStringIs(nodes[4].func, 'j') 103 self.assertQNStringIs(nodes[4].args[0], 'k') 104 self.assertQNStringIs(nodes[4].args[1], 'l') 105 106 107 if __name__ == '__main__': 108 test.main() 109