# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for live_values module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import context
from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.contrib.py2tf.pyct import qual_names
from tensorflow.contrib.py2tf.pyct.static_analysis import activity
from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test


class LiveValuesResolverTest(test.TestCase):
  """Checks the 'live_val' and 'fqn' annotations produced by live_values."""

  def _parse_and_analyze(self,
                         test_fn,
                         namespace,
                         literals=None,
                         arg_types=None):
    """Parses `test_fn` and runs the static analysis passes on its AST.

    Args:
      test_fn: The function whose source is parsed and analyzed.
      namespace: Dict of names visible to `test_fn`, passed to the
        EntityContext as its namespace (e.g. {'foo': foo}).
      literals: Optional dict of names to resolve to literal values
        (e.g. {'Foo': 'bar'}). Defaults to an empty dict.
      arg_types: Optional dict of argument names to type information, as
        used below in the form {'self': (TestClass.__name__, TestClass)}.
        Defaults to an empty dict.

    Returns:
      The AST node returned by the final live_values.resolve pass.
    """
    literals = literals or {}
    arg_types = arg_types or {}
    node, source = parser.parse_entity(test_fn)
    ctx = context.EntityContext(
        namer=None,
        source_code=source,
        source_file=None,
        namespace=namespace,
        arg_values=None,
        arg_types=arg_types,
        recursive=True)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx)
    node = live_values.resolve(node, ctx, literals)
    node = type_info.resolve(node, ctx)
    # live_values is run a second time after type_info, presumably so it can
    # use type annotations (e.g. for attributes of typed args such as `self`
    # in test_attributes_with_type_hints) -- confirm against live_values.
    node = live_values.resolve(node, ctx, literals)
    return node

  def test_literals(self):
    """A name supplied via `literals` resolves to the literal value."""

    def test_fn():
      return Foo  # pylint: disable=undefined-variable

    node = self._parse_and_analyze(test_fn, {}, {'Foo': 'bar'})
    retval_node = node.body[0].body[0].value
    self.assertEqual('bar', anno.getanno(retval_node, 'live_val'))

  def test_namespace(self):
    """A name found in the namespace resolves to its live object and fqn."""

    def foo():
      return 'bar'

    def test_fn():
      return foo()

    node = self._parse_and_analyze(test_fn, {'foo': foo})
    func_node = node.body[0].body[0].value.func
    self.assertEqual(foo, anno.getanno(func_node, 'live_val'))
    self.assertEqual(('foo',), anno.getanno(func_node, 'fqn'))

  def test_attribute_names(self):
    """An attribute of a namespace module resolves to the module member."""

    def test_fn():
      return constant_op.constant(0)

    node = self._parse_and_analyze(test_fn, {'constant_op': constant_op})
    func_node = node.body[0].body[0].value.func
    self.assertEqual(constant_op.constant, anno.getanno(func_node, 'live_val'))
    self.assertEqual((constant_op.__name__, 'constant'),
                     anno.getanno(func_node, 'fqn'))

  def test_attributes_with_type_hints(self):
    """An attribute of a typed argument resolves via `arg_types`."""

    class TestClass(object):

      def member(self):
        pass

      def test_fn(self):
        return self.member()

    node = self._parse_and_analyze(
        TestClass.test_fn, {'constant_op': constant_op},
        arg_types={'self': (TestClass.__name__, TestClass)})
    func_node = node.body[0].body[0].value.func
    self.assertEqual(TestClass.member, anno.getanno(func_node, 'live_val'))
    self.assertEqual(('TestClass', 'member'), anno.getanno(func_node, 'fqn'))


if __name__ == '__main__':
  test.main()