# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Functional operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope as vs
# pylint: disable=unused-import
from tensorflow.python.ops.gen_functional_ops import remote_call
# pylint: enable=unused-import
from tensorflow.python.ops.gen_functional_ops import symbolic_gradient
from tensorflow.python.util import compat
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
@tf_export("foldl")
def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
          swap_memory=False, name=None):
     49   """foldl on the list of tensors unpacked from `elems` on dimension 0.
     50 
     51   This foldl operator repeatedly applies the callable `fn` to a sequence
     52   of elements from first to last. The elements are made of the tensors
     53   unpacked from `elems` on dimension 0. The callable fn takes two tensors as
     54   arguments. The first argument is the accumulated value computed from the
     55   preceding invocation of fn. If `initializer` is None, `elems` must contain
     56   at least one element, and its first element is used as the initializer.
     57 
     58   Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
     59   of the result tensor is fn(initializer, values[0]).shape`.
     60 
     61   This method also allows multi-arity `elems` and output of `fn`.  If `elems`
     62   is a (possibly nested) list or tuple of tensors, then each of these tensors
     63   must have a matching first (unpack) dimension.  The signature of `fn` may
     64   match the structure of `elems`.  That is, if `elems` is
     65   `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
     66   `fn = lambda (t1, [t2, t3, [t4, t5]]):`.
     67 
     68   Args:
     69     fn: The callable to be performed.
     70     elems: A tensor or (possibly nested) sequence of tensors, each of which
     71       will be unpacked along their first dimension.  The nested sequence
     72       of the resulting slices will be the first argument to `fn`.
     73     initializer: (optional) A tensor or (possibly nested) sequence of tensors,
     74       as the initial value for the accumulator.
     75     parallel_iterations: (optional) The number of iterations allowed to run
     76       in parallel.
     77     back_prop: (optional) True enables support for back propagation.
     78     swap_memory: (optional) True enables GPU-CPU memory swapping.
     79     name: (optional) Name prefix for the returned tensors.
     80 
     81   Returns:
     82     A tensor or (possibly nested) sequence of tensors, resulting from applying
     83     `fn` consecutively to the list of tensors unpacked from `elems`, from first
     84     to last.
     85 
     86   Raises:
     87     TypeError: if `fn` is not callable.
     88 
     89   Example:
     90     ```python
     91     elems = tf.constant([1, 2, 3, 4, 5, 6])
     92     sum = foldl(lambda a, x: a + x, elems)
     93     # sum == 21
     94     ```
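
    The same fold with an explicit `initializer` (a minimal sketch of the
    accumulator semantics described above):

    ```python
    elems = tf.constant([1, 2, 3, 4, 5, 6])
    sum = foldl(lambda a, x: a + x, elems, initializer=tf.constant(10))
    # sum == 31
    ```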
     95   """
     96   if not callable(fn):
     97     raise TypeError("fn must be callable.")
     98 
     99   def create_ta(elem):
    100     return tensor_array_ops.TensorArray(
    101         dtype=elem.dtype, size=n, dynamic_size=False,
    102         infer_shape=True).unstack(elem)
    103 
    104   in_graph_mode = not context.executing_eagerly()
    105   with ops.name_scope(name, "foldl", [elems]):
    106     # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    107     # supported in Eager
    108     if in_graph_mode:
    109       # Any get_variable calls in fn will cache the first call locally
    110       # and not issue repeated network I/O requests for each iteration.
    111       varscope = vs.get_variable_scope()
    112       varscope_caching_device_was_none = False
    113       if varscope.caching_device is None:
    114         # TODO(ebrevdo): Change to using colocate_with here and in other
    115         # methods.
    116         varscope.set_caching_device(lambda op: op.device)
    117         varscope_caching_device_was_none = True
    118 
    119     # Convert elems to tensor array. n may be known statically.
    120     elems_flat = [
    121         ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
    122     ]
    123     n = (tensor_shape.dimension_value(elems_flat[0].shape[0])
    124          or array_ops.shape(elems_flat[0])[0])
    125 
    126     elems_ta = nest.map_structure(create_ta, elems)
    127 
    128     if initializer is None:
    129       a = nest.map_structure(lambda elem: elem.read(0), elems_ta)
    130       i = constant_op.constant(1)
    131     else:
    132       a = initializer
    133       i = constant_op.constant(0)
    134 
    135     def compute(i, a):
    136       elem_i = nest.map_structure(lambda elem: elem.read(i), elems_ta)
    137       a = fn(a, elem_i)
    138       return [i + 1, a]
    139 
    140     _, r_a = control_flow_ops.while_loop(
    141         lambda i, a: i < n, compute, [i, a],
    142         parallel_iterations=parallel_iterations,
    143         back_prop=back_prop,
    144         swap_memory=swap_memory,
    145         maximum_iterations=n)
    146 
    147     # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    148     # supported in Eager
    149     if in_graph_mode and varscope_caching_device_was_none:
    150       varscope.set_caching_device(None)
    151 
    152     return r_a
    153 
    154 
    155 @tf_export("foldr")
    156 def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
    157           swap_memory=False, name=None):
    158   """foldr on the list of tensors unpacked from `elems` on dimension 0.
    159 
    160   This foldr operator repeatedly applies the callable `fn` to a sequence
    161   of elements from last to first. The elements are made of the tensors
    162   unpacked from `elems`. The callable fn takes two tensors as arguments.
    163   The first argument is the accumulated value computed from the preceding
    164   invocation of fn. If `initializer` is None, `elems` must contain at least
    165   one element, and its first element is used as the initializer.
    166 
    167   Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
    168   of the result tensor is `fn(initializer, values[0]).shape`.
    169 
    170   This method also allows multi-arity `elems` and output of `fn`.  If `elems`
    171   is a (possibly nested) list or tuple of tensors, then each of these tensors
    172   must have a matching first (unpack) dimension.  The signature of `fn` may
    173   match the structure of `elems`.  That is, if `elems` is
    174   `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
    175   `fn = lambda (t1, [t2, t3, [t4, t5]]):`.
    176 
    177   Args:
    178     fn: The callable to be performed.
    179     elems: A tensor or (possibly nested) sequence of tensors, each of which
    180       will be unpacked along their first dimension.  The nested sequence
    181       of the resulting slices will be the first argument to `fn`.
    182     initializer: (optional) A tensor or (possibly nested) sequence of tensors,
    183       as the initial value for the accumulator.
    184     parallel_iterations: (optional) The number of iterations allowed to run
    185       in parallel.
    186     back_prop: (optional) True enables support for back propagation.
    187     swap_memory: (optional) True enables GPU-CPU memory swapping.
    188     name: (optional) Name prefix for the returned tensors.
    189 
    190   Returns:
    191     A tensor or (possibly nested) sequence of tensors, resulting from applying
    192     `fn` consecutively to the list of tensors unpacked from `elems`, from last
    193     to first.
    194 
    195   Raises:
    196     TypeError: if `fn` is not callable.
    197 
    198   Example:
    199     ```python
    200     elems = [1, 2, 3, 4, 5, 6]
    201     sum = foldr(lambda a, x: a + x, elems)
    202     # sum == 21
    203     ```
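
    The same fold with an explicit `initializer` (a minimal sketch of the
    accumulator semantics described above):

    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = foldr(lambda a, x: a + x, elems, initializer=10)
    # sum == 31
    ```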
    204   """
    205   if not callable(fn):
    206     raise TypeError("fn must be callable.")
    207 
    208   def create_ta(elem):
    209     return tensor_array_ops.TensorArray(
    210         dtype=elem.dtype, size=n, dynamic_size=False,
    211         infer_shape=True).unstack(elem)
    212 
    213   in_graph_mode = not context.executing_eagerly()
    214   with ops.name_scope(name, "foldr", [elems]):
    215     # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    216     # supported in Eager
    217     if in_graph_mode:
    218       # Any get_variable calls in fn will cache the first call locally and not
    219       # issue repeated network I/O requests for each iteration.
    220       varscope = vs.get_variable_scope()
    221       varscope_caching_device_was_none = False
    222       if varscope.caching_device is None:
    223         # TODO(ebrevdo): Change to using colocate_with here and in other
    224         # methods.
    225         varscope.set_caching_device(lambda op: op.device)
    226         varscope_caching_device_was_none = True
    227 
    228     # Convert elems to tensor array. n may be known statically.
    229     elems_flat = [
    230         ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
    231     ]
    232     n = (tensor_shape.dimension_value(elems_flat[0].shape[0])
    233          or array_ops.shape(elems_flat[0])[0])
    234 
    235     elems_ta = nest.map_structure(create_ta, elems)
    236 
    237     if initializer is None:
    238       i = n - 1
    239       a = nest.map_structure(lambda elem: elem.read(i), elems_ta)
    240     else:
    241       i = n
    242       a = initializer
    243 
    244     def compute(i, a):
    245       i -= 1
    246       elem = nest.map_structure(lambda elem: elem.read(i), elems_ta)
    247       a_out = fn(a, elem)
    248       return [i, a_out]
    249 
    250     _, r_a = control_flow_ops.while_loop(
    251         lambda i, a: i > 0,
    252         compute, [i, a],
    253         parallel_iterations=parallel_iterations,
    254         back_prop=back_prop,
    255         swap_memory=swap_memory,
    256         maximum_iterations=n)
    257 
    258     # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    259     # supported in Eager
    260     if in_graph_mode and varscope_caching_device_was_none:
    261       varscope.set_caching_device(None)
    262 
    263     return r_a
    264 
    265 
    266 @tf_export("scan")
    267 def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
    268          swap_memory=False, infer_shape=True, reverse=False, name=None):
    269   """scan on the list of tensors unpacked from `elems` on dimension 0.
    270 
    271   The simplest version of `scan` repeatedly applies the callable `fn` to a
    272   sequence of elements from first to last. The elements are made of the tensors
    273   unpacked from `elems` on dimension 0. The callable fn takes two tensors as
    274   arguments. The first argument is the accumulated value computed from the
    275   preceding invocation of fn. If `initializer` is None, `elems` must contain
    276   at least one element, and its first element is used as the initializer.
    277 
    278   Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
    279   of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
    280   If reverse=True, it's fn(initializer, values[-1]).shape.
    281 
    282   This method also allows multi-arity `elems` and accumulator.  If `elems`
    283   is a (possibly nested) list or tuple of tensors, then each of these tensors
    284   must have a matching first (unpack) dimension.  The second argument of
    285   `fn` must match the structure of `elems`.
    286 
    287   If no `initializer` is provided, the output structure and dtypes of `fn`
    288   are assumed to be the same as its input; and in this case, the first
    289   argument of `fn` must match the structure of `elems`.
    290 
    291   If an `initializer` is provided, then the output of `fn` must have the same
    292   structure as `initializer`; and the first argument of `fn` must match
    293   this structure.
    294 
    295   For example, if `elems` is `(t1, [t2, t3])` and `initializer` is
    296   `[i1, i2]` then an appropriate signature for `fn` in `python2` is:
    297   `fn = lambda (acc_p1, acc_p2), (t1, [t2, t3]):` and `fn` must return a list,
    298   `[acc_n1, acc_n2]`.  An alternative correct signature for `fn`, and the
    299    one that works in `python3`, is:
    300   `fn = lambda a, t:`, where `a` and `t` correspond to the input tuples.
    301 
    302   Args:
    303     fn: The callable to be performed.  It accepts two arguments.  The first
    304       will have the same structure as `initializer` if one is provided,
    305       otherwise it will have the same structure as `elems`.  The second
    306       will have the same (possibly nested) structure as `elems`.  Its output
    307       must have the same structure as `initializer` if one is provided,
    308       otherwise it must have the same structure as `elems`.
    309     elems: A tensor or (possibly nested) sequence of tensors, each of which
    310       will be unpacked along their first dimension.  The nested sequence
    311       of the resulting slices will be the first argument to `fn`.
    312     initializer: (optional) A tensor or (possibly nested) sequence of tensors,
    313       initial value for the accumulator, and the expected output type of `fn`.
    314     parallel_iterations: (optional) The number of iterations allowed to run
    315       in parallel.
    316     back_prop: (optional) True enables support for back propagation.
    317     swap_memory: (optional) True enables GPU-CPU memory swapping.
    318     infer_shape: (optional) False disables tests for consistent output shapes.
    319     reverse: (optional) True scans the tensor last to first (instead of first
    320       to last).
    321     name: (optional) Name prefix for the returned tensors.
    322 
    323   Returns:
    324     A tensor or (possibly nested) sequence of tensors.  Each tensor packs the
    325     results of applying `fn` to tensors unpacked from `elems` along the first
    326     dimension, and the previous accumulator value(s), from first to last (or
    327     last to first, if `reverse=True`).
    328 
    329   Raises:
    330     TypeError: if `fn` is not callable or the structure of the output of
    331       `fn` and `initializer` do not match.
    332     ValueError: if the lengths of the output of `fn` and `initializer`
    333       do not match.
    334 
    335   Examples:
    336     ```python
    337     elems = np.array([1, 2, 3, 4, 5, 6])
    338     sum = scan(lambda a, x: a + x, elems)
    339     # sum == [1, 3, 6, 10, 15, 21]
    340     sum = scan(lambda a, x: a + x, elems, reverse=True)
    341     # sum == [22, 21, 18, 15, 11, 6]
    342     ```
    343 
    344     ```python
    345     elems = np.array([1, 2, 3, 4, 5, 6])
    346     initializer = np.array(0)
    347     sum_one = scan(
    348         lambda a, x: x[0] - x[1] + a, (elems + 1, elems), initializer)
    349     # sum_one == [1, 2, 3, 4, 5, 6]
    350     ```
    351 
    352     ```python
    353     elems = np.array([1, 0, 0, 0, 0, 0])
    354     initializer = (np.array(0), np.array(1))
    355     fibonaccis = scan(lambda a, _: (a[1], a[0] + a[1]), elems, initializer)
    356     # fibonaccis == ([1, 1, 2, 3, 5, 8], [1, 2, 3, 5, 8, 13])
    357     ```
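
    A sketch with multi-arity `elems` (the second argument of `fn` matches the
    tuple structure of `elems`, while the accumulator matches `initializer`):

    ```python
    elems = (np.array([1, 2, 3]), np.array([10, 20, 30]))
    initializer = np.array(0)
    totals = scan(lambda a, t: a + t[0] + t[1], elems, initializer)
    # totals == [11, 33, 66]
    ```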
    358   """
    359   if not callable(fn):
    360     raise TypeError("fn must be callable.")
    361 
    362   input_is_sequence = nest.is_sequence(elems)
    363   input_flatten = lambda x: nest.flatten(x) if input_is_sequence else [x]
    364   def input_pack(x):
    365     return nest.pack_sequence_as(elems, x) if input_is_sequence else x[0]
    366 
    367   if initializer is None:
    368     output_is_sequence = input_is_sequence
    369     output_flatten = input_flatten
    370     output_pack = input_pack
    371   else:
    372     output_is_sequence = nest.is_sequence(initializer)
    373     output_flatten = lambda x: nest.flatten(x) if output_is_sequence else [x]
    374     def output_pack(x):
    375       return (nest.pack_sequence_as(initializer, x)
    376               if output_is_sequence else x[0])
    377 
    378   elems_flat = input_flatten(elems)
    379 
    380   in_graph_mode = not context.executing_eagerly()
    381   with ops.name_scope(name, "scan", elems_flat):
    382     # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    383     # supported in Eager
    384     if in_graph_mode:
    385       # Any get_variable calls in fn will cache the first call locally
    386       # and not issue repeated network I/O requests for each iteration.
    387       varscope = vs.get_variable_scope()
    388       varscope_caching_device_was_none = False
    389       if varscope.caching_device is None:
    390         # TODO(ebrevdo): Change to using colocate_with here and in other
    391         # methods.
    392         varscope.set_caching_device(lambda op: op.device)
    393         varscope_caching_device_was_none = True
    394 
    395     # Convert elems to tensor array.
    396     elems_flat = [
    397         ops.convert_to_tensor(elem, name="elem") for elem in elems_flat]
    398 
    399     # Convert elems to tensor array. n may be known statically.
    400     n = tensor_shape.dimension_value(elems_flat[0].shape[0])
    401     if n is None:
    402       n = array_ops.shape(elems_flat[0])[0]
    403 
    404     # TensorArrays are always flat
    405     elems_ta = [
    406         tensor_array_ops.TensorArray(dtype=elem.dtype, size=n,
    407                                      dynamic_size=False,
    408                                      element_shape=elem.shape[1:],
    409                                      infer_shape=True)
    410         for elem in elems_flat]
    411     # Unpack elements
    412     elems_ta = [
    413         elem_ta.unstack(elem) for elem_ta, elem in zip(elems_ta, elems_flat)]
    414 
    415     if initializer is None:
    416       a_flat = [elem.read(n - 1 if reverse else 0) for elem in elems_ta]
    417       i = constant_op.constant(1)
    418     else:
    419       initializer_flat = output_flatten(initializer)
    420       a_flat = [ops.convert_to_tensor(init) for init in initializer_flat]
    421       i = constant_op.constant(0)
    422 
    423     # Create a tensor array to store the intermediate values.
    424     accs_ta = [
    425         tensor_array_ops.TensorArray(
    426             dtype=init.dtype, size=n,
    427             element_shape=init.shape if infer_shape else None,
    428             dynamic_size=False,
    429             infer_shape=infer_shape)
    430         for init in a_flat]
    431 
    432     if initializer is None:
    433       accs_ta = [acc_ta.write(n - 1 if reverse else 0, a)
    434                  for (acc_ta, a) in zip(accs_ta, a_flat)]
    435 
    436     def compute(i, a_flat, tas):
    437       """The loop body of scan.
    438 
    439       Args:
    440         i: the loop counter.
    441         a_flat: the accumulator value(s), flattened.
    442         tas: the output accumulator TensorArray(s), flattened.
    443 
    444       Returns:
    445         [i + 1, a_flat, tas]: the updated counter + new accumulator values +
    446           updated TensorArrays
    447 
    448       Raises:
    449         TypeError: if initializer and fn() output structure do not match
    450         ValueType: if initializer and fn() output lengths do not match
    451       """
    452       packed_elems = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
    453       packed_a = output_pack(a_flat)
    454       a_out = fn(packed_a, packed_elems)
    455       nest.assert_same_structure(
    456           elems if initializer is None else initializer, a_out)
    457       flat_a_out = output_flatten(a_out)
    458       tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_a_out)]
    459       if reverse:
    460         next_i = i - 1
    461       else:
    462         next_i = i + 1
    463       return (next_i, flat_a_out, tas)
    464 
    465     if reverse:
    466       initial_i = n - 1 - i
    467       condition = lambda i, _1, _2: i >= 0
    468     else:
    469       initial_i = i
    470       condition = lambda i, _1, _2: i < n
    471     _, _, r_a = control_flow_ops.while_loop(
    472         condition, compute, (initial_i, a_flat, accs_ta),
    473         parallel_iterations=parallel_iterations,
    474         back_prop=back_prop, swap_memory=swap_memory,
    475         maximum_iterations=n)
    476 
    477     results_flat = [r.stack() for r in r_a]
    478 
    479     n_static = tensor_shape.Dimension(tensor_shape.dimension_value(
    480         elems_flat[0].get_shape().with_rank_at_least(1)[0]))
    481     for elem in elems_flat[1:]:
    482       n_static.merge_with(tensor_shape.Dimension(tensor_shape.dimension_value(
    483           elem.get_shape().with_rank_at_least(1)[0])))
    484     for r in results_flat:
    485       r.set_shape(tensor_shape.TensorShape(n_static).concatenate(
    486           r.get_shape()[1:]))
    487 
    488     # TODO(akshayka): Remove the in_graph_mode check once caching devices are
    489     # supported in Eager
    490     if in_graph_mode and varscope_caching_device_was_none:
    491       varscope.set_caching_device(None)
    492 
    493     return output_pack(results_flat)
    494 
    495 
    496 # pylint: disable=invalid-name
    497 def If(cond, inputs, then_branch, else_branch, name=None):
    498   r"""output = Cond(inputs) ? then_branch(inputs) : else_branch(inputs).
    499 
    500   Args:
    501     cond: A `Tensor`. A scalar. If the scalar is not a boolean, the scalar is
    502       converted to a boolean according to the following rule: if the
    503       scalar is a numerical value, non-zero means True and zero means
    504       False; if the scalar is a string, non-empty means True and empty
    505       means False.
    506     inputs: A list of input tensors.
    507     then_branch: A function takes 'inputs' and returns a list of tensors,
    508         whose types are the same as what else_branch returns.
    509     else_branch: A function takes 'inputs' and returns a list of tensors.
    510         whose types are the same as what then_branch returns.
    511     name: A name for the operation (optional).
    512 
    513   Returns:
    514     A list of tensors returned by either then_branch(inputs)
    515     or else_branch(inputs).
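
  Example:
    A minimal sketch, assuming `then_branch` and `else_branch` are
    `function.Defun`-defined functions with matching output types (the names
    below are illustrative only):

    ```python
    @function.Defun(dtypes.float32)
    def TwoTimes(x):
      return x * 2.

    @function.Defun(dtypes.float32)
    def ThreeTimes(x):
      return x * 3.

    x = constant_op.constant(4.)
    y = If(math_ops.greater(x, 0.), [x], TwoTimes, ThreeTimes)
    # y == [8.0] when evaluated
    ```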
    516   """
    517   # pylint: disable=protected-access
    518   return gen_functional_ops._if(
    519       cond,
    520       inputs, [_.type for _ in then_branch.definition.signature.output_arg],
    521       then_branch,
    522       else_branch,
    523       name=name)
    524 
    525 
    526 def Gradient(inputs, f, name=None):
    527   r"""Computes the gradient function for function f via backpropagation.
    528 
    529   Args:
    530     inputs: A list of tensors of size N + M.
    531     f: The function we want to compute the gradient for.
    532 
    533       The function 'f' must be a numerical function which takes N inputs and
    534       produces M outputs. Its gradient function 'g', which is  a function
    535       taking N + M inputs and produces N outputs.
    536 
    537       I.e. if we have
    538          (y1, y2, ..., yM) = f(x1, x2, ..., xN),
    539       then, g is
    540          (dL/dx1, dL/dx2, ..., dL/dxN) = g(x1, x2, ..., xN,
    541                                            dL/dy1, dL/dy2, ..., dL/dyM),
    542 
    543       where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
    544       loss function). dL/dxi is the partial derivative of L with respect
    545       to xi.
    546 
    547     name: A name for the operation (optional).
    548 
    549   Returns:
    550     A list of tensors of size N.
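
  Example:
    A minimal sketch for a scalar function (so N = 1 and M = 1), assuming `f`
    is defined with `function.Defun`; `Square` is an illustrative name:

    ```python
    @function.Defun(dtypes.float32)
    def Square(x):
      return x * x

    x = constant_op.constant(3.)
    dy = constant_op.constant(1.)
    dx = Gradient([x, dy], Square)
    # dx == [6.0] when evaluated, i.e. d(x * x)/dx at x == 3, times dy
    ```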
    551   """
    552   # TODO(zhifengc): Pretty-print the above spec in latex.
    553   # TODO(zhfiengc): Needs some math expert to say the comment above better.
    554   tlist = [_.type for _ in f.definition.signature.input_arg]
    555   return symbolic_gradient(input=inputs, Tout=tlist, f=f, name=name)
    556 
    557 
    558 def _LoopBodyCaptureWrapper(func):
    559   """Returns a wrapper for `func` that handles loop-carried captured inputs."""
    560 
    561   @function.Defun(
    562       *func.declared_input_types, func_name="%s_Wrapper" % func.name)
    563   def Wrapper(*args):
    564     """A wrapper that handles loop-carried captured inputs."""
    565     result = func(*args)
    566     extra_args = tuple(function.get_extra_args())
    567     # Nullary functions return an Operation. Normal functions can't do this
    568     # because their return values are converted to Tensors.
    569     if isinstance(result, ops.Operation):
    570       return extra_args
    571     # Unary functions return a single Tensor value.
    572     elif not isinstance(result, tuple):
    573       return (result,) + extra_args
    574     # N-ary functions return a tuple of Tensors.
    575     else:
    576       return result + extra_args
    577 
    578   return Wrapper
    579 
    580 
    581 # pylint: disable=invalid-name,protected-access
    582 def While(input_, cond, body, name=None, hostmem=None):
    583   r"""output = input; While (Cond(output)) { output = Body(output) }.
    584 
    585   Args:
    586     input_: A list of `Tensor` objects.
    587       A list of input tensors whose types are T.
    588     cond: . A function takes 'input' and returns a tensor.  If the tensor is
    589       a scalar of non-boolean, the scalar is converted to a boolean
    590       according to the following rule: if the scalar is a numerical
    591       value, non-zero means True and zero means False; if the scalar is
    592       a string, non-empty means True and empty means False. If the
    593       tensor is not a scalar, non-emptiness means True and False
    594       otherwise.
    595     body: . A function takes a list of tensors and returns another
    596       list tensors. Both lists have the same types as specified
    597       by T.
    598     name: A name for the operation (optional).
    599     hostmem: A list of integer. If i is in the list, input[i] is a
    600       host memory tensor.
    601 
    602   Raises:
    603     ValueError: if `cond` has implicitly captured inputs or if `cond` and `body`
    604       have different signatures.
    605 
    606   Returns:
    607     A list of `Tensor` objects. Has the same type as `input`.
    608     A list of output tensors whose types are T.
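
  Example:
    A minimal sketch that counts down from ten, assuming `cond` and `body` are
    `function.Defun`-defined with matching signatures (the names below are
    illustrative only):

    ```python
    @function.Defun(dtypes.int32)
    def Cond(i):
      return i > 0

    @function.Defun(dtypes.int32)
    def Body(i):
      return i - 1

    result = While([constant_op.constant(10)], Cond, Body)
    # result == [0] when evaluated
    ```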
    609   """
    610   if cond.captured_inputs:
    611     raise ValueError("While op 'cond' argument must be a function "
    612                      "without implicitly captured inputs.")
    613 
    614   if cond.declared_input_types != body.declared_input_types:
    615     raise ValueError(
    616         "While op 'cond' and 'body' signatures do not match. %r vs %r" %
    617         (cond.declared_input_types, body.declared_input_types))
    618 
    619   if body.captured_inputs:
    620     cond_dtypes = list(
    621         body.declared_input_types) + [t.dtype for t in body.captured_inputs]
    622 
    623     @function.Defun(*cond_dtypes, func_name="%s_Wrapper" % cond.name)
    624     def CondWrapper(*args):
    625       """A wrapper that handles loop-carried captured inputs."""
    626       return cond(*args[:len(body.declared_input_types)])
    627 
    628     ret = gen_functional_ops._while(
    629         input_ + body.captured_inputs,
    630         CondWrapper,
    631         _LoopBodyCaptureWrapper(body),
    632         name=name)
    633     # Slice off the loop-carried captured inputs.
    634     ret = ret[:-len(body.captured_inputs)]
    635   else:
    636     ret = gen_functional_ops._while(input_, cond, body, name=name)
    637   if hostmem:
    638     input_attr = attr_value_pb2.AttrValue()
    639     input_attr.list.i.extend(hostmem)
    640     ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access
    641 
    642     output_attr = attr_value_pb2.AttrValue()
    643     output_attr.list.i.extend(hostmem)
    644     ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
    645   return ret
    646 
    647 
# b/36459430
#
# Ideally, we would not need to rewrite a For loop into a While loop.
# However, today, if a While runs on GPU and the condition returns a
# boolean, the While kernel crashes. Even if we fix the crash, the
# bool needs to be copied between GPU and CPU. So, a For loop is much
# preferred when running on GPU.
#
# On the other hand, the For op has no direct XLA kernel. So, when we
# run a For loop, we need to rewrite it using a While op.
#
# It should be possible, and probably better, to write an XLA C++ kernel
# implementing the logic in _ForUsingWhile.
def _ForUsingWhile(start,
                   limit,
                   delta,
                   inputs,
                   forbody,
                   name=None,
                   hostmem=None):
  """Helper to implement a For loop using a While."""
  # To support negative delta (e.g., range(100, 0, -3)), we iterate
  # over the range(n) and use iter * delta + start as the real
  # iteration index. (e.g., for i in range(34): iter = i * (-3) +
  # 100).
  d = math_ops.abs(delta)
  # XLA on TPUs doesn't support integer division
  n = math_ops.cast(
      math_ops.cast((math_ops.abs(limit - start) + d - 1), dtypes.float32) /
      math_ops.cast(d, dtypes.float32), dtypes.int32)

  # Carried loop variables ("extra_args") are implicitly added to the input list
  # of the WhileBody function. WhileCond does not call forbody, and so does not
  # depend on any of forbody's extra_args. Since WhileCond and WhileBody
  # must have identical inputs, we have to augment the cond signature to take
  # the same types as the carried loop variables.
  body_sig = [dtypes.int32] * 4 + list(forbody.declared_input_types)[1:]

  cond_name = "%s_Cond" % forbody.name

  @function.Defun(*body_sig, func_name=cond_name)
  def WhileCond(i, n, *args):
    del args
    return i < n

  body_name = "%s_Body" % forbody.name

  @function.Defun(*body_sig, func_name=body_name)
  def WhileBody(i, n, start, delta, *args):
    """A While wrapper for forbody that handles loop-carried captured inputs."""
    for_result = forbody(start + i * delta, *args)
    # Nullary functions return an Operation. Normal functions can't do this
    # because their return values are converted to Tensors.
    if isinstance(for_result, ops.Operation):
      for_result = ()
    # Unary functions return a single Tensor value.
    elif isinstance(for_result, ops.Tensor):
      for_result = (for_result,)
    return (i + 1, n, start, delta) + tuple(for_result)

  if hostmem is not None:
    hostmem = [0, 1, 2, 3] + [(4 + _) for _ in hostmem]
  else:
    hostmem = [0, 1, 2, 3]

  results = While(
      input_=[0, n, start, delta] + inputs,
      cond=WhileCond,
      body=WhileBody,
      name=name,
      hostmem=hostmem)
  # Slice off the loop-carried captured inputs.
  return list(results[4:len(results)])


def For(start,
        limit,
        delta,
        inputs,
        body,
        name=None,
        hostmem=None,
        rewrite_with_while=None):
    731   r"""out = input; for i in range(start, limit, delta) out = body(i, out).
    732 
    733   Args:
    734     start: A `Tensor` of type `int32`.
    735     limit: A `Tensor` of type `int32`.
    736     delta: A `Tensor` of type `int32`.
    737     inputs: A list of `Tensor` objects.
    738       A list of input tensors whose types are T.
    739     body: A function takes a list of tensors and returns another
    740       list of tensors. Both lists have the same types as (int32, T...).
    741     name: A name for the operation (optional).
    742     hostmem: A list of integer. If i is in the list, inputs[i] is a
    743       host memory tensor. In other words, (i+1)-th argument of the body
    744       function is expecting a host memory.
    745     rewrite_with_while: If True, using While op to implement the For.
    746 
    747   Returns:
    748     A list of `Tensor` objects. Has the same type as `input`.
    749     A list of output tensors whose types are T.
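
  Example:
    A minimal sketch that adds the loop index to a running total, assuming
    `body` is a `function.Defun`-defined function of types (int32, T...); the
    name `Body` is illustrative only:

    ```python
    @function.Defun(dtypes.int32, dtypes.int32)
    def Body(i, acc):
      return acc + i

    result = For(0, 10, 1, [constant_op.constant(0)], Body)
    # result == [45] when evaluated (0 + 1 + ... + 9)
    ```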
    750   """
    751   if rewrite_with_while:
    752     return _ForUsingWhile(start, limit, delta, inputs, body, name, hostmem)
    753   if body.captured_inputs:
    754     ret = gen_functional_ops._for(
    755         start,
    756         limit,
    757         delta,
    758         inputs + body.captured_inputs,
    759         _LoopBodyCaptureWrapper(body),
    760         name=name)
    761     # Slice off the loop-carried captured inputs.
    762     ret = ret[:-len(body.captured_inputs)]
    763   else:
    764     ret = gen_functional_ops._for(start, limit, delta, inputs, body, name=name)
    765   if hostmem:
    766     num_for_params = 3  # start/limit/delta
    767 
    768     input_attr = attr_value_pb2.AttrValue()
    769     input_attr.list.i.extend([num_for_params + i for i in hostmem])
    770     ret[0].op._set_attr("_input_hostmem", input_attr)  # pylint: disable=protected-access
    771 
    772     output_attr = attr_value_pb2.AttrValue()
    773     output_attr.list.i.extend(hostmem)
    774     ret[0].op._set_attr("_output_hostmem", output_attr)  # pylint: disable=protected-access
    775   return ret
    776 # pylint: enable=invalid-name,protected-access
    777 
    778 
    779 def partitioned_call(args, f, tout=None, executing_eagerly=None, config=None,
    780                      executor_type=None):
    781   """Executes a function while respecting device annotations.
    782 
    783   Currently, only those functions that execute within the same address space
    784   can be executed.
    785 
    786   Args:
    787     args: The arguments of the function, including captured inputs.
    788     f: The function to execute; an instance of `_DefinedFunction` or
    789       `_EagerDefinedFunction`.
    790     tout: a list containing the output dtypes enums; if `None`, inferred from
    791       the signature of `f`.
    792     executing_eagerly: (Optional) A boolean indicating whether the context is
    793       executing eagerly. If `None`, fetched from the global context.
    794     config: (Optional) A `tensorflow::ConfigProto` proto, serialized. If
    795       `None`, all optimizations are disabled. Currently only handled for eager
    796       defined functions.
    797     executor_type: (Optional) A string for the name of the executor to be used
    798       in the function call. If not set, or set to an empty string, the default
    799       tensorflow executor will be used.
    800 
    801   Returns:
    802     The list of `Tensor`s returned by invoking `f(args)`. If the function does
    803     not return anything, then returns `None` if eager execution is enabled, or
    804     the `Operation` if not.
    805   """
    806 
    807   if tout is None:
    808     tout = tuple(x.type for x in f.definition.signature.output_arg)
    809 
    810   if executing_eagerly is None:
    811     executing_eagerly = context.executing_eagerly()
    812 
    813   if config is None:
    814     config = function_utils.get_disabled_rewriter_config()
    815 
    816   if executor_type is None:
    817     executor_type = ""
    818 
    819   if executing_eagerly or len(tout):
    820     if f.stateful_ops:
    821       outputs = gen_functional_ops.stateful_partitioned_call(
    822           args=args, Tout=tout, f=f, config_proto=config,
    823           executor_type=executor_type)
    824     else:
    825       outputs = gen_functional_ops.partitioned_call(
    826           args=args, Tout=tout, f=f, config_proto=config,
    827           executor_type=executor_type)
    828     return outputs if outputs else None
    829 
    830   # The generated binding returns an empty list for functions that don't
    831   # return any Tensors, hence the need to use `create_op` directly.
    832   args = [ops.internal_convert_to_tensor(x) for x in args]
    833   tin_attr = attr_value_pb2.AttrValue(
    834       list=attr_value_pb2.AttrValue.ListValue(
    835           type=[x.dtype.as_datatype_enum for x in args]))
    836   tout_attr = attr_value_pb2.AttrValue(
    837       list=attr_value_pb2.AttrValue.ListValue(type=tout))
    838   func_attr = attr_value_pb2.AttrValue(
    839       func=attr_value_pb2.NameAttrList(name=f.name))
    840   executor_type_attr = attr_value_pb2.AttrValue(
    841       s=compat.as_bytes(executor_type))
    842 
    843   # When running in graph mode, the graph and function graphs are optimized
    844   # (i.e. run through grappler) per the session options, so we can disable any
    845   # eager-specific rewriting.
    846   config_proto = attr_value_pb2.AttrValue(
    847       s=function_utils.get_disabled_rewriter_config())
    848 
    849   graph = ops.get_default_graph()
    850   f.add_to_graph(graph)
    851   op_name = "StatefulPartitionedCall" if f.stateful_ops else "PartitionedCall"
    852   op = graph.create_op(
    853       op_name,
    854       args,
    855       tout,
    856       compute_shapes=False,
    857       name="PartitionedFunctionCall",
    858       attrs={
    859           "Tin": tin_attr,
    860           "Tout": tout_attr,
    861           "f": func_attr,
    862           "config_proto": config_proto,
    863           "executor_type": executor_type_attr,
    864       })
    865   outputs = op.outputs
    866   return outputs if outputs else op
    867