# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Handles control flow statements: while, for, if."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gast

from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct.static_analysis import annos


class SymbolNamer(object):
  """Describes the interface for ControlFlowTransformer's namer."""

  def new_symbol(self, name_root, reserved_locals):
    """Generate a new unique symbol.

    Args:
      name_root: String, used as stem in the new name.
      reserved_locals: Set(string), additional local symbols that are reserved
        and which should not be used.
    Returns:
      String.
    """
    raise NotImplementedError()


class ControlFlowTransformer(converter.Base):
  """Transforms control flow structures like loops and conditionals."""

  def _create_cond_branch(self, body_name, aliased_orig_names,
                          aliased_new_names, body, returns):
    """Builds the function definition for one branch of a conditional.

    Args:
      body_name: Name given to the generated branch function.
      aliased_orig_names: Symbols whose values are copied into aliases at the
        top of the function. May be empty, in which case no aliasing statement
        is emitted.
      aliased_new_names: Alias names, parallel to aliased_orig_names.
      body: The statements that make up the branch.
      returns: Sequence of values the branch returns. An empty sequence yields
        `return 1`, a single value is returned bare, and several values are
        returned as a tuple.
    Returns:
      The result of templates.replace for the branch function definition.
    """
    if not returns:
      # TODO(b/110167197): Replace with a plain return.
      template = """
        return 1
      """
      return_stmt = templates.replace(template)
    elif len(returns) == 1:
      template = """
        return retval
      """
      return_stmt = templates.replace(template, retval=returns[0])
    else:
      template = """
        return (retvals,)
      """
      return_stmt = templates.replace(template, retvals=returns)

    if aliased_orig_names:
      template = """
        def body_name():
          aliased_new_names, = aliased_orig_names,
          body
          return_stmt
      """
      return templates.replace(
          template,
          body_name=body_name,
          body=body,
          aliased_orig_names=aliased_orig_names,
          aliased_new_names=aliased_new_names,
          return_stmt=return_stmt)
    else:
      template = """
        def body_name():
          body
          return_stmt
      """
      return templates.replace(
          template, body_name=body_name, body=body, return_stmt=return_stmt)

  def _create_cond_expr(self, results, test, body_name, orelse_name,
                        state_getter_name,
                        state_setter_name):
    """Builds the `ag__.if_stmt` call that replaces the original `if`.

    Args:
      results: Target(s) the call result is assigned to, or None to emit the
        call as a bare expression statement.
      test: The condition passed to if_stmt.
      body_name: Name of the function implementing the true branch.
      orelse_name: Name of the function implementing the false branch.
      state_getter_name: Name of the generated state getter function.
      state_setter_name: Name of the generated state setter function.
    Returns:
      The result of templates.replace for the call statement.
    """
    if results is not None:
      template = """
        results = ag__.if_stmt(test, body_name, orelse_name,
                               state_getter_name, state_setter_name)
      """
      return templates.replace(
          template,
          test=test,
          results=results,
          body_name=body_name,
          orelse_name=orelse_name,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name)
    else:
      template = """
        ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name)
      """
      return templates.replace(
          template,
          test=test,
          body_name=body_name,
          orelse_name=orelse_name,
          getter_name=state_getter_name,
          setter_name=state_setter_name)

  def _fmt_symbols(self, symbol_set):
    """Formats symbols as a comma-separated string for messages.

    Args:
      symbol_set: Collection of symbols; each is rendered with str().
    Returns:
      'no variables' when the collection is empty, otherwise the symbols
      joined with ', ' in the collection's iteration order.
    """
    if not symbol_set:
      return 'no variables'
    return ', '.join(map(str, symbol_set))

  def _determine_aliased_symbols(self, scope, node_defined_in, block):
    """Finds the symbols that a branch function must alias.

    Args:
      scope: Activity scope of the branch; its `modified` set is used.
      node_defined_in: Symbols defined on entry to the conditional.
      block: List of statements forming the branch. May be empty.
    Returns:
      The symbols that are modified in the branch, defined before the
      conditional and live on entry to the branch.
    """
    if block:
      block_live_in = set(anno.getanno(block[0], anno.Static.LIVE_VARS_IN))
    else:
      block_live_in = set()

    # For the purpose of aliasing, composite symbols with live owners are live
    # as well. Otherwise this would leak tensors from the conditional's body.
    #
    # For example:
    #
    #   obj = some_obj
    #   if cond:
    #     obj.a = val
    #
    # Translating to the code below would be incorrect:
    #
    #   def true_fn():
    #     obj.a = val()  # Wrong! leaks ops owned by true_fn
    #     return obj.a
    for s in scope.modified:
      if s.is_composite():
        live_parents = block_live_in & s.owner_set
        if live_parents:
          block_live_in.add(s)
    return scope.modified & node_defined_in & block_live_in

  def _create_state_functions(self, composites,
                              state_getter_name, state_setter_name):
    """Builds the state getter and setter passed to `ag__.if_stmt`.

    Args:
      composites: Collection of composite symbols whose state is exposed.
        When empty, the getter returns `()` and the setter is a no-op.
      state_getter_name: Name given to the generated getter function.
      state_setter_name: Name given to the generated setter function.
    Returns:
      The result of templates.replace holding both function definitions.
    """
    if composites:
      composite_tuple = tuple(composites)
      template = """
        def state_getter_name():
          return composite_tuple,
        def state_setter_name(vals):
          composite_tuple, = vals
      """
      node = templates.replace(
          template,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name,
          composite_tuple=composite_tuple)
    else:
      template = """
        def state_getter_name():
          return ()
        def state_setter_name(_):
          pass
      """
      node = templates.replace(
          template,
          state_getter_name=state_getter_name,
          state_setter_name=state_setter_name)

    return node

  def _create_undefined_assigns(self, undefined_symbols):
    """Builds `var = ag__.Undefined('var')` statements for the given symbols.

    Args:
      undefined_symbols: Iterable of symbols that may be undefined when the
        converted statement runs.
    Returns:
      List of assignment nodes, one group per symbol; empty if there are no
      symbols.
    """
    assignments = []
    for s in undefined_symbols:
      template = '''
        var = ag__.Undefined(symbol_name)
      '''
      assignments += templates.replace(
          template,
          var=s,
          symbol_name=gast.Str(s.ssf()))
    return assignments

  def visit_If(self, node):
    """Converts an `if` statement into a call to `ag__.if_stmt`.

    The body and orelse blocks are turned into functions. The symbols they
    modify that are live after the statement (plus all modified composite
    symbols) become the return values of those functions.

    Args:
      node: The `if` node, annotated with body/orelse scopes, defined-in and
        live-out information.
    Returns:
      The replacement statements, in order: undefined-symbol assignments, the
      condition assignment, the state getter/setter, both branch functions
      and the if_stmt call.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)

    # Note: this information needs to be extracted before the body conversion
    # that happens in the call to generic_visit below, because the conversion
    # generates nodes that lack static analysis annotations.
    need_alias_in_body = self._determine_aliased_symbols(
        body_scope, defined_in, node.body)
    need_alias_in_orelse = self._determine_aliased_symbols(
        orelse_scope, defined_in, node.orelse)

    node = self.generic_visit(node)

    modified_in_cond = body_scope.modified | orelse_scope.modified
    returned_from_cond = set()
    composites = set()
    for s in modified_in_cond:
      if s in live_out:
        returned_from_cond.add(s)
      if s.is_composite():
        # Special treatment for compound objects, always return them.
        # This allows special handling within the if_stmt itself.
        # For example, in TensorFlow we need to restore the state of composite
        # symbols to ensure that only effects from the executed branch are seen.
        returned_from_cond.add(s)
        composites.add(s)

    # Note: `-` binds tighter than `&`, so these read as
    # modified & (returned_from_cond - defined_in).
    created_in_body = body_scope.modified & returned_from_cond - defined_in
    created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in

    basic_created_in_body = tuple(
        s for s in created_in_body if not s.is_composite())
    basic_created_in_orelse = tuple(
        s for s in created_in_orelse if not s.is_composite())

    # These variables are defined only in a single branch. This is fine in
    # Python so we pass them through. Another backend, e.g. Tensorflow, may need
    # to handle these cases specially or throw an Error.
    possibly_undefined = (set(basic_created_in_body) ^
                          set(basic_created_in_orelse))

    # Alias the closure variables inside the conditional functions, to allow
    # the functions access to the respective variables.
    # We will alias variables independently for body and orelse scope,
    # because different branches might write different variables.
    aliased_body_orig_names = tuple(need_alias_in_body)
    aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    aliased_body_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
        for s in aliased_body_orig_names)
    aliased_orelse_new_names = tuple(
        self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
        for s in aliased_orelse_orig_names)

    alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
    alias_orelse_map = dict(
        zip(aliased_orelse_orig_names, aliased_orelse_new_names))

    node_body = ast_util.rename_symbols(node.body, alias_body_map)
    node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)

    # The condition variable and both branch functions are bound in the
    # enclosing scope, before either branch function is defined or called.
    # Their names must therefore avoid the symbols referenced by *both*
    # branches; reserving against a single branch could shadow a name that
    # only the other branch reads.
    all_referenced = body_scope.referenced | orelse_scope.referenced
    cond_var_name = self.ctx.namer.new_symbol('cond', all_referenced)
    body_name = self.ctx.namer.new_symbol('if_true', all_referenced)
    orelse_name = self.ctx.namer.new_symbol('if_false', all_referenced)
    state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced)
    state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced)

    returned_from_cond = tuple(returned_from_cond)
    if returned_from_cond:
      if len(returned_from_cond) == 1:
        cond_results = returned_from_cond[0]
      else:
        cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)

      # Inside each branch function, aliased symbols go by their alias name.
      returned_from_body = tuple(
          alias_body_map[s] if s in need_alias_in_body else s
          for s in returned_from_cond)
      returned_from_orelse = tuple(
          alias_orelse_map[s] if s in need_alias_in_orelse else s
          for s in returned_from_cond)

    else:
      # When the cond would return no value, we leave the cond called without
      # results. That in turn should trigger the side effect guards. The
      # branch functions will return a dummy value that ensures cond
      # actually has some return value as well.
      cond_results = None
      # TODO(mdan): Replace with None once side_effect_guards is retired.
      returned_from_body = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)
      returned_from_orelse = (templates.replace_as_expression(
          'ag__.match_staging_level(1, cond_var_name)',
          cond_var_name=cond_var_name),)

    cond_assign = self.create_assignment(cond_var_name, node.test)
    body_def = self._create_cond_branch(
        body_name,
        aliased_orig_names=aliased_body_orig_names,
        aliased_new_names=aliased_body_new_names,
        body=node_body,
        returns=returned_from_body)
    orelse_def = self._create_cond_branch(
        orelse_name,
        aliased_orig_names=aliased_orelse_orig_names,
        aliased_new_names=aliased_orelse_new_names,
        body=node_orelse,
        returns=returned_from_orelse)
    undefined_assigns = self._create_undefined_assigns(possibly_undefined)
    composite_defs = self._create_state_functions(
        composites, state_getter_name, state_setter_name)

    cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
                                       orelse_name, state_getter_name,
                                       state_setter_name)

    return (undefined_assigns + cond_assign + composite_defs + body_def +
            orelse_def + cond_expr)

  def _get_loop_state(self, node):
    """Determines the loop-carried state of a loop node.

    Args:
      node: A loop node annotated with its body scope and with defined-in,
        live-in and live-out information.
    Returns:
      Tuple (loop_state, reserved_symbols, possibly_undefs):
        loop_state: frozenset of symbols modified in the body that make up
          the loop state.
        reserved_symbols: The symbols referenced in the loop body.
        possibly_undefs: Set of simple loop state symbols that are not
          defined on entry to the loop.
    """
    body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
    live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
    reserved_symbols = body_scope.referenced

    loop_state = []
    for s in body_scope.modified:

      # Variables not live into or out of the loop are considered local to the
      # loop.
      if s not in live_in and s not in live_out:
        continue

      # Mutations made to objects created inside the loop will appear as writes
      # to composite symbols. Because these mutations appear as modifications
      # made to composite symbols, we check whether the composite's parent is
      # actually live into the loop.
      # Example:
      #   while cond:
      #     x = Foo()
      #     x.foo = 2 * x.foo  # x.foo is live into the loop, but x is not.
      if s.is_composite() and not all(p in live_in for p in s.support_set):
        continue

      loop_state.append(s)
    loop_state = frozenset(loop_state)

    # Variables that are used or defined inside the loop, but not defined
    # before entering the loop
    undefined_lives = loop_state - defined_in

    # Only simple variables must be defined. The composite ones will be
    # implicitly checked at runtime.
    possibly_undefs = {v for v in undefined_lives if v.is_simple()}

    return loop_state, reserved_symbols, possibly_undefs

  def _state_constructs(self, loop_state, reserved_symbols):
    """Builds the name structures needed to thread state through a loop.

    Args:
      loop_state: Collection of loop state symbols.
      reserved_symbols: Symbols that newly generated names must avoid.
    Returns:
      Tuple (loop_state, state_ssf, state_ast_tuple, ssf_map):
        loop_state: The state symbols as a tuple, or the symbol itself when
          there is exactly one.
        state_ssf: Newly generated names, parallel to loop_state; a single
          name when there is exactly one state symbol.
        state_ast_tuple: gast.Tuple holding the AST of every state symbol.
        ssf_map: Maps each state symbol to its generated name, omitting
          entries where the two are textually equal.
    """
    loop_state = tuple(loop_state)
    state_ssf = [
        self.ctx.namer.new_symbol(s.ssf(), reserved_symbols) for s in loop_state
    ]
    ssf_map = {
        name: ssf
        for name, ssf in zip(loop_state, state_ssf)
        if str(name) != ssf
    }

    state_ast_tuple = gast.Tuple([n.ast() for n in loop_state], None)

    # Unwrap the single-element case so that templates see a plain symbol.
    if len(loop_state) == 1:
      loop_state = loop_state[0]
      state_ssf = state_ssf[0]

    return loop_state, state_ssf, state_ast_tuple, ssf_map

  def visit_While(self, node):
    """Converts a `while` loop into a call to `ag__.while_stmt`.

    Args:
      node: The `while` node, annotated with body/cond scopes and liveness
        information.
    Returns:
      The replacement statements: undefined-symbol assignments followed by
      the test function, the body function and the while_stmt call.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(node)

    # Note: one might expect we can dispatch based on the loop condition.
    # But because that is dependent on the state, it cannot be evaluated ahead
    # of time - doing that would risk duplicating any effects the condition has.
    # Furthermore, we cannot evaluate slices and attributes, because they might
    # trigger __getitem__ or __getattribute__.
    #
    # A case where this fails includes ops with side effects on a stateful
    # resource captured in an object:
    #
    #   while self.v.read() > 0:
    #     self.v.assign(1)
    #
    # TODO(mdan): Handle the case above.
    cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
    cond_closure = set()
    for s in cond_scope.read:
      cond_closure |= s.support_set

    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    node_body = ast_util.rename_symbols(node.body, ssf_map)
    test = ast_util.rename_symbols(node.test, ssf_map)

    if loop_state:
      template = """
        def test_name(state_ssf):
          return test
        def body_name(state_ssf):
          body
          return state_ssf,
        state_ast_tuple = ag__.while_stmt(
            test_name, body_name, (state,), (extra_deps,))
      """
      node = templates.replace(
          template,
          state=loop_state,
          state_ssf=state_ssf,
          state_ast_tuple=state_ast_tuple,
          test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
          test=test,
          body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
          body=node_body,
          extra_deps=tuple(s.ast() for s in cond_closure),
      )
    else:
      template = """
        def test_name():
          return test
        def body_name():
          body
          return ()
        ag__.while_stmt(test_name, body_name, (), (extra_deps,))
      """
      node = templates.replace(
          template,
          test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
          test=test,
          body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
          body=node_body,
          extra_deps=tuple(s.ast() for s in cond_closure),
      )

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    return undefined_assigns + node

  def _create_for_loop_early_stopping(self, loop_state, state_ssf,
                                      state_ast_tuple, original_node,
                                      extra_test_name, extra_test,
                                      body_name, loop_body):
    """Create node for for-loop with early stopping (e.g. break or return)."""
    template = """
      def extra_test_name(state_ssf):
        return extra_test_expr
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, extra_test_name, body_name, (state,))
    """
    return templates.replace(
        template,
        state=loop_state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=original_node.iter,
        iterate=original_node.target,
        extra_test_name=extra_test_name,
        extra_test_expr=extra_test,
        body_name=body_name,
        body=loop_body)

  def _create_for_loop_with_state(self, loop_state, state_ssf, state_ast_tuple,
                                  original_node, body_name, loop_body):
    """Create node for for-loop with loop-carried state, no early stopping."""
    template = """
      def body_name(loop_vars, state_ssf):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return state_ssf,
      state_ast_tuple = ag__.for_stmt(
          iter_, None, body_name, (state,))
    """
    return templates.replace(
        template,
        state=loop_state,
        state_ssf=state_ssf,
        state_ast_tuple=state_ast_tuple,
        iter_=original_node.iter,
        iterate=original_node.target,
        body_name=body_name,
        body=loop_body)

  def _create_for_loop_without_state(self, original_node, body_name, loop_body):
    """Create node for for-loop without loop-carried state or early stopping."""
    template = """
      def body_name(loop_vars):
        # Workaround for PEP-3113
        iterate = loop_vars
        body
        return ()
      ag__.for_stmt(iter_, None, body_name, ())
    """
    return templates.replace(
        template,
        iter_=original_node.iter,
        iterate=original_node.target,
        body_name=body_name,
        body=loop_body)

  def visit_For(self, node):
    """Converts a `for` loop into a call to `ag__.for_stmt`.

    Args:
      node: The `for` node, annotated with its body scope and liveness
        information, and optionally with an 'extra_test' annotation holding
        the early-stopping condition.
    Returns:
      The replacement statements: undefined-symbol assignments followed by
      the generated loop functions and the for_stmt call.
    """
    self.generic_visit(node)

    loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(node)
    loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
        loop_state, reserved_symbols)
    node_body = ast_util.rename_symbols(node.body, ssf_map)
    body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)

    has_extra_test = anno.hasanno(node, 'extra_test')
    if loop_state:
      if has_extra_test:
        # Loop with early stopping (e.g. break or return)
        extra_test = anno.getanno(node, 'extra_test')
        extra_test = ast_util.rename_symbols(extra_test, ssf_map)
        extra_test_name = self.ctx.namer.new_symbol('extra_test',
                                                    reserved_symbols)
        node = self._create_for_loop_early_stopping(
            loop_state, state_ssf, state_ast_tuple, node, extra_test_name,
            extra_test, body_name, node_body)
      else:
        # Loop with loop-carried state and no early stopping
        node = self._create_for_loop_with_state(
            loop_state, state_ssf, state_ast_tuple, node, body_name, node_body)
    else:
      # Loop with no loop-carried state and no early stopping
      assert not has_extra_test, ('Early stoppiong (e.g. break and/or return) '
                                  'should create state variables.')
      node = self._create_for_loop_without_state(node, body_name, node_body)

    undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    return undefined_assigns + node


def transform(node, ctx):
  """Applies ControlFlowTransformer to `node` and returns the result.

  Args:
    node: Root AST node to convert.
    ctx: Context object handed to the ControlFlowTransformer constructor.
  Returns:
    The node returned by the transformer's visit of `node`.
  """
  node = ControlFlowTransformer(ctx).visit(node)
  return node