# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Live value resolution.

Live values are extracted from the known execution context.

Requires activity analysis annotations.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import transformer
from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno


class LiveValueResolver(transformer.Base):
  """Annotates nodes with live values.

  A node's live value is the Python object it refers to. It is looked up in
  the conversion context (``context.namespace`` and ``context.arg_values``)
  or in the ``literals`` mapping. Resolved nodes receive a ``'live_val'``
  annotation. Where a name can be derived, they also receive an ``'fqn'``
  annotation holding a tuple of name components.
  """

  def __init__(self, context, literals):
    """Creates the resolver.

    Args:
      context: Conversion context, forwarded to ``transformer.Base``. Its
        ``namespace`` and ``arg_values`` mappings are consulted for values.
      literals: Mapping from symbol name to value. For non-local,
        non-parameter names it is checked before ``context.namespace``.
    """
    super(LiveValueResolver, self).__init__(context)
    self.literals = literals

  def visit_ClassDef(self, node):
    """Annotates a class definition with the object bound to its name.

    The object is read from ``context.namespace`` by direct indexing, so a
    class missing from the namespace raises. NOTE(review): presumably a
    KeyError if namespace is a dict; confirm.
    """
    self.generic_visit(node)
    anno.setanno(node, 'live_val', self.context.namespace[node.name])
    return node

  def visit_Name(self, node):
    """Annotates a loaded name with its live value, when one can be found.

    Only names in a ``gast.Load`` context are resolved. Activity analysis
    must already have set the ``IS_LOCAL``, ``IS_MODIFIED_SINCE_ENTRY`` and
    ``IS_PARAM`` annotations on the node.

    Names that are neither local nor parameters are looked up first in
    ``literals``, then in ``context.namespace``. Afterwards, a name that was
    not modified since entry and appears in ``context.arg_values`` is
    annotated with that argument value (written last, so it wins).
    """
    self.generic_visit(node)
    if isinstance(node.ctx, gast.Load):
      assert anno.hasanno(node, NodeAnno.IS_LOCAL), node
      symbol_is_local = anno.getanno(node, NodeAnno.IS_LOCAL)
      assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node
      symbol_is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY)
      assert anno.hasanno(node, NodeAnno.IS_PARAM), node
      symbol_is_param = anno.getanno(node, NodeAnno.IS_PARAM)

      if not symbol_is_local and not symbol_is_param:
        if node.id in self.literals:
          anno.setanno(node, 'live_val', self.literals[node.id])
          # TODO(mdan): Could live values have FQNs? i.e. 'a'.join()
        elif node.id in self.context.namespace:
          obj = self.context.namespace[node.id]
          anno.setanno(node, 'live_val', obj)
          if hasattr(obj, '__name__'):
            anno.setanno(node, 'fqn', (obj.__name__,))
          else:
            # Functions, classes and modules carry a __name__; plain instances
            # (e.g. module-level constants) do not, and reading it used to
            # raise AttributeError. Fall back to the class name, the same
            # labelling used for argument values below.
            anno.setanno(node, 'fqn', (obj.__class__.__name__,))
        # TODO(mdan): Should we raise an error when the name is unresolved?
        # It is currently left unannotated. This can happen when:
        #   * a symbol truly lacks reference
        #   * a symbol is new, like the new name of a function we just renamed.
      # TODO(mdan): For local symbols, attempt to trace the value through the
      # local chain, and use type annotations as fallback.

      if not symbol_is_modified and node.id in self.context.arg_values:
        obj = self.context.arg_values[node.id]
        anno.setanno(node, 'live_val', obj)
        anno.setanno(node, 'fqn', (obj.__class__.__name__,))
    return node

  def visit_Attribute(self, node):
    """Annotates ``x.attr`` with a live value derived from ``x``.

    If ``x`` has a live value, ``attr`` is read from that object, and the
    ``'fqn'`` of ``x`` is extended with ``attr``.

    Otherwise, if ``x`` has a ``'type'`` annotation and that type has
    ``attr``, the value is read from the type, and the ``'type_fqn'`` of
    ``x`` is extended with ``attr``.

    Otherwise, if ``x`` is a bare name, it is only checked to carry the
    ``IS_LOCAL`` annotation. In every remaining case the node is left
    unannotated.

    Raises:
      AttributeError: if ``x`` has a live value that lacks ``attr``.
    """
    self.generic_visit(node)
    if anno.hasanno(node.value, 'live_val'):
      assert anno.hasanno(node.value, 'fqn')
      parent_object = anno.getanno(node.value, 'live_val')
      if not hasattr(parent_object, node.attr):
        raise AttributeError('%s has no attribute %s' % (parent_object,
                                                         node.attr))
      anno.setanno(node, 'live_val', getattr(parent_object, node.attr))
      anno.setanno(node, 'fqn', anno.getanno(node.value, 'fqn') + (node.attr,))
    # TODO(mdan): Investigate the role built-in annotations can play here.
    elif anno.hasanno(node.value, 'type'):
      parent_type = anno.getanno(node.value, 'type')
      if hasattr(parent_type, node.attr):
        # This should hold for static members like methods.
        # This would not hold for dynamic members like function attributes.
        # For the dynamic case, we simply leave the node without an annotation,
        # and let downstream consumers figure out what to do.
        anno.setanno(node, 'live_val', getattr(parent_type, node.attr))
        anno.setanno(node, 'fqn',
                     anno.getanno(node.value, 'type_fqn') + (node.attr,))
    elif isinstance(node.value, gast.Name):
      stem_name = node.value
      # All nonlocal symbols should be fully resolved.
      assert anno.hasanno(stem_name, NodeAnno.IS_LOCAL), stem_name
      # TODO(mdan): Figure out what to do when calling attribute on local object
      # Maybe just leave as-is?
    return node


def resolve(node, context, literals):
  """Runs ``LiveValueResolver`` over ``node``.

  Args:
    node: AST to annotate.
    context: Conversion context; see ``LiveValueResolver.__init__``.
    literals: Name-to-value mapping; see ``LiveValueResolver.__init__``.

  Returns:
    The result of ``LiveValueResolver.visit(node)``.
  """
  return LiveValueResolver(context, literals).visit(node)