Home | History | Annotate | Download | only in converters
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Handles control flow statements: while, for, if."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import gast
     22 
     23 from tensorflow.python.autograph.core import converter
     24 from tensorflow.python.autograph.pyct import anno
     25 from tensorflow.python.autograph.pyct import ast_util
     26 from tensorflow.python.autograph.pyct import templates
     27 from tensorflow.python.autograph.pyct.static_analysis import annos
     28 
     29 
     30 class SymbolNamer(object):
     31   """Describes the interface for ControlFlowTransformer's namer."""
     32 
     33   def new_symbol(self, name_root, reserved_locals):
     34     """Generate a new unique symbol.
     35 
     36     Args:
     37       name_root: String, used as stem in the new name.
     38       reserved_locals: Set(string), additional local symbols that are reserved
     39           and which should not be used.
     40     Returns:
     41       String.
     42     """
     43     raise NotImplementedError()
     44 
     45 
     46 class ControlFlowTransformer(converter.Base):
     47   """Transforms control flow structures like loops an conditionals."""
     48 
     49   def _create_cond_branch(self, body_name, aliased_orig_names,
     50                           aliased_new_names, body, returns):
     51     if not returns:
     52       # TODO(b/110167197): Replace with a plain return.
     53       template = """
     54         return 1
     55       """
     56       return_stmt = templates.replace(template)
     57     elif len(returns) == 1:
     58       template = """
     59         return retval
     60       """
     61       return_stmt = templates.replace(template, retval=returns[0])
     62     else:
     63       template = """
     64         return (retvals,)
     65       """
     66       return_stmt = templates.replace(template, retvals=returns)
     67 
     68     if aliased_orig_names:
     69       template = """
     70         def body_name():
     71           aliased_new_names, = aliased_orig_names,
     72           body
     73           return_stmt
     74       """
     75       return templates.replace(
     76           template,
     77           body_name=body_name,
     78           body=body,
     79           aliased_orig_names=aliased_orig_names,
     80           aliased_new_names=aliased_new_names,
     81           return_stmt=return_stmt)
     82     else:
     83       template = """
     84         def body_name():
     85           body
     86           return_stmt
     87       """
     88       return templates.replace(
     89           template, body_name=body_name, body=body, return_stmt=return_stmt)
     90 
     91   def _create_cond_expr(self, results, test, body_name, orelse_name,
     92                         state_getter_name,
     93                         state_setter_name):
     94     if results is not None:
     95       template = """
     96         results = ag__.if_stmt(test, body_name, orelse_name,
     97                                state_getter_name, state_setter_name)
     98       """
     99       return templates.replace(
    100           template,
    101           test=test,
    102           results=results,
    103           body_name=body_name,
    104           orelse_name=orelse_name,
    105           state_getter_name=state_getter_name,
    106           state_setter_name=state_setter_name)
    107     else:
    108       template = """
    109         ag__.if_stmt(test, body_name, orelse_name, getter_name, setter_name)
    110       """
    111       return templates.replace(
    112           template,
    113           test=test,
    114           body_name=body_name,
    115           orelse_name=orelse_name,
    116           getter_name=state_getter_name,
    117           setter_name=state_setter_name)
    118 
    119   def _fmt_symbols(self, symbol_set):
    120     if not symbol_set:
    121       return 'no variables'
    122     return ', '.join(map(str, symbol_set))
    123 
    124   def _determine_aliased_symbols(self, scope, node_defined_in, block):
    125     if block:
    126       block_live_in = set(anno.getanno(block[0], anno.Static.LIVE_VARS_IN))
    127     else:
    128       block_live_in = set()
    129 
    130     # For the purpose of aliasing, composite symbols with live owners are live
    131     # as well. Otherwise this would leak tensors from the conditional's body.
    132     #
    133     # For example:
    134     #
    135     #   obj = some_obj
    136     #   if cond:
    137     #     obj.a = val
    138     #
    139     # Thanslating to the code below would be incorrect:
    140     #
    141     #   def true_fn():
    142     #     obj.a = val()  # Wrong! leaks ops owned by true_fn
    143     #     return obj.a
    144     for s in scope.modified:
    145       if s.is_composite():
    146         live_parents = block_live_in & s.owner_set
    147         if live_parents:
    148           block_live_in.add(s)
    149     return scope.modified & node_defined_in & block_live_in
    150 
    151   def _create_state_functions(self, composites,
    152                               state_getter_name, state_setter_name):
    153     if composites:
    154       composite_tuple = tuple(composites)
    155       template = """
    156         def state_getter_name():
    157           return composite_tuple,
    158         def state_setter_name(vals):
    159           composite_tuple, = vals
    160       """
    161       node = templates.replace(
    162           template,
    163           state_getter_name=state_getter_name,
    164           state_setter_name=state_setter_name,
    165           composite_tuple=composite_tuple)
    166     else:
    167       template = """
    168         def state_getter_name():
    169           return ()
    170         def state_setter_name(_):
    171           pass
    172         """
    173       node = templates.replace(
    174           template,
    175           state_getter_name=state_getter_name,
    176           state_setter_name=state_setter_name)
    177 
    178     return node
    179 
    180   def _create_undefined_assigns(self, undefined_symbols):
    181     assignments = []
    182     for s in undefined_symbols:
    183       template = '''
    184         var = ag__.Undefined(symbol_name)
    185       '''
    186       assignments += templates.replace(
    187           template,
    188           var=s,
    189           symbol_name=gast.Str(s.ssf()))
    190     return assignments
    191 
    192   def visit_If(self, node):
    193     body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    194     orelse_scope = anno.getanno(node, annos.NodeAnno.ORELSE_SCOPE)
    195     defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    196     live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
    197 
    198     # Note: this information needs to be extracted before the body conversion
    199     # that happens in the call to generic_visit below, because the conversion
    200     # generates nodes that lack static analysis annotations.
    201     need_alias_in_body = self._determine_aliased_symbols(
    202         body_scope, defined_in, node.body)
    203     need_alias_in_orelse = self._determine_aliased_symbols(
    204         orelse_scope, defined_in, node.orelse)
    205 
    206     node = self.generic_visit(node)
    207 
    208     modified_in_cond = body_scope.modified | orelse_scope.modified
    209     returned_from_cond = set()
    210     composites = set()
    211     for s in modified_in_cond:
    212       if s in live_out:
    213         returned_from_cond.add(s)
    214       if s.is_composite():
    215         # Special treatment for compound objects, always return them.
    216         # This allows special handling within the if_stmt itself.
    217         # For example, in TensorFlow we need to restore the state of composite
    218         # symbols to ensure that only effects from the executed branch are seen.
    219         returned_from_cond.add(s)
    220         composites.add(s)
    221 
    222     created_in_body = body_scope.modified & returned_from_cond - defined_in
    223     created_in_orelse = orelse_scope.modified & returned_from_cond - defined_in
    224 
    225     basic_created_in_body = tuple(
    226         s for s in created_in_body if not s.is_composite())
    227     basic_created_in_orelse = tuple(
    228         s for s in created_in_orelse if not s.is_composite())
    229 
    230     # These variables are defined only in a single branch. This is fine in
    231     # Python so we pass them through. Another backend, e.g. Tensorflow, may need
    232     # to handle these cases specially or throw an Error.
    233     possibly_undefined = (set(basic_created_in_body) ^
    234                           set(basic_created_in_orelse))
    235 
    236     # Alias the closure variables inside the conditional functions, to allow
    237     # the functions access to the respective variables.
    238     # We will alias variables independently for body and orelse scope,
    239     # because different branches might write different variables.
    240     aliased_body_orig_names = tuple(need_alias_in_body)
    241     aliased_orelse_orig_names = tuple(need_alias_in_orelse)
    242     aliased_body_new_names = tuple(
    243         self.ctx.namer.new_symbol(s.ssf(), body_scope.referenced)
    244         for s in aliased_body_orig_names)
    245     aliased_orelse_new_names = tuple(
    246         self.ctx.namer.new_symbol(s.ssf(), orelse_scope.referenced)
    247         for s in aliased_orelse_orig_names)
    248 
    249     alias_body_map = dict(zip(aliased_body_orig_names, aliased_body_new_names))
    250     alias_orelse_map = dict(
    251         zip(aliased_orelse_orig_names, aliased_orelse_new_names))
    252 
    253     node_body = ast_util.rename_symbols(node.body, alias_body_map)
    254     node_orelse = ast_util.rename_symbols(node.orelse, alias_orelse_map)
    255 
    256     cond_var_name = self.ctx.namer.new_symbol('cond', body_scope.referenced)
    257     body_name = self.ctx.namer.new_symbol('if_true', body_scope.referenced)
    258     orelse_name = self.ctx.namer.new_symbol('if_false', orelse_scope.referenced)
    259     all_referenced = body_scope.referenced | orelse_scope.referenced
    260     state_getter_name = self.ctx.namer.new_symbol('get_state', all_referenced)
    261     state_setter_name = self.ctx.namer.new_symbol('set_state', all_referenced)
    262 
    263     returned_from_cond = tuple(returned_from_cond)
    264     if returned_from_cond:
    265       if len(returned_from_cond) == 1:
    266         cond_results = returned_from_cond[0]
    267       else:
    268         cond_results = gast.Tuple([s.ast() for s in returned_from_cond], None)
    269 
    270       returned_from_body = tuple(
    271           alias_body_map[s] if s in need_alias_in_body else s
    272           for s in returned_from_cond)
    273       returned_from_orelse = tuple(
    274           alias_orelse_map[s] if s in need_alias_in_orelse else s
    275           for s in returned_from_cond)
    276 
    277     else:
    278       # When the cond would return no value, we leave the cond called without
    279       # results. That in turn should trigger the side effect guards. The
    280       # branch functions will return a dummy value that ensures cond
    281       # actually has some return value as well.
    282       cond_results = None
    283       # TODO(mdan): Replace with None once side_effect_guards is retired.
    284       returned_from_body = (templates.replace_as_expression(
    285           'ag__.match_staging_level(1, cond_var_name)',
    286           cond_var_name=cond_var_name),)
    287       returned_from_orelse = (templates.replace_as_expression(
    288           'ag__.match_staging_level(1, cond_var_name)',
    289           cond_var_name=cond_var_name),)
    290 
    291     cond_assign = self.create_assignment(cond_var_name, node.test)
    292     body_def = self._create_cond_branch(
    293         body_name,
    294         aliased_orig_names=aliased_body_orig_names,
    295         aliased_new_names=aliased_body_new_names,
    296         body=node_body,
    297         returns=returned_from_body)
    298     orelse_def = self._create_cond_branch(
    299         orelse_name,
    300         aliased_orig_names=aliased_orelse_orig_names,
    301         aliased_new_names=aliased_orelse_new_names,
    302         body=node_orelse,
    303         returns=returned_from_orelse)
    304     undefined_assigns = self._create_undefined_assigns(possibly_undefined)
    305     composite_defs = self._create_state_functions(
    306         composites, state_getter_name, state_setter_name)
    307 
    308     cond_expr = self._create_cond_expr(cond_results, cond_var_name, body_name,
    309                                        orelse_name, state_getter_name,
    310                                        state_setter_name)
    311 
    312     return (undefined_assigns + cond_assign + composite_defs + body_def +
    313             orelse_def + cond_expr)
    314 
    315   def _get_loop_state(self, node):
    316     body_scope = anno.getanno(node, annos.NodeAnno.BODY_SCOPE)
    317     defined_in = anno.getanno(node, anno.Static.DEFINED_VARS_IN)
    318     live_in = anno.getanno(node, anno.Static.LIVE_VARS_IN)
    319     live_out = anno.getanno(node, anno.Static.LIVE_VARS_OUT)
    320     reserved_symbols = body_scope.referenced
    321 
    322     loop_state = []
    323     for s in body_scope.modified:
    324 
    325       # Variables not live into or out of the loop are considered local to the
    326       # loop.
    327       if s not in live_in and s not in live_out:
    328         continue
    329 
    330       # Mutations made to objects created inside the loop will appear as writes
    331       # to composite symbols. Because these mutations appear as modifications
    332       # made to composite symbols, we check whether the composite's parent is
    333       # actually live into the loop.
    334       # Example:
    335       #   while cond:
    336       #     x = Foo()
    337       #     x.foo = 2 * x.foo  # x.foo is live into the loop, but x is not.
    338       if s.is_composite() and not all(p in live_in for p in s.support_set):
    339         continue
    340 
    341       loop_state.append(s)
    342     loop_state = frozenset(loop_state)
    343 
    344     # Variable that are used or defined inside the loop, but not defined
    345     # before entering the loop
    346     undefined_lives = loop_state - defined_in
    347 
    348     # Only simple variables must be defined. The composite ones will be
    349     # implicitly checked at runtime.
    350     possibly_undefs = {v for v in undefined_lives if v.is_simple()}
    351 
    352     return loop_state, reserved_symbols, possibly_undefs
    353 
    354   def _state_constructs(self, loop_state, reserved_symbols):
    355     loop_state = tuple(loop_state)
    356     state_ssf = [
    357         self.ctx.namer.new_symbol(s.ssf(), reserved_symbols) for s in loop_state
    358     ]
    359     ssf_map = {
    360         name: ssf
    361         for name, ssf in zip(loop_state, state_ssf)
    362         if str(name) != ssf
    363     }
    364 
    365     state_ast_tuple = gast.Tuple([n.ast() for n in loop_state], None)
    366 
    367     if len(loop_state) == 1:
    368       loop_state = loop_state[0]
    369       state_ssf = state_ssf[0]
    370 
    371     return loop_state, state_ssf, state_ast_tuple, ssf_map
    372 
    373   def visit_While(self, node):
    374     self.generic_visit(node)
    375 
    376     loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(node)
    377 
    378     # Note: one might expect we can dispatch based on the loop condition.
    379     # But because that is dependent on the state, it cannot be evaluated ahead
    380     # of time - doing that would risk duplicating any effects the condition has.
    381     # Furthermore, we cannot evaluate slices and attributes, because they might
    382     # trigger __getitem__ or __getattribute__.
    383     #
    384     # A case where this fails includes ops with side effects on a stateful
    385     # resource captured in an object:
    386     #
    387     #   while self.v.read() > 0:
    388     #     self.v.assign(1)
    389     #
    390     # TODO(mdan): Handle the case above.
    391     cond_scope = anno.getanno(node, annos.NodeAnno.COND_SCOPE)
    392     cond_closure = set()
    393     for s in cond_scope.read:
    394       cond_closure |= s.support_set
    395 
    396     loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
    397         loop_state, reserved_symbols)
    398     node_body = ast_util.rename_symbols(node.body, ssf_map)
    399     test = ast_util.rename_symbols(node.test, ssf_map)
    400 
    401     if loop_state:
    402       template = """
    403         def test_name(state_ssf):
    404           return test
    405         def body_name(state_ssf):
    406           body
    407           return state_ssf,
    408         state_ast_tuple = ag__.while_stmt(
    409             test_name, body_name, (state,), (extra_deps,))
    410       """
    411       node = templates.replace(
    412           template,
    413           state=loop_state,
    414           state_ssf=state_ssf,
    415           state_ast_tuple=state_ast_tuple,
    416           test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
    417           test=test,
    418           body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
    419           body=node_body,
    420           extra_deps=tuple(s.ast() for s in cond_closure),
    421       )
    422     else:
    423       template = """
    424         def test_name():
    425           return test
    426         def body_name():
    427           body
    428           return ()
    429         ag__.while_stmt(test_name, body_name, (), (extra_deps,))
    430       """
    431       node = templates.replace(
    432           template,
    433           test_name=self.ctx.namer.new_symbol('loop_test', reserved_symbols),
    434           test=test,
    435           body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
    436           body=node_body,
    437           extra_deps=tuple(s.ast() for s in cond_closure),
    438       )
    439 
    440     undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    441     return undefined_assigns + node
    442 
    443   def _create_for_loop_early_stopping(self, loop_state, state_ssf,
    444                                       state_ast_tuple, original_node,
    445                                       extra_test_name, extra_test,
    446                                       body_name, loop_body):
    447     """Create node for for-loop with early stopping (e.g. break or return)."""
    448     template = """
    449       def extra_test_name(state_ssf):
    450         return extra_test_expr
    451       def body_name(loop_vars, state_ssf):
    452         # Workaround for PEP-3113
    453         iterate = loop_vars
    454         body
    455         return state_ssf,
    456       state_ast_tuple = ag__.for_stmt(
    457           iter_, extra_test_name, body_name, (state,))
    458     """
    459     return templates.replace(
    460         template,
    461         state=loop_state,
    462         state_ssf=state_ssf,
    463         state_ast_tuple=state_ast_tuple,
    464         iter_=original_node.iter,
    465         iterate=original_node.target,
    466         extra_test_name=extra_test_name,
    467         extra_test_expr=extra_test,
    468         body_name=body_name,
    469         body=loop_body)
    470 
    471   def _create_for_loop_with_state(self, loop_state, state_ssf, state_ast_tuple,
    472                                   original_node, body_name, loop_body):
    473     """Create node for for-loop with loop-carried state, no early stopping."""
    474     template = """
    475       def body_name(loop_vars, state_ssf):
    476         # Workaround for PEP-3113
    477         iterate = loop_vars
    478         body
    479         return state_ssf,
    480       state_ast_tuple = ag__.for_stmt(
    481           iter_, None, body_name, (state,))
    482     """
    483     return templates.replace(
    484         template,
    485         state=loop_state,
    486         state_ssf=state_ssf,
    487         state_ast_tuple=state_ast_tuple,
    488         iter_=original_node.iter,
    489         iterate=original_node.target,
    490         body_name=body_name,
    491         body=loop_body)
    492 
    493   def _create_for_loop_without_state(self, original_node, body_name, loop_body):
    494     """Create node for for-loop with loop-carried state, no early stopping."""
    495     template = """
    496       def body_name(loop_vars):
    497         # Workaround for PEP-3113
    498         iterate = loop_vars
    499         body
    500         return ()
    501       ag__.for_stmt(iter_, None, body_name, ())
    502     """
    503     return templates.replace(
    504         template,
    505         iter_=original_node.iter,
    506         iterate=original_node.target,
    507         body_name=body_name,
    508         body=loop_body)
    509 
    510   def visit_For(self, node):
    511     self.generic_visit(node)
    512 
    513     loop_state, reserved_symbols, possibly_undefs = self._get_loop_state(node)
    514     loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
    515         loop_state, reserved_symbols)
    516     node_body = ast_util.rename_symbols(node.body, ssf_map)
    517     body_name = self.ctx.namer.new_symbol('loop_body', reserved_symbols)
    518 
    519     has_extra_test = anno.hasanno(node, 'extra_test')
    520     if loop_state:
    521       if has_extra_test:
    522         # Loop with early stopping (e.g. break or return)
    523         extra_test = anno.getanno(node, 'extra_test')
    524         extra_test = ast_util.rename_symbols(extra_test, ssf_map)
    525         extra_test_name = self.ctx.namer.new_symbol('extra_test',
    526                                                     reserved_symbols)
    527         node = self._create_for_loop_early_stopping(
    528             loop_state, state_ssf, state_ast_tuple, node, extra_test_name,
    529             extra_test, body_name, node_body)
    530       else:
    531         # Loop with loop-carried state and no early stopping
    532         node = self._create_for_loop_with_state(
    533             loop_state, state_ssf, state_ast_tuple, node, body_name, node_body)
    534     else:
    535       # Loop with no loop-carried state and no early stopping
    536       assert not has_extra_test, ('Early stoppiong (e.g. break and/or return) '
    537                                   'should create state variables.')
    538       node = self._create_for_loop_without_state(node, body_name, node_body)
    539 
    540     undefined_assigns = self._create_undefined_assigns(possibly_undefs)
    541     return undefined_assigns + node
    542 
    543 
    544 def transform(node, ctx):
    545   node = ControlFlowTransformer(ctx).visit(node)
    546   return node
    547