# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adds guards against function calls with side effects.

Only standalone calls (call expressions used as statements) are guarded.
Each such call is wrapped in a
`with py2tf_utils.control_dependency_on_returns(call):` block, and the
statements that follow it in the same statement list are moved into that
block. When the call has plain-name arguments, the guard also re-binds them
through `py2tf_utils.alias_tensors`, and later statements are rewritten to
use the re-bound names.

WARNING: This mechanism is incomplete. In particular, it only guards the
arguments passed to functions, and does not account for indirectly modified
state.

Example:
  y = tf.layers.dense(x)       # Creates TF variable 'foo'
  loss = loss(y)
  opt.minimize(loss)           # indirectly affects 'foo'
  z = tf.get_variable('foo')   # Reads 'foo', which minimize() affected
  # Here, later uses of `loss` (an argument of the call) can be guarded.
  # But `z` cannot, because 'foo' is never passed to minimize().

# TODO(mdan): We should probably define a safe mode where we guard everything.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import ast_util
from tensorflow.contrib.py2tf.pyct import qual_names
from tensorflow.contrib.py2tf.pyct import templates
from tensorflow.contrib.py2tf.pyct import transformer
from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno


class SymbolNamer(object):
  """Describes the interface for SideEffectGuardTransformer's namer.

  SideEffectGuardTransformer reaches its namer through `context.namer` and
  uses it to pick fresh names for the aliases of guarded call arguments.
  """

  def new_symbol(self, name_root, reserved_locals):
    """Generate a new unique symbol name.

    Args:
      name_root: String, used as stem in the new name.
      reserved_locals: Set(string), additional local symbols that are reserved.
    Returns:
      String, the new name.
    Raises:
      NotImplementedError: Always; concrete namers must override this method.
    """
    raise NotImplementedError()


