Home | History | Annotate | Download | only in static_analysis
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Live value resolution.
     16 
     17 Live values are extracted from the known execution context.
     18 
     19 Requires activity analysis annotations.
     20 """
     21 
     22 from __future__ import absolute_import
     23 from __future__ import division
     24 from __future__ import print_function
     25 
     26 import gast
     27 
     28 from tensorflow.contrib.py2tf.pyct import anno
     29 from tensorflow.contrib.py2tf.pyct import transformer
     30 from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
     31 
     32 
     33 class LiveValueResolver(transformer.Base):
     34   """Annotates nodes with live values."""
     35 
     36   def __init__(self, context, literals):
     37     super(LiveValueResolver, self).__init__(context)
     38     self.literals = literals
     39 
     40   def visit_ClassDef(self, node):
     41     self.generic_visit(node)
     42     anno.setanno(node, 'live_val', self.context.namespace[node.name])
     43     return node
     44 
     45   def visit_Name(self, node):
     46     self.generic_visit(node)
     47     if isinstance(node.ctx, gast.Load):
     48       assert anno.hasanno(node, NodeAnno.IS_LOCAL), node
     49       symbol_is_local = anno.getanno(node, NodeAnno.IS_LOCAL)
     50       assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node
     51       symbol_is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY)
     52       assert anno.hasanno(node, NodeAnno.IS_PARAM), node
     53       symbol_is_param = anno.getanno(node, NodeAnno.IS_PARAM)
     54 
     55       if not symbol_is_local and not symbol_is_param:
     56         if node.id in self.literals:
     57           anno.setanno(node, 'live_val', self.literals[node.id])
     58           # TODO(mdan): Could live values have FQNs? i.e. 'a'.join()
     59         elif node.id in self.context.namespace:
     60           obj = self.context.namespace[node.id]
     61           anno.setanno(node, 'live_val', obj)
     62           anno.setanno(node, 'fqn', (obj.__name__,))
     63         else:
     64           pass
     65           # TODO(mdan): Should we raise an error here?
     66           # Can encounter this when:
     67           #  * a symbol truly lacks reference
     68           #  * a symbol is new, like the new name of a function we just renamed.
     69       else:
     70         pass
     71         # TODO(mdan): Attempt to trace its value through the local chain.
     72         # TODO(mdan): Use type annotations as fallback.
     73 
     74       if not symbol_is_modified:
     75         if node.id in self.context.arg_values:
     76           obj = self.context.arg_values[node.id]
     77           anno.setanno(node, 'live_val', obj)
     78           anno.setanno(node, 'fqn', (obj.__class__.__name__,))
     79     return node
     80 
     81   def visit_Attribute(self, node):
     82     self.generic_visit(node)
     83     if anno.hasanno(node.value, 'live_val'):
     84       assert anno.hasanno(node.value, 'fqn')
     85       parent_object = anno.getanno(node.value, 'live_val')
     86       if not hasattr(parent_object, node.attr):
     87         raise AttributeError('%s has no attribute %s' % (parent_object,
     88                                                          node.attr))
     89       anno.setanno(node, 'live_val', getattr(parent_object, node.attr))
     90       anno.setanno(node, 'fqn', anno.getanno(node.value, 'fqn') + (node.attr,))
     91     # TODO(mdan): Investigate the role built-in annotations can play here.
     92     elif anno.hasanno(node.value, 'type'):
     93       parent_type = anno.getanno(node.value, 'type')
     94       if hasattr(parent_type, node.attr):
     95         # This should hold for static members like methods.
     96         # This would not hold for dynamic members like function attributes.
     97         # For the dynamic case, we simply leave the node without an annotation,
     98         # and let downstream consumers figure out what to do.
     99         anno.setanno(node, 'live_val', getattr(parent_type, node.attr))
    100         anno.setanno(node, 'fqn',
    101                      anno.getanno(node.value, 'type_fqn') + (node.attr,))
    102     elif isinstance(node.value, gast.Name):
    103       stem_name = node.value
    104       # All nonlocal symbols should be fully resolved.
    105       assert anno.hasanno(stem_name, NodeAnno.IS_LOCAL), stem_name
    106       # TODO(mdan): Figure out what to do when calling attribute on local object
    107       # Maybe just leave as-is?
    108     return node
    109 
    110 
    111 def resolve(node, context, literals):
    112   return LiveValueResolver(context, literals).visit(node)
    113