Home | History | Annotate | Download | only in static_analysis
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for live_values module."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib.py2tf.pyct import anno
     22 from tensorflow.contrib.py2tf.pyct import context
     23 from tensorflow.contrib.py2tf.pyct import parser
     24 from tensorflow.contrib.py2tf.pyct import qual_names
     25 from tensorflow.contrib.py2tf.pyct.static_analysis import activity
     26 from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
     27 from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
     28 from tensorflow.python.framework import constant_op
     29 from tensorflow.python.platform import test
     30 
     31 
     32 class LiveValuesResolverTest(test.TestCase):
     33 
     34   def _parse_and_analyze(self,
     35                          test_fn,
     36                          namespace,
     37                          literals=None,
     38                          arg_types=None):
     39     literals = literals or {}
     40     arg_types = arg_types or {}
     41     node, source = parser.parse_entity(test_fn)
     42     ctx = context.EntityContext(
     43         namer=None,
     44         source_code=source,
     45         source_file=None,
     46         namespace=namespace,
     47         arg_values=None,
     48         arg_types=arg_types,
     49         recursive=True)
     50     node = qual_names.resolve(node)
     51     node = activity.resolve(node, ctx)
     52     node = live_values.resolve(node, ctx, literals)
     53     node = type_info.resolve(node, ctx)
     54     node = live_values.resolve(node, ctx, literals)
     55     return node
     56 
     57   def test_literals(self):
     58 
     59     def test_fn():
     60       return Foo  # pylint: disable=undefined-variable
     61 
     62     node = self._parse_and_analyze(test_fn, {}, {'Foo': 'bar'})
     63     retval_node = node.body[0].body[0].value
     64     self.assertEquals('bar', anno.getanno(retval_node, 'live_val'))
     65 
     66   def test_namespace(self):
     67 
     68     def foo():
     69       return 'bar'
     70 
     71     def test_fn():
     72       return foo()
     73 
     74     node = self._parse_and_analyze(test_fn, {'foo': foo})
     75     func_node = node.body[0].body[0].value.func
     76     self.assertEquals(foo, anno.getanno(func_node, 'live_val'))
     77     self.assertEquals(('foo',), anno.getanno(func_node, 'fqn'))
     78 
     79   def test_attribute_names(self):
     80 
     81     def test_fn():
     82       return constant_op.constant(0)
     83 
     84     node = self._parse_and_analyze(test_fn, {'constant_op': constant_op})
     85     func_node = node.body[0].body[0].value.func
     86     self.assertEquals(constant_op.constant, anno.getanno(func_node, 'live_val'))
     87     self.assertEquals((constant_op.__name__, 'constant'),
     88                       anno.getanno(func_node, 'fqn'))
     89 
     90   def test_attributes_with_type_hints(self):
     91 
     92     class TestClass(object):
     93 
     94       def member(self):
     95         pass
     96 
     97       def test_fn(self):
     98         return self.member()
     99 
    100     node = self._parse_and_analyze(
    101         TestClass.test_fn, {'constant_op': constant_op},
    102         arg_types={'self': (TestClass.__name__, TestClass)})
    103     func_node = node.body[0].body[0].value.func
    104     self.assertEquals(TestClass.member, anno.getanno(func_node, 'live_val'))
    105     self.assertEquals(('TestClass', 'member'), anno.getanno(func_node, 'fqn'))
    106 
    107 
    108 if __name__ == '__main__':
    109   test.main()
    110