class SideEffectGuardTransformer(transformer.Base):
  """Adds control dependencies to functions with side effects.

  visit_Expr replaces each standalone call with a `with` guard node tagged
  with the anno.Basic.INDENT_BLOCK_REMAINDER annotation. The visitors for
  statement containers (FunctionDef, With, If, While) route their statement
  lists through _visit_and_reindent, which moves the statements that follow
  each guard into the guard's body.

  NOTE(review): no visit_For / visit_Try is defined here, so the statement
  lists of those nodes are not passed through _visit_and_reindent by this
  class; confirm that guarded calls cannot appear directly inside them at
  this stage of the conversion.
  """

  def __init__(self, context):
    # `context` is handed to transformer.Base; visit_Expr later reads
    # `self.context.namer` from it.
    super(SideEffectGuardTransformer, self).__init__(context)

  # pylint:disable=invalid-name

  def _visit_and_reindent(self, nodes):
    """Visits a statement list and nests the statements that follow a guard.

    Each statement is visited in turn. When a visited statement carries the
    INDENT_BLOCK_REMAINDER annotation (set by visit_Expr), the annotation
    supplies a destination list (the guard's body) and an alias map. Every
    later statement in `nodes` is appended to that destination instead of
    the enclosing list, after its symbols are renamed with the alias map.
    Consecutive guards therefore nest inside one another. The guard bodies
    are filled in place.

    Args:
      nodes: List of AST statement nodes, e.g. a function or branch body.
    Returns:
      A new list holding the top-level statements for the block.
    Raises:
      ValueError: If, once all statements are placed, the body of the most
        recent guard is empty. This happens when a guard created without
        arguments to alias (whose body starts empty) is not followed by any
        statement it could gate.
    """
    new_nodes = []
    # The list visited statements are currently appended to; it switches to
    # a guard's body each time a guard is encountered.
    current_dest = new_nodes
    # Maps original symbols to their aliases for all guards seen so far.
    alias_map = {}
    reindent_requested = False
    for n in nodes:
      n = self.visit(n)
      # NOTE: the order in which these statements execute is important; in
      # particular, watch out for ending up with cycles in the AST.
      # Renaming happens before the annotation is read, so a guard's own call
      # is rewritten with the aliases of earlier guards, not with its own.
      if alias_map:
        n = ast_util.rename_symbols(n, alias_map)
      if isinstance(n, (list, tuple)):
        current_dest.extend(n)
      else:
        current_dest.append(n)
      if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
        reindent_requested = True
        new_dest, new_alias_map = anno.getanno(
            n, anno.Basic.INDENT_BLOCK_REMAINDER)
        # Consume the marker so it is not left on the node.
        anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
        # Merge the maps; on a key conflict the previously active alias wins.
        new_alias_map.update(alias_map)
        alias_map = new_alias_map
        current_dest = new_dest
    if reindent_requested and not current_dest:
      # TODO(mdan): There may still be something that could be done.
      raise ValueError('Unable to insert statement into the computation flow: '
                       'it is not followed by any computation which '
                       'the statement could gate.')
    return new_nodes

  def visit_FunctionDef(self, node):
    # Only the body is rewritten; arguments and decorators are not visited.
    node.body = self._visit_and_reindent(node.body)
    return node

  def visit_With(self, node):
    # Only the body is rewritten; the `with` items are not visited.
    node.body = self._visit_and_reindent(node.body)
    return node

  def visit_If(self, node):
    # Each branch is processed by a separate call, so a guard never extends
    # across branches. The test expression is not visited.
    node.body = self._visit_and_reindent(node.body)
    node.orelse = self._visit_and_reindent(node.orelse)
    return node

  def visit_While(self, node):
    # As in visit_If: the body and the else-clause are processed separately,
    # and the loop condition is not visited.
    node.body = self._visit_and_reindent(node.body)
    node.orelse = self._visit_and_reindent(node.orelse)
    return node

  def visit_Expr(self, node):
    """Replaces a standalone call statement with a control-dependency guard.

    Children are visited first. Expression statements that are not calls are
    returned after that visit, otherwise unchanged. For a call, the result is
    a `with py2tf_utils.control_dependency_on_returns(call):` node. Its body
    starts out either empty, or holding a single statement that re-binds the
    guarded arguments via `py2tf_utils.alias_tensors`. The node is annotated
    with INDENT_BLOCK_REMAINDER = (body, alias_map). The enclosing
    _visit_and_reindent uses that annotation to move the following
    statements into the body and to rename their symbols with alias_map.

    Args:
      node: gast.Expr statement node.
    Returns:
      The visited node, or the new guard node when the expression is a call.
    """
    self.generic_visit(node)
    if isinstance(node.value, gast.Call):
      # Patterns of single function calls, like:
      #   opt.minimize(loss)
      # or:
      #   tf.py_func(...)

      # First, attempt to gate future evaluation of args. If that's not
      # possible, gate all remaining statements (and that may fail too, see
      # _visit_and_reindent).
      args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
      # NOTE: We can't guard object attributes because they may not be writable.
      # In addition, avoid renaming well-known names.
      # TODO(mdan): Move these names into config.
      unguarded_names = (qual_names.QN('self'), qual_names.QN('tf'))
      # Only simple names are guarded; composite names (e.g. attributes, per
      # the NOTE above) are skipped.
      # NOTE(review): if args_scope.used is a set, the order of guarded_args
      # (and so of the generated tuple and the new alias names) may differ
      # between runs; confirm, and consider sorting for deterministic output.
      guarded_args = tuple(s for s in args_scope.used
                           if not s.is_composite() and s not in unguarded_names)

      # TODO(mdan): Include all arguments which depended on guarded_args too.
      # For example, the following will still cause a race:
      #   tf.assign(a, a + 1)
      #   b = a + 1
      #   tf.assign(a, a + 1)  # Control deps here should include `b`
      #   c = b + 1
      # Or maybe we should just raise an "unsafe assign" error?

      if guarded_args:
        # The aliases may need new names to avoid incorrectly making them local.
        # Names already modified in the enclosing scope keep their own name;
        # the rest get fresh names from the namer.
        # TODO(mdan): This is brutal. It will even rename modules - any fix?
        need_alias = tuple(
            s for s in guarded_args if s not in args_scope.parent.modified)
        aliased_new_names = tuple(
            qual_names.QN(
                self.context.namer.new_symbol(
                    s.ssf(), args_scope.parent.referenced)) for s in need_alias)
        alias_map = dict(zip(need_alias, aliased_new_names))
        # Assignment target: each guarded name maps to its alias, or to itself
        # when no alias was needed.
        if len(guarded_args) == 1:
          s, = guarded_args
          aliased_guarded_args = alias_map.get(s, s)
        else:
          # ctx is left as None here; presumably templates.replace sets the
          # proper (store) context on the target - confirm.
          aliased_guarded_args = gast.Tuple(
              [alias_map.get(s, s).ast() for s in guarded_args], None)

        template = """
          with py2tf_utils.control_dependency_on_returns(call):
            aliased_guarded_args = py2tf_utils.alias_tensors(guarded_args)
        """
        # [-1] picks the `with` statement from the statements replace() returns.
        control_deps_guard = templates.replace(
            template,
            call=node.value,
            aliased_guarded_args=aliased_guarded_args,
            guarded_args=guarded_args)[-1]
      else:
        alias_map = {}

        template = """
          with py2tf_utils.control_dependency_on_returns(call):
            pass
        """
        control_deps_guard = templates.replace(template, call=node.value)[-1]
        # `pass` only keeps the template syntactically valid; drop it so the
        # body holds just the statements that _visit_and_reindent moves in.
        control_deps_guard.body = []

      node = control_deps_guard
      # Tell the enclosing _visit_and_reindent where subsequent statements go
      # and which symbols to rename in them.
      anno.setanno(node, anno.Basic.INDENT_BLOCK_REMAINDER,
                   (node.body, alias_map))
    return node

  # pylint:enable=invalid-name


def transform(node, context):
  """Applies SideEffectGuardTransformer to an AST.

  Args:
    node: AST node to transform (presumably a gast module or FunctionDef;
      confirm against callers).
    context: Conversion context handed to the transformer. It must expose a
      `namer` (see SymbolNamer), which is used when guarded arguments need
      fresh alias names.
  Returns:
    The transformed node.
  """
  return SideEffectGuardTransformer(context).visit(node)