Home | History | Annotate | Download | only in converters
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Adds guards against function calls with side effects.
     16 
     17 Only standalone calls are guarded.
     18 
     19 WARNING: This mechanism is incomplete. Particularly, it only guards the
     20 arguments passed to functions, and does not account for indirectly modified
     21 state.
     22 
     23 Example:
     24   y = tf.layers.dense(x)       # Creates TF variable 'foo'
     25   loss = loss(y)
     26   opt.minimize(loss)           # indirectly affects 'foo'
     27   z = tf.get_variable('foo')   # Indirectly affects `loss` and 'foo'
     28   # Here, `loss` can be guarded. But `z` cannot.
     29 
     30 # TODO(mdan): We should probably define a safe mode where we guard everything.
     31 """
     32 
     33 from __future__ import absolute_import
     34 from __future__ import division
     35 from __future__ import print_function
     36 
     37 import gast
     38 
     39 from tensorflow.contrib.py2tf.pyct import anno
     40 from tensorflow.contrib.py2tf.pyct import ast_util
     41 from tensorflow.contrib.py2tf.pyct import qual_names
     42 from tensorflow.contrib.py2tf.pyct import templates
     43 from tensorflow.contrib.py2tf.pyct import transformer
     44 from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
     45 
     46 
     47 class SymbolNamer(object):
     48   """Describes the interface for SideEffectGuardTransformer's namer."""
     49 
     50   def new_symbol(self, name_root, reserved_locals):
     51     """Generate a new unique function_name.
     52 
     53     Args:
     54       name_root: String, used as stem in the new name.
     55       reserved_locals: Set(string), additional local symbols that are reserved.
     56     Returns:
     57       String.
     58     """
     59     raise NotImplementedError()
     60 
     61 
