# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Canonicalizes continue statements by de-sugaring into a control boolean."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.py2tf.pyct import anno
from tensorflow.contrib.py2tf.pyct import templates
from tensorflow.contrib.py2tf.pyct import transformer
from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno


class ContinueCanonicalizationTransformer(transformer.Base):
  """Canonicalizes continue statements into additional conditionals.

  Each `continue` inside a loop body is replaced by an assignment that sets a
  per-loop boolean control variable (a symbol generated from
  'cont_requested' by the context's namer) to True. The statements that
  follow it in the same block are nested under `if not <var>:` so they are
  skipped once the flag is set. When a loop body uses `continue`, the flag is
  initialized to False at the top of that body.
  """

  def __init__(self, context):
    super(ContinueCanonicalizationTransformer, self).__init__(context)
    # This is a stack structure, to correctly process nested loops. Each entry
    # is a mutable [continue_used, control_var_name] pair for one loop. The
    # first element records whether a `continue` was seen in the statement
    # (or block) currently being visited.
    self.continuation_uses = []

  def _create_continuation_check(self):
    """Returns an empty `if not <var>:` node for the innermost loop."""
    template = """
      if not var_name:
        pass
    """
    cond, = templates.replace(template, var_name=self.continuation_uses[-1][1])
    # Drop the template's placeholder `pass`; callers fill the body in.
    cond.body = []
    return cond

  def _create_continuation_trigger(self):
    """Returns a `<var> = True` assignment for the innermost loop."""
    template = """
      var_name = True
    """
    assign, = templates.replace(
        template, var_name=self.continuation_uses[-1][1])
    return assign

  def _create_continuation_init(self):
    """Returns a `<var> = False` assignment for the innermost loop."""
    template = """
      var_name = False
    """
    assign, = templates.replace(
        template, var_name=self.continuation_uses[-1][1])
    return assign

  def _visit_and_reindent_if_necessary(self, nodes):
    """Visits a statement block, guarding statements that follow a continue.

    After any statement in which a `continue` was found, the remaining
    statements of the block are moved into the body of a new
    `if not <var>:` check. A later continue nests another check inside it.

    Args:
      nodes: list of statement nodes forming one block.

    Returns:
      The new list of statements for the block. On return, the innermost
      loop's flag is True iff some statement in the block used `continue`.
    """
    reorganized_nodes = []
    current_dest = reorganized_nodes
    continue_used_in_block = False
    for i, n in enumerate(nodes):
      # TODO(mdan): This could be optimized if control structures are simple.
      self.continuation_uses[-1][0] = False
      n = self.visit(n)
      current_dest.append(n)
      if self.continuation_uses[-1][0]:
        continue_used_in_block = True
        if i < len(nodes) - 1:  # Last statement in block needs no protection.
          cond = self._create_continuation_check()
          current_dest.append(cond)
          current_dest = cond.body
    self.continuation_uses[-1][0] = continue_used_in_block
    return reorganized_nodes

  def _process_loop_block(self, block, scope):
    """Canonicalizes the body of a loop, using a fresh control variable.

    Args:
      block: list of statements forming the loop body.
      scope: the body's scope annotation; its `referenced` symbols are passed
        to the namer when generating the control variable's name.

    Returns:
      The rewritten body. If it used `continue`, the body starts with the
      control variable's initialization.
    """
    cont_var = self.context.namer.new_symbol('cont_requested', scope.referenced)
    self.continuation_uses.append([False, cont_var])
    block = self._visit_and_reindent_if_necessary(block)
    if self.continuation_uses[-1][0]:
      block.insert(0, self._create_continuation_init())
    self.continuation_uses.pop()
    return block

  def _visit_loop_orelse(self, orelse):
    """Visits the else clause of a loop.

    A `continue` inside a loop's else clause refers to the enclosing loop, not
    to the loop that owns the clause. The clause is therefore processed in the
    enclosing loop's scope. This must be called after the owning loop's entry
    has been popped from `continuation_uses`.

    Args:
      orelse: list of statements forming the else clause.

    Returns:
      The rewritten list of statements.
    """
    if not self.continuation_uses:
      # No enclosing loop, so no continue can target it. Still dispatch each
      # statement so that nested loops are processed by visit_While/visit_For.
      return [self.visit(n) for n in orelse]
    used_before = self.continuation_uses[-1][0]
    orelse = self._visit_and_reindent_if_necessary(orelse)
    # Keep any usage already recorded for the statement being visited.
    self.continuation_uses[-1][0] = (
        used_before or self.continuation_uses[-1][0])
    return orelse

  def visit_While(self, node):
    self.generic_visit(node.test)
    node.body = self._process_loop_block(node.body,
                                         anno.getanno(node,
                                                      NodeAnno.BODY_SCOPE))
    node.orelse = self._visit_loop_orelse(node.orelse)
    return node

  def visit_For(self, node):
    self.generic_visit(node.target)
    self.generic_visit(node.iter)
    node.body = self._process_loop_block(node.body,
                                         anno.getanno(node,
                                                      NodeAnno.BODY_SCOPE))
    node.orelse = self._visit_loop_orelse(node.orelse)
    return node

  def visit_If(self, node):
    if not self.continuation_uses:
      # Not inside a loop: nothing to guard.
      return self.generic_visit(node)
    # When reached through generic_visit (e.g. inside a `with` body), an
    # earlier sibling statement may already have recorded a continue; do not
    # discard that information.
    used_before = self.continuation_uses[-1][0]
    self.generic_visit(node.test)
    node.body = self._visit_and_reindent_if_necessary(node.body)
    continue_used_in_body = self.continuation_uses[-1][0]
    node.orelse = self._visit_and_reindent_if_necessary(node.orelse)
    continue_used_in_orelse = self.continuation_uses[-1][0]
    self.continuation_uses[-1][0] = (
        used_before or continue_used_in_body or continue_used_in_orelse)
    return node

  def visit_Continue(self, node):
    self.continuation_uses[-1][0] = True
    return self._create_continuation_trigger()

  def visit_Break(self, node):
    # Raise explicitly instead of `assert False`: asserts are stripped under
    # `python -O`, which would make this return None and silently delete the
    # break statement from the tree.
    raise AssertionError('break statement should be desugared at this point')


def transform(node, namer):
  """Canonicalizes all continue statements in `node`.

  Args:
    node: the AST to transform.
    namer: despite its name, this is passed to the transformer as its
      context. The transformer later reads `self.context.namer` from it.

  Returns:
    The transformed AST.
  """
  return ContinueCanonicalizationTransformer(namer).visit(node)