Home | History | Annotate | Download | only in static_analysis
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for type_info module."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib.py2tf.pyct import anno
     22 from tensorflow.contrib.py2tf.pyct import context
     23 from tensorflow.contrib.py2tf.pyct import parser
     24 from tensorflow.contrib.py2tf.pyct import qual_names
     25 from tensorflow.contrib.py2tf.pyct.static_analysis import activity
     26 from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
     27 from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
     28 from tensorflow.python.client import session
     29 from tensorflow.python.platform import test
     30 from tensorflow.python.training import training
     31 
     32 
     33 class ScopeTest(test.TestCase):
     34 
     35   def test_basic(self):
     36     scope = type_info.Scope(None)
     37     self.assertFalse(scope.hasval('foo'))
     38 
     39     scope.setval('foo', 'bar')
     40     self.assertTrue(scope.hasval('foo'))
     41 
     42     self.assertFalse(scope.hasval('baz'))
     43 
     44   def test_nesting(self):
     45     scope = type_info.Scope(None)
     46     scope.setval('foo', '')
     47 
     48     child = type_info.Scope(scope)
     49     self.assertTrue(child.hasval('foo'))
     50     self.assertTrue(scope.hasval('foo'))
     51 
     52     child.setval('bar', '')
     53     self.assertTrue(child.hasval('bar'))
     54     self.assertFalse(scope.hasval('bar'))
     55 
     56 
     57 class TypeInfoResolverTest(test.TestCase):
     58 
     59   def _parse_and_analyze(self, test_fn, namespace, arg_types=None):
     60     node, source = parser.parse_entity(test_fn)
     61     ctx = context.EntityContext(
     62         namer=None,
     63         source_code=source,
     64         source_file=None,
     65         namespace=namespace,
     66         arg_values=None,
     67         arg_types=arg_types,
     68         recursive=True)
     69     node = qual_names.resolve(node)
     70     node = activity.resolve(node, ctx)
     71     node = live_values.resolve(node, ctx, {})
     72     node = type_info.resolve(node, ctx)
     73     node = live_values.resolve(node, ctx, {})
     74     return node
     75 
     76   def test_constructor_detection(self):
     77 
     78     def test_fn():
     79       opt = training.GradientDescentOptimizer(0.1)
     80       return opt
     81 
     82     node = self._parse_and_analyze(test_fn, {'training': training})
     83     call_node = node.body[0].body[0].value
     84     self.assertEquals(training.GradientDescentOptimizer,
     85                       anno.getanno(call_node, 'type'))
     86     self.assertEquals((training.__name__, 'GradientDescentOptimizer'),
     87                       anno.getanno(call_node, 'type_fqn'))
     88 
     89   def test_class_members_of_detected_constructor(self):
     90 
     91     def test_fn():
     92       opt = training.GradientDescentOptimizer(0.1)
     93       opt.minimize(0)
     94 
     95     node = self._parse_and_analyze(test_fn, {'training': training})
     96     method_call = node.body[0].body[1].value.func
     97     self.assertEquals(training.GradientDescentOptimizer.minimize,
     98                       anno.getanno(method_call, 'live_val'))
     99 
    100   def test_class_members_in_with_stmt(self):
    101 
    102     def test_fn(x):
    103       with session.Session() as sess:
    104         sess.run(x)
    105 
    106     node = self._parse_and_analyze(test_fn, {'session': session})
    107     constructor_call = node.body[0].body[0].items[0].context_expr
    108     self.assertEquals(session.Session, anno.getanno(constructor_call, 'type'))
    109     self.assertEquals((session.__name__, 'Session'),
    110                       anno.getanno(constructor_call, 'type_fqn'))
    111 
    112     method_call = node.body[0].body[0].body[0].value.func
    113     self.assertEquals(session.Session.run, anno.getanno(method_call,
    114                                                         'live_val'))
    115 
    116   def test_constructor_data_dependent(self):
    117 
    118     def test_fn(x):
    119       if x > 0:
    120         opt = training.GradientDescentOptimizer(0.1)
    121       else:
    122         opt = training.GradientDescentOptimizer(0.01)
    123       opt.minimize(0)
    124 
    125     node = self._parse_and_analyze(test_fn, {'training': training})
    126     method_call = node.body[0].body[1].value.func
    127     self.assertFalse(anno.hasanno(method_call, 'live_val'))
    128 
    129   def test_parameter_class_members(self):
    130 
    131     def test_fn(opt):
    132       opt.minimize(0)
    133 
    134     node = self._parse_and_analyze(test_fn, {})
    135     method_call = node.body[0].body[0].value.func
    136     self.assertFalse(anno.hasanno(method_call, 'live_val'))
    137 
    138   def test_parameter_class_members_with_value_hints(self):
    139 
    140     def test_fn(opt):
    141       opt.minimize(0)
    142 
    143     node = self._parse_and_analyze(
    144         test_fn, {'training': training},
    145         arg_types={
    146             'opt': (training.GradientDescentOptimizer.__name__,
    147                     training.GradientDescentOptimizer)
    148         })
    149 
    150     method_call = node.body[0].body[0].value.func
    151     self.assertEquals(training.GradientDescentOptimizer.minimize,
    152                       anno.getanno(method_call, 'live_val'))
    153 
    154   def test_function_variables(self):
    155 
    156     def bar():
    157       pass
    158 
    159     def test_fn():
    160       foo = bar
    161       foo()
    162 
    163     node = self._parse_and_analyze(test_fn, {'bar': bar})
    164     method_call = node.body[0].body[1].value.func
    165     self.assertFalse(anno.hasanno(method_call, 'live_val'))
    166 
    167   def test_nested_members(self):
    168 
    169     def test_fn():
    170       foo = training.GradientDescentOptimizer(0.1)
    171       foo.bar.baz()
    172 
    173     node = self._parse_and_analyze(test_fn, {'training': training})
    174     method_call = node.body[0].body[1].value.func
    175     self.assertFalse(anno.hasanno(method_call, 'live_val'))
    176 
    177 
# Run the tests through test.main() when this file is executed as a script.
if __name__ == '__main__':
  test.main()
    180