Home | History | Annotate | Download | only in utils
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Utilities for type-dependent behavior used in py2tf-generated code."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import six
     22 
     23 from tensorflow.contrib.py2tf.utils.type_check import is_tensor
     24 from tensorflow.python.ops import control_flow_ops
     25 
     26 
     27 def run_cond(condition, true_fn, false_fn):
     28   """Type-dependent functional conditional.
     29 
     30   Args:
     31     condition: A Tensor or Python bool.
     32     true_fn: A Python callable implementing the true branch of the conditional.
     33     false_fn: A Python callable implementing the false branch of the
     34       conditional.
     35 
     36   Returns:
     37     result: The result of calling the appropriate branch. If condition is a
     38     Tensor, tf.cond will be used. Otherwise, a standard Python if statement will
     39     be ran.
     40   """
     41   if is_tensor(condition):
     42     return control_flow_ops.cond(condition, true_fn, false_fn)
     43   else:
     44     return py_cond(condition, true_fn, false_fn)
     45 
     46 
     47 def py_cond(condition, true_fn, false_fn):
     48   if condition:
     49     return true_fn()
     50   else:
     51     return false_fn()
     52 
     53 
     54 def run_while(cond_fn, body_fn, init_args):
     55   """Type-dependent functional while loop.
     56 
     57   Args:
     58     cond_fn: A Python callable implementing the stop conditions of the loop.
     59     body_fn: A Python callable implementing the body of the loop.
     60     init_args: The initial values of the arguments that will be passed to both
     61       cond_fn and body_fn.
     62 
     63   Returns:
     64     result: A list of values with the same shape and type as init_args. If any
     65     of the init_args, or any variables closed-over in cond_fn are Tensors,
     66     tf.while_loop will be used, otherwise a Python while loop will be ran.
     67 
     68   Raises:
     69     ValueError: if init_args is not a tuple or list with one or more elements.
     70   """
     71   if not isinstance(init_args, (tuple, list)) or not init_args:
     72     raise ValueError(
     73         'init_args must be a non-empty list or tuple, found %s' % init_args)
     74 
     75   # TODO(alexbw): statically determine all active variables in cond_fn,
     76   # and pass them directly
     77   closure_vars = tuple(
     78       [c.cell_contents for c in six.get_function_closure(cond_fn) or []])
     79   possibly_tensors = tuple(init_args) + closure_vars
     80   if is_tensor(*possibly_tensors):
     81     return control_flow_ops.while_loop(cond_fn, body_fn, init_args)
     82   else:
     83     return py_while_loop(cond_fn, body_fn, init_args)
     84 
     85 
     86 def py_while_loop(cond_fn, body_fn, init_args):
     87   state = init_args
     88   while cond_fn(*state):
     89     state = body_fn(*state)
     90   return state
     91