      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Control Flow Operations.
     16 
     17 See the [autograph](https://www.tensorflow.org/guide/autographs) guide.
     18 """
     19 # pylint: disable=g-bad-name
     20 from __future__ import absolute_import
     21 from __future__ import division
     22 from __future__ import print_function
     23 
     24 import abc
     25 import collections
     26 import functools
     27 
     28 import six
     29 
     30 from tensorflow.core.framework import attr_value_pb2
     31 from tensorflow.core.protobuf import control_flow_pb2
     32 from tensorflow.python.eager import context
     33 from tensorflow.python.framework import composite_tensor
     34 from tensorflow.python.framework import constant_op
     35 from tensorflow.python.framework import dtypes
     36 from tensorflow.python.framework import errors
     37 from tensorflow.python.framework import ops
     38 from tensorflow.python.framework import tensor_shape
     39 from tensorflow.python.framework import tensor_util
     40 from tensorflow.python.ops import array_ops
     41 from tensorflow.python.ops import control_flow_util as util
     42 from tensorflow.python.ops import gen_array_ops
     43 from tensorflow.python.ops import gen_control_flow_ops
     44 from tensorflow.python.ops import gen_data_flow_ops
     45 from tensorflow.python.ops import gen_logging_ops
     46 from tensorflow.python.ops import gen_resource_variable_ops
     47 from tensorflow.python.ops import math_ops
     48 from tensorflow.python.ops import tensor_array_ops
     49 # go/tf-wildcard-import
     50 # pylint: disable=wildcard-import,undefined-variable
     51 from tensorflow.python.ops.gen_control_flow_ops import *
     52 # pylint: enable=wildcard-import
     53 from tensorflow.python.platform import tf_logging as logging
     54 from tensorflow.python.util import compat
     55 from tensorflow.python.util import deprecation
     56 from tensorflow.python.util import nest
     57 from tensorflow.python.util import tf_should_use
     58 from tensorflow.python.util.lazy_loader import LazyLoader
     59 from tensorflow.python.util.tf_export import tf_export
     60 
     61 # This is to avoid a circular dependency:
     62 # cond_v2 -> gradients_util -> control_flow_ops
     63 cond_v2 = LazyLoader("cond_v2", globals(),
     64                      "tensorflow.python.ops.cond_v2")
     65 
     66 # This is to avoid circular dependencies:
     67 # while_v2 -> control_flow_ops
     68 # while_v2 -> gradients_util -> control_flow_ops
     69 while_v2 = LazyLoader("while_v2", globals(),
     70                       "tensorflow.python.ops.while_v2")
     71 
      72 # We override 'tuple' with a control flow op, so we keep Python's
      73 # built-in 'tuple' for later use in this module.
     74 _basetuple = tuple
     75 
     76 
     77 def _summarize_eager(tensor, summarize=None):
     78   """Returns a summarized string representation of eager `tensor`.
     79 
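           For example, with the default `summarize=3`, a 1-D `tensor` holding
           [1, 2, 3, 4, 5] is summarized as "1, 2, 3, ...".
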
     80   Args:
      81     tensor: EagerTensor to summarize.
      82     summarize: Include only this many of the first elements of `tensor`.
     83   """
     84   # Emulate the behavior of Tensor::SummarizeValue()
     85   if summarize is None:
     86     summarize = 3
     87   elif summarize < 0:
     88     summarize = array_ops.size(tensor)
     89   # reshape((-1,)) is the fastest way to get a flat array view
     90   if tensor._rank():  # pylint: disable=protected-access
     91     flat = tensor.numpy().reshape((-1,))
     92     lst = [str(x) for x in flat[:summarize]]
     93     if len(lst) < flat.size:
     94       lst.append("...")
     95   else:
     96     # tensor.numpy() returns a scalar for zero dimensional arrays
     97     if summarize != 0:
     98       lst = [str(tensor.numpy())]
     99     else:
    100       lst = []
    101 
    102   return ", ".join(lst)
    103 
    104 
    105 # pylint: disable=protected-access
    106 
    107 
    108 # Assert and Print are special symbols in python, so we must
    109 # use an upper-case version of them.
    110 @tf_export("debugging.Assert", "Assert")
    111 @tf_should_use.should_use_result
    112 def Assert(condition, data, summarize=None, name=None):
    113   """Asserts that the given condition is true.
    114 
    115   If `condition` evaluates to false, print the list of tensors in `data`.
    116   `summarize` determines how many entries of the tensors to print.
    117 
    118   NOTE: In graph mode, to ensure that Assert executes, one usually attaches
    119   a dependency:
    120 
    121   ```python
     122   # Ensure the maximum element of x is less than or equal to 1
    123   assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
    124   with tf.control_dependencies([assert_op]):
    125     ... code using x ...
    126   ```
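
           When executing eagerly, the condition is checked immediately; a minimal
           sketch of that behavior (assuming eager execution is enabled):

           ```python
           x = tf.constant([2.0, 3.0])
           try:
             tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
           except tf.errors.InvalidArgumentError as e:
             print(e.message)  # The message includes a summary of `x`.
           ```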
    127 
    128   Args:
    129     condition: The condition to evaluate.
    130     data: The tensors to print out when condition is false.
    131     summarize: Print this many entries of each tensor.
    132     name: A name for this operation (optional).
    133 
    134   Returns:
    135     assert_op: An `Operation` that, when executed, raises a
    136     `tf.errors.InvalidArgumentError` if `condition` is not true.
    137     @compatibility{eager} returns None.
    138 
    139   Raises:
    140     @compatibility{eager} `tf.errors.InvalidArgumentError` if `condition`
    141     is not true
    142   """
    143   if context.executing_eagerly():
    144     if not condition:
    145       xs = ops.convert_n_to_tensor(data)
    146       data_str = [_summarize_eager(x, summarize) for x in xs]
    147       raise errors.InvalidArgumentError(
    148           node_def=None,
    149           op=None,
    150           message="Expected '%s' to be true. Summarized data: %s" %
    151           (condition, "\n".join(data_str)))
    152     return
    153 
    154   with ops.name_scope(name, "Assert", [condition, data]) as name:
    155     xs = ops.convert_n_to_tensor(data)
    156     if all(x.dtype in {dtypes.string, dtypes.int32} for x in xs):
    157       # As a simple heuristic, we assume that string and int32 are
     158       # on host to avoid the need to use cond. If that is not the case,
     159       # we will pay the price of copying the tensor to host memory.
    160       return gen_logging_ops._assert(condition, data, summarize, name="Assert")
    161     else:
    162       condition = ops.convert_to_tensor(condition, name="Condition")
    163 
    164       def true_assert():
    165         return gen_logging_ops._assert(
    166             condition, data, summarize, name="Assert")
    167 
    168       guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
    169       if context.executing_eagerly():
    170         return
    171       return guarded_assert.op
    172 
    173 
    174 def _Identity(data, name=None):
    175   """Return a tensor with the same shape and contents as the input tensor.
    176 
    177   Args:
    178     data: A Tensor.
    179     name: A name for this operation (optional).
    180 
    181   Returns:
    182     A Tensor with the same type and value as the input Tensor.
    183   """
    184   data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
    185   if isinstance(data, ops.Tensor):
    186     if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
    187       return gen_array_ops.ref_identity(data, name=name)
    188     else:
    189       return array_ops.identity(data, name=name)
    190   elif isinstance(data, composite_tensor.CompositeTensor):
    191     return nest.map_structure(_Identity, data, expand_composites=True)
    192   else:
    193     raise TypeError("Type %s not supported" % type(data))
    194 
    195 
    196 def _NextIteration(data, name=None):
    197   data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
    198   if isinstance(data, ops.Tensor):
    199     if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
    200       return ref_next_iteration(data, name=name)
    201     else:
    202       return next_iteration(data, name=name)
    203   elif isinstance(data, composite_tensor.CompositeTensor):
    204     return nest.map_structure(_NextIteration, data, expand_composites=True)
    205   else:
    206     raise TypeError("Type %s not supported" % type(data))
    207 
    208 
    209 def _Enter(data,
    210            frame_name,
    211            is_constant=False,
    212            parallel_iterations=10,
    213            use_ref=True,
    214            use_input_shape=True,
    215            name=None):
    216   """Creates or finds a child frame, and makes `data` available to it.
    217 
    218   The unique `frame_name` is used by the `Executor` to identify frames. If
    219   `is_constant` is true, `data` is a constant in the child frame; otherwise
    220   it may be changed in the child frame. At most `parallel_iterations`
    221   iterations are run in parallel in the child frame.
    222 
    223   Args:
    224     data: The tensor to be made available to the child frame.
    225     frame_name: The name of the child frame.
    226     is_constant: If true, the output is constant within the child frame.
    227     parallel_iterations: The number of iterations allowed to run in parallel.
    228     use_ref: If true, use ref_enter if data is of ref type.
    229     use_input_shape: If true, set the result's shape based on data's shape.
    230     name: A name for this operation (optional).
    231 
    232   Returns:
    233     The same tensor as `data`.
    234   """
    235   data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
    236   if isinstance(data, ops.Tensor):
    237     if data.dtype._is_ref_dtype and use_ref:  # pylint: disable=protected-access
    238       result = gen_control_flow_ops.ref_enter(
    239           data, frame_name, is_constant, parallel_iterations, name=name)
    240     else:
    241       result = gen_control_flow_ops.enter(
    242           data, frame_name, is_constant, parallel_iterations, name=name)
    243     if use_input_shape:
    244       result.set_shape(data.get_shape())
    245     return result
    246   elif isinstance(data, composite_tensor.CompositeTensor):
    247     def enter_component(t):
    248       return _Enter(t, frame_name, is_constant, parallel_iterations,
    249                     use_ref, use_input_shape)
    250     return nest.map_structure(enter_component, data, expand_composites=True)
    251   else:
    252     raise TypeError("Type %s not supported" % type(data))
    253 
    254 
    255 def exit(data, name=None):  # pylint: disable=redefined-builtin
    256   """Exits the current frame to its parent frame.
    257 
    258   Exit makes its input `data` available to the parent frame.
    259 
    260   Args:
    261     data: The tensor to be made available to the parent frame.
    262     name: A name for this operation (optional).
    263 
    264   Returns:
    265     The same tensor as `data`.
    266   """
    267   data = ops.internal_convert_to_tensor_or_composite(data, as_ref=True)
    268   if isinstance(data, ops.Tensor):
    269     if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
    270       return gen_control_flow_ops.ref_exit(data, name)
    271     else:
    272       return gen_control_flow_ops._exit(data, name)
    273   elif isinstance(data, composite_tensor.CompositeTensor):
    274     return nest.map_structure(exit, data, expand_composites=True)
    275   else:
    276     raise TypeError("Type %s not supported" % type(data))
    277 
    278 
    279 def switch(data, pred, dtype=None, name=None):
    280   """Forwards `data` to an output determined by `pred`.
    281 
    282   If `pred` is false, the `data` input is forwarded to the first output.
    283   Otherwise, the data goes to the second output.
    284 
    285   This op handles `Tensor`s and `IndexedSlices`.
    286 
    287   Args:
    288     data: The tensor to be forwarded to the appropriate output.
    289     pred: A scalar that specifies which output port will receive data.
    290     dtype: Optional element type for the returned tensor. If missing, the type
     291       is inferred from the type of `data`.
    292     name: A name for this operation (optional).
    293 
    294   Returns:
    295     `(output_false, output_true)`: If `pred` is true, data will be forwarded
    296     to `output_true`, otherwise it goes to `output_false`.
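
           A minimal graph-mode sketch (illustrative only):

           ```python
           pred = array_ops.placeholder(dtypes.bool)
           output_false, output_true = switch(constant_op.constant(10), pred)
           # If `pred` evaluates to True at runtime, only `output_true` receives the
           # value and `output_false` is marked dead (and vice versa when False).
           ```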
    297   """
    298   with ops.name_scope(name, "Switch", [data, pred]) as name:
    299     data = ops.internal_convert_to_tensor_or_composite(
    300         data, dtype=dtype, name="data", as_ref=True)
    301     pred = ops.convert_to_tensor(pred, name="pred")
    302     if isinstance(data, ops.Tensor):
    303       return gen_control_flow_ops.switch(data, pred, name=name)
    304     else:
    305       if not isinstance(data, composite_tensor.CompositeTensor):
    306         raise TypeError("Type %s not supported" % type(data))
    307       tensors = nest.flatten(data, expand_composites=True)
    308       mapped = [gen_control_flow_ops.switch(tensor, pred) for tensor in tensors]
    309       mapped_f, mapped_t = zip(*mapped)
    310       return (nest.pack_sequence_as(data, mapped_f, expand_composites=True),
    311               nest.pack_sequence_as(data, mapped_t, expand_composites=True))
    312 
    313 
    314 def _SwitchRefOrTensor(data, pred, name="Switch"):
    315   """Forwards `data` to an output determined by `pred`.
    316 
    317   If `pred` is false, the `data` input is forwarded to the first output.
    318   Otherwise, the data goes to the second output.
    319 
    320   This op handles `Tensor`s and `IndexedSlices`.
    321 
    322   Args:
    323     data: The tensor to be forwarded to the appropriate output.
    324     pred: A scalar that specifies which output port will receive data.
    325     name: A name for this operation (optional).
    326 
    327   Returns:
    328     `(output_false, output_true)`: If `pred` is true, data will be forwarded to
    329     `output_true`, otherwise it goes to `output_false`.
    330 
    331   Raises:
    332     TypeError: if data is not a Tensor or IndexedSlices
    333   """
    334   data = ops.convert_to_tensor_or_composite(data, name="data")
    335   # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
    336   # addresses the following scenario.
    337   #
    338   # Assume you execute Optimizer.apply_gradients() in a branch of a cond().
    339   #
    340   # 1. The update op is created inside a `with ops.colocate(var):` block
    341   #
    342   # 2. Some tensor `data` is captured and a switch is created in a
    343   #    `with ops.colocate_with(data):` block.
    344   #
    345   # with ops.colocate_with(var):
    346   #  with ops.colocate_with(data):
    347   #    op = ...
    348   #
     349   # var and data may be pinned to different devices, so we want ops
    350   # created within ops.colocate_with(data) to ignore the existing stack.
    351   with ops.colocate_with(data, ignore_existing=True):
    352     if isinstance(data, ops.Tensor):
    353       if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
    354         return ref_switch(data, pred, name=name)
    355     return switch(data, pred, name=name)
    356 
    357 
    358 def merge(inputs, name=None):
    359   """Returns the value of an available element of `inputs`.
    360 
    361   This op tests each of the tensors in `inputs` in turn to determine if any of
    362   them is available. If it finds an available tensor, it returns it and its
    363   index in `inputs`.
    364 
    365   It is an error if more than one tensor in `inputs` is available. If no tensor
    366   in `inputs` is available, the returned tensor and index are not set.
    367 
    368   This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
    369   `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
    370   before merging.
    371 
    372   Args:
    373     inputs: The input tensors, at most one of which is available.
    374     name: A name for this operation (optional).
    375 
    376   Returns:
    377     A tuple containing the chosen input tensor and its index in `inputs`.
    378 
    379   Raises:
    380     ValueError: If any of the inputs is None, or inputs are IndexedSlices and
    381       some but not all have a dense_shape property.
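
           A minimal graph-mode sketch (illustrative only), recombining the two
           outputs of a `switch` so that `index` reports which branch produced a value:

           ```python
           pred = array_ops.placeholder(dtypes.bool)
           branch_false, branch_true = switch(constant_op.constant(1.0), pred)
           value, index = merge([branch_false, branch_true])
           # `index` is 0 if `pred` was false at runtime, 1 if it was true.
           ```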
    382   """
    383   if any(inp is None for inp in inputs):
    384     raise ValueError("At least one of the merge inputs is None: %s" % inputs)
    385   with ops.name_scope(name, "Merge", inputs) as name:
    386     inputs = [
    387         ops.internal_convert_to_tensor_or_composite(inp, as_ref=True)
    388         for inp in inputs
    389     ]
    390     if all(isinstance(v, ops.Tensor) for v in inputs):
    391       if all(v.dtype._is_ref_dtype for v in inputs):  # pylint: disable=protected-access
    392         return gen_control_flow_ops.ref_merge(inputs, name)
    393       else:
    394         return gen_control_flow_ops.merge(inputs, name)
    395     else:
    396       # If there is a mix of tensors and indexed slices, then convert the
    397       # tensors to indexed slices.
    398       if all(isinstance(v, (ops.IndexedSlices, ops.Tensor)) for v in inputs):
    399         inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)
    400 
    401       for v in inputs:
    402         if not isinstance(v, composite_tensor.CompositeTensor):
    403           raise TypeError("Type %s not supported" % type(v))
    404 
    405       for v in inputs[1:]:
    406         nest.assert_same_structure(inputs[0], v, expand_composites=True)
    407 
    408       flat_inputs = [nest.flatten(v, expand_composites=True) for v in inputs]
    409       merged_results = [gen_control_flow_ops.merge(component)
    410                         for component in zip(*flat_inputs)]
    411       flat_merged = [tensor for (tensor, _) in merged_results]
    412       chosen_index = merged_results[0][1]
    413       merged_inputs = nest.pack_sequence_as(inputs[0], flat_merged,
    414                                             expand_composites=True)
    415       return (merged_inputs, chosen_index)
    416 
    417 
    418 # pylint: enable=protected-access
    419 
    420 
    421 def _convert_tensorarray_to_flow(tensor_or_tensor_array):
    422   if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
    423     return tensor_or_tensor_array.flow
    424   else:
    425     return tensor_or_tensor_array
    426 
    427 
    428 def _make_tensor_array(ta, t_or_flow):
    429   # pylint: disable=protected-access
    430   new_ta = tensor_array_ops.TensorArray(
    431       dtype=ta.dtype,
    432       handle=ta.handle,
    433       flow=t_or_flow,
    434       infer_shape=ta._infer_shape,
    435       colocate_with_first_write_call=ta._colocate_with_first_write_call)
    436   new_ta._colocate_with = ta._colocate_with
    437   new_ta._element_shape = ta._element_shape
    438   # pylint: enable=protected-access
    439   return new_ta
    440 
    441 
    442 def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
    443   if len(tensors_or_tensorarrays) != len(tensors_or_flows):
    444     raise ValueError(
    445         "Lengths of original Tensor list and new list do not match: %d vs. %d" %
    446         (len(tensors_or_tensorarrays), len(tensors_or_flows)))
    447   return [
    448       _make_tensor_array(ta, t_or_flow) if isinstance(
    449           ta, tensor_array_ops.TensorArray) else t_or_flow
    450       for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)
    451   ]
    452 
    453 
    454 def _ShapeLessThanOrEqual(shape1, shape2):
    455   if shape2.dims is None:
    456     return True
    457   if shape1.ndims != shape2.ndims:
    458     return False
    459   for dim1, dim2 in zip(shape1.dims, shape2.dims):
    460     if dim2.value is not None and dim1.value != dim2.value:
    461       return False
    462   return True
    463 
    464 
    465 def _get_shape_invariant(var, shape=None):
    466   """Returns a shape invariant for the given variable.
    467 
    468   If `var` is a `CompositeTensor`, then this uses
    469   `_shape_invariant_to_components()` to get shape invariants for the
    470   component tensors.
    471 
    472   Args:
    473     var: The tensor whose shape is described.
    474     shape: The shape invariant for the tensor.  If not specified, then a default
    475       shape invariant for `var` is returned.
    476 
    477   Returns:
    478     The shape invariant for `var` (if it is a `Tensor`), or the shape invariants
    479     for the components that comprise `var` (if it is a `CompositeTensor`).
    480   """
    481   if isinstance(var, composite_tensor.CompositeTensor):
    482     return var._shape_invariant_to_components(shape)  # pylint: disable=protected-access
    483   elif shape is None:
    484     return var.shape
    485   else:
    486     return shape
    487 
    488 
    489 def _SetShapeInvariants(input_vars, enter_vars, shapes):
    490   """Set the shapes of the tensors in `enter_vars` to `shapes`.
    491 
    492   Args:
    493     input_vars: A list of tensors that are inputs to `enter_vars`.
    494     enter_vars: A list of tensors whose shapes will be set.
    495     shapes: A (possibly nested) list of shapes.
    496 
    497   Raises:
    498     ValueError: If any tensor in `enter_vars` has a less specific shape
    499       than its corresponding shape in `shapes`.
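
           For illustration, a hedged sketch of the user-facing pattern these shapes
           come from: a loop variable that grows along its first axis needs a
           less-specific invariant via the `shape_invariants` argument of
           `tf.while_loop`:

           ```python
           i0 = tf.constant(0)
           x0 = tf.constant([[1.0]])
           _, x = tf.while_loop(
               lambda i, x: i < 3,
               lambda i, x: (i + 1, tf.concat([x, x], axis=0)),
               [i0, x0],
               shape_invariants=[i0.get_shape(), tf.TensorShape([None, 1])])
           ```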
    500   """
    501   if shapes is None:
    502     return
    503   flat_shapes = nest.flatten(shapes)
    504   if not all(isinstance(s, tensor_shape.TensorShape) for s in flat_shapes):
    505     raise ValueError("`shapes` must be a (possibly nested) list of shapes.")
    506   # Check that the shapes of the inputs are less than the shape invariants,
    507   # and set the shapes of `enter_vars` to the shape invariants.
    508   for inp, var, shape in zip(input_vars, enter_vars, flat_shapes):
    509     if isinstance(var, ops.Tensor):
    510       if not _ShapeLessThanOrEqual(inp.get_shape(), shape):
    511         raise ValueError(
    512             "The shape invariant specified for %s is not compatible with "
    513             "the initial shape of the loop variable. It enters the loop "
    514             "with shape %s, but the specified shape invariant is %s." %
    515             (inp.name, inp.get_shape(), shape))
    516       var.set_shape(shape)
    517     else:
    518       raise TypeError("Type %s not supported" % type(var))
    519 
    520 
    521 def _EnforceShapeInvariant(merge_var, next_var):
     522   """Check if the shapes of the loop variables are invariant.
    523 
    524   Args:
     525     merge_var: The tensor representing the initial value of the loop
     526       variable.
     527     next_var: The tensor representing the value of the loop variable
     528       after one loop iteration.
    529 
    530   Raises:
    531     ValueError: If any tensor in `merge_var` has a more specific shape than
     532       its corresponding tensor in `next_var`.
    533   """
    534   if isinstance(merge_var, ops.Tensor):
    535     m_shape = merge_var.get_shape()
    536     n_shape = next_var.get_shape()
    537     if not _ShapeLessThanOrEqual(n_shape, m_shape):
    538       enter = merge_var.op.inputs[0].op
    539       assert util.IsLoopEnter(enter)
    540       input_t = enter.inputs[0]
    541       raise ValueError(
    542           "Input tensor '%s' enters the loop with shape %s, but has shape %s "
    543           "after one iteration. To allow the shape to vary across iterations, "
    544           "use the `shape_invariants` argument of tf.while_loop to specify a "
    545           "less-specific shape." % (input_t.name, input_t.shape, n_shape))
    546   else:
    547     raise TypeError("Type %s not supported" % type(merge_var))
    548 
    549 
    550 def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
    551   """Add NextIteration and back edge from v to m."""
    552   if isinstance(m, ops.Tensor):
    553     v = ops.convert_to_tensor(v)
    554     v = _NextIteration(v)
    555     if enforce_shape_invariant:
    556       # Make sure the shapes of loop outputs are correct. We do this before
    557       # calling _update_input, which will raise a less-helpful error message if
    558       # the types don't match.
    559       # TODO(skyewm): call this for other cases below (needs testing)
    560       _EnforceShapeInvariant(m, v)
    561     m.op._update_input(1, v)  # pylint: disable=protected-access
    562   elif isinstance(m, composite_tensor.CompositeTensor):
    563     # pylint: disable=protected-access
    564     def update_component(m_component, v_component):
    565       m_component.op._update_input(1, v_component)
    566     if isinstance(m, ops.IndexedSlices):
    567       v = math_ops._as_indexed_slices(v, optimize=False)
    568     # pylint: enable=protected-access
    569     v = _NextIteration(v)
    570     return nest.map_structure(update_component, m, v, expand_composites=True)
    571   else:
    572     raise TypeError("Type %s not supported" % type(m))
    573   return v
    574 
    575 
    576 def GetMaxSizeFromNestedMaximumIterations(value, while_ctxt):
    577   """Calculate a max_size for use by stack ops inside an XLA while_loop.
    578 
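           For example (an illustrative nesting only): if `value` lives inside two
           nested while loops built with `maximum_iterations` of 10 and 20, the
           returned `max_size` is their product, 10 * 20 = 200.
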
    579   Args:
    580     value: The value inside the while_loop forward context.  Used for printing
    581       error messages.
    582     while_ctxt: The forward context inside which value resides.  This does not
    583       always match the value's immediate context, as `value` may be inside e.g.
    584       a cond context inside the while_loop.
    585 
    586   Returns:
    587     A tensor containing the `max_size` to feed to a Stack initializer.
    588 
    589   Raises:
    590     ValueError: If `value` is nested inside a `while_loop` that either
    591       lacks a `maximum_iterations` parameter, or the `maximum_iterations`
    592       parameter:
    593 
    594         - is inside a `while_loop` that is a parent of the calling context, and
    595         - cannot be evaluated at graph build time to a constant.
    596   """
    597   value_name = value.name
    598   # curr_ctxt is the context that tf.gradients was called in.
    599   curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
    600 
    601   curr_ctxt_name = curr_ctxt.name if curr_ctxt is not None else ""
    602   max_size = constant_op.constant(1)
    603 
    604   # Loop through all containing while contexts between value and the
    605   # current context, multiplying together each context's
    606   # max_iterations to get the maximum stack size.
    607   while while_ctxt not in (None, curr_ctxt):
    608     max_iter = while_ctxt.maximum_iterations
    609     if max_iter is None:
    610       raise ValueError(
    611           "Cannot create a gradient accumulator for tensor '%s' inside "
    612           "XLA while_loop because maximum_iterations was not passed to "
    613           "the tf.while_loop call ('%s')." % (value_name, while_ctxt.name))
    614 
    615     # pylint: disable=protected-access
    616     max_iter_ctxt = max_iter.op._get_control_flow_context()
    617     # pylint: enable=protected-access
    618 
    619     # If max_iter_ctxt (non-strictly) contains curr_ctxt, then it's OK to use.
    620     if util.IsContainingContext(curr_ctxt, max_iter_ctxt):
    621       max_size *= max_iter
    622     else:
    623       # We cannot use max_iter because it's defined in a nested while
    624       # or cond context, so will fail if we try to use it as input to
    625       # any ops in curr_ctxt (e.g. max_size or the final accumulator
    626       # stack). Attempt to get a constant value out to use instead.
    627       const_max_iter = tensor_util.constant_value(max_iter)
    628       if const_max_iter is None:
    629         raise ValueError(
    630             "Cannot create a gradient accumulator for tensor '%s' inside XLA "
    631             "while_loop. maximum_iterations tensor '%s' for while_loop context "
    632             "'%s' must be statically known (e.g. a constant value or known "
    633             "shape dimension), or be defined at or outside the while loop "
    634             "context '%s' (currently defined in '%s')." %
    635             (value_name, max_iter.name, while_ctxt.name, curr_ctxt_name,
    636              max_iter_ctxt.name))
    637       max_size *= const_max_iter
    638 
    639     # Find the next outer WhileContext (or stop if we reach the
    640     # tf.gradient's context).
    641     while_ctxt = util.GetContainingWhileContext(
    642         while_ctxt.outer_context, stop_ctxt=curr_ctxt)
    643 
    644   return max_size
    645 
    646 
    647 class GradLoopState(object):
    648   """The state used for constructing the gradient graph for a while loop.
    649 
    650   We create a GradLoopState for each while loop in forward and its
    651   corresponding while loop in backprop. This gives us access to both
    652   the forward and the backprop WhileContexts.
    653 
     654   During the construction of the gradient graph, whenever we detect
     655   a forward value that is needed for backprop, we create a history
     656   accumulator and add it to `history_map`. Whenever we backprop
     657   a loop switch op (in _SwitchGrad), we add the grad merge op to
     658   `switch_map`.
    659   """
    660 
    661   def __init__(self, forward_ctxt, outer_grad_state):
    662     # The grad loop state for the outer while loop.
    663     self._outer_grad_state = None
    664 
    665     # The while loop context for forward.
    666     self._forward_context = None
    667 
    668     # The loop counter added by AddForwardLoopCounter. It is the value
    669     # of the loop counter for the next iteration.
    670     self._forward_index = None
    671 
    672     # A sync op for forward.
    673     self._forward_sync = None
    674 
    675     # The while loop context for backprop.
    676     self._grad_context = None
    677 
    678     # The loop counter added by AddBackpropLoopCounter. It is the value
    679     # of the loop counter for the current iteration.
    680     self._grad_index = None
    681 
    682     # A sync op for backprop.
    683     self._grad_sync = None
    684 
    685     # Information needed by backprop.
    686     self._history_map = {}
    687     self._switch_map = {}
    688     self._unused_exits = []
    689     self._deferred_exits = []
    690     self._forward_loop_exits = list(forward_ctxt.loop_exits)
    691     self._pending_exits_count = len(forward_ctxt.loop_exits)
    692 
    693     self._outer_grad_state = outer_grad_state
    694     if outer_grad_state:
    695       outer_forward_ctxt = outer_grad_state.forward_context
    696     else:
    697       if not hasattr(forward_ctxt, "outer_context"):
     698         raise ValueError("Failed to call gradients on a while loop without "
    699                          "properly serializing graph via MetaGraphDef")
    700       outer_forward_ctxt = forward_ctxt.outer_context
    701 
    702     # Add the forward loop counter.
    703     with forward_ctxt._graph.as_default():  # pylint: disable=protected-access
    704       if outer_forward_ctxt:
    705         outer_forward_ctxt.Enter()
    706       cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
    707       if outer_forward_ctxt:
    708         outer_forward_ctxt.Exit()
    709     self._forward_context = forward_ctxt
    710     self._forward_index = forward_index
    711 
    712     # Add the backprop WhileContext, and the backprop loop counter.
    713     if outer_grad_state:
    714       # This is a nested loop. Remember the iteration counts for each
    715       # execution of this inner loop.
    716       outer_forward_ctxt.AddName(cnt.name)
    717       history_cnt = outer_grad_state.AddForwardAccumulator(cnt)
    718 
    719       outer_grad_ctxt = outer_grad_state.grad_context
    720       outer_grad_ctxt.Enter()
    721       self._grad_context = WhileContext(
    722           maximum_iterations=forward_ctxt.maximum_iterations,
    723           parallel_iterations=forward_ctxt.parallel_iterations,
    724           back_prop=forward_ctxt.back_prop,
    725           swap_memory=forward_ctxt.swap_memory,
    726           name=forward_ctxt.name,
    727           grad_state=self)
    728       real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt)
    729       self._grad_index = self._grad_context.AddBackpropLoopCounter(
    730           real_cnt, outer_grad_state)
    731       outer_grad_ctxt.Exit()
    732     else:
    733       if outer_forward_ctxt:
    734         outer_forward_ctxt.Enter()
    735       self._grad_context = WhileContext(
    736           maximum_iterations=forward_ctxt.maximum_iterations,
    737           parallel_iterations=forward_ctxt.parallel_iterations,
    738           back_prop=forward_ctxt.back_prop,
    739           swap_memory=forward_ctxt.swap_memory,
    740           name=forward_ctxt.name,
    741           grad_state=self)
    742       self._grad_index = self._grad_context.AddBackpropLoopCounter(
    743           cnt, outer_grad_state)
    744       if outer_forward_ctxt:
    745         outer_forward_ctxt.Exit()
    746 
    747   @property
    748   def outer_grad_state(self):
    749     """The grad loop state for outer loop."""
    750     return self._outer_grad_state
    751 
    752   @property
    753   def forward_context(self):
    754     """The while loop context for forward."""
    755     return self._forward_context
    756 
    757   @property
    758   def forward_index(self):
    759     """The loop index of forward loop."""
    760     return self._forward_index
    761 
    762   @property
    763   def forward_sync(self):
    764     """A control trigger node for synchronization in the forward loop.
    765 
    766     One main use is to keep the push ops of a stack executed in the
    767     iteration order.
    768     """
    769     if self._forward_sync is None:
    770       with ops.control_dependencies(None):
    771         self._forward_sync = control_trigger(name="f_sync")
    772       self._forward_sync._set_control_flow_context(self._forward_context)
    773       self._forward_index.op._add_control_input(self._forward_sync)
    774     return self._forward_sync
    775 
    776   @property
    777   def grad_context(self):
    778     """The corresponding WhileContext for gradient."""
    779     return self._grad_context
    780 
    781   @property
    782   def grad_index(self):
    783     """The loop index of backprop loop."""
    784     return self._grad_index
    785 
    786   @property
    787   def grad_sync(self):
    788     """A control trigger node for synchronization in the grad loop.
    789 
    790     One main use is to keep the pop ops of a stack executed in the
    791     iteration order.
    792     """
    793     if self._grad_sync is None:
    794       with ops.control_dependencies(None):
    795         self._grad_sync = control_trigger(name="b_sync")
    796       self._grad_sync._set_control_flow_context(self._grad_context)
    797       self._grad_index.op._add_control_input(self._grad_sync)
    798       if self._grad_context.outer_context:
    799         self._grad_context.outer_context.AddInnerOp(self._grad_sync)
    800     return self._grad_sync
    801 
    802   @property
    803   def history_map(self):
    804     """The map that records all the tensors needed for backprop."""
    805     return self._history_map
    806 
    807   @property
    808   def switch_map(self):
    809     """The map that records all the Switch ops for the while loop."""
    810     return self._switch_map
    811 
    812   @property
    813   def unused_exits(self):
    814     """The list of "unused" exits."""
    815     return self._unused_exits
    816 
    817   @property
    818   def deferred_exits(self):
    819     """The list of "deferred" exits."""
    820     return self._deferred_exits
    821 
    822   @property
    823   def forward_loop_exits(self):
    824     """The list of exits of the forward loop."""
    825     return self._forward_loop_exits
    826 
    827   @property
    828   def pending_exits_count(self):
    829     """The number of exits we expect to see but haven't."""
    830     return self._pending_exits_count
    831 
    832   @pending_exits_count.setter
    833   def pending_exits_count(self, cnt):
    834     """Set the pending count to cnt."""
    835     self._pending_exits_count = cnt
    836 
    837   def AddForwardAccumulator(self, value, dead_branch=False):
    838     """Add an accumulator for each forward tensor that is needed in backprop.
    839 
     840     This is added to the forward loop the first time a tensor
     841     in the forward loop is used by the backprop gradient computation loop.
     842     We create an accumulator that accumulates the value of the tensor at each
    843     iteration. Called in the control flow context where gradients() is called.
    844 
    845     The pseudocode is:
    846     ```
    847       acc = stack();
    848       while (_pivot) {
    849         acc = stack_push(acc, value);
    850       }
    851     ```
    852 
    853     We make sure that the stack push op in one iteration is executed before
     854     the next iteration. This is achieved by adding a control edge from
    855     `forward_index.op.inputs[0].op` to the push op, and another control
    856     edge from the push op to either `forward_index.op` or `forward_sync`.
    857 
    858     Args:
    859       value: The source tensor in forward that is to be accumulated.
    860       dead_branch: True iff the tensor is on a dead branch of a cond.
    861 
    862     Returns:
    863       The stack that contains the accumulated history of the tensor.
    864 
    865     Raises:
    866       TypeError: For internal errors involving the value condition context.
     867       ValueError: If `value` is inside an XLA scope and a valid max size
    868         for the stack can't be found.
    869     """
    870     # curr_ctxt is the context that tf.gradients was called in.
    871     with self._forward_index.graph.as_default():
    872       curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
    873       with ops.control_dependencies(None):
    874         if curr_ctxt:
    875           curr_ctxt.Enter()
    876         with ops.colocate_with(value):
    877           # We only need to pass maximum_iterations to the stack if
    878           # we're inside an XLA context.
    879           if not util.IsInXLAContext(value.op):
    880             max_size = constant_op.constant(-1, dtypes.int32)
    881           else:
    882             max_size = GetMaxSizeFromNestedMaximumIterations(
    883                 value, self.forward_context)
    884           acc = gen_data_flow_ops.stack_v2(
    885               max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
    886         if curr_ctxt:
    887           curr_ctxt.Exit()
    888 
    889         # Make acc available in the forward context.
    890         enter_acc = self.forward_context.AddValue(acc)
    891 
    892         # Add the stack_push op in the context of value.op.
    893         swap_enabled = self.forward_context.swap_memory
    894         value_ctxt = util.GetOutputContext(value.op)
    895         if value_ctxt == self.forward_context:
     896           # value is directly in the forward context (not in an inner cond).
    897           self.forward_context.Enter()
    898           push = gen_data_flow_ops.stack_push_v2(
    899               enter_acc, value, swap_memory=swap_enabled)
    900           self.forward_context.Exit()
    901           # Protect stack push and order it before forward_index.
    902           self.forward_index.op._add_control_input(push.op)
    903         else:
    904           # value is in a cond context within the forward context.
    905           if not isinstance(value_ctxt, CondContext):
    906             raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
    907           if dead_branch:
    908             # The special case for creating a zero tensor for a dead
    909             # branch of a switch. See ControlFlowState.ZerosLike().
    910             value_ctxt.outer_context.Enter()
    911             push = gen_data_flow_ops.stack_push_v2(
    912                 enter_acc, value, swap_memory=swap_enabled)
    913             value_ctxt.outer_context.Exit()
    914             push.op._set_control_flow_context(value_ctxt)
    915           else:
    916             value_ctxt.Enter()
    917             push = gen_data_flow_ops.stack_push_v2(
    918                 enter_acc, value, swap_memory=swap_enabled)
    919             value_ctxt.Exit()
    920           # Protect stack push and order it before forward_sync.
    921           self.forward_sync._add_control_input(push.op)
    922         # Order stack push after the successor of forward_index
    923         add_op = self.forward_index.op.inputs[0].op
    924         push.op._add_control_input(add_op)
    925         return acc
    926 
    927   def AddBackpropAccumulatedValue(self, history_value, value,
    928                                   dead_branch=False):
    929     """Add the getter for an accumulated value in the grad context.
    930 
    931     This is added to the backprop loop. Called in the grad context to
    932     get the value of an accumulated value. The stack pop op must be guarded
    933     by the pred of the controlling cond.
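
               The pseudocode, mirroring the accumulator built by `AddForwardAccumulator`,
               is roughly:
               ```
                 // acc is the history stack built by AddForwardAccumulator
                 while (_pivot) {
                   value = stack_pop(acc);
                 }
               ```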
    934 
    935     Args:
    936       history_value: The history (a stack) of a value.
    937       value: The value that is pushed onto the stack.
    938       dead_branch: True iff the tensor is on a dead branch of a cond.
    939 
    940     Returns:
    941       The current value (the top of the stack).
    942     """
    943     history_ctxt = history_value.op._get_control_flow_context()
    944     # Find the cond context that controls history_value if any.
    945     cond_ctxt = None
    946     value_ctxt = value.op._get_control_flow_context()
    947     while value_ctxt and value_ctxt != history_ctxt:
    948       if isinstance(value_ctxt, CondContext):
    949         cond_ctxt = value_ctxt
    950         break
    951       value_ctxt = value_ctxt.outer_context
    952     with ops.control_dependencies(None):
    953       self.grad_context.Enter()
    954       if cond_ctxt:
    955         # Guard stack pop with a switch if it is controlled by a cond.
    956         grad_state = self
    957         pred = None
    958         while pred is None and grad_state:
    959           pred = grad_state.history_map.get(cond_ctxt.pred.name)
    960           grad_state = grad_state.outer_grad_state
    961         if pred is None:
    962           pred = cond_ctxt.pred
    963         branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
    964         history_value = _SwitchRefOrTensor(history_value, pred)[branch]
    965       pop = gen_data_flow_ops.stack_pop_v2(history_value,
    966                                            value.dtype.base_dtype)
    967       pop.set_shape(value.get_shape())
    968       self.grad_context.Exit()
    969     parallel_iterations = self.grad_context.parallel_iterations
    970     if parallel_iterations > 1:
    971       # All pops are ordered after pivot_for_body and before grad_sync.
    972       self.grad_sync._add_control_input(pop.op)
    973     return pop
    974 
    975   def GetRealValue(self, value):
    976     """Get the real value of `value`.
    977 
    978     If backprop "uses" a value produced by forward inference, an accumulator
    979     is added in the forward loop to accumulate its values.  We use the
    980     accumulated value. This method must be called in the grad loop context.
    981     `value` must be in forward and needed for backprop.
    982 
    983     Args:
    984       value: A tensor to be captured.
    985 
    986     Returns:
    987       The same tensor obtained from the saved history.
    988     """
    989     assert value.op.type not in ["Variable", "VariableV2"]
    990     real_value = self._history_map.get(value.name)
    991     if real_value is None:
    992       cur_value = value
    993       cur_grad_state = self
    994       while True:
    995         enter_op = util.GetLoopConstantEnter(cur_value)
    996         if enter_op:
    997           # Special case: cur_value comes from a constant Enter node.
    998           cur_value = enter_op.inputs[0]
    999           cur_grad_state = cur_grad_state.outer_grad_state
   1000           if cur_grad_state is None:
   1001             # We are now outside all nested loops for this gradient(),
   1002             # so `value` is a loop invariant and there is no need to
    1003             # save the history of value. Just make cur_value enter
   1004             # the right control flow context.
   1005             real_value = self._grad_context.AddValue(cur_value)
   1006             break
   1007         elif constant_op.is_constant(cur_value):
   1008           # If the value to be forwarded is a constant, clone the constant in
   1009           # the gradient loop rather than using a stack.
   1010           # TODO(phawkins): consider hoisting the constant out of the loop
   1011           # instead.
   1012           real_value = constant_op.constant(
   1013               tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
   1014           break
   1015         else:
   1016           # Record the history of this value in forward_ctxt.
   1017           self._grad_context.Exit()
   1018           history_value = cur_grad_state.AddForwardAccumulator(cur_value)
   1019           self._grad_context.Enter()
   1020           break
   1021 
   1022       if real_value is None:
   1023         # Add the stack pop op in the grad context.
   1024         real_value = cur_grad_state.AddBackpropAccumulatedValue(
   1025             history_value, cur_value)
   1026         if cur_grad_state != self:
   1027           real_value = self._grad_context.AddValue(real_value)
   1028       self._history_map[value.name] = real_value
   1029     return real_value
   1030 
   1031 
   1032 def _GetWhileContext(op):
   1033   """Get the WhileContext to which this op belongs."""
   1034   ctxt = op._get_control_flow_context()
   1035   if ctxt:
   1036     ctxt = ctxt.GetWhileContext()
   1037   return ctxt
   1038 
   1039 
   1040 class ControlFlowState(object):
   1041   """Maintain the mapping from the loops to their grad states."""
   1042 
   1043   def __init__(self):
   1044     self._map = {}  # maps forward loop context to GradLoopState
   1045 
   1046   def GetGradState(self, op, before):
   1047     """Return the grad state for this op if it's in a forward loop context."""
   1048     if before and util.IsLoopExit(op):
   1049       forward_ctxt = op._get_control_flow_context()
   1050       forward_ctxt = forward_ctxt.outer_context
   1051       if forward_ctxt:
   1052         forward_ctxt = forward_ctxt.GetWhileContext()
   1053     else:
   1054       forward_ctxt = _GetWhileContext(op)
   1055     if forward_ctxt:
   1056       return self._map.get(forward_ctxt)
   1057     return None
   1058 
   1059   def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
   1060     """Process all the "unused" loop exits.
   1061 
   1062     The "unused" exits of the loops are added to `unused_exits`. An exit is
    1063     unused if its pending_count is 0. If there is an exit with a real gradient,
   1064     all these deferred exits will enter the backprop loop with zero gradient.
   1065     Otherwise, they will enter the backprop loop with None. As an example,
   1066     people often write:
   1067 
   1068     ```python
   1069     v1, _ = tf.while_loop(p, b, [x1, x2])
   1070     result = gradients(v1, x1)
   1071     ```
   1072 
   1073     The exit node for x2 is not included by the betweenness analysis. But we
   1074     need to backprop x2 if x2 is involved in computing v1.
   1075 
   1076     Args:
   1077       pending_count: The number of backprop inputs for every op.
   1078       to_ops_set: The set of ops for ys in gradients(ys, xs)
   1079 
   1080     Returns:
   1081       The set of unused loop exits that we know at this point we need
   1082       to backprop.
   1083     """
   1084     loop_exits = []
   1085     for grad_state in self._map.values():
   1086       for y in grad_state.forward_loop_exits:
   1087         if pending_count[y.op] == 0:
   1088           grad_state.pending_exits_count -= 1
   1089           if y.op not in to_ops_set:
   1090             grad_state.unused_exits.append(y)
   1091           if grad_state.pending_exits_count == 0:
   1092             loop_exits.extend(grad_state.unused_exits)
   1093       # Need to include Enters in backprop for higher-order gradients.
   1094       for y in grad_state.forward_context.loop_enters:
   1095         if pending_count[y.op] == 0:
   1096           pending_count[y.op] = 1
   1097     return loop_exits
   1098 
   1099   def EnterGradWhileContext(self, op, before):
   1100     """Enter the WhileContext for gradient computation."""
   1101     grad_state = self.GetGradState(op, before)
   1102     if grad_state:
   1103       grad_state.grad_context.Enter()
   1104 
   1105   def ExitGradWhileContext(self, op, before):
   1106     """Exit the WhileContext for gradient computation."""
   1107     grad_state = self.GetGradState(op, before)
   1108     if grad_state:
   1109       grad_state.grad_context.Exit()
   1110 
   1111   def AddWhileContext(self, op, between_op_list, between_ops):
   1112     """Add the grad state for the while loop that op belongs to.
   1113 
   1114     Note that op is an Exit, and this method must be called in
   1115     the control flow context where gradients() is called.
   1116 
   1117     Note that this method modifies `between_op_list` and `between_ops`.
   1118     """
   1119     forward_ctxt = _GetWhileContext(op)
   1120     grad_state = self._map.get(forward_ctxt)
   1121     if grad_state is None:
   1122       # This is a new while loop so create a grad state for it.
   1123       outer_forward_ctxt = forward_ctxt.outer_context
   1124       if outer_forward_ctxt:
   1125         outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
   1126       outer_grad_state = None
   1127       if outer_forward_ctxt:
   1128         outer_grad_state = self._map.get(outer_forward_ctxt)
   1129       grad_state = GradLoopState(forward_ctxt, outer_grad_state)
   1130       self._map[forward_ctxt] = grad_state
   1131 
   1132       # We need to include all exits of a loop for backprop.
   1133       for loop_exit in grad_state.forward_loop_exits:
   1134         if loop_exit.op not in between_ops:
   1135           between_ops.add(loop_exit.op)
   1136           between_op_list.append(loop_exit.op)
   1137 
   1138   def ZerosLikeForExit(self, val):
   1139     """Create zeros_like gradient for a loop exit.
   1140 
   1141     If the result of a loop variable is not used but is involved in
   1142     computing the result of some needed loop variable, we create a
   1143     zero-valued tensor that is fed as gradient for the Exit node of that
   1144     loop variable. Note that val.op is an Exit, and this method must be
   1145     called in the control flow context where gradients() is called.
   1146 
   1147     Args:
   1148       val: The output tensor of an Exit op.
   1149 
   1150     Returns:
    1151       A zero tensor of the same shape as val.
   1152     """
   1153     val_shape = val.get_shape()
   1154     forward_ctxt = val.op._get_control_flow_context()
   1155     outer_forward_ctxt = forward_ctxt.outer_context
   1156     if outer_forward_ctxt:
   1157       outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
   1158     outer_grad_state = None
   1159     if outer_forward_ctxt:
   1160       outer_grad_state = self._map.get(outer_forward_ctxt)
   1161     if outer_grad_state:
   1162       # This is a nested loop.
   1163       if val_shape.is_fully_defined():
   1164         # If the shape is known statically, just create a zero tensor
   1165         # with the right shape in the right context.
   1166         outer_grad_state.grad_context.Enter()
   1167         result = array_ops.zeros(val_shape.dims, val.dtype)
   1168         outer_grad_state.grad_context.Exit()
   1169       else:
   1170         # Only the shape of value is needed for backprop.
   1171         forward_ctxt.outer_context.Enter()
   1172         shape = array_ops.shape_internal(val, optimize=False)
   1173         forward_ctxt.outer_context.Exit()
   1174         # Save the shape to a stack.
   1175         history_shape = outer_grad_state.AddForwardAccumulator(shape)
   1176         # Get the shape back from the stack.
   1177         outer_grad_ctxt = outer_grad_state.grad_context
   1178         outer_grad_ctxt.Enter()
   1179         real_shape = outer_grad_state.AddBackpropAccumulatedValue(
   1180             history_shape, shape)
   1181         result = array_ops.zeros(real_shape, val.dtype)
   1182         outer_grad_ctxt.Exit()
   1183     else:
   1184       # This is not a nested loop.
   1185       if val_shape.is_fully_defined():
   1186         # If the shape is known statically, just create a zero tensor
   1187         # with the right shape.
   1188         result = array_ops.zeros(val_shape.dims, val.dtype)
   1189       else:
   1190         result = array_ops.zeros_like(val, optimize=False)
   1191     return result
   1192 
   1193   def ZerosLike(self, op, index):
   1194     """Create zeros_like for the specified output of an op.
   1195 
   1196     If op is in a while loop that is part of gradients(), this method
   1197     must be called in its grad loop context.
   1198 
   1199     Args:
   1200       op: A tensorflow operation.
   1201       index: the index for a specific output of the op.
   1202 
   1203     Returns:
    1204       A zero tensor of the same shape as op.outputs[index].
   1205     """
   1206     if util.IsLoopSwitch(op):
   1207       return None
   1208     if op.graph._building_function:  # pylint: disable=protected-access
   1209       # The optimization here is tricky to apply to functions
   1210       return array_ops.zeros_like(op.outputs[index])
   1211     dead_branch = util.IsSwitch(op)
   1212     forward_ctxt = _GetWhileContext(op)
   1213     grad_state = self._map.get(forward_ctxt)
   1214     if grad_state is None:
   1215       # op is not in a while loop that is part of gradients().
   1216       return ZerosLikeOutsideLoop(op, index)
   1217     op_ctxt = op._get_control_flow_context()
   1218     val = ops.convert_to_tensor(op.outputs[index], name="tensor")
   1219     shape = val.get_shape()
   1220     if shape.is_fully_defined():
   1221       # If the shape is known statically, just create a zero tensor with
   1222       # the right shape in the grad loop context.
   1223       result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
   1224       if dead_branch:
   1225         # op is a cond switch. Guard the zero tensor with a switch.
   1226         pred = grad_state.history_map.get(op_ctxt.pred.name)
   1227         branch = op_ctxt.branch
   1228         result = _SwitchRefOrTensor(result, pred)[1 - branch]
   1229     else:
   1230       # Unknown shape so keep a history of the shape at runtime.
   1231       if dead_branch:
   1232         # Need to add a special switch to guard the value.
   1233         pred = op_ctxt.pred
   1234         branch = op_ctxt.branch
   1235         op_ctxt.outer_context.Enter()
   1236         val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch]
   1237         zeros_shape = array_ops.shape_internal(val, optimize=False)
   1238         op_ctxt.outer_context.Exit()
   1239         val.op._set_control_flow_context(op_ctxt)
   1240         zeros_shape.op._set_control_flow_context(op_ctxt)
   1241       else:
   1242         op_ctxt.Enter()
   1243         zeros_shape = array_ops.shape_internal(val, optimize=False)
   1244         op_ctxt.Exit()
   1245 
   1246       # Add forward accumulator for shape.
   1247       grad_state.grad_context.Exit()
   1248       history_zeros_shape = grad_state.AddForwardAccumulator(
   1249           zeros_shape, dead_branch=dead_branch)
   1250       grad_state.grad_context.Enter()
   1251 
   1252       # Create a zero tensor with the right shape.
   1253       shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
   1254                                                      zeros_shape, dead_branch)
   1255       result = array_ops.zeros(shape, val.dtype)
   1256     return result
   1257 
   1258   def PostProcessing(self):
   1259     """Perform postprocessing at the end of gradients().
   1260 
   1261     We have created the gradient graph at this point. So this function
   1262     can be used to perform any postprocessing on the gradient graph.
   1263     We currently perform the following postprocessing:
   1264       1. Patch the gradient graph if the output of a loop variable
   1265          doesn't depend on its input.
   1266     """
   1267     for _, grad_state in self._map.items():
   1268       for _, b_merge in grad_state.switch_map.items():
   1269         if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
   1270           # The value of this loop variable at iteration i+1 doesn't
   1271           # depend on its value at iteration i. So use zeros as the
   1272           # gradients for all iterations > 0.
   1273           dtype = b_merge.op.inputs[0].dtype
   1274           shape = b_merge.op.inputs[0].get_shape()
   1275           # pylint: disable=protected-access
   1276           if shape.is_fully_defined():
   1277             grad_state.grad_context.Enter()
   1278             # Create a zero tensor and use it for iterations > 0.
   1279             grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
   1280             next_grad_val = _NextIteration(grad_val)
   1281             grad_state.grad_context.Exit()
   1282           else:
   1283             # Create a zero tensor in the outer grad context.
   1284             outer_grad_ctxt = grad_state.grad_context.outer_context
   1285             if outer_grad_ctxt:
   1286               outer_grad_ctxt.Enter()
   1287             enter_grad_op = b_merge.op.inputs[0].op
   1288             enter_grad = enter_grad_op.inputs[0]
   1289             grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
   1290             grad_val = array_ops.zeros(grad_shape)
   1291             if outer_grad_ctxt:
   1292               outer_grad_ctxt.Exit()
   1293             # Use the zeros for iterations > 0.
   1294             grad_state.grad_context.Enter()
   1295             next_grad_val = _NextIteration(grad_val)
   1296             grad_state.grad_context.Exit()
   1297           b_merge.op._update_input(1, next_grad_val)
   1298           # pylint: enable=protected-access
   1299 
   1300 
   1301 def MaybeCreateControlFlowState(between_op_list, between_ops,
   1302                                 colocate_gradients_with_ops):
   1303   """Create the state for all the while loops involved in one gradients().
   1304 
   1305   We create a ControlFlowState when there are while loops involved in
   1306   gradients(). In gradients(), control flow logic is only invoked when
   1307   the ControlFlowState is not None.
   1308 
   1309   Note that this method modifies `between_op_list` and `between_ops`.
   1310   """
   1311   loop_state = None
   1312   for op in between_op_list:
   1313     if util.IsLoopExit(op):
   1314       if loop_state is None:
   1315         loop_state = ControlFlowState()
   1316       if colocate_gradients_with_ops:
   1317         with ops.colocate_with(op):
   1318           loop_state.AddWhileContext(op, between_op_list, between_ops)
   1319       else:
   1320         loop_state.AddWhileContext(op, between_op_list, between_ops)
   1321   return loop_state
   1322 
   1323 
   1324 def ZerosLikeOutsideLoop(op, index):
   1325   """Create zeros_like for the specified output of an op."""
   1326   val = op.outputs[index]
   1327   if not util.IsSwitch(op):
   1328     if val.dtype == dtypes.resource:
   1329       return array_ops.zeros(gen_resource_variable_ops.variable_shape(val))
   1330     return array_ops.zeros_like(val, optimize=False)
   1331   else:
   1332     op_ctxt = op._get_control_flow_context()
   1333     if op_ctxt:
   1334       # We are in a cond context. Use a switch to create zeros only when needed.
   1335       pred = op_ctxt.pred
   1336       branch = op_ctxt.branch
   1337       switch_val = switch(op.inputs[0], pred)[1 - branch]
   1338       # An op is created along the taken branch because control dependencies
   1339       # apply to the whole op and not to an individual tensor output.
   1340       pivot = array_ops.identity(switch_val)
   1341       if val.dtype == dtypes.resource:
   1342         with ops.control_dependencies([pivot]):
   1343           return array_ops.zeros(
   1344               gen_resource_variable_ops.variable_shape(switch_val))
   1345       zeros_shape = array_ops.shape_internal(switch_val, optimize=False)
   1346       # Ensure ops created within array_ops.zeros are dominated by switch in
   1347       # cond context.
   1348       with ops.control_dependencies([pivot]):
   1349         return array_ops.zeros(zeros_shape, dtype=val.dtype)
   1350     else:
   1351       return array_ops.zeros_like(val, optimize=False)
   1352 
   1353 
   1354 @six.add_metaclass(abc.ABCMeta)
   1355 class ControlFlowContext(object):
   1356   """The base class for control flow context.
   1357 
   1358   The usage pattern is a sequence of (Enter, Exit) followed by a final
   1359   ExitResult.
   1360 
   1361   We maintain the following state for control flow contexts during graph
   1362   construction:
   1363    1. graph has _control_flow_context: the current context used to
   1364       construct new nodes. Changed by ctxt.Enter() and ctxt.Exit()
   1365    2. op has _control_flow_context: the context to which the op belongs.
   1366       Set at the time the op is created. Immutable.
   1367    3. A ControlFlowContext has _outer_context: the context in which this
   1368       context is created. Set at the time a context is created. Immutable.
   1369    4. A ControlFlowContext has _context_stack.
   1370       Pushed and popped by ctxt.Enter() and ctxt.Exit()
   1371   """
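
          # A minimal usage sketch of the (Enter, Exit, ExitResult) pattern described
          # above. Illustrative only; `pred`, `pivot` and `branch_outputs` stand for
          # tensors built elsewhere:
          #
          #   ctxt = CondContext(pred, pivot, branch=1)
          #   ctxt.Enter()            # ops created now record `ctxt` as their context
          #   ...                     # build the branch subgraph
          #   ctxt.Exit()             # restore the previously active context
          #   ctxt.ExitResult(branch_outputs)  # expose outputs to the outer context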
   1372 
   1373   def __init__(self, values_def=None, import_scope=None):
   1374     self._nested_contexts = []
   1375     self._outer_context = ops.get_default_graph()._get_control_flow_context()
   1376     if self._outer_context:
   1377       self._outer_context._nested_contexts.append(self)  # pylint: disable=protected-access
   1378     self._context_stack = []
   1379     if values_def:
   1380       self._init_values_from_proto(values_def, import_scope=import_scope)
   1381     else:
   1382       # The names of tensors that have already been seen in this context.
   1383       self._values = set()
   1384       # The keys are the names of tensors referenced by but external to this
   1385       # context. Each value is the Tensor that should be used by this context to
   1386       # access the key value (e.g. a switch output guarding a cond input value).
   1387       self._external_values = {}
   1388 
   1389   def _init_values_from_proto(self, values_def, import_scope=None):
   1390     """Initializes values and external_values from `ValuesDef` protocol buffer.
   1391 
   1392     Args:
   1393       values_def: `ValuesDef` protocol buffer.
   1394       import_scope: Optional `string`. Name scope to add.
   1395     """
   1396     assert isinstance(values_def, control_flow_pb2.ValuesDef)
   1397     self._values = set(
   1398         ops.prepend_name_scope(value, import_scope)
   1399         for value in values_def.values)
   1400     g = ops.get_default_graph()
   1401     self._external_values = {}
   1402     for k, v in values_def.external_values.items():
   1403       k = ops.prepend_name_scope(k, import_scope)
   1404       self._external_values[k] = g.as_graph_element(
   1405           ops.prepend_name_scope(v, import_scope))
   1406     op_names = set([
   1407         op.split(":")[0]
   1408         for op in self._values - set(self._external_values.keys())
   1409     ])
   1410     for op in op_names:
   1411       # pylint: disable=protected-access
   1412       g.as_graph_element(op)._set_control_flow_context(self)
   1413       # pylint: enable=protected-access
   1414 
   1415   @property
   1416   def name(self):
   1417     return self._name
   1418 
   1419   @property
   1420   def outer_context(self):
   1421     """Return the context containing this context."""
   1422     return self._outer_context
   1423 
   1424   @property
   1425   def grad_state(self):
   1426     raise NotImplementedError("Abstract method")
   1427 
   1428   @property
   1429   def back_prop(self):
   1430     raise NotImplementedError("Abstract method")
   1431 
   1432   @abc.abstractmethod
   1433   def to_control_flow_context_def(self, context_def, export_scope=None):
   1434     """Serializes this into `context_def`.
   1435 
   1436     Args:
   1437       context_def: a `ControlFlowContextDef` protocol buffer.
   1438       export_scope: Optional `string`. Name scope to remove.
   1439     """
   1440     raise NotImplementedError("Abstract method")
   1441 
   1442   def _to_values_def(self, export_scope=None):
   1443     """Converts the values to a `ValuesDef` protocol buffer.
   1444 
   1445     Args:
   1446       export_scope: Optional `string`. Name scope to remove.
   1447 
   1448     Returns:
   1449       A `ValuesDef` protocol buffer.
   1450     """
   1451     values_def = control_flow_pb2.ValuesDef()
   1452     values_def.values.extend(
   1453         [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)])
   1454     for k, v in self._external_values.items():
   1455       k = ops.strip_name_scope(k, export_scope)
   1456       values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope)
   1457     return values_def
   1458 
   1459   def AddName(self, name):
   1460     self._values.add(name)
   1461 
   1462   # pylint: disable=protected-access
   1463   def Enter(self):
   1464     """Enter this control flow context."""
   1465     graph = ops.get_default_graph()
   1466     self._context_stack.append(graph._get_control_flow_context())
   1467     graph._set_control_flow_context(self)
   1468 
   1469   def Exit(self):
   1470     """Exit this control flow context."""
   1471     graph = ops.get_default_graph()
   1472     last_context = self._context_stack.pop()
   1473     graph._set_control_flow_context(last_context)
   1474 
   1475   def EnterGradientColocation(self, op, gradient_uid):
   1476     """Start building a gradient colocated with an op."""
   1477     if self._outer_context:
   1478       self._outer_context.EnterGradientColocation(op, gradient_uid)
   1479 
   1480   def ExitGradientColocation(self, op, gradient_uid):
   1481     """Stop building a gradient colocated with an op."""
   1482     if self._outer_context:
   1483       self._outer_context.ExitGradientColocation(op, gradient_uid)
   1484 
   1485   def ExitResult(self, result):
   1486     """Make a list of tensors available in the outer context."""
   1487     if self._outer_context:
   1488       nest.map_structure(lambda x: self._outer_context.AddName(x.name), result,
   1489                          expand_composites=True)
   1490 
   1491   def GetWhileContext(self):
   1492     """Return the while context containing this context."""
   1493     if self._outer_context:
   1494       return self._outer_context.GetWhileContext()
   1495     return None
   1496 
   1497   def _IsInOuterContext(self, op):
   1498     op_ctxt = util.GetOutputContext(op)
   1499     outer_ctxt = self.outer_context
   1500     while outer_ctxt != op_ctxt:
   1501       if outer_ctxt is None:
   1502         return False
   1503       outer_ctxt = outer_ctxt.outer_context
   1504     return True
   1505 
   1506   def _RemoveExternalControlEdges(self, op):
   1507     """Remove any external control dependency on this op."""
   1508     while_ctxt = self.GetWhileContext()
   1509     # A control input of `op` is internal if it is in the same while
   1510     # loop context as the enclosing while loop context of self.
   1511     if while_ctxt is None:
   1512       internal_control_inputs = op.control_inputs
   1513     else:
   1514       internal_control_inputs = []
   1515       for x in op.control_inputs:
   1516         ctxt = util.GetOutputContext(x)
   1517         if ctxt is not None and ctxt.GetWhileContext() == while_ctxt:
   1518           internal_control_inputs.append(x)
   1519     external_control_inputs = []
   1520     if len(internal_control_inputs) != len(op.control_inputs):
   1521       external_control_inputs = list(
   1522           set(op.control_inputs) - set(internal_control_inputs))
   1523       op._remove_all_control_inputs()
   1524       op._add_control_inputs(internal_control_inputs)
   1525     return internal_control_inputs, external_control_inputs
   1526 
   1527   # pylint: enable=protected-access
   1528 
   1529   def AddInnerOp(self, op):
   1530     """Notifies a scope about an operator added to an inner scope."""
   1531     if self._outer_context:
   1532       self._outer_context.AddInnerOp(op)
   1533 
   1534   def GetControlPivot(self):
   1535     """Returns the pivot node for this context, or None."""
   1536     return None
   1537 
   1538   def IsWhileContext(self):
   1539     return False
   1540 
   1541   def IsCondContext(self):
   1542     return False
   1543 
   1544   def IsXLAContext(self):
   1545     return False
   1546 
   1547   def __str__(self):
   1548     return self.name
   1549 
   1550 
   1551 class CondContext(ControlFlowContext):
   1552   """The context for the conditional construct."""
   1553 
   1554   def __init__(self,
   1555                pred=None,
   1556                pivot=None,
   1557                branch=None,
   1558                name="cond_text",
   1559                context_def=None,
   1560                import_scope=None):
   1561     """Creates a `CondContext`.
   1562 
   1563     Args:
   1564       pred: The `boolean` tensor for the conditional predicate.
   1565       pivot: The predicate tensor in this branch.
   1566       branch: 0 or 1 representing this branch.
   1567       name: Name of the `CondContext` python object.
   1568       context_def: Optional `ContextDef` protocol buffer to initialize the
   1569         `CondContext` object from.
   1570       import_scope: Optional `string`. Name scope to add. Only used when
   1571         initializing from a protocol buffer.
   1572     """
   1573     self._name = ops.get_default_graph().unique_name(name)
   1574 
   1575     if context_def:
   1576       self._init_from_proto(context_def, import_scope=import_scope)
   1577     else:
   1578       # Initializes the default fields.
   1579       ControlFlowContext.__init__(self)
   1580       self._pred = pred  # The boolean tensor for the cond predicate
   1581       self._pivot = pivot  # The predicate tensor in this branch
   1582       self._branch = branch  # 0 or 1 representing this branch
   1583 
   1584       # Values considered to have already been seen in this context. `pred` is
   1585       # treated as external: it is produced outside this context.
   1586       self._values.add(pred.name)
   1587       self._external_values[pred.name] = pred
   1588       self._values.add(pivot.name)
   1589       pivot.op._set_control_flow_context(self)  # pylint: disable=protected-access
   1590 
   1591   def _init_from_proto(self, context_def, import_scope=None):
   1592     """Creates a new `CondContext` from protocol buffer.
   1593 
   1594     Args:
   1595       context_def: `CondContextDef` protocol buffer.
   1596       import_scope: Optional `string`. Name scope to add.
   1597     """
   1598     assert isinstance(context_def, control_flow_pb2.CondContextDef)
   1599     # Create from context_def.
   1600     g = ops.get_default_graph()
   1601     self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
   1602     self._pred = g.as_graph_element(
   1603         ops.prepend_name_scope(context_def.pred_name, import_scope))
   1604     self._pivot = g.as_graph_element(
   1605         ops.prepend_name_scope(context_def.pivot_name, import_scope))
   1606     self._branch = context_def.branch
   1607     super(CondContext, self).__init__(
   1608         values_def=context_def.values_def, import_scope=import_scope)
   1609 
   1610   @property
   1611   def pred(self):
   1612     return self._pred
   1613 
   1614   @property
   1615   def pivot(self):
   1616     return self._pivot
   1617 
   1618   @property
   1619   def branch(self):
   1620     return self._branch
   1621 
   1622   @property
   1623   def grad_state(self):
   1624     if self.GetWhileContext():
   1625       return self.GetWhileContext().grad_state
   1626     return None
   1627 
   1628   @property
   1629   def back_prop(self):
   1630     if self.GetWhileContext():
   1631       return self.GetWhileContext().back_prop
   1632     return False
   1633 
   1634   def GetControlPivot(self):
   1635     return self._pivot
   1636 
   1637   def to_proto(self, export_scope=None):
   1638     """Converts a `CondContext` to a `CondContextDef` protocol buffer.
   1639 
   1640     Args:
   1641       export_scope: Optional `string`. Name scope to remove.
   1642 
   1643     Returns:
   1644       A `CondContextDef` protocol buffer.
   1645     """
   1646     if (export_scope is None or self.name.startswith(export_scope)):
   1647       context_def = control_flow_pb2.CondContextDef()
   1648       context_def.context_name = ops.strip_name_scope(self.name, export_scope)
   1649       context_def.pred_name = ops.strip_name_scope(self._pred.name,
   1650                                                    export_scope)
   1651       context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
   1652                                                     export_scope)
   1653       context_def.branch = self._branch
   1654       context_def.values_def.MergeFrom(
   1655           super(CondContext, self)._to_values_def(export_scope))
   1656       for nested in self._nested_contexts:
   1657         nested_def = context_def.nested_contexts.add()
   1658         nested.to_control_flow_context_def(nested_def)
   1659 
   1660       return context_def
   1661     else:
   1662       return None
   1663 
   1664   @staticmethod
   1665   def from_proto(context_def, import_scope=None):
   1666     """Returns a `CondContext` object created from `context_def`."""
   1667     ret = CondContext(context_def=context_def, import_scope=import_scope)
   1668 
   1669     ret.Enter()
   1670     for nested_def in context_def.nested_contexts:
   1671       from_control_flow_context_def(nested_def, import_scope=import_scope)
   1672     ret.Exit()
   1673     return ret
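
          # A serialization round-trip sketch (illustrative; `ctxt` stands for an
          # already-constructed CondContext whose referenced nodes exist in the
          # default graph, which is what `from_proto` resolves names against):
          #
          #   context_def = ctxt.to_proto()
          #   restored = CondContext.from_proto(context_def)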
   1674 
   1675   def to_control_flow_context_def(self, context_def, export_scope=None):
   1676     context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
   1677 
   1678   def AddValue(self, val):
   1679     """Add `val` to the current context and its outer context recursively."""
   1680     if val.name in self._values:
   1681       # Use the real value if it comes from outer context. This is needed in
   1682       # particular for nested conds.
   1683       result = self._external_values.get(val.name)
   1684       result = val if result is None else result
   1685     else:
   1686       result = val
   1687       self._values.add(val.name)
   1688       if self._outer_context:
   1689         result = self._outer_context.AddValue(val)
   1690         self._values.add(result.name)
   1691         self._external_values[result.name] = result
   1692       with ops.control_dependencies(None):
   1693         result = _SwitchRefOrTensor(result, self._pred)[self._branch]
   1694         if self._outer_context:
   1695           self._outer_context.AddInnerOp(result.op)
   1696 
   1697       result.op.graph.prevent_fetching(result.op)
   1698       # pylint: disable=protected-access
   1699       result.op._set_control_flow_context(self)
   1700       # pylint: enable=protected-access
   1701 
   1702       # Mark Switch output as seen by this context and any outer contexts,
   1703       # just like what we do for normal op outputs in _AddOpInternal() below.
   1704       ctxt = self
   1705       while ctxt is not None:
   1706         # pylint: disable=protected-access
   1707         ctxt._values.add(result.name)
   1708         ctxt = ctxt._outer_context
   1709         # pylint: enable=protected-access
   1710 
   1711       self._external_values[val.name] = result
   1712     return result
   1713 
   1714   def AddOp(self, op):
   1715     self._AddOpInternal(op)
   1716 
   1717   def _AddOpInternal(self, op):
   1718     """Add `op` to the current context."""
   1719     if not op.inputs:
   1720       # If we're in a while loop, remove any control inputs from outside the
   1721       # loop.
   1722       self._RemoveExternalControlEdges(op)
   1723 
   1724       if not any(
   1725           util.OpInContext(input_op, self) for input_op in op.control_inputs):
   1726         # pylint: disable=protected-access
   1727         op._add_control_input(self._pivot.op)
   1728         # pylint: enable=protected-access
   1729     else:
   1730       # Make each input to 'op' available in this CondContext. If an input is
   1731       # already part of this context there's nothing to do, but if it's
   1732       # external, AddValue() will handle adding the appropriate Switch node and
   1733       # other bookkeeping.
   1734       for index in range(len(op.inputs)):
   1735         x = op.inputs[index]
   1736         if op.type == "Merge" and x.op.type == "NextIteration":
   1737           # Edge case: if we're importing a while loop inside this CondContext,
   1738           # AddValue() will not correctly handle the NextIteration inputs to
   1739           # Merge node. The problem is that the NextIteration should also be
   1740           # part of this context, but if we're importing it won't have been
   1741           # processed and added to the context yet, so AddValue() will try to
   1742           # add a Switch which results in an invalid graph. Instead, we use the
   1743           # NextIteration input as-is here, and it will eventually be added to
   1744           # the context via AddOp().
   1745           real_x = x
   1746         else:
   1747           real_x = self.AddValue(x)
   1748         if real_x != x:
   1749           # pylint: disable=protected-access
   1750           op._update_input(index, real_x)
   1751           # pylint: enable=protected-access
   1752       # Remove any external control dependency on this op.
   1753       self._RemoveExternalControlEdges(op)
   1754       # pylint: disable=protected-access
   1755       if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
   1756         op._add_control_input(self._pivot.op)
   1757       # pylint: enable=protected-access
   1758 
   1759     # Mark op's outputs as seen by this context and any outer contexts.
   1760     output_names = [x.name for x in op.outputs]
   1761     ctxt = self
   1762     while ctxt is not None:
   1763       # pylint: disable=protected-access
   1764       ctxt._values.update(output_names)
   1765       ctxt = ctxt._outer_context
   1766       # pylint: enable=protected-access
   1767 
   1768     if self._outer_context or not util.IsLoopExit(op):
   1769       op.graph.prevent_fetching(op)
   1770 
   1771     if self._outer_context:
   1772       self._outer_context.AddInnerOp(op)
   1773 
   1774   def _ProcessOutputTensor(self, val):
   1775     """Process an output tensor of a conditional branch."""
   1776     real_val = val
   1777     if val.name not in self._values:
   1778       # Handle the special case of lambda: x
   1779       self._values.add(val.name)
   1780       if self._outer_context:
   1781         real_val = self._outer_context.AddValue(val)
   1782         self._values.add(real_val.name)
   1783         self._external_values[real_val.name] = real_val
   1784       real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch]
   1785       self._external_values[val.name] = real_val
   1786     else:
   1787       external_val = self._external_values.get(val.name)
   1788       if external_val is not None:
   1789         real_val = external_val
   1790     return real_val
   1791 
   1792   def _BuildCondTensor(self, v):
   1793     if isinstance(v, ops.Operation):
   1794       # Use pivot as the proxy for this op.
   1795       return with_dependencies([v], self._pivot)
   1796     else:
   1797       v = nest.map_structure(_convert_tensorarray_to_flow, v,
   1798                              expand_composites=True)
   1799       return self._ProcessOutputTensor(ops.convert_to_tensor(v))
   1800 
   1801   def BuildCondBranch(self, fn):
   1802     """Add the subgraph defined by fn() to the graph."""
   1803     pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   1804     original_result = fn()
   1805     post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   1806     if len(post_summaries) > len(pre_summaries):
   1807       new_summaries = post_summaries[len(pre_summaries):]
   1808       summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   1809       summary_ref[:] = pre_summaries
   1810       with ops.control_dependencies(new_summaries):
   1811         if original_result is None:
   1812           return no_op(), None
   1813         else:
   1814           original_result = nest.map_structure(array_ops.identity,
   1815                                                original_result,
   1816                                                expand_composites=True)
   1817     if original_result is None:
   1818       return None, None
   1819 
   1820     result = nest.map_structure(self._BuildCondTensor, original_result,
   1821                                 expand_composites=True)
   1822     if not isinstance(result, (list, _basetuple)):
   1823       result = [result]
   1824     return original_result, result
   1825 
   1826   def IsCondContext(self):
   1827     return True
   1828 
   1829 
   1830 def _UnpackIfSingleton(res):
   1831   if isinstance(res, (list, _basetuple)) and len(res) == 1:
   1832     return res[0]
   1833   else:
   1834     return res
   1835 
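        # For example (illustrative; `t`, `t1` and `t2` stand for tensors):
        #   _UnpackIfSingleton([t])      -> t
        #   _UnpackIfSingleton([t1, t2]) -> [t1, t2]  (unchanged)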
   1836 
   1837 # pylint: disable=redefined-outer-name
   1838 # pylint: disable=g-doc-args
   1839 @tf_export(v1=["cond"])
   1840 @deprecation.deprecated_args(
   1841     None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
   1842     "fn1", "fn2")
   1843 def cond(pred,
   1844          true_fn=None,
   1845          false_fn=None,
   1846          strict=False,
   1847          name=None,
   1848          fn1=None,
   1849          fn2=None):
   1850   """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
   1851 
   1852   `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
   1853   `false_fn` must have the same non-zero number and type of outputs.
   1854 
   1855   **WARNING**: Any Tensors or Operations created outside of `true_fn` and
   1856   `false_fn` will be executed regardless of which branch is selected at runtime.
   1857 
   1858   Although this behavior is consistent with the dataflow model of TensorFlow,
   1859   it has frequently surprised users who expected a lazier semantics.
   1860   Consider the following simple program:
   1861 
   1862   ```python
   1863   z = tf.multiply(a, b)
   1864   result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
   1865   ```
   1866 
   1867   If `x < y`, the `tf.add` operation will be executed and the `tf.square`
   1868   operation will not be executed. Since `z` is needed for at least one
   1869   branch of the `cond`, the `tf.multiply` operation is always executed,
   1870   unconditionally.
   1871 
   1872   Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
   1873   call to `cond`, and not at all during `Session.run()`). `cond`
   1874   stitches together the graph fragments created during the `true_fn` and
   1875   `false_fn` calls with some additional graph nodes to ensure that the right
   1876   branch gets executed depending on the value of `pred`.
   1877 
   1878   `tf.cond` supports nested structures as implemented in
   1879   `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
   1880   same (possibly nested) value structure of lists, tuples, and/or named tuples.
   1881   Singleton lists and tuples form the only exceptions to this: when returned by
   1882   `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
   1883   This behavior is disabled by passing `strict=True`.
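
          For illustration, a minimal sketch of that difference on the v1 control-flow
          path described here (`pred` stands for a scalar boolean tensor defined
          elsewhere):

          ```python
          r = tf.cond(pred, lambda: [tf.constant(1)], lambda: [tf.constant(2)])
          # r is a single Tensor: the singleton list is unpacked.
          r = tf.cond(pred, lambda: [tf.constant(1)], lambda: [tf.constant(2)],
                      strict=True)
          # r is a one-element list: the structure is preserved.
          ```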
   1884 
   1885   Args:
   1886     pred: A scalar determining whether to return the result of `true_fn` or
   1887       `false_fn`.
   1888     true_fn: The callable to be performed if pred is true.
   1889     false_fn: The callable to be performed if pred is false.
   1890     strict: A boolean that enables/disables 'strict' mode; see above.
   1891     name: Optional name prefix for the returned tensors.
   1892 
   1893   Returns:
   1894     Tensors returned by the call to either `true_fn` or `false_fn`. If the
   1895     callables return a singleton list, the element is extracted from the list.
   1896 
   1897   Raises:
   1898     TypeError: if `true_fn` or `false_fn` is not callable.
   1899     ValueError: if `true_fn` and `false_fn` do not return the same number of
   1900       tensors, or return tensors of different types.
   1901 
   1902   Example:
   1903 
   1904   ```python
   1905   x = tf.constant(2)
   1906   y = tf.constant(5)
   1907   def f1(): return tf.multiply(x, 17)
   1908   def f2(): return tf.add(y, 23)
   1909   r = tf.cond(tf.less(x, y), f1, f2)
   1910   # r is set to f1().
   1911   # Operations in f2 (e.g., tf.add) are not executed.
   1912   ```
   1913 
   1914   """
   1915   # Always enable control flow v2 if building a function, regardless of toggle.
   1916   if (util.EnableControlFlowV2(ops.get_default_graph()) and
   1917       not context.executing_eagerly()):
   1918     return cond_v2.cond_v2(pred, true_fn, false_fn, name)
   1919 
   1920   # We needed to make true_fn/false_fn keyword arguments for
   1921   # backwards-compatibility. This check exists so that we can convert back to
   1922   # having them be positional arguments.
   1923   # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
   1924   # `fn1` and `fn2` are deleted.
   1925   if fn1 is not None:
   1926     if true_fn is not None:
   1927       raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
   1928     true_fn = fn1
   1929   elif true_fn is None:
   1930     raise TypeError("cond(): true_fn argument required")
   1931   if fn2 is not None:
   1932     if false_fn is not None:
   1933       raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
   1934     false_fn = fn2
   1935   elif false_fn is None:
   1936     raise TypeError("cond(): false_fn argument required")
   1937 
   1938   if not callable(true_fn):
   1939     raise TypeError("true_fn must be callable.")
   1940   if not callable(false_fn):
   1941     raise TypeError("false_fn must be callable.")
   1942 
   1943   with ops.name_scope(name, "cond", [pred]):
   1944     if context.executing_eagerly():
   1945       if pred:
   1946         return _UnpackIfSingleton(true_fn())
   1947       return _UnpackIfSingleton(false_fn())
   1948 
   1949     # Add the Switch to the graph.
   1950     if isinstance(pred, bool):
   1951       raise TypeError("pred must not be a Python bool")
   1952     p_2, p_1 = switch(pred, pred)
   1953     pivot_1 = array_ops.identity(p_1, name="switch_t")
   1954     pivot_2 = array_ops.identity(p_2, name="switch_f")
   1955     pred = array_ops.identity(pred, name="pred_id")
   1956     # Disable the fetching of tensors that are only on one branch of cond.
   1957     for tensor in [p_1, p_2, pivot_1, pivot_2, pred]:
   1958       tensor.op.graph.prevent_fetching(tensor.op)
   1959 
   1960     # Build the graph for the true branch in a new context.
   1961     context_t = CondContext(pred, pivot_1, branch=1)
   1962     try:
   1963       context_t.Enter()
   1964       orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
   1965       if orig_res_t is None:
   1966         raise ValueError("true_fn must have a return value.")
   1967       context_t.ExitResult(res_t)
   1968     finally:
   1969       context_t.Exit()
   1970 
   1971     # Build the graph for the false branch in a new context.
   1972     context_f = CondContext(pred, pivot_2, branch=0)
   1973     try:
   1974       context_f.Enter()
   1975       orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
   1976       if orig_res_f is None:
   1977         raise ValueError("false_fn must have a return value.")
   1978       context_f.ExitResult(res_f)
   1979     finally:
   1980       context_f.Exit()
   1981 
   1982     if not strict:
   1983       orig_res_t = _UnpackIfSingleton(orig_res_t)
   1984       orig_res_f = _UnpackIfSingleton(orig_res_f)
   1985 
   1986     # Check that the return values of the two branches have the same structure.
   1987     try:
   1988       nest.assert_same_structure(orig_res_t, orig_res_f,
   1989                                  expand_composites=True)
   1990     except TypeError as e:
   1991       raise TypeError(
   1992           "Incompatible return types of true_fn and false_fn: {}".format(e))
   1993     except ValueError as e:
   1994       raise ValueError(
   1995           "Incompatible return values of true_fn and false_fn: {}".format(e))
   1996 
   1997     # Add the final merge to the graph.
   1998     if not res_t:
   1999       raise ValueError("true_fn and false_fn must return at least one result.")
   2000 
   2001     res_t_flat = nest.flatten(res_t, expand_composites=True)
   2002     res_f_flat = nest.flatten(res_f, expand_composites=True)
   2003 
   2004     for i, (x, y) in enumerate(zip(res_t_flat, res_f_flat)):
   2005       assert isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)
   2006       if x.dtype.base_dtype != y.dtype.base_dtype:
   2007         _cast_indexed_slice_indices(res_t, res_t_flat, res_f_flat)
   2008         if res_t_flat[i].dtype.base_dtype != res_f_flat[i].dtype.base_dtype:
   2009           raise ValueError(
   2010               "Outputs of true_fn and false_fn must have the same type: "
   2011               "%s, %s" % (x.dtype.name, y.dtype.name))
   2012 
   2013     merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
   2014     merges = _convert_flows_to_tensorarrays(
   2015         nest.flatten(orig_res_t, expand_composites=True), merges)
   2016 
   2017     # Only add non-nested conds to the collection. Any nested control flow will
   2018     # be encapsulated in the root context.
   2019     assert context_t.outer_context == context_f.outer_context
   2020     if context_t.outer_context is None:
   2021       ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
   2022       ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)
   2023 
   2024     merges = nest.pack_sequence_as(structure=orig_res_t, flat_sequence=merges,
   2025                                    expand_composites=True)
   2026 
   2027     # Singleton lists and tuples are automatically unpacked if strict == False.
   2028     if not strict:
   2029       merges = _UnpackIfSingleton(merges)
   2030     return merges
   2031 
   2032 
   2033 def _cast_indexed_slice_indices(structure, flat_a, flat_b):
   2034   """Cast IndexedSlice.indices from int32 to int64 where necessary.
   2035 
   2036   For each `IndexedSlices` in the nested structure `structure`, find its
   2037   indices `Tensor` in the corresponding flattened lists `flat_a` and `flat_b`
   2038   (where composites have been expanded); and if those indices tensors have
   2039   different dtypes (i.e., if one is int64 but the other is int32), then cast
   2040   them to both be int64.
   2041 
   2042   Args:
   2043     structure: The nested structure that was flattened.
   2044     flat_a: A flattened list of `Tensors` whose structure matches
   2045         `structure`.  Will be modified in place to cast `IndexedSlices`
   2046         indices tensors to int64, where necessary.
   2047     flat_b: A flattened list of `Tensors` whose structure matches
   2048         `structure`.  Will be modified in place to cast `IndexedSlices`
   2049         indices tensors to int64, where necessary.
   2050   """
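          # Illustrative example (assumes an IndexedSlices that carries a dense_shape;
          # the names are placeholders):
          #   structure = [x, ops.IndexedSlices(values=v, indices=i, dense_shape=s)]
          #   flat_a    = [x, v, i, s]   # the indices tensor sits at offset +1
          # If flat_a[2] is int32 while flat_b[2] is int64, flat_a[2] is cast to int64.
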
   2051   # Find the locations (in flat_a and flat_b) of the IndexedSlices'
   2052   # indices tensors.
   2053   indexed_slice_indices = []
   2054   current_index = 0
   2055   for item in nest.flatten(structure, expand_composites=False):
   2056     if isinstance(item, ops.IndexedSlices):
   2057       # indices is the second component of the composite tensor.
   2058       indexed_slice_indices.append(current_index + 1)
   2059     if nest.is_sequence_or_composite(item):
   2060       current_index += len(nest.flatten(item, expand_composites=True))
   2061     else:
   2062       current_index += 1
   2063   assert current_index == len(flat_a)
   2064 
   2065   for index in indexed_slice_indices:
   2066     assert flat_a[index].dtype in (dtypes.int32, dtypes.int64)
   2067     assert flat_b[index].dtype in (dtypes.int32, dtypes.int64)
   2068     if flat_a[index].dtype != flat_b[index].dtype:
   2069       if flat_b[index].dtype == dtypes.int32:
   2070         flat_b[index] = math_ops.cast(flat_b[index], dtypes.int64)
   2071       else:
   2072         flat_a[index] = math_ops.cast(flat_a[index], dtypes.int64)
   2073 
   2074 
   2075 # pylint: enable=g-doc-args
   2076 # pylint: enable=redefined-outer-name
   2077 
   2078 
   2079 @tf_export("cond", v1=[])
   2080 def cond_for_tf_v2(pred,
   2081                    true_fn=None,
   2082                    false_fn=None,
   2083                    name=None):
   2084   """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
   2085 
   2086   `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
   2087   `false_fn` must have the same non-zero number and type of outputs.
   2088 
   2089   **WARNING**: Any Tensors or Operations created outside of `true_fn` and
   2090   `false_fn` will be executed regardless of which branch is selected at runtime.
   2091 
   2092   Although this behavior is consistent with the dataflow model of TensorFlow,
   2093   it has frequently surprised users who expected a lazier semantics.
   2094   Consider the following simple program:
   2095 
   2096   ```python
   2097   z = tf.multiply(a, b)
   2098   result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
   2099   ```
   2100 
   2101   If `x < y`, the `tf.add` operation will be executed and the `tf.square`
   2102   operation will not be executed. Since `z` is needed for at least one
   2103   branch of the `cond`, the `tf.multiply` operation is always executed,
   2104   unconditionally.
   2105 
   2106   Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
   2107   call to `cond`, and not at all during `Session.run()`). `cond`
   2108   stitches together the graph fragments created during the `true_fn` and
   2109   `false_fn` calls with some additional graph nodes to ensure that the right
   2110   branch gets executed depending on the value of `pred`.
   2111 
   2112   `tf.cond` supports nested structures as implemented in
   2113   `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
   2114   same (possibly nested) value structure of lists, tuples, and/or named tuples.
   2115   Note that this function applies strict structure semantics: singleton lists
   2116   and tuples returned by `true_fn` and/or `false_fn` are not implicitly unpacked.
   2117 
   2118   Args:
   2119     pred: A scalar determining whether to return the result of `true_fn` or
   2120       `false_fn`.
   2121     true_fn: The callable to be performed if pred is true.
   2122     false_fn: The callable to be performed if pred is false.
   2123     name: Optional name prefix for the returned tensors.
   2124 
   2125   Returns:
   2126     Tensors returned by the call to either `true_fn` or `false_fn`, keeping
   2127     the structure returned by the chosen callable.
   2128 
   2129   Raises:
   2130     TypeError: if `true_fn` or `false_fn` is not callable.
   2131     ValueError: if `true_fn` and `false_fn` do not return the same number of
   2132       tensors, or return tensors of different types.
   2133 
   2134   Example:
   2135 
   2136   ```python
   2137   x = tf.constant(2)
   2138   y = tf.constant(5)
   2139   def f1(): return tf.multiply(x, 17)
   2140   def f2(): return tf.add(y, 23)
   2141   r = tf.cond(tf.less(x, y), f1, f2)
   2142   # r is set to f1().
   2143   # Operations in f2 (e.g., tf.add) are not executed.
   2144   ```
   2145 
   2146   """
   2147   return cond(pred, true_fn=true_fn, false_fn=false_fn, strict=True, name=name)
   2148 
   2149 
   2150 def _resource_safe_shape(t):
   2151   """Returns the shape of t or the variable it points to."""
   2152   if t.dtype == dtypes.resource:
   2153     while t.op.inputs:
   2154       t = t.op.inputs[0]
   2155     return tensor_shape.TensorShape(t.op.get_attr("shape"))
   2156   return array_ops.shape_internal(t, optimize=False)
   2157 
   2158 
   2159 # TODO(yuanbyu): Consider having a unified notion of context for
   2160 # not only conditionals and loops but also control dependency and
   2161 # subgraphs.
   2162 class WhileContext(ControlFlowContext):
   2163   """The context for the loop construct."""
   2164 
   2165   def __init__(self,
   2166                maximum_iterations=None,
   2167                parallel_iterations=10,
   2168                back_prop=True,
   2169                swap_memory=False,
   2170                name="while_context",
   2171                grad_state=None,
   2172                context_def=None,
   2173                import_scope=None):
   2174     """Creates a `WhileContext`.
   2175 
   2176     Args:
   2177       maximum_iterations: Optional upper bound on number of loop iterations.
   2178       parallel_iterations: The number of iterations allowed to run in parallel.
   2179       back_prop: Whether backprop is enabled for this while loop.
   2180       swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
   2181       name: Optional name prefix for the returned tensors.
   2182       grad_state: The gradient loop state.
   2183       context_def: Optional `WhileContextDef` protocol buffer to initialize the
   2184         `WhileContext` Python object from.
   2185       import_scope: Optional `string`. Name scope to add. Only used when
   2186         initializing from a protocol buffer.
   2187     """
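            # These constructor arguments mirror the public `tf.while_loop` keywords,
            # e.g. (an illustrative call; `cond_fn`, `body_fn` and `loop_vars` are
            # assumed to be defined by the caller):
            #   tf.while_loop(cond_fn, body_fn, loop_vars, parallel_iterations=10,
            #                 back_prop=True, swap_memory=False,
            #                 maximum_iterations=100)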
   2188     if context_def:
   2189       self._init_from_proto(context_def, import_scope=import_scope)
   2190     else:
   2191       ControlFlowContext.__init__(self)
   2192       self._init_from_args(maximum_iterations, parallel_iterations, back_prop,
   2193                            swap_memory, name)
   2194     # The gradient loop state.
   2195     self._grad_state = grad_state
   2196 
   2197   def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop,
   2198                       swap_memory, name):
   2199     """Creates a new `WhileContext` from arguments.
   2200 
   2201     Args:
   2202       maximum_iterations: Optional upper bound on number of loop iterations.
   2203       parallel_iterations: The number of iterations allowed to run in parallel.
   2204       back_prop: Whether backprop is enabled for this while loop.
   2205       swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
   2206       name: Optional name prefix for the returned tensors.
   2207 
   2208     Raises:
   2209       ValueError: If `parallel_iterations` has invalid value.
   2210     """
   2211     if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0):
   2212       raise ValueError("`parallel_iterations` must be a positive integer: "
   2213                        "%s" % parallel_iterations)
   2214     self._name = ops.get_default_graph().unique_name(name)
   2215     self._maximum_iterations = maximum_iterations
   2216     self._parallel_iterations = parallel_iterations
   2217     self._back_prop = back_prop
   2218     self._swap_memory = swap_memory
   2219     # We use this node to control constants created by the pred lambda.
   2220     self._pivot_for_pred = None
   2221     # We use this node to control constants created by the body lambda.
   2222     self._pivot_for_body = None
   2223     # The boolean tensor for loop termination condition. Used in code
   2224     # generation for gradient computation
   2225     self._pivot = None
   2226     # The list of exit tensors for loop variables.
   2227     self._loop_exits = []
   2228     # The list of enter tensors for loop variables.
   2229     self._loop_enters = []
   2230     self._graph = ops.get_default_graph()
   2231 
   2232   def _init_from_proto(self, context_def, import_scope=None):
   2233     """Creates a new `WhileContext` from protocol buffer.
   2234 
   2235     Args:
   2236       context_def: `WhileContextDef` protocol buffer.
   2237       import_scope: Optional `string`. Name scope to add.
   2238     """
   2239     assert isinstance(context_def, control_flow_pb2.WhileContextDef)
   2240     # Create from context_def.
   2241     g = ops.get_default_graph()
   2242     self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
   2243     if context_def.maximum_iterations_name:
   2244       self._maximum_iterations = g.as_graph_element(
   2245           ops.prepend_name_scope(context_def.maximum_iterations_name,
   2246                                  import_scope))
   2247     else:
   2248       self._maximum_iterations = None
   2249     self._parallel_iterations = context_def.parallel_iterations
   2250     self._back_prop = context_def.back_prop
   2251     self._swap_memory = context_def.swap_memory
   2252     self._pivot_for_pred = g.as_graph_element(
   2253         ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope))
   2254     # We use this node to control constants created by the body lambda.
   2255     self._pivot_for_body = g.as_graph_element(
   2256         ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope))
   2257     # The boolean tensor for loop termination condition. Used in code
   2258     # generation for gradient computation.
   2259     self._pivot = g.as_graph_element(
   2260         ops.prepend_name_scope(context_def.pivot_name, import_scope))
   2261     # The list of exit tensors for loop variables.
   2262     self._loop_exits = [
   2263         g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope))
   2264         for exit_name in context_def.loop_exit_names
   2265     ]
   2266     # The list of enter tensors for loop variables.
   2267     self._loop_enters = [
   2268         g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope))
   2269         for enter_name in context_def.loop_enter_names
   2270     ]
   2271     super(WhileContext, self).__init__(
   2272         values_def=context_def.values_def, import_scope=import_scope)
   2273 
   2274     # import_scope causes self.name to be different from the original serialized
   2275     # context's name. Rewrite "frame_name" attrs with the new name.
   2276     if import_scope:
   2277       for tensor_name in self._values:
   2278         op = g.as_graph_element(tensor_name).op
   2279         if util.IsLoopEnter(op):
   2280           # pylint: disable=protected-access
   2281           op._set_attr("frame_name",
   2282                        attr_value_pb2.AttrValue(s=compat.as_bytes(self.name)))
   2283           # pylint: enable=protected-access
   2284     self._graph = ops.get_default_graph()
   2285 
   2286   @property
   2287   def maximum_iterations(self):
   2288     """The maximum number of iterations that will be executed."""
   2289     return self._maximum_iterations
   2290 
   2291   @property
   2292   def parallel_iterations(self):
   2293     """The number of iterations allowed to run in parallel."""
   2294     return self._parallel_iterations
   2295 
   2296   @property
   2297   def back_prop(self):
   2298     """True iff backprop is enabled for this while loop."""
   2299     return self._back_prop
   2300 
   2301   @property
   2302   def swap_memory(self):
   2303     """True iff GPU-CPU memory swap is enabled for this while loop."""
   2304     return self._swap_memory
   2305 
   2306   @property
   2307   def pivot(self):
   2308     """The boolean tensor representing the loop termination condition."""
   2309     return self._pivot
   2310 
   2311   @property
   2312   def loop_enters(self):
   2313     """The list of enter tensors for loop variables."""
   2314     return self._loop_enters
   2315 
   2316   @property
   2317   def loop_exits(self):
   2318     """The list of exit tensors for loop variables."""
   2319     return self._loop_exits
   2320 
   2321   @property
   2322   def grad_state(self):
   2323     """The gradient loop state."""
   2324     return self._grad_state
   2325 
   2326   def to_proto(self, export_scope=None):
   2327     """Converts a `WhileContext` to a `WhileContextDef` protocol buffer.
   2328 
   2329     Args:
   2330       export_scope: Optional `string`. Name scope to remove.
   2331 
   2332     Returns:
   2333       A `WhileContextDef` protocol buffer.
   2334     """
   2335     if (export_scope is None or self.name.startswith(export_scope)):
   2336       context_def = control_flow_pb2.WhileContextDef()
   2337       context_def.context_name = ops.strip_name_scope(self.name, export_scope)
   2338       context_def.parallel_iterations = self._parallel_iterations
   2339       if self._maximum_iterations is not None:
   2340         context_def.maximum_iterations_name = ops.strip_name_scope(
   2341             self._maximum_iterations.name, export_scope)
   2342       context_def.back_prop = self._back_prop
   2343       context_def.swap_memory = self._swap_memory
   2344       context_def.pivot_for_pred_name = ops.strip_name_scope(
   2345           self._pivot_for_pred.name, export_scope)
   2346       context_def.pivot_for_body_name = ops.strip_name_scope(
   2347           self._pivot_for_body.name, export_scope)
   2348       context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
   2349                                                     export_scope)
   2350       context_def.loop_exit_names.extend([
   2351           ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits
   2352       ])
   2353       context_def.loop_enter_names.extend([
   2354           ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters
   2355       ])
   2356       context_def.values_def.MergeFrom(
   2357           super(WhileContext, self)._to_values_def(export_scope=export_scope))
   2358       for nested in self._nested_contexts:
   2359         nested_def = context_def.nested_contexts.add()
   2360         nested.to_control_flow_context_def(nested_def)
   2361 
   2362       return context_def
   2363     else:
   2364       return None
   2365 
   2366   def to_control_flow_context_def(self, context_def, export_scope=None):
   2367     context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
   2368 
   2369   @staticmethod
   2370   def from_proto(context_def, import_scope=None):
   2371     """Returns a `WhileContext` object created from `context_def`.
   2372 
   2373     Args:
   2374       context_def: A `WhileContextDef` protocol buffer.
   2375       import_scope: Optional `string`. Name scope to add.
   2376 
   2377     Returns:
   2378       A `WhileContext` Python object.
   2379     """
   2380     ret = WhileContext(context_def=context_def, import_scope=import_scope)
   2381     ret.Enter()
   2382     for nested_def in context_def.nested_contexts:
   2383       from_control_flow_context_def(nested_def, import_scope=import_scope)
   2384     ret.Exit()
   2385     return ret
   2386 
   2387   def GetWhileContext(self):
   2388     return self
   2389 
   2390   def GetControlPivot(self):
   2391     if self._pivot_for_body is not None:
   2392       return self._pivot_for_body
   2393     return self._pivot_for_pred
   2394 
   2395   def AddValue(self, val):
   2396     """Add `val` to the current context and its outer context recursively."""
   2397     result = val
   2398     new_value = val.name not in self._values
   2399     # Don't treat ops in this context as new values. Usually all known values
   2400     # are in self._values, except when we're importing a while loop inside this
   2401     # WhileContext. Since there's a cycle in this case, `val` may be part of the
   2402     # imported while loop but not yet processed by this context and added to
   2403     # self._values in _AddOpInternal. We only want to process external input
   2404     # tensors to the while loop here.
   2405     new_value &= val.op._control_flow_context is not self  # pylint: disable=protected-access
   2406     if new_value:
   2407       self._values.add(val.name)
   2408 
   2409       # If we are in a grad context and val is from its forward context,
   2410       # use GetRealValue(), which adds the logic to save the history of
   2411       # val in forward.
   2412       grad_ctxt = ops.get_default_graph()._get_control_flow_context()
   2413       if grad_ctxt:
   2414         grad_ctxt = grad_ctxt.GetWhileContext()
   2415         if grad_ctxt.grad_state:
   2416           forward_ctxt = _GetWhileContext(val.op)
   2417           if util.IsLoopExit(val.op):
   2418             forward_ctxt = forward_ctxt.outer_context
   2419             if forward_ctxt:
   2420               forward_ctxt = forward_ctxt.GetWhileContext()
   2421           if forward_ctxt == grad_ctxt.grad_state.forward_context:
   2422             real_val = grad_ctxt.grad_state.GetRealValue(val)
   2423             self._external_values[val.name] = real_val
   2424             return real_val
   2425 
   2426       if self._outer_context is not None:
   2427         result = self._outer_context.AddValue(val)
   2428       # Create an Enter to make `result` known to this loop context.
   2429       with ops.control_dependencies(None):
   2430         enter = _Enter(
   2431             result,
   2432             self._name,
   2433             is_constant=True,
   2434             parallel_iterations=self._parallel_iterations)
   2435         enter.graph.prevent_feeding(enter)
   2436         if self._outer_context:
   2437           self._outer_context.AddInnerOp(enter.op)
   2438       # Fix the control inputs and control flow context of these enter ops.
   2439       self._FixControlInputsAndContext([enter])
   2440 
   2441       # Add `enter` in this context.
   2442       self._values.add(enter.name)
   2443       self._external_values[val.name] = enter
   2444       result = enter
   2445     else:
   2446       actual_val = self._external_values.get(val.name)
   2447       if actual_val is not None:
   2448         result = actual_val
   2449     return result
   2450 
   2451   def AddOp(self, op):
   2452     """Add `op` to the current context."""
   2453     # For a reduction op, if op is in a grad context and its input is from
   2454     # its forward context, moving op to the forward context means we would
   2455     # store the tensor after the reduction as opposed to the tensor before
   2456     # reduction, and therefore could significantly reduce memory consumption.
   2457     # For now, we do this only for a few ops.
   2458     if op.type in {"Shape", "Size", "Rank"}:
   2459       grad_ctxt = ops.get_default_graph()._get_control_flow_context()
   2460       if grad_ctxt:
   2461         grad_ctxt = grad_ctxt.GetWhileContext()
   2462         if grad_ctxt.grad_state:
   2463           op_input_forward_ctxt = _GetWhileContext(op.inputs[0].op)
   2464           if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context:
   2465             op_input_ctxt = op.inputs[0].op._get_control_flow_context()
   2466             op._set_control_flow_context(op_input_ctxt)
   2467             op_input_ctxt._AddOpInternal(op)
   2468             return
   2469     self._AddOpInternal(op)
   2470 
   2471   def _AddOpInternal(self, op):
   2472     """Add `op` to the current context.
   2473 
   2474     We move any external control dependencies of the op to the loop pivot, to
   2475     ensure they get executed.
   2476     """
   2477     if not op.inputs:
   2478       # Remove any external control dependency on this op
   2479       control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
   2480       # Add a control edge from the control pivot to this op.
   2481       if not control_inputs:
   2482         # pylint: disable=protected-access
   2483         op._add_control_input(self.GetControlPivot().op)
   2484         # pylint: enable=protected-access
   2485       for x in op.outputs:
   2486         self._values.add(x.name)
   2487     else:
   2488       for index in range(len(op.inputs)):
   2489         x = op.inputs[index]
   2490         real_x = self.AddValue(x)
   2491         if real_x != x:
   2492           op._update_input(index, real_x)  # pylint: disable=protected-access
   2493       # Remove any external control dependency on this op.
   2494       _, external_inputs = self._RemoveExternalControlEdges(op)
   2495       # Add a control dependency to prevent loop invariants from
   2496       # enabling ops that should not be executed.
   2497       self._MaybeAddControlDependency(op)
   2498       for x in op.outputs:
   2499         self._values.add(x.name)
   2500     if external_inputs:
   2501       # Use an identity to pull control inputs as data inputs. Note that we
   2502       # ignore ops which don't have outputs. TODO(apassos): fix that
   2503       with ops.control_dependencies(None):
   2504         self.Enter()
   2505         external_inputs = [
   2506             array_ops.identity(x.outputs[0]).op
   2507             for x in external_inputs
   2508             if x.outputs
   2509         ]
   2510         self.Exit()
   2511       op._add_control_inputs(external_inputs)  # pylint: disable=protected-access
   2512     if self._outer_context or not util.IsLoopExit(op):
   2513       op.graph.prevent_fetching(op)
   2514       for x in op.outputs:
   2515         op.graph.prevent_feeding(x)
   2516 
   2517     if self._outer_context:
   2518       self._outer_context.AddInnerOp(op)
   2519 
   2520   def _MaybeAddControlDependency(self, op):
   2521     """Add a control input to the op if it only depends on loop invariants."""
   2522 
   2523     def _IsOpFree(op):
   2524       """Determines if `op` needs a control dependency."""
   2525       if op.control_inputs:
   2526         return False
   2527       # pylint: disable=protected-access
   2528       if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
   2529         return True
   2530       # pylint: enable=protected-access
   2531       for x in op.inputs:
   2532         if not util.IsLoopConstantEnter(x.op):
   2533           return False
   2534       return True
   2535 
   2536     if _IsOpFree(op):
   2537       # pylint: disable=protected-access
   2538       op._add_control_input(self.GetControlPivot().op)
   2539       # pylint: enable=protected-access
   2540 
   2541   def AddForwardLoopCounter(self, outer_grad_state):
   2542     """Adds a loop that counts the number of iterations.
   2543 
   2544     This is added to the forward loop at the time when we start to
   2545     create the loop for backprop gradient computation. Called in
   2546     the outer context of this forward context.
   2547 
   2548     The pseudocode is:
   2549       `n = 0; while (_pivot) { n++; }`
   2550 
   2551     Note that a control dependency is added to `n` to ensure the correct
   2552     execution order of stack push ops.
   2553 
   2554     Args:
   2555       outer_grad_state: The outer grad state. None if not nested.
   2556 
   2557     Returns:
   2558       The number of iterations taken by the forward loop and the loop index.
   2559     """
   2560     n = constant_op.constant(0, name="f_count")
   2561     if outer_grad_state is not None:
   2562       # Force the stack pushes of i-th execution of an inner loop to be ordered
   2563       # before the pushes of (i+1)-th execution of the same inner loop.
   2564       outer_add_op = outer_grad_state.forward_index.op.inputs[0].op
   2565       n.op._add_control_input(outer_add_op)  # pylint: disable=protected-access
   2566 
   2567     self.Enter()
   2568     self.AddName(n.name)
   2569     enter_n = _Enter(
   2570         n,
   2571         self._name,
   2572         is_constant=False,
   2573         parallel_iterations=self._parallel_iterations,
   2574         name="f_count")
   2575     self.loop_enters.append(enter_n)
   2576 
   2577     merge_n = merge([enter_n, enter_n])[0]
   2578     switch_n = switch(merge_n, self._pivot)
   2579 
   2580     index = math_ops.add(switch_n[1], 1)
   2581     next_n = _NextIteration(index)
   2582     merge_n.op._update_input(1, next_n)
   2583 
   2584     total_iterations = exit(switch_n[0], name="f_count")
   2585     self.loop_exits.append(total_iterations)
   2586     self.ExitResult([total_iterations])
   2587     self.Exit()
   2588     return total_iterations, next_n
   2589 
   2590   def AddBackpropLoopCounter(self, count, outer_grad_state):
   2591     """Add the backprop loop that controls the iterations.
   2592 
   2593     This is added to the backprop loop. It is used to control the loop
   2594     termination of the backprop loop. Called in the outer context of
   2595     this grad context.
   2596 
   2597     The pseudocode is:
   2598       `n = count; while (n >= 1) { n--; }`
   2599 
   2600     Note that a control dependency is added to `final_zero` to ensure the
   2601     correct execution order of stack pop ops.
   2602 
   2603     Args:
   2604       count: The number of iterations for backprop.
   2605       outer_grad_state: The outer grad state. None if not nested.
   2606 
   2607     Returns:
   2608       The loop index.
   2609     """
   2610     in_separate_functions = count.graph is not ops.get_default_graph()
   2611     if in_separate_functions:
   2612       # Brings the count into this graph
   2613       count = array_ops.identity(count)
   2614     else:
   2615       # TODO(apassos) XLA expects this constant to be created outside the loop,
   2616       # so doing that for now.
   2617       one = constant_op.constant(1, name="b_count")
   2618 
   2619     self.Enter()
   2620     self.AddName(count.name)
   2621     enter_count = _Enter(
   2622         count,
   2623         self._name,
   2624         is_constant=False,
   2625         parallel_iterations=self._parallel_iterations,
   2626         name="b_count")
   2627     self.loop_enters.append(enter_count)
   2628 
   2629     merge_count = merge([enter_count, enter_count])[0]
   2630     self._pivot_for_pred = merge_count
   2631 
   2632     if in_separate_functions:
   2633       one = constant_op.constant(1, name="b_count")
   2634     pred = math_ops.greater_equal(merge_count, one)
   2635     self._pivot = loop_cond(pred, name="b_count")
   2636     switch_count = switch(merge_count, self._pivot)
   2637 
   2638     index = math_ops.subtract(switch_count[1], one)
   2639     self._pivot_for_body = index
   2640     next_count = _NextIteration(index)
   2641     merge_count.op._update_input(1, next_count)
   2642 
   2643     final_zero = exit(switch_count[0], name="b_count")
   2644     self.loop_exits.append(final_zero)
   2645     if outer_grad_state is not None:
   2646       # Force the stack pops of i-th execution of an inner loop to be ordered
   2647       # before the pops of (i+1)-th execution of the same inner loop.
   2648       # pylint: disable=protected-access
   2649       outer_grad_state.grad_sync._add_control_input(final_zero.op)
   2650       # pylint: enable=protected-access
   2651 
   2652     self.ExitResult([final_zero])
   2653     self.Exit()
   2654     return next_count
   2655 
   2656   def AddBackpropAccumulator(self, op, grad):
   2657     """Add an accumulation loop for every loop invariant.
   2658 
   2659     This is added to the backprop loop. It is used to accumulate partial
   2660     gradients within each loop iteration. Called when in the gradient while
   2661     context.
   2662 
   2663     The pseudocode is:
   2664       ```
   2665       acc = 0.0;
   2666       while (_pivot) {
   2667         acc += grad;
   2668       }
   2669       ```
   2670 
   2671     Args:
   2672       op: The Enter op for a loop invariant.
   2673       grad: The partial gradient of an iteration for a loop invariant.
   2674 
   2675     Returns:
   2676       The gradient for a loop invariant.
   2677     """
   2678     self.Exit()
   2679     # Create a zeros tensor with the right shape for acc. If we don't
   2680     # know the full shape statically, we will have to get the shape
   2681     # dynamically from the forward inference. Getting the shape right
   2682     # for the zeros is only needed for the base case when the loop exits
   2683     # without running any iterations.
   2684     shape = grad.get_shape()
   2685     if shape.is_fully_defined():
   2686       if self.outer_context:
   2687         self.outer_context.Enter()
   2688       acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")
   2689       if self.outer_context:
   2690         self.outer_context.Exit()
   2691     else:
   2692       value = op.inputs[0]
   2693       if (isinstance(self.outer_context, WhileContext) and
   2694           self.outer_context.grad_state is not None):
   2695         # We are in a nested while loop.
   2696         forward_ctxt = self.grad_state.forward_context
   2697         forward_ctxt.outer_context.Enter()
   2698         zeros_shape = array_ops.shape_internal(value, optimize=False)
   2699         forward_ctxt.outer_context.Exit()
   2700         outer_grad_state = self.grad_state.outer_grad_state
   2701         history_zeros_shape = outer_grad_state.AddForwardAccumulator(
   2702             zeros_shape)
   2703         self.outer_context.Enter()
   2704         real_shape = outer_grad_state.AddBackpropAccumulatedValue(
   2705             history_zeros_shape, zeros_shape)
   2706         acc = array_ops.zeros(real_shape, grad.dtype)
   2707         self.outer_context.Exit()
   2708       else:
   2709         if self.outer_context:
   2710           self.outer_context.Enter()
   2711         zeros_shape = array_ops.shape_internal(value, optimize=False)
   2712         acc = array_ops.zeros(zeros_shape, grad.dtype)
   2713         if self.outer_context:
   2714           self.outer_context.Exit()
   2715 
   2716     self.Enter()
   2717     self.AddName(acc.name)
   2718     enter_acc = _Enter(
   2719         acc,
   2720         self._name,
   2721         is_constant=False,
   2722         parallel_iterations=self._parallel_iterations,
   2723         name="b_acc")
   2724     self.loop_enters.append(enter_acc)
   2725 
   2726     merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
   2727     switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot)
   2728 
   2729     add_acc = math_ops.add(switch_acc_true, grad)
   2730     next_acc = _NextIteration(add_acc)
   2731     merge_acc.op._update_input(1, next_acc)  # pylint: disable=protected-access
   2732 
   2733     result_acc = exit(switch_acc_false, name="b_acc")
   2734     self.loop_exits.append(result_acc)
   2735     self.ExitResult([result_acc])
   2736     return result_acc
   2737 
   2738   def AddBackpropIndexedSlicesAccumulator(self, op, grad):
   2739     """This is used for accumulating gradients that are IndexedSlices.
   2740 
   2741     This is essentially the equivalent of AddBackpropAccumulator but optimized
   2742     for things like updating embeddings from within a while loop.
   2743 
   2744     Args:
   2745       op: The Enter op for a loop invariant.
   2746       grad: The partial gradients represented as an IndexedSlices.
   2747 
   2748     Returns:
   2749       The accumulated IndexedSlices gradient of the loop invariant.
   2750     """
   2751     values = grad.values
   2752     indices = grad.indices
   2753     dense_shape = grad.dense_shape
   2754 
   2755     self.Exit()
   2756     if self.outer_context:
   2757       self.outer_context.Enter()
   2758     if values.get_shape().is_fully_defined():
   2759       values_shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] +
   2760                                               values.get_shape().dims[1:])
   2761       if self.outer_context:
   2762         self.outer_context.Enter()
   2763       values_acc = constant_op.constant(
   2764           0, values.dtype, shape=values_shape, name="b_acc")
   2765       if self.outer_context:
   2766         self.outer_context.Exit()
   2767     else:
   2768       values_shape = _resource_safe_shape(op.inputs[0])[1:]
   2769       values_shape = array_ops.concat([[1], values_shape], 0)
   2770       values_acc = array_ops.zeros(values_shape, dtype=values.dtype)
   2771     indices_acc = constant_op.constant([0], indices.dtype)
   2772     shape_acc = None
   2773     if dense_shape is not None:
   2774       if dense_shape.get_shape().is_fully_defined():
   2775         if self.outer_context:
   2776           self.outer_context.Enter()
   2777         shape_acc = constant_op.constant(
   2778             0, dense_shape.dtype, shape=dense_shape.get_shape())
   2779         if self.outer_context:
   2780           self.outer_context.Exit()
   2781       else:
   2782         shape_acc = array_ops.zeros_like(
   2783             array_ops.shape_internal(
   2784                 op.inputs[0], optimize=False, out_type=dense_shape.dtype),
   2785             optimize=False)
   2786 
   2787     if self.outer_context:
   2788       self.outer_context.Exit()
   2789 
   2790     self.Enter()
   2791     self.AddName(values_acc.name)
   2792     self.AddName(indices_acc.name)
   2793     init_acc = [indices_acc, values_acc]
   2794     if shape_acc is not None:
   2795       self.AddName(shape_acc.name)
   2796       init_acc.append(shape_acc)
   2797 
   2798     # Set use_input_shape=False since the accumulator tensors will grow in
   2799     # size. If use_input_shape=True, the _update_input call below will result in
   2800     # incompatible shapes.
   2801     enter_acc = [
   2802         _Enter(
   2803             x,
   2804             self._name,
   2805             is_constant=False,
   2806             parallel_iterations=self._parallel_iterations,
   2807             use_input_shape=False,
   2808             name="b_acc") for x in init_acc
   2809     ]
   2810     # Manually set appropriate partial shapes.
   2811     enter_acc[0].set_shape([None])
   2812     if values_acc.shape.dims is not None:
   2813       enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:])
   2814     self.loop_enters.extend(enter_acc)
   2815 
   2816     merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc]
   2817     switch_acc = [switch(x, self._pivot) for x in merge_acc]
   2818 
   2819     # The actual accumulation.
   2820     acc_indexed_slices = [
   2821         array_ops.concat([xa[1], xv], 0)
   2822         for xa, xv in zip(switch_acc[:2], [indices, values])
   2823     ]
   2824     if shape_acc is not None:
   2825       # For the shape we just keep the maximum
   2826       acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1]))
   2827 
   2828     next_acc = [_NextIteration(x) for x in acc_indexed_slices]
   2829     for xm, xn in zip(merge_acc, next_acc):
   2830       xm.op._update_input(1, xn)  # pylint: disable=protected-access
   2831 
   2832     exit_acc = [exit(x[0], name="b_acc") for x in switch_acc]
   2833     self.loop_exits.extend(exit_acc)
   2834 
   2835     self.ExitResult(exit_acc)
   2836     return ops.IndexedSlices(
   2837         indices=exit_acc[0],
   2838         values=exit_acc[1],
   2839         dense_shape=exit_acc[2] if shape_acc is not None else None)
   2840 
   2841   def _InitializeValues(self, values):
   2842     """Makes the values known to this context."""
   2843     self._values = set()
   2844     for x in values:
   2845       if isinstance(x, ops.Tensor):
   2846         self._values.add(x.name)
   2847       else:
   2848         raise TypeError("Type %s not supported" % type(x))
   2849 
   2850   def _BuildLoop(self, pred, body, original_loop_vars, loop_vars,
   2851                  shape_invariants):
   2852     """Core: Add the loop termination condition and body to the graph."""
   2853     flat_loop_vars = nest.flatten(original_loop_vars, expand_composites=True)
   2854 
    2855     # Let the context know the loop variables so that they are added to
    2856     # the outer contexts properly.
   2857     self._InitializeValues(loop_vars)
   2858     real_vars = loop_vars
   2859     if self._outer_context:
   2860       real_vars = [self._outer_context.AddValue(x) for x in loop_vars]
   2861     with ops.control_dependencies(None):
   2862       enter_vars = [
   2863           _Enter(
   2864               x,
   2865               self._name,
   2866               is_constant=False,
   2867               parallel_iterations=self._parallel_iterations,
   2868               use_input_shape=(shape_invariants is None)) for x in real_vars
   2869       ]
   2870       for x in enter_vars:
   2871         x.graph.prevent_feeding(x)
   2872         if self._outer_context:
   2873           self._outer_context.AddInnerOp(x.op)
   2874 
   2875     # Finds the closest enclosing non-None control pivot.
   2876     outer_context = self._outer_context
   2877     control_pivot = None
   2878     while outer_context is not None and control_pivot is None:
   2879       control_pivot = outer_context.GetControlPivot()
   2880       # pylint: disable=protected-access
   2881       outer_context = outer_context._outer_context
   2882       # pylint: enable=protected-access
   2883 
   2884     if control_pivot is not None:
   2885       for var in enter_vars:
   2886         if util.IsLoopConstantEnter(var.op.inputs[0].op):
   2887           # pylint: disable=protected-access
   2888           var.op._add_control_input(control_pivot.op)
   2889           # pylint: enable=protected-access
   2890     _SetShapeInvariants(real_vars, enter_vars, shape_invariants)
   2891 
   2892     # Fix the control inputs and control flow context of these enter ops.
   2893     self._FixControlInputsAndContext(enter_vars)
   2894     self._InitializeValues(enter_vars)
   2895     self._loop_enters = enter_vars
   2896 
   2897     merge_vars = [merge([x, x])[0] for x in enter_vars]
   2898     self._pivot_for_pred = merge_vars[0]
   2899 
   2900     # Build the graph for pred.
   2901     merge_vars_with_tensor_arrays = (
   2902         _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars))
   2903     packed_vars = nest.pack_sequence_as(
   2904         structure=original_loop_vars,
   2905         flat_sequence=merge_vars_with_tensor_arrays,
   2906         expand_composites=True)
   2907     c = ops.convert_to_tensor(pred(*packed_vars))
   2908     self._pivot = loop_cond(c, name="LoopCond")
   2909     switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]
   2910 
   2911     # Build the graph for body.
   2912     vars_for_body = [_Identity(x[1]) for x in switch_vars]
   2913     self._pivot_for_body = vars_for_body[0]
   2914     # Convert TensorArray flow variables inside the context back into
   2915     # their associated TensorArrays for calling the body.
   2916     vars_for_body_with_tensor_arrays = (
   2917         _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body))
   2918     packed_vars_for_body = nest.pack_sequence_as(
   2919         structure=original_loop_vars,
   2920         flat_sequence=vars_for_body_with_tensor_arrays,
   2921         expand_composites=True)
   2922     pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   2923     body_result = body(*packed_vars_for_body)
   2924     post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   2925     if not nest.is_sequence_or_composite(body_result):
   2926       body_result = [body_result]
   2927     if len(post_summaries) > len(pre_summaries):
   2928       new_summaries = post_summaries[len(pre_summaries):]
   2929       summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   2930       summary_ref[:] = pre_summaries
   2931       with ops.control_dependencies(new_summaries):
   2932 
   2933         def map_fn(x):
   2934           # TODO(apassos) figure out how to trigger with tensor arrays as well
   2935           if isinstance(x, tensor_array_ops.TensorArray):
   2936             return x
   2937           return array_ops.identity(x)
   2938 
   2939         body_result = nest.map_structure(map_fn, body_result,
   2940                                          expand_composites=True)
   2941 
   2942     # Compare the structure types of input and output of body.
   2943     # For backwards compatibility, the first layer is forced to a list
   2944     # during this comparison, because inputs are typically lists and
   2945     # outputs of the body are typically tuples.
   2946     nest.assert_same_structure(list(packed_vars_for_body), list(body_result),
   2947                                expand_composites=True)
   2948 
   2949     # Store body_result to keep track of TensorArrays returned by body
   2950     original_body_result = body_result
   2951     # Convert TensorArrays returned by body into their flow variables
   2952     result = nest.map_structure(
   2953         _convert_tensorarray_to_flow,
   2954         nest.flatten(body_result, expand_composites=True),
   2955         expand_composites=True)
   2956     result = ops.convert_n_to_tensor_or_composite(result)
   2957 
   2958     # Add NextIteration and the back edges to complete the loop.
   2959     if len(merge_vars) != len(result):
   2960       raise ValueError("Number of inputs and outputs of body must match "
   2961                        "loop_vars: %d, %d" % (len(merge_vars), len(result)))
   2962     next_vars = []
   2963     for m, v in zip(merge_vars, result):
   2964       next_vars.append(_AddNextAndBackEdge(m, v))
   2965 
   2966     # Add the exit ops.
   2967     exit_vars = [exit(x[0]) for x in switch_vars]
   2968     self._loop_exits = exit_vars
   2969 
   2970     # Exit the loop.
   2971     self.ExitResult(exit_vars)
   2972 
   2973     return original_body_result, exit_vars
   2974 
   2975   def BuildLoop(self, pred, body, loop_vars, shape_invariants,
   2976                 return_same_structure):
   2977     """Add the loop termination condition and body to the graph."""
   2978 
   2979     # Keep original_loop_vars to identify which are TensorArrays
   2980     original_loop_vars = loop_vars
   2981     # Convert TensorArrays to their flow variables
   2982     loop_vars = nest.map_structure(
   2983         _convert_tensorarray_to_flow,
   2984         nest.flatten(loop_vars, expand_composites=False),
   2985         expand_composites=True)
   2986     loop_vars = ops.convert_n_to_tensor_or_composite(loop_vars)
   2987     if shape_invariants is None:
   2988       shape_invariants = nest.map_structure(
   2989           _get_shape_invariant, loop_vars, expand_composites=False)
   2990     loop_vars = nest.flatten(loop_vars, expand_composites=True)
   2991     try:
   2992       self.Enter()
   2993       # _BuildLoop calls _update_input in several places. _mutation_lock()
   2994       # ensures a Session.run call cannot occur between creating and mutating
   2995       # new ops.
   2996       with ops.get_default_graph()._mutation_lock():  # pylint: disable=protected-access
   2997         original_body_result, exit_vars = self._BuildLoop(
   2998             pred, body, original_loop_vars, loop_vars, shape_invariants)
   2999     finally:
   3000       self.Exit()
   3001 
   3002     flat_result = nest.flatten(original_body_result, expand_composites=True)
   3003     # Convert TensorArray flow variables outside the context back into
   3004     # their associated TensorArrays for returning to caller.
   3005     exit_vars_with_tensor_arrays = (
   3006         _convert_flows_to_tensorarrays(flat_result, exit_vars))
   3007     packed_exit_vars = nest.pack_sequence_as(
   3008         structure=original_body_result,
   3009         flat_sequence=exit_vars_with_tensor_arrays,
   3010         expand_composites=True)
   3011 
   3012     if return_same_structure:
   3013       return packed_exit_vars
   3014     else:
   3015       return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars
   3016 
   3017   def _FixControlInputsAndContext(self, enters):
   3018     graph = ops.get_default_graph()
   3019     # pylint: disable=protected-access
   3020     for e in enters:
   3021       if isinstance(e, ops.Tensor):
   3022         xs = [e]
   3023       else:
   3024         raise TypeError("Type %s not supported" % type(e))
   3025       for x in xs:
   3026         inp_op = x.op.inputs[0].op
   3027         control_inputs = graph._control_dependencies_for_inputs([inp_op])
   3028         outer_control_inputs = [
   3029             op for op in control_inputs if self._IsInOuterContext(op)
   3030         ]
   3031         x.op._set_control_flow_context(self)
   3032         x.op._add_control_inputs(outer_control_inputs)
   3033         graph._record_op_seen_by_control_dependencies(x.op)
   3034     # pylint: enable=protected-access
   3035 
   3036   def IsWhileContext(self):
   3037     return True
   3038 
   3039 
   3040 # pylint: disable=redefined-outer-name
   3041 @tf_export("while_loop", v1=[])
   3042 def while_loop_v2(cond,
   3043                   body,
   3044                   loop_vars,
   3045                   shape_invariants=None,
   3046                   parallel_iterations=10,
   3047                   back_prop=True,
   3048                   swap_memory=False,
   3049                   maximum_iterations=None,
   3050                   name=None):
   3051   """Repeat `body` while the condition `cond` is true.
   3052 
   3053   `cond` is a callable returning a boolean scalar tensor. `body` is a callable
   3054   returning a (possibly nested) tuple, namedtuple or list of tensors of the same
   3055   arity (length and structure) and types as `loop_vars`. `loop_vars` is a
   3056   (possibly nested) tuple, namedtuple or list of tensors that is passed to both
   3057   `cond` and `body`. `cond` and `body` both take as many arguments as there are
   3058   `loop_vars`.
   3059 
   3060   In addition to regular Tensors or IndexedSlices, the body may accept and
   3061   return TensorArray objects.  The flows of the TensorArray objects will
   3062   be appropriately forwarded between loops and during gradient calculations.
   3063 
   3064   Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
   3065   call to `while_loop`, and not at all during `Session.run()`). `while_loop`
   3066   stitches together the graph fragments created during the `cond` and `body`
   3067   calls with some additional graph nodes to create the graph flow that
   3068   repeats `body` until `cond` returns false.
   3069 
   3070   For correctness, `tf.while_loop()` strictly enforces shape invariants for
   3071   the loop variables. A shape invariant is a (possibly partial) shape that
   3072   is unchanged across the iterations of the loop. An error will be raised
   3073   if the shape of a loop variable after an iteration is determined to be more
   3074   general than or incompatible with its shape invariant. For example, a shape
   3075   of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
   3076   compatible with [11, 17]. By default (if the argument `shape_invariants` is
   3077   not specified), it is assumed that the initial shape of each tensor in
   3078   `loop_vars` is the same in every iteration. The `shape_invariants` argument
   3079   allows the caller to specify a less specific shape invariant for each loop
   3080   variable, which is needed if the shape varies between iterations. The
   3081   `tf.Tensor.set_shape`
   3082   function may also be used in the `body` function to indicate that
    3083   the output loop variable has a particular shape. The shape invariants for
   3084   SparseTensor and IndexedSlices are treated specially as follows:
   3085 
   3086   a) If a loop variable is a SparseTensor, the shape invariant must be
   3087   TensorShape([r]) where r is the rank of the dense tensor represented
   3088   by the sparse tensor. It means the shapes of the three tensors of the
   3089   SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
   3090   is the shape of the SparseTensor.dense_shape property. It must be the shape of
   3091   a vector.
   3092 
   3093   b) If a loop variable is an IndexedSlices, the shape invariant must be
   3094   a shape invariant of the values tensor of the IndexedSlices. It means
   3095   the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
   3096   [shape.ndims]).
   3097 
   3098   `while_loop` implements non-strict semantics, enabling multiple iterations
   3099   to run in parallel. The maximum number of parallel iterations can be
   3100   controlled by `parallel_iterations`, which gives users some control over
   3101   memory consumption and execution order. For correct programs, `while_loop`
   3102   should return the same result for any parallel_iterations > 0.
   3103 
   3104   For training, TensorFlow stores the tensors that are produced in the
   3105   forward inference and are needed in back propagation. These tensors are a
   3106   main source of memory consumption and often cause OOM errors when training
   3107   on GPUs. When the flag swap_memory is true, we swap out these tensors from
   3108   GPU to CPU. This for example allows us to train RNN models with very long
   3109   sequences and large batches.
   3110 
   3111   Args:
   3112     cond: A callable that represents the termination condition of the loop.
   3113     body: A callable that represents the loop body.
   3114     loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
   3115       `Tensor`, and `TensorArray` objects.
   3116     shape_invariants: The shape invariants for the loop variables.
   3117     parallel_iterations: The number of iterations allowed to run in parallel. It
   3118       must be a positive integer.
   3119     back_prop: Whether backprop is enabled for this while loop.
   3120     swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
   3121     maximum_iterations: Optional maximum number of iterations of the while loop
   3122       to run.  If provided, the `cond` output is AND-ed with an additional
   3123       condition ensuring the number of iterations executed is no greater than
   3124       `maximum_iterations`.
   3125     name: Optional name prefix for the returned tensors.
   3126 
   3127   Returns:
   3128     The output tensors for the loop variables after the loop. The return value
   3129       has the same structure as `loop_vars`.
   3130 
   3131   Raises:
   3132     TypeError: if `cond` or `body` is not callable.
   3133     ValueError: if `loop_vars` is empty.
   3134 
   3135   Example:
   3136 
   3137   ```python
   3138   i = tf.constant(0)
   3139   c = lambda i: tf.less(i, 10)
   3140   b = lambda i: tf.add(i, 1)
   3141   r = tf.while_loop(c, b, [i])
   3142   ```
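
           Example using `maximum_iterations` (a minimal sketch: the loop runs only 5
           of the 10 possible iterations because the condition is AND-ed with the
           iteration bound, so `r` ends up as 5):

           ```python
           i = tf.constant(0)
           c = lambda i: tf.less(i, 10)
           b = lambda i: tf.add(i, 1)
           r = tf.while_loop(c, b, [i], maximum_iterations=5)
           ```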
   3143 
   3144   Example with nesting and a namedtuple:
   3145 
   3146   ```python
   3147   import collections
   3148   Pair = collections.namedtuple('Pair', 'j, k')
   3149   ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
   3150   c = lambda i, p: i < 10
   3151   b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
   3152   ijk_final = tf.while_loop(c, b, ijk_0)
   3153   ```
   3154 
   3155   Example using shape_invariants:
   3156 
   3157   ```python
   3158   i0 = tf.constant(0)
   3159   m0 = tf.ones([2, 2])
   3160   c = lambda i, m: i < 10
   3161   b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
   3162   tf.while_loop(
   3163       c, b, loop_vars=[i0, m0],
   3164       shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
   3165   ```
   3166 
   3167   Example which demonstrates non-strict semantics: In the following
   3168   example, the final value of the counter `i` does not depend on `x`. So
    3169   the `while_loop` can increment the counter in parallel with updates of `x`.
   3170   However, because the loop counter at one loop iteration depends
   3171   on the value at the previous iteration, the loop counter itself cannot
   3172   be incremented in parallel. Hence if we just want the final value of the
   3173   counter (which we print on the line `print(sess.run(i))`), then
   3174   `x` will never be incremented, but the counter will be updated on a
   3175   single thread. Conversely, if we want the value of the output (which we
   3176   print on the line `print(sess.run(out).shape)`), then the counter may be
   3177   incremented on its own thread, while `x` can be incremented in
   3178   parallel on a separate thread. In the extreme case, it is conceivable
   3179   that the thread incrementing the counter runs until completion before
   3180   `x` is incremented even a single time. The only thing that can never
   3181   happen is that the thread updating `x` can never get ahead of the
   3182   counter thread because the thread incrementing `x` depends on the value
   3183   of the counter.
   3184 
   3185   ```python
   3186   import tensorflow as tf
   3187 
   3188   n = 10000
   3189   x = tf.constant(list(range(n)))
   3190   c = lambda i, x: i < n
   3191   b = lambda i, x: (tf.Print(i + 1, [i]), tf.Print(x + 1, [i], "x:"))
   3192   i, out = tf.while_loop(c, b, (0, x))
   3193   with tf.Session() as sess:
   3194       print(sess.run(i))  # prints [0] ... [9999]
   3195 
   3196       # The following line may increment the counter and x in parallel.
   3197       # The counter thread may get ahead of the other thread, but not the
   3198       # other way around. So you may see things like
   3199       # [9996] x:[9987]
   3200       # meaning that the counter thread is on iteration 9996,
   3201       # while the other thread is on iteration 9987
   3202       print(sess.run(out).shape)
   3203   ```
   3204 
   3205   """
   3206   return while_loop(
   3207       cond=cond,
   3208       body=body,
   3209       loop_vars=loop_vars,
   3210       shape_invariants=shape_invariants,
   3211       parallel_iterations=parallel_iterations,
   3212       back_prop=back_prop,
   3213       swap_memory=swap_memory,
   3214       name=name,
   3215       maximum_iterations=maximum_iterations,
   3216       return_same_structure=True)
   3217 
   3218 
   3219 # pylint: disable=redefined-outer-name
   3220 @tf_export(v1=["while_loop"])
   3221 def while_loop(cond,
   3222                body,
   3223                loop_vars,
   3224                shape_invariants=None,
   3225                parallel_iterations=10,
   3226                back_prop=True,
   3227                swap_memory=False,
   3228                name=None,
   3229                maximum_iterations=None,
   3230                return_same_structure=False):
   3231   """Repeat `body` while the condition `cond` is true.
   3232 
   3233   `cond` is a callable returning a boolean scalar tensor. `body` is a callable
   3234   returning a (possibly nested) tuple, namedtuple or list of tensors of the same
   3235   arity (length and structure) and types as `loop_vars`. `loop_vars` is a
   3236   (possibly nested) tuple, namedtuple or list of tensors that is passed to both
   3237   `cond` and `body`. `cond` and `body` both take as many arguments as there are
   3238   `loop_vars`.
   3239 
   3240   In addition to regular Tensors or IndexedSlices, the body may accept and
   3241   return TensorArray objects.  The flows of the TensorArray objects will
   3242   be appropriately forwarded between loops and during gradient calculations.
   3243 
   3244   Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
   3245   call to `while_loop`, and not at all during `Session.run()`). `while_loop`
   3246   stitches together the graph fragments created during the `cond` and `body`
   3247   calls with some additional graph nodes to create the graph flow that
   3248   repeats `body` until `cond` returns false.
   3249 
   3250   For correctness, `tf.while_loop()` strictly enforces shape invariants for
   3251   the loop variables. A shape invariant is a (possibly partial) shape that
   3252   is unchanged across the iterations of the loop. An error will be raised
   3253   if the shape of a loop variable after an iteration is determined to be more
   3254   general than or incompatible with its shape invariant. For example, a shape
   3255   of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
   3256   compatible with [11, 17]. By default (if the argument `shape_invariants` is
   3257   not specified), it is assumed that the initial shape of each tensor in
   3258   `loop_vars` is the same in every iteration. The `shape_invariants` argument
   3259   allows the caller to specify a less specific shape invariant for each loop
   3260   variable, which is needed if the shape varies between iterations. The
   3261   `tf.Tensor.set_shape`
   3262   function may also be used in the `body` function to indicate that
    3263   the output loop variable has a particular shape. The shape invariants for
   3264   SparseTensor and IndexedSlices are treated specially as follows:
   3265 
   3266   a) If a loop variable is a SparseTensor, the shape invariant must be
   3267   TensorShape([r]) where r is the rank of the dense tensor represented
   3268   by the sparse tensor. It means the shapes of the three tensors of the
   3269   SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
   3270   is the shape of the SparseTensor.dense_shape property. It must be the shape of
   3271   a vector.
   3272 
   3273   b) If a loop variable is an IndexedSlices, the shape invariant must be
   3274   a shape invariant of the values tensor of the IndexedSlices. It means
   3275   the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
   3276   [shape.ndims]).
   3277 
   3278   `while_loop` implements non-strict semantics, enabling multiple iterations
   3279   to run in parallel. The maximum number of parallel iterations can be
   3280   controlled by `parallel_iterations`, which gives users some control over
   3281   memory consumption and execution order. For correct programs, `while_loop`
   3282   should return the same result for any parallel_iterations > 0.
   3283 
   3284   For training, TensorFlow stores the tensors that are produced in the
   3285   forward inference and are needed in back propagation. These tensors are a
   3286   main source of memory consumption and often cause OOM errors when training
   3287   on GPUs. When the flag swap_memory is true, we swap out these tensors from
   3288   GPU to CPU. This for example allows us to train RNN models with very long
   3289   sequences and large batches.
   3290 
   3291   Args:
   3292     cond: A callable that represents the termination condition of the loop.
   3293     body: A callable that represents the loop body.
   3294     loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
   3295       `Tensor`, and `TensorArray` objects.
   3296     shape_invariants: The shape invariants for the loop variables.
   3297     parallel_iterations: The number of iterations allowed to run in parallel. It
   3298       must be a positive integer.
   3299     back_prop: Whether backprop is enabled for this while loop.
   3300     swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
   3301     name: Optional name prefix for the returned tensors.
   3302     maximum_iterations: Optional maximum number of iterations of the while loop
   3303       to run.  If provided, the `cond` output is AND-ed with an additional
   3304       condition ensuring the number of iterations executed is no greater than
   3305       `maximum_iterations`.
   3306     return_same_structure: If True, output has same structure as `loop_vars`. If
   3307       eager execution is enabled, this is ignored (and always treated as True).
   3308 
   3309   Returns:
   3310     The output tensors for the loop variables after the loop.
   3311      If `return_same_structure` is True, the return value has the same
   3312      structure as `loop_vars`.
   3313      If `return_same_structure` is False, the return value is a Tensor,
    3314      TensorArray or IndexedSlices if the length of `loop_vars` is 1, or a list
   3315      otherwise.
   3316 
   3317   Raises:
   3318     TypeError: if `cond` or `body` is not callable.
   3319     ValueError: if `loop_vars` is empty.
   3320 
   3321   Example:
   3322 
   3323   ```python
   3324   i = tf.constant(0)
   3325   c = lambda i: tf.less(i, 10)
   3326   b = lambda i: tf.add(i, 1)
   3327   r = tf.while_loop(c, b, [i])
   3328   ```
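
           Example using `return_same_structure` (a minimal sketch: with
           `return_same_structure=True` the result keeps the list structure of
           `loop_vars` rather than being unpacked to a single tensor):

           ```python
           i = tf.constant(0)
           c = lambda i: tf.less(i, 10)
           b = lambda i: tf.add(i, 1)
           [r] = tf.while_loop(c, b, [i], return_same_structure=True)
           ```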
   3329 
   3330   Example with nesting and a namedtuple:
   3331 
   3332   ```python
   3333   import collections
   3334   Pair = collections.namedtuple('Pair', 'j, k')
   3335   ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
   3336   c = lambda i, p: i < 10
   3337   b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
   3338   ijk_final = tf.while_loop(c, b, ijk_0)
   3339   ```
   3340 
   3341   Example using shape_invariants:
   3342 
   3343   ```python
   3344   i0 = tf.constant(0)
   3345   m0 = tf.ones([2, 2])
   3346   c = lambda i, m: i < 10
   3347   b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
   3348   tf.while_loop(
   3349       c, b, loop_vars=[i0, m0],
   3350       shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
   3351   ```
   3352 
   3353   Example which demonstrates non-strict semantics: In the following
   3354   example, the final value of the counter `i` does not depend on `x`. So
    3355   the `while_loop` can increment the counter in parallel with updates of `x`.
   3356   However, because the loop counter at one loop iteration depends
   3357   on the value at the previous iteration, the loop counter itself cannot
   3358   be incremented in parallel. Hence if we just want the final value of the
   3359   counter (which we print on the line `print(sess.run(i))`), then
   3360   `x` will never be incremented, but the counter will be updated on a
   3361   single thread. Conversely, if we want the value of the output (which we
   3362   print on the line `print(sess.run(out).shape)`), then the counter may be
   3363   incremented on its own thread, while `x` can be incremented in
   3364   parallel on a separate thread. In the extreme case, it is conceivable
   3365   that the thread incrementing the counter runs until completion before
   3366   `x` is incremented even a single time. The only thing that can never
   3367   happen is that the thread updating `x` can never get ahead of the
   3368   counter thread because the thread incrementing `x` depends on the value
   3369   of the counter.
   3370 
   3371   ```python
   3372   import tensorflow as tf
   3373 
   3374   n = 10000
   3375   x = tf.constant(list(range(n)))
   3376   c = lambda i, x: i < n
   3377   b = lambda i, x: (tf.Print(i + 1, [i]), tf.Print(x + 1, [i], "x:"))
   3378   i, out = tf.while_loop(c, b, (0, x))
   3379   with tf.Session() as sess:
   3380       print(sess.run(i))  # prints [0] ... [9999]
   3381 
   3382       # The following line may increment the counter and x in parallel.
   3383       # The counter thread may get ahead of the other thread, but not the
   3384       # other way around. So you may see things like
   3385       # [9996] x:[9987]
   3386       # meaning that the counter thread is on iteration 9996,
   3387       # while the other thread is on iteration 9987
   3388       print(sess.run(out).shape)
   3389   ```
   3390 
   3391   """
   3392   # Always enable control flow v2 if building a function, regardless of toggle.
   3393   if (util.EnableControlFlowV2(ops.get_default_graph()) and
   3394       not context.executing_eagerly()):
   3395     return while_v2.while_loop(
   3396         cond,
   3397         body,
   3398         loop_vars,
   3399         shape_invariants=shape_invariants,
   3400         parallel_iterations=parallel_iterations,
   3401         maximum_iterations=maximum_iterations,
   3402         name=name,
   3403         return_same_structure=return_same_structure)
   3404 
   3405   with ops.name_scope(name, "while", loop_vars):
   3406     if not loop_vars:
   3407       raise ValueError("No loop variables provided")
   3408     if not callable(cond):
   3409       raise TypeError("cond must be callable.")
   3410     if not callable(body):
   3411       raise TypeError("body must be callable.")
   3412     if parallel_iterations < 1:
   3413       raise TypeError("parallel_iterations must be a positive integer.")
   3414 
   3415     if maximum_iterations is not None:
   3416       maximum_iterations = ops.convert_to_tensor(
   3417           maximum_iterations, name="maximum_iterations")
   3418       if maximum_iterations.shape.ndims != 0:
   3419         raise ValueError("maximum_iterations must be a scalar, saw shape: %s" %
   3420                          maximum_iterations.shape)
   3421 
   3422       counter = constant_op.constant(
   3423           0, dtype=maximum_iterations.dtype, name="iteration_counter")
   3424       orig_cond = cond
   3425       orig_body = body
   3426       if len(loop_vars) == 1:
   3427         loop_vars = (counter, loop_vars[0])
   3428         cond = lambda i, lv: (  # pylint: disable=g-long-lambda
   3429             math_ops.logical_and(i < maximum_iterations, orig_cond(lv)))
   3430         body = lambda i, lv: (i + 1, orig_body(lv))
   3431       else:
   3432         loop_vars = (counter, loop_vars)
   3433         cond = lambda i, lv: (  # pylint: disable=g-long-lambda
   3434             math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
   3435         body = lambda i, lv: (i + 1, orig_body(*lv))
   3436 
   3437     if context.executing_eagerly():
   3438       try_to_pack = len(loop_vars) == 1
   3439       packed = False  # whether the body result was packed into a 1-item tuple
   3440 
   3441       while cond(*loop_vars):
   3442         loop_vars = body(*loop_vars)
   3443         if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
   3444           packed = True
   3445           loop_vars = (loop_vars,)
   3446 
   3447       def convert(x):
   3448         if isinstance(x, tensor_array_ops.TensorArray):
   3449           return x
   3450         return ops.convert_to_tensor(x)
   3451       loop_vars = nest.map_structure(convert, loop_vars)
   3452       if maximum_iterations is not None:
   3453         return loop_vars[1]
   3454       else:
   3455         return loop_vars[0] if packed else loop_vars
   3456 
   3457     if shape_invariants is not None:
   3458       if maximum_iterations is not None:
   3459         shape_invariants = (tensor_shape.TensorShape([]), shape_invariants)
   3460 
   3461       nest.assert_same_structure(loop_vars, shape_invariants,
   3462                                  expand_composites=False)
   3463       shape_invariants = nest.map_structure(
   3464           _get_shape_invariant, loop_vars, shape_invariants,
   3465           expand_composites=False)
   3466 
   3467     loop_context = WhileContext(
   3468         maximum_iterations=maximum_iterations,
   3469         parallel_iterations=parallel_iterations,
   3470         back_prop=back_prop,
   3471         swap_memory=swap_memory)
   3472     # Only add non-nested loops to the collection. Any nested control flow will
   3473     # be encapsulated in the root context.
   3474     if loop_context.outer_context is None:
   3475       ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
   3476     result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
   3477                                     return_same_structure)
   3478     if maximum_iterations is not None:
   3479       return result[1]
   3480     else:
   3481       return result
   3482 
   3483 
   3484 # pylint: enable=redefined-outer-name
   3485 
   3486 
   3487 def _AsTensorList(x, p):
   3488   """Return x as a list of Tensors or IndexedSlices.
   3489 
   3490   For entries of `x` that are Operations, this returns an Identity of `p`
   3491   with a dependency on the operation.
   3492 
   3493   Args:
   3494     x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
   3495     p: A Tensor to return for entries in `x` that are Operations.
   3496 
   3497   Returns:
   3498     A list of Tensors or IndexedSlices.
   3499   """
   3500   if not isinstance(x, (list, _basetuple)):
   3501     x = [x]
   3502 
   3503   l = []
   3504   for v in x:
   3505     if isinstance(v, ops.Operation):
   3506       v = with_dependencies([v], p)
   3507     v = ops.convert_to_tensor_or_composite(v)
   3508     if isinstance(v, ops.Tensor):
   3509       l.append(array_ops.identity(v))
   3510     else:
   3511       l.append(
   3512           ops.IndexedSlices(
   3513               array_ops.identity(v.values), array_ops.identity(v.indices)))
   3514   return l
   3515 
   3516 
   3517 def _CheckResults(a, b):
   3518   assert len(a) == len(b), (
   3519       "Values returned by a() and b() must have the same length.")
   3520   for x, y in zip(a, b):
   3521     assert x.dtype == y.dtype, (
   3522         "Values returned by a() [%s] and b() [%s] must have "
   3523         "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name))
   3524 
   3525 
   3526 def with_dependencies(dependencies, output_tensor, name=None):
   3527   """Produces the content of `output_tensor` only after `dependencies`.
   3528 
   3529   In some cases, a user may want the output of an operation to be
   3530   consumed externally only after some other dependencies have run
    3531   first. This function returns `output_tensor`, but only after all
    3532   operations in `dependencies` have run. Note that this applies only to the
    3533   tensor returned here: the original `output_tensor` op itself may still be
    3534   evaluated before `dependencies` have run.
   3535 
   3536   See also `tf.tuple` and `tf.group`.
   3537 
   3538   Args:
   3539     dependencies: Iterable of operations to run before this op finishes.
   3540     output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
   3541     name: (Optional) A name for this operation.
   3542 
   3543   Returns:
   3544     Same as `output_tensor`.
   3545 
   3546   Raises:
   3547     TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
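
           Example (a minimal graph-mode sketch; the variable and constant are
           illustrative):

           ```python
           v = tf.Variable(1.0)
           increment = tf.assign_add(v, 1.0)
           # `out` yields 7.0, but only after `increment` has run.
           out = with_dependencies([increment], tf.constant(7.0))
           ```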
   3548   """
   3549   if context.executing_eagerly():
   3550     return output_tensor
   3551   with ops.name_scope(name, "control_dependency",
   3552                       list(dependencies) + [output_tensor]) as name:
   3553     with ops.colocate_with(output_tensor):
   3554       with ops.control_dependencies(dependencies):
   3555         output_tensor = ops.convert_to_tensor_or_composite(output_tensor)
   3556         if isinstance(output_tensor, ops.Tensor):
   3557           return _Identity(output_tensor, name=name)
   3558         else:
   3559           return ops.IndexedSlices(
   3560               _Identity(output_tensor.values, name=name), output_tensor.indices,
   3561               output_tensor.dense_shape)
   3562 
   3563 
   3564 def _GroupControlDeps(dev, deps, name=None):
   3565   with ops.control_dependencies(deps):
   3566     if dev is None:
   3567       return no_op(name=name)
   3568     else:
   3569       with ops.device(dev):
   3570         return no_op(name=name)
   3571 
   3572 
   3573 # TODO(touts): Accept "inputs" as a list.
   3574 @tf_export("group")
   3575 def group(*inputs, **kwargs):
   3576   """Create an op that groups multiple operations.
   3577 
   3578   When this op finishes, all ops in `inputs` have finished. This op has no
   3579   output.
   3580 
   3581   See also `tf.tuple` and
   3582   `tf.control_dependencies`.
   3583 
   3584   Args:
   3585     *inputs: Zero or more tensors to group.
   3586     name: A name for this operation (optional).
   3587 
   3588   Returns:
   3589     An Operation that executes all its inputs.
   3590 
   3591   Raises:
   3592     ValueError: If an unknown keyword argument is provided.
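
           Example (a minimal sketch; the variables are illustrative):

           ```python
           v1 = tf.Variable(0.0)
           v2 = tf.Variable(0.0)
           updates = tf.group(tf.assign(v1, 1.0), tf.assign(v2, 2.0))
           # Running `updates` runs both assignments; the grouping op itself has
           # no output.
           ```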
   3593   """
   3594   if context.executing_eagerly():
   3595     return None
   3596   name = kwargs.pop("name", None)
   3597   if kwargs:
   3598     raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys()))
   3599   with ops.name_scope(name, "group_deps", inputs) as name:
   3600     # Grouping no inputs means do nothing
   3601     if not inputs:
   3602       return no_op(name=name)
   3603 
   3604     # Sorts *inputs according to their devices.
   3605     ops_on_device = {}  # device -> operations specified on the device.
   3606     for inp in nest.flatten(inputs, expand_composites=True):
   3607       if not hasattr(inp, "device"):
    3608         raise TypeError("tf.group() expected Tensor arguments, not "
   3609                         "'%s' with type '%s'" % (inp, type(inp)))
   3610       dev = inp.device
   3611       if dev in ops_on_device:
   3612         ops_on_device[dev].append(inp)
   3613       else:
   3614         ops_on_device[dev] = [inp]
   3615     if len(ops_on_device) == 1:
   3616       # 1-level tree. The root node is the returned NoOp node.
   3617       (dev, deps), = ops_on_device.items()
   3618       return _GroupControlDeps(dev, deps, name=name)
   3619 
   3620     # 2-level tree. The root node is the returned NoOp node.
   3621     # deps contains 1 NoOp node for each device.
   3622     deps = []
   3623 
   3624     def device_key(dev):
   3625       """A sort key that allows None to be compared to strings."""
   3626       return "" if dev is None else dev
   3627 
   3628     for dev in sorted(ops_on_device, key=device_key):
   3629       deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
   3630 
   3631     with ops.control_dependencies(deps):
   3632       return no_op(name=name)
   3633 
   3634 
   3635 @tf_export("tuple", v1=[])
   3636 def tuple_v2(tensors, control_inputs=None, name=None):
   3637   """Group tensors together.
   3638 
   3639   This creates a tuple of tensors with the same values as the `tensors`
   3640   argument, except that the value of each tensor is only returned after the
   3641   values of all tensors have been computed.
   3642 
   3643   `control_inputs` contains additional ops that have to finish before this op
   3644   finishes, but whose outputs are not returned.
   3645 
   3646   This can be used as a "join" mechanism for parallel computations: all the
   3647   argument tensors can be computed in parallel, but the values of any tensor
   3648   returned by `tuple` are only available after all the parallel computations
   3649   are done.
   3650 
   3651   See also `tf.group` and
   3652   `tf.control_dependencies`.
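        
          For example, a minimal graph-mode sketch (`a` and `b` are illustrative
          tensors computed elsewhere):
        
          ```python
          a = tf.constant(1.0) * 2.0
          b = tf.constant(2.0) * 3.0
          # `joined_a` and `joined_b` have the same values as `a` and `b`, but
          # neither is returned until both `a` and `b` have been computed.
          joined_a, joined_b = tf.tuple([a, b])
          ```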
   3653 
   3654   Args:
   3655     tensors: A list of `Tensor`s or `IndexedSlices`; some entries can be `None`.
   3656     control_inputs: List of additional ops to finish before returning.
   3657     name: (optional) A name to use as a `name_scope` for the operation.
   3658 
   3659   Returns:
   3660     Same as `tensors`.
   3661 
   3662   Raises:
   3663     ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
   3664     TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
   3665       objects.
   3666 
   3667   """
   3668   return tuple(tensors=tensors, name=name, control_inputs=control_inputs)  # pylint: disable=redefined-builtin
   3669 
   3670 
   3671 @tf_export(v1=["tuple"])
   3672 def tuple(tensors, name=None, control_inputs=None):  # pylint: disable=redefined-builtin
   3673   """Group tensors together.
   3674 
   3675   This creates a tuple of tensors with the same values as the `tensors`
   3676   argument, except that the value of each tensor is only returned after the
   3677   values of all tensors have been computed.
   3678 
   3679   `control_inputs` contains additional ops that have to finish before this op
   3680   finishes, but whose outputs are not returned.
   3681 
   3682   This can be used as a "join" mechanism for parallel computations: all the
   3683   argument tensors can be computed in parallel, but the values of any tensor
   3684   returned by `tuple` are only available after all the parallel computations
   3685   are done.
   3686 
   3687   See also `tf.group` and
   3688   `tf.control_dependencies`.
   3689 
   3690   Args:
   3691     tensors: A list of `Tensor`s or `IndexedSlices`; some entries can be `None`.
   3692     name: (optional) A name to use as a `name_scope` for the operation.
   3693     control_inputs: List of additional ops to finish before returning.
   3694 
   3695   Returns:
   3696     Same as `tensors`.
   3697 
   3698   Raises:
   3699     ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
   3700     TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
   3701       objects.
   3702 
   3703   """
   3704   if context.executing_eagerly():
   3705     return tensors
   3706   with ops.name_scope(name, "tuple", tensors) as name:
   3707     tensors = [
   3708         t if (isinstance(t, ops.Operation) or tensor_util.is_tensor(t) or
   3709               t is None) else ops.convert_to_tensor(t) for t in tensors
   3710     ]
   3711     gating_ops = [
   3712         t if isinstance(t, ops.Operation) else t.op
   3713         for t in tensors
   3714         if t is not None
   3715     ]
   3716     if control_inputs:
   3717       for c in control_inputs:
   3718         if isinstance(c, ops.Tensor):
   3719           c = c.op
   3720         elif not isinstance(c, ops.Operation):
   3721           raise TypeError("Control input must be Operation or Tensor: %s" % c)
   3722         gating_ops.append(c)
   3723     # Note that in order to ensure a deterministic order in the pbtxt, we must
   3724     # sort the ops here.
   3725     gating_ops = sorted(set(gating_ops), key=lambda op: op._id)  # Uniquify ops.
   3726     if not gating_ops:
   3727       raise ValueError("Must have at least one Tensor: %s" % tensors)
   3728     gate = group(*gating_ops)
   3729     tpl = []
   3730     for t in tensors:
   3731       if tensor_util.is_tensor(t):
   3732         tpl.append(with_dependencies([gate], t))
   3733       elif isinstance(t, ops.Operation):
   3734         with ops.control_dependencies([gate]):
   3735           tpl.append(group(t))
   3736       else:
   3737         tpl.append(None)
   3738     return tpl
   3739 
   3740 
   3741 def _assert_at_most_n_true(predicates, n, msg):
   3742   """Returns an Assert op that checks that at most n predicates are True.
   3743 
   3744   Args:
   3745     predicates: list of bool scalar tensors.
   3746     n: maximum number of true predicates allowed.
   3747     msg: Error message.
   3748   """
   3749   preds_c = array_ops.stack(predicates, name="preds_c")
   3750   num_true_conditions = math_ops.reduce_sum(
   3751       math_ops.cast(preds_c, dtypes.int32), name="num_true_conds")
   3752   condition = math_ops.less_equal(num_true_conditions,
   3753                                   constant_op.constant(n, name="n_true_conds"))
   3754   preds_names = ", ".join(getattr(p, "name", "?") for p in predicates)
   3755   error_msg = [
   3756       "%s: more than %d conditions (%s) evaluated as True:" %
   3757       (msg, n, preds_names), preds_c
   3758   ]
   3759   return Assert(condition, data=error_msg, summarize=len(predicates))
   3760 
   3761 
   3762 def _case_create_default_action(predicates, actions):
   3763   """Creates default action for a list of actions and their predicates.
   3764 
   3765   It uses the input actions to select an arbitrary action as the default and
   3766   makes sure that the corresponding predicates have valid values.
   3767 
   3768   Args:
   3769     predicates: a list of bool scalar tensors
   3770     actions: a list of callable objects which return tensors.
   3771 
   3772   Returns:
   3773     a callable
   3774   """
   3775   k = len(predicates) - 1  # could pick any
   3776   predicate, action = predicates[k], actions[k]
   3777   other_predicates, other_actions = predicates[:k], actions[:k]
   3778 
   3779   def default_action():
   3780     others_msg = ("Implementation error: "
   3781                   "selected default action #%d was called, but some other "
   3782                   "predicates are True: " % k)
   3783     default_msg = ("Input error: "
   3784                    "None of the conditions evaluated as True:",
   3785                    array_ops.stack(predicates, name="preds_c"))
   3786     with ops.control_dependencies([
   3787         _assert_at_most_n_true(other_predicates, n=0, msg=others_msg),
   3788         Assert(predicate, data=default_msg)
   3789     ]):
   3790       return action()
   3791 
   3792   return default_action, other_predicates, other_actions
   3793 
   3794 
   3795 def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name,
   3796                                        allow_python_preds):
   3797   """Verifies input arguments for the case function.
   3798 
   3799   Args:
   3800     pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
   3801       callable which returns a list of tensors.
   3802     exclusive: True iff at most one predicate is allowed to evaluate to `True`.
   3803     name: A name for the case operation.
   3804     allow_python_preds: If true, pred_fn_pairs may contain Python bools in
   3805       addition to boolean Tensors.
   3806 
   3807   Raises:
   3808     TypeError: If `pred_fn_pairs` is not a list/dictionary.
   3809     TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
   3810     TypeError: If `fns[i]` is not callable for any i, or `default` is not
   3811                callable.
   3812 
   3813   Returns:
   3814     a tuple <list of scalar bool tensors, list of callables>.
   3815   """
   3816   if not isinstance(pred_fn_pairs, (list, _basetuple, dict)):
   3817     raise TypeError("pred_fn_pairs must be a list, tuple, or dict")
   3818 
   3819   if isinstance(pred_fn_pairs, collections.OrderedDict):
   3820     pred_fn_pairs = pred_fn_pairs.items()
   3821   elif isinstance(pred_fn_pairs, dict):
   3822     if context.executing_eagerly():
   3823       # No name to sort on in eager mode. Use dictionary traversal order,
   3824       # which is nondeterministic in versions of Python < 3.6
   3825       if not exclusive:
   3826         raise ValueError("Unordered dictionaries are not supported for the "
   3827                          "`pred_fn_pairs` argument when `exclusive=False` and "
   3828                          "eager mode is enabled.")
   3829       pred_fn_pairs = list(pred_fn_pairs.items())
   3830     else:
   3831       pred_fn_pairs = sorted(
   3832           pred_fn_pairs.items(), key=lambda item: item[0].name)
   3833       if not exclusive:
   3834         logging.warn(
   3835             "%s: An unordered dictionary of predicate/fn pairs was "
   3836             "provided, but exclusive=False. The order of conditional tests "
   3837             "is deterministic but may not match the provided order.", name)
   3838   for pred_fn_pair in pred_fn_pairs:
   3839     if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2:
   3840       raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple")
   3841     pred, fn = pred_fn_pair
   3842 
   3843     if isinstance(pred, ops.Tensor):
   3844       if pred.dtype != dtypes.bool:
   3845         raise TypeError("pred must be a Tensor of type bool: %s" % pred.name)
   3846     elif not allow_python_preds:
   3847       raise TypeError("pred must be a Tensor, got: %s" % pred)
   3848     elif not isinstance(pred, bool):
   3849       raise TypeError("pred must be a Tensor or bool, got: %s" % pred)
   3850 
   3851     if not callable(fn):
   3852       raise TypeError("fn for pred %s must be callable." % pred)
   3853 
   3854   predicates, actions = zip(*pred_fn_pairs)
   3855   return predicates, actions
   3856 
   3857 
   3858 def _case_helper(cond_fn,
   3859                  pred_fn_pairs,
   3860                  default,
   3861                  exclusive,
   3862                  name,
   3863                  allow_python_preds=False,
   3864                  **cond_kwargs):
   3865   """Implementation of case that allows for different cond functions.
   3866 
   3867   Args:
   3868     cond_fn: A method that has the signature and semantics of `cond` above.
   3869     pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor, and a
   3870       callable which returns a list of tensors.
   3871     default: Optional callable that returns a list of tensors.
   3872     exclusive: True iff at most one predicate is allowed to evaluate to `True`.
   3873     name: A name for this operation (optional).
   3874     allow_python_preds: If true, pred_fn_pairs may contain Python bools in
   3875       addition to boolean Tensors.
   3876     **cond_kwargs: keyword arguments that will be passed to `cond_fn`.
   3877 
   3878   Returns:
   3879     The tensors returned by the first pair whose predicate evaluated to True, or
   3880     those returned by `default` if none does.
   3881 
   3882   Raises:
   3883     TypeError: If `pred_fn_pairs` is not a list/dictionary.
   3884     TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
   3885     TypeError: If `fns[i]` is not callable for any i, or `default` is not
   3886                callable.
   3887   """
   3888   predicates, actions = _case_verify_and_canonicalize_args(
   3889       pred_fn_pairs, exclusive, name, allow_python_preds)
   3890   with ops.name_scope(name, "case", [predicates]):
   3891     if default is None:
   3892       default, predicates, actions = _case_create_default_action(
   3893           predicates, actions)
   3894     fn = default
   3895     # To eval conditions in direct order we create nested conditions in reverse:
   3896     #   cond_fn(c[0], true_fn=.., false_fn=cond_fn(c[1], ...))
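            # For example, with predicates [p0, p1], actions [a0, a1] and default d,
            # the loop below builds, innermost first:
            #   fn = partial(cond_fn, p1, true_fn=a1, false_fn=d)
            #   fn = partial(cond_fn, p0, true_fn=a0, false_fn=<fn above>)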
   3897     for predicate, action in reversed(list(zip(predicates, actions))):
   3898       fn = functools.partial(
   3899           cond_fn, predicate, true_fn=action, false_fn=fn, **cond_kwargs)
   3900     if exclusive:
   3901       with ops.control_dependencies([
   3902           _assert_at_most_n_true(
   3903               predicates, n=1, msg="Input error: exclusive=True")
   3904       ]):
   3905         return fn()
   3906     else:
   3907       return fn()
   3908 
   3909 
   3910 @tf_export("case")
   3911 def case(pred_fn_pairs,
   3912          default=None,
   3913          exclusive=False,
   3914          strict=False,
   3915          name="case"):
   3916   """Create a case operation.
   3917 
   3918   The `pred_fn_pairs` parameter is a dict or list of N pairs.
   3919   Each pair contains a boolean scalar tensor and a python callable that
   3920   creates the tensors to be returned if the boolean evaluates to True.
   3921   `default` is a callable generating a list of tensors. All the callables
   3922   in `pred_fn_pairs` as well as `default` (if provided) should return the same
   3923   number and types of tensors.
   3924 
   3925   If `exclusive==True`, all predicates are evaluated, and an exception is
   3926   thrown if more than one of the predicates evaluates to `True`.
   3927   If `exclusive==False`, execution stops at the first predicate which
   3928   evaluates to True, and the tensors generated by the corresponding function
   3929   are returned immediately. If none of the predicates evaluate to True, this
   3930   operation returns the tensors generated by `default`.
   3931 
   3932   `tf.case` supports nested structures as implemented in
   3933   `tf.contrib.framework.nest`. All of the callables must return the same
   3934   (possibly nested) value structure of lists, tuples, and/or named tuples.
   3935   Singleton lists and tuples form the only exceptions to this: when returned by
   3936   a callable, they are implicitly unpacked to single values. This
   3937   behavior is disabled by passing `strict=True`.
   3938 
   3939   If an unordered dictionary is used for `pred_fn_pairs`, the order of the
   3940   conditional tests is not guaranteed. However, the order is guaranteed to be
   3941   deterministic, so that variables created in conditional branches are created
   3942   in fixed order across runs.
   3943 
   3944   @compatibility(eager)
   3945   Unordered dictionaries are not supported in eager mode when `exclusive=False`.
   3946   Use a list of tuples instead.
   3947   @end_compatibility
   3948 
   3949 
   3950   **Example 1:**
   3951 
   3952   Pseudocode:
   3953 
   3954   ```
   3955   if (x < y) return 17;
   3956   else return 23;
   3957   ```
   3958 
   3959   Expressions:
   3960 
   3961   ```python
   3962   f1 = lambda: tf.constant(17)
   3963   f2 = lambda: tf.constant(23)
   3964   r = tf.case([(tf.less(x, y), f1)], default=f2)
   3965   ```
   3966 
   3967   **Example 2:**
   3968 
   3969   Pseudocode:
   3970 
   3971   ```
   3972   if (x < y && x > z) raise OpError("Only one predicate may evaluate to True");
   3973   if (x < y) return 17;
   3974   else if (x > z) return 23;
   3975   else return -1;
   3976   ```
   3977 
   3978   Expressions:
   3979 
   3980   ```python
   3981   def f1(): return tf.constant(17)
   3982   def f2(): return tf.constant(23)
   3983   def f3(): return tf.constant(-1)
   3984   r = tf.case({tf.less(x, y): f1, tf.greater(x, z): f2},
   3985               default=f3, exclusive=True)
   3986   ```
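        
          **Example 3 (singleton unpacking, a minimal sketch):**
        
          As described above, a callable that returns a singleton list has its result
          implicitly unpacked unless `strict=True`:
        
          ```python
          f1 = lambda: [tf.constant(17)]
          f2 = lambda: [tf.constant(23)]
          # With strict=False (the default) `r` would be a scalar Tensor; with
          # strict=True it is a single-element list.
          r = tf.case([(tf.less(x, y), f1)], default=f2, strict=True)
          ```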
   3987 
   3988   Args:
   3989     pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
   3990       callable which returns a list of tensors.
   3991     default: Optional callable that returns a list of tensors.
   3992     exclusive: True iff at most one predicate is allowed to evaluate to `True`.
   3993     strict: A boolean that enables/disables 'strict' mode; see above.
   3994     name: A name for this operation (optional).
   3995 
   3996   Returns:
   3997     The tensors returned by the first pair whose predicate evaluated to True, or
   3998     those returned by `default` if none does.
   3999 
   4000   Raises:
   4001     TypeError: If `pred_fn_pairs` is not a list/dictionary.
   4002     TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
   4003     TypeError: If `fns[i]` is not callable for any i, or `default` is not
   4004                callable.
   4005   """
   4006   return _case_helper(
   4007       cond,
   4008       pred_fn_pairs,
   4009       default,
   4010       exclusive,
   4011       name,
   4012       allow_python_preds=False,
   4013       strict=strict)
   4014 
   4015 
   4016 class XLAControlFlowContext(ControlFlowContext):
   4017   """Base class for XLA and TPU control flow contexts."""
   4018 
   4019   def __init__(self):
   4020     super(XLAControlFlowContext, self).__init__()
   4021     self._name = "XLAControlFlowContext"
   4022 
   4023   def to_control_flow_context_def(self, context_def, export_scope=None):
   4024     # pylint: disable=useless-super-delegation
   4025     # NOTE(slebedev): the method is required by `ControlFlowContext`.
   4026     super(XLAControlFlowContext, self).to_control_flow_context_def(
   4027         context_def, export_scope)
   4028 
   4029   def IsXLAContext(self):
   4030     return True
   4031 
   4032   def AddOp(self, _):
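            """No-op; this context does not modify ops added to it."""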
   4033     pass
   4034 
   4035   def AddValue(self, x):
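            """Returns `x` unchanged."""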
   4036     return x
   4037 
   4038 
   4039 def from_control_flow_context_def(context_def, import_scope=None):
   4040   """Deserializes `context_def` into the appropriate ControlFlowContext.
   4041 
   4042   Args:
   4043     context_def: ControlFlowContextDef proto
   4044     import_scope: Optional `string`. Name scope to add.
   4045 
   4046   Returns:
   4047     A ControlFlowContext subclass
   4048   """
   4049   if context_def.HasField("cond_ctxt"):
   4050     return CondContext.from_proto(
   4051         context_def.cond_ctxt, import_scope=import_scope)
   4052   if context_def.HasField("while_ctxt"):
   4053     return WhileContext.from_proto(
   4054         context_def.while_ctxt, import_scope=import_scope)
   4055   raise NotImplementedError("Unknown ControlFlowContextDef field: %s" %
   4056                             context_def.WhichOneof("ctxt"))
   4057 
   4058 
   4059 ops.register_proto_function(
   4060     ops.GraphKeys.COND_CONTEXT,
   4061     proto_type=control_flow_pb2.CondContextDef,
   4062     to_proto=CondContext.to_proto,
   4063     from_proto=CondContext.from_proto)
   4064 
   4065 ops.register_proto_function(
   4066     ops.GraphKeys.WHILE_CONTEXT,
   4067     proto_type=control_flow_pb2.WhileContextDef,
   4068     to_proto=WhileContext.to_proto,
   4069     from_proto=WhileContext.from_proto)
   4070