Home | History | Annotate | Download | only in converters
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Canonicalizes continue statements by de-sugaring into a control boolean."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 from tensorflow.contrib.py2tf.pyct import anno
     22 from tensorflow.contrib.py2tf.pyct import templates
     23 from tensorflow.contrib.py2tf.pyct import transformer
     24 from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno
     25 
     26 
     27 class ContinueCanonicalizationTransformer(transformer.Base):
     28   """Canonicalizes continue statements into additional conditionals."""
     29 
     30   def __init__(self, context):
     31     super(ContinueCanonicalizationTransformer, self).__init__(context)
     32     # This is a stack structure, to correctly process nested loops.
     33     self.continuation_uses = []
     34 
     35   def _create_continuation_check(self):
     36     template = """
     37       if not var_name:
     38         pass
     39     """
     40     cond, = templates.replace(template, var_name=self.continuation_uses[-1][1])
     41     cond.body = []
     42     return cond
     43 
     44   def _create_continuation_trigger(self):
     45     template = """
     46       var_name = True
     47     """
     48     assign, = templates.replace(
     49         template, var_name=self.continuation_uses[-1][1])
     50     return assign
     51 
     52   def _create_continuation_init(self):
     53     template = """
     54       var_name = False
     55     """
     56     assign, = templates.replace(
     57         template, var_name=self.continuation_uses[-1][1])
     58     return assign
     59 
     60   def _visit_and_reindent_if_necessary(self, nodes):
     61     reorganized_nodes = []
     62     current_dest = reorganized_nodes
     63     continue_used_in_block = False
     64     for i, n in enumerate(nodes):
     65       # TODO(mdan): This could be optimized if control structures are simple.
     66       self.continuation_uses[-1][0] = False
     67       n = self.visit(n)
     68       current_dest.append(n)
     69       if self.continuation_uses[-1][0]:
     70         continue_used_in_block = True
     71         if i < len(nodes) - 1:  # Last statement in block needs no protection.
     72           cond = self._create_continuation_check()
     73           current_dest.append(cond)
     74           current_dest = cond.body
     75     self.continuation_uses[-1][0] = continue_used_in_block
     76     return reorganized_nodes
     77 
     78   def _process_loop_block(self, block, scope):
     79     cont_var = self.context.namer.new_symbol('cont_requested', scope.referenced)
     80     self.continuation_uses.append([False, cont_var])
     81     block = self._visit_and_reindent_if_necessary(block)
     82     if self.continuation_uses[-1][0]:
     83       block.insert(0, self._create_continuation_init())
     84     self.continuation_uses.pop()
     85     return block
     86 
     87   def visit_While(self, node):
     88     self.generic_visit(node.test)
     89     node.body = self._process_loop_block(node.body,
     90                                          anno.getanno(node,
     91                                                       NodeAnno.BODY_SCOPE))
     92     for n in node.orelse:
     93       self.generic_visit(n)
     94     return node
     95 
     96   def visit_For(self, node):
     97     self.generic_visit(node.target)
     98     self.generic_visit(node.iter)
     99     node.body = self._process_loop_block(node.body,
    100                                          anno.getanno(node,
    101                                                       NodeAnno.BODY_SCOPE))
    102     for n in node.orelse:
    103       self.generic_visit(n)
    104     return node
    105 
    106   def visit_If(self, node):
    107     if self.continuation_uses:
    108       self.generic_visit(node.test)
    109       node.body = self._visit_and_reindent_if_necessary(node.body)
    110       continue_used_in_body = self.continuation_uses[-1][0]
    111       node.orelse = self._visit_and_reindent_if_necessary(node.orelse)
    112       self.continuation_uses[-1][0] = (
    113           continue_used_in_body or self.continuation_uses[-1][0])
    114     else:
    115       node = self.generic_visit(node)
    116     return node
    117 
    118   def visit_Continue(self, node):
    119     self.continuation_uses[-1][0] = True
    120     return self._create_continuation_trigger()
    121 
    122   def visit_Break(self, node):
    123     assert False, 'break statement should be desugared at this point'
    124 
    125 
    126 def transform(node, namer):
    127   return ContinueCanonicalizationTransformer(namer).visit(node)
    128