class SideEffectGuardTransformer(transformer.Base):
  """Adds control dependencies to functions with side effects.

  A standalone call statement (a `gast.Expr` whose value is a `gast.Call`) is
  replaced by a
  `with py2tf_utils.control_dependency_on_returns(<call>):` block. The
  statements that follow it in the same block are then moved into the body of
  that `with` statement by `_visit_and_reindent`.

  NOTE(review): only the bodies of FunctionDef, With, If and While nodes are
  routed through `_visit_and_reindent`. How guards generated inside other
  block statements are handled is not visible from this class - confirm.
  """

  def __init__(self, context):
    super(SideEffectGuardTransformer, self).__init__(context)

  # pylint:disable=invalid-name

  def _visit_and_reindent(self, nodes):
    """Visits a block of statements, nesting them under generated guards.

    The statements are visited in order. Whenever a visited statement carries
    the `anno.Basic.INDENT_BLOCK_REMAINDER` annotation (set by `visit_Expr`),
    the annotation is consumed and every statement after it is appended to the
    destination list stored in the annotation rather than to the current list.
    The alias map stored in the annotation is applied to all the statements
    that follow, using `ast_util.rename_symbols`.

    Args:
      nodes: Iterable of AST statement nodes, the contents of one block.
    Returns:
      List with the new top-level statements of the block. Statements that
      were re-nested are only reachable through the guard nodes in this list.
    Raises:
      ValueError: If at least one guard was generated and the destination list
        of the last one is still empty after all statements were processed.
    """
    new_nodes = []
    # The list currently receiving statements. Starts as the block itself and
    # is switched to a guard's body each time a guard is encountered.
    current_dest = new_nodes
    # Accumulated renames to apply to the statements that are yet to come.
    alias_map = {}
    reindent_requested = False
    for n in nodes:
      n = self.visit(n)
      # NOTE: the order in which these statements execute is important; in
      # particular, watch out for ending up with cycles in the AST.
      if alias_map:
        n = ast_util.rename_symbols(n, alias_map)
      if isinstance(n, (list, tuple)):
        current_dest.extend(n)
      else:
        current_dest.append(n)
      if anno.hasanno(n, anno.Basic.INDENT_BLOCK_REMAINDER):
        reindent_requested = True
        new_dest, new_alias_map = anno.getanno(
            n, anno.Basic.INDENT_BLOCK_REMAINDER)
        # Consume the annotation so that it is not left on the output node.
        anno.delanno(n, anno.Basic.INDENT_BLOCK_REMAINDER)
        # Merge the maps. On key collisions the entries that were already in
        # alias_map overwrite the ones that came with the new guard.
        new_alias_map.update(alias_map)
        alias_map = new_alias_map
        current_dest = new_dest
    if reindent_requested and not current_dest:
      # TODO(mdan): There may still be something that could be done.
      raise ValueError('Unable to insert statement into the computation flow: '
                       'it is not followed by any computation which '
                       'the statement could gate.')
    return new_nodes

  def visit_FunctionDef(self, node):
    """Processes the function body; no other field of the node is visited."""
    node.body = self._visit_and_reindent(node.body)
    return node

  def visit_With(self, node):
    """Processes the `with` body; no other field of the node is visited."""
    node.body = self._visit_and_reindent(node.body)
    return node

  def visit_If(self, node):
    """Processes both branches; the test expression is not visited."""
    node.body = self._visit_and_reindent(node.body)
    node.orelse = self._visit_and_reindent(node.orelse)
    return node

  def visit_While(self, node):
    """Processes the loop body and else block; the test is not visited."""
    node.body = self._visit_and_reindent(node.body)
    node.orelse = self._visit_and_reindent(node.orelse)
    return node

  def visit_Expr(self, node):
    """Replaces a standalone call statement with a control dependency guard.

    Expression statements whose value is not a `gast.Call` are returned after
    a generic visit, without further changes.

    Args:
      node: The `gast.Expr` statement node.
    Returns:
      Either `node` itself, or the `with` statement generated from the
      template. The generated node is annotated with
      `anno.Basic.INDENT_BLOCK_REMAINDER`, holding a tuple of its body list
      and the alias map that `_visit_and_reindent` should apply to the
      statements that follow.
    """
    self.generic_visit(node)
    if isinstance(node.value, gast.Call):
      # Patterns of single function calls, like:
      #   opt.minimize(loss)
      # or:
      #   tf.py_func(...)

      # First, attempt to gate future evaluation of args. If that's not
      # possible, gate all remaining statements (and that may fail too, see
      # _visit_and_reindent).
      args_scope = anno.getanno(node.value, NodeAnno.ARGS_SCOPE)
      # NOTE: We can't guard object attributes because they may not be writable.
      # In addition, avoid renaming well-known names.
      # TODO(mdan): Move these names into config.
      unguarded_names = (qual_names.QN('self'), qual_names.QN('tf'))
      guarded_args = tuple(s for s in args_scope.used
                           if not s.is_composite() and s not in unguarded_names)

      # TODO(mdan): Include all arguments which depended on guarded_args too.
      # For example, the following will still cause a race:
      #   tf.assign(a, a + 1)
      #   b = a + 1
      #   tf.assign(a, a + 1)  # Control deps here should include `b`
      #   c = b + 1
      # Or maybe we should just raise an "unsafe assign" error?

      if guarded_args:
        # The aliases may need new names to avoid incorrectly making them local.
        # TODO(mdan): This is brutal. It will even rename modules - any fix?
        # Only the symbols that the parent scope does not modify get a new
        # name; the others keep their name and are simply reassigned.
        need_alias = tuple(
            s for s in guarded_args if s not in args_scope.parent.modified)
        # NOTE(review): s.ssf() presumably yields a string form of the symbol
        # that is usable as a name stem - confirm against qual_names.
        aliased_new_names = tuple(
            qual_names.QN(
                self.context.namer.new_symbol(
                    s.ssf(), args_scope.parent.referenced)) for s in need_alias)
        alias_map = dict(zip(need_alias, aliased_new_names))
        # Build the assignment target: a single symbol when there is exactly
        # one guarded argument, a tuple of symbols otherwise.
        if len(guarded_args) == 1:
          s, = guarded_args
          aliased_guarded_args = alias_map.get(s, s)
        else:
          aliased_guarded_args = gast.Tuple(
              [alias_map.get(s, s).ast() for s in guarded_args], None)

        template = """
          with py2tf_utils.control_dependency_on_returns(call):
            aliased_guarded_args = py2tf_utils.alias_tensors(guarded_args)
        """
        # Keep only the last node produced from the template.
        control_deps_guard = templates.replace(
            template,
            call=node.value,
            aliased_guarded_args=aliased_guarded_args,
            guarded_args=guarded_args)[-1]
      else:
        # Nothing to alias: the guard can only gate the remaining statements.
        alias_map = {}

        template = """
          with py2tf_utils.control_dependency_on_returns(call):
            pass
        """
        control_deps_guard = templates.replace(template, call=node.value)[-1]
        # Drop the placeholder `pass`. The body is left empty here and is
        # expected to be filled by _visit_and_reindent, which raises if no
        # statement follows.
        control_deps_guard.body = []

      node = control_deps_guard
      # Ask _visit_and_reindent to move the rest of the block into the guard's
      # body and to apply the aliases to it.
      anno.setanno(node, anno.Basic.INDENT_BLOCK_REMAINDER,
                   (node.body, alias_map))
    return node

  # pylint:enable=invalid-name
    187 
    188 
    189 def transform(node, context):
    190   return SideEffectGuardTransformer(context).visit(node)
    191