# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for type_info module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import context
from tensorflow.contrib.py2tf.pyct import parser
from tensorflow.contrib.py2tf.pyct import qual_names
from tensorflow.contrib.py2tf.pyct.static_analysis import activity
from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
from tensorflow.python.client import session
from tensorflow.python.platform import test
from tensorflow.python.training import training


class ScopeTest(test.TestCase):
  """Tests name lookup in type_info.Scope."""

  def test_basic(self):
    """A name is reported by hasval only after setval has been called."""
    scope = type_info.Scope(None)
    self.assertFalse(scope.hasval('foo'))

    scope.setval('foo', 'bar')
    self.assertTrue(scope.hasval('foo'))

    self.assertFalse(scope.hasval('baz'))

  def test_nesting(self):
    """A child scope sees its parent's names, but not the other way round."""
    scope = type_info.Scope(None)
    scope.setval('foo', '')

    child = type_info.Scope(scope)
    self.assertTrue(child.hasval('foo'))
    self.assertTrue(scope.hasval('foo'))

    child.setval('bar', '')
    self.assertTrue(child.hasval('bar'))
    self.assertFalse(scope.hasval('bar'))


class TypeInfoResolverTest(test.TestCase):
  """Tests the annotations that type_info.resolve attaches to the AST."""

  def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
    """Parses `test_fn` and runs the static analysis passes over it.

    Args:
      test_fn: the Python function whose source is parsed.
      namespace: dict of names visible to `test_fn`, passed to the
        EntityContext.
      arg_types: optional dict mapping argument names to
        (type name, type) tuples, used as type hints for parameters.

    Returns:
      The annotated AST node returned by parser.parse_entity after all passes.
    """
    node, source = parser.parse_entity(test_fn)
    ctx = context.EntityContext(
        namer=None,
        source_code=source,
        source_file=None,
        namespace=namespace,
        arg_values=None,
        arg_types=arg_types,
        recursive=True)
    node = qual_names.resolve(node)
    node = activity.resolve(node, ctx)
    node = live_values.resolve(node, ctx, {})
    node = type_info.resolve(node, ctx)
    # live_values runs a second time after type_info, presumably so that
    # members of objects whose types were just inferred can be resolved.
    # NOTE(review): confirm against live_values.resolve.
    node = live_values.resolve(node, ctx, {})
    return node

  def test_constructor_detection(self):
    """A constructor call is annotated with its type and (module, name)."""

    def test_fn():
      opt = training.GradientDescentOptimizer(0.1)
      return opt

    node = self._parse_and_analyze(test_fn, {'training': training})
    call_node = node.body[0].body[0].value
    self.assertEqual(training.GradientDescentOptimizer,
                     anno.getanno(call_node, 'type'))
    self.assertEqual((training.__name__, 'GradientDescentOptimizer'),
                     anno.getanno(call_node, 'type_fqn'))

  def test_class_members_of_detected_constructor(self):
    """A method call on a constructed local resolves to the class method."""

    def test_fn():
      opt = training.GradientDescentOptimizer(0.1)
      opt.minimize(0)

    node = self._parse_and_analyze(test_fn, {'training': training})
    method_call = node.body[0].body[1].value.func
    self.assertEqual(training.GradientDescentOptimizer.minimize,
                     anno.getanno(method_call, 'live_val'))

  def test_class_members_in_with_stmt(self):
    """Type detection works for a with-statement context and its target."""

    def test_fn(x):
      with session.Session() as sess:
        sess.run(x)

    node = self._parse_and_analyze(test_fn, {'session': session})
    constructor_call = node.body[0].body[0].items[0].context_expr
    self.assertEqual(session.Session, anno.getanno(constructor_call, 'type'))
    self.assertEqual((session.__name__, 'Session'),
                     anno.getanno(constructor_call, 'type_fqn'))

    method_call = node.body[0].body[0].body[0].value.func
    self.assertEqual(session.Session.run,
                     anno.getanno(method_call, 'live_val'))

  def test_constructor_data_dependent(self):
    """A local assigned in both branches of an `if` gets no live_val."""

    def test_fn(x):
      if x > 0:
        opt = training.GradientDescentOptimizer(0.1)
      else:
        opt = training.GradientDescentOptimizer(0.01)
      opt.minimize(0)

    node = self._parse_and_analyze(test_fn, {'training': training})
    method_call = node.body[0].body[1].value.func
    self.assertFalse(anno.hasanno(method_call, 'live_val'))

  def test_parameter_class_members(self):
    """Method calls on a parameter without a type hint get no live_val."""

    def test_fn(opt):
      opt.minimize(0)

    node = self._parse_and_analyze(test_fn, {})
    method_call = node.body[0].body[0].value.func
    self.assertFalse(anno.hasanno(method_call, 'live_val'))

  def test_parameter_class_members_with_value_hints(self):
    """With an arg_types hint, a parameter's method resolves to the class's."""

    def test_fn(opt):
      opt.minimize(0)

    node = self._parse_and_analyze(
        test_fn, {'training': training},
        arg_types={
            'opt': (training.GradientDescentOptimizer.__name__,
                    training.GradientDescentOptimizer)
        })

    method_call = node.body[0].body[0].value.func
    self.assertEqual(training.GradientDescentOptimizer.minimize,
                     anno.getanno(method_call, 'live_val'))

  def test_function_variables(self):
    """Calling a local that aliases a function yields no live_val."""

    def bar():
      pass

    def test_fn():
      foo = bar
      foo()

    node = self._parse_and_analyze(test_fn, {'bar': bar})
    method_call = node.body[0].body[1].value.func
    self.assertFalse(anno.hasanno(method_call, 'live_val'))

  def test_nested_members(self):
    """Chained attribute access on a constructed local yields no live_val."""

    def test_fn():
      foo = training.GradientDescentOptimizer(0.1)
      foo.bar.baz()

    node = self._parse_and_analyze(test_fn, {'training': training})
    method_call = node.body[0].body[1].value.func
    self.assertFalse(anno.hasanno(method_call, 'live_val'))


if __name__ == '__main__':
  test.main()