# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for type-dependent behavior used in py2tf-generated code."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.contrib.py2tf.utils.type_check import is_tensor
from tensorflow.python.ops import control_flow_ops


def run_cond(condition, true_fn, false_fn):
  """Type-dependent functional conditional.

  Args:
    condition: A Tensor or Python bool.
    true_fn: A Python callable implementing the true branch of the conditional.
    false_fn: A Python callable implementing the false branch of the
      conditional.

  Returns:
    result: The result of calling the appropriate branch. If condition is a
    Tensor, tf.cond will be used. Otherwise, a standard Python if statement will
    be run.
  """
  if is_tensor(condition):
    return control_flow_ops.cond(condition, true_fn, false_fn)
  else:
    return py_cond(condition, true_fn, false_fn)


def py_cond(condition, true_fn, false_fn):
  """Plain-Python conditional: calls `true_fn` or `false_fn` by truthiness.

  Args:
    condition: Any value; only its truthiness is used.
    true_fn: A zero-argument callable invoked when `condition` is truthy.
    false_fn: A zero-argument callable invoked when `condition` is falsy.

  Returns:
    The return value of whichever branch callable was invoked.
  """
  if condition:
    return true_fn()
  else:
    return false_fn()


def _closure_values(fn):
  """Returns the current values of the variables closed over by `fn`.

  Args:
    fn: A callable.

  Returns:
    A tuple with the contents of each closure cell of `fn`. It is empty when
    `fn` has no closure, or when `fn` has no `__closure__` attribute at all
    (e.g. a `functools.partial` object), instead of raising AttributeError.
  """
  try:
    cells = six.get_function_closure(fn)
  except AttributeError:
    return ()
  return tuple(c.cell_contents for c in cells or ())


def run_while(cond_fn, body_fn, init_args):
  """Type-dependent functional while loop.

  Args:
    cond_fn: A Python callable implementing the stop conditions of the loop.
    body_fn: A Python callable implementing the body of the loop. On the
      Python path its return value is unpacked as the arguments of the next
      `cond_fn`/`body_fn` call, so it must return a tuple or list.
    init_args: The initial values of the arguments that will be passed to both
      cond_fn and body_fn.

  Returns:
    result: A list of values with the same shape and type as init_args. If any
    of the init_args, or any variables closed-over in cond_fn are Tensors,
    tf.while_loop will be used, otherwise a Python while loop will be run.
    Only cond_fn's closure is inspected; body_fn's closure is not.

  Raises:
    ValueError: if init_args is not a tuple or list with one or more elements.
  """
  if not isinstance(init_args, (tuple, list)) or not init_args:
    # Wrap init_args in a 1-tuple: if it is itself a tuple (the empty tuple
    # reaches this branch), `%` would otherwise treat it as the argument list
    # and raise TypeError instead of the intended ValueError.
    raise ValueError(
        'init_args must be a non-empty list or tuple, found %s' % (init_args,))

  # TODO(alexbw): statically determine all active variables in cond_fn,
  # and pass them directly
  possibly_tensors = tuple(init_args) + _closure_values(cond_fn)
  if is_tensor(*possibly_tensors):
    return control_flow_ops.while_loop(cond_fn, body_fn, init_args)
  else:
    return py_while_loop(cond_fn, body_fn, init_args)


def py_while_loop(cond_fn, body_fn, init_args):
  """Plain-Python functional while loop.

  Args:
    cond_fn: Callable taking the unpacked loop state; the loop continues while
      its result is truthy.
    body_fn: Callable taking the unpacked loop state and returning the next
      state (unpacked again on the following iteration).
    init_args: The initial loop state, unpacked into the first calls.

  Returns:
    The final loop state: the last value returned by `body_fn`, or `init_args`
    itself if `cond_fn` is falsy on the first check.
  """
  state = init_args
  while cond_fn(*state):
    state = body_fn(*state)
  return state