Home | History | Annotate | Download | only in ops
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Control Flow Operations.
     16 
     17 See the @{$python/control_flow_ops} guide.
     18 
     19 @@identity
     20 @@identity_n
     21 @@tuple
     22 @@group
     23 @@no_op
     24 @@count_up_to
     25 @@cond
     26 @@smart_cond
     27 @@case
     28 @@while_loop
     29 @@logical_and
     30 @@logical_not
     31 @@logical_or
     32 @@logical_xor
     33 @@equal
     34 @@not_equal
     35 @@less
     36 @@less_equal
     37 @@greater
     38 @@greater_equal
     39 @@where
     40 @@is_finite
     41 @@is_inf
     42 @@is_nan
     43 @@verify_tensor_all_finite
     44 @@check_numerics
     45 @@add_check_numerics_ops
     46 @@Assert
     47 @@Print
     48 """
     49 # pylint: disable=g-bad-name
     50 from __future__ import absolute_import
     51 from __future__ import division
     52 from __future__ import print_function
     53 
     54 import abc
     55 import collections
     56 import functools
     57 
     58 import six
     59 
     60 from tensorflow.core.framework import attr_value_pb2
     61 from tensorflow.core.protobuf import control_flow_pb2
     62 from tensorflow.python.eager import context
     63 from tensorflow.python.framework import constant_op
     64 from tensorflow.python.framework import dtypes
     65 from tensorflow.python.framework import errors
     66 from tensorflow.python.framework import ops
     67 from tensorflow.python.framework import sparse_tensor
     68 from tensorflow.python.framework import tensor_shape
     69 from tensorflow.python.framework import tensor_util
     70 from tensorflow.python.ops import array_ops
     71 from tensorflow.python.ops import control_flow_util as util
     72 from tensorflow.python.ops import gen_array_ops
     73 from tensorflow.python.ops import gen_control_flow_ops
     74 from tensorflow.python.ops import gen_data_flow_ops
     75 from tensorflow.python.ops import gen_logging_ops
     76 from tensorflow.python.ops import math_ops
     77 from tensorflow.python.ops import tensor_array_ops
     78 # go/tf-wildcard-import
     79 # pylint: disable=wildcard-import,undefined-variable
     80 from tensorflow.python.ops.gen_control_flow_ops import *
     81 # pylint: enable=wildcard-import
     82 from tensorflow.python.platform import tf_logging as logging
     83 from tensorflow.python.util import compat
     84 from tensorflow.python.util import deprecation
     85 from tensorflow.python.util import nest
     86 from tensorflow.python.util import tf_should_use
     87 from tensorflow.python.util.tf_export import tf_export
     88 
     89 # We override the 'tuple' for a control flow op, so we keep python's
     90 # existing 'tuple' for later use in this module.
     91 _basetuple = tuple
     92 
     93 
     94 def _summarize_eager(tensor, summarize=None):
     95   """Returns a summarized string representation of eager `tensor`.
     96 
     97   Args:
     98     tensor: EagerTensor to summarize
     99     summarize: Include these many first elements of `array`
    100   """
    101   # reshape((-1,)) is the fastest way to get a flat array view
    102   if tensor._rank():  # pylint: disable=protected-access
    103     flat = tensor.numpy().reshape((-1,))
    104     lst = [str(x) for x in flat[:summarize]]
    105     if len(lst) < flat.size:
    106       lst.append("...")
    107   else:
    108     # tensor.numpy() returns a scalar for zero dimensional arrays
    109     if summarize != 0:
    110       lst = [str(tensor.numpy())]
    111     else:
    112       lst = []
    113 
    114   return ", ".join(lst)
    115 
    116 
    117 # pylint: disable=protected-access
    118 
    119 
    120 # Assert and Print are special symbols in python, so we must
    121 # use an upper-case version of them.
    122 @tf_export("Assert")
    123 @tf_should_use.should_use_result
    124 def Assert(condition, data, summarize=None, name=None):
    125   """Asserts that the given condition is true.
    126 
    127   If `condition` evaluates to false, print the list of tensors in `data`.
    128   `summarize` determines how many entries of the tensors to print.
    129 
    130   NOTE: In graph mode, to ensure that Assert executes, one usually attaches
    131   a dependency:
    132 
    133   ```python
    134   # Ensure maximum element of x is smaller or equal to 1
    135   assert_op = tf.Assert(tf.less_equal(tf.reduce_max(x), 1.), [x])
    136   with tf.control_dependencies([assert_op]):
    137     ... code using x ...
    138   ```
    139 
    140   Args:
    141     condition: The condition to evaluate.
    142     data: The tensors to print out when condition is false.
    143     summarize: Print this many entries of each tensor.
    144     name: A name for this operation (optional).
    145 
    146   Returns:
    147     assert_op: An `Operation` that, when executed, raises a
    148     `tf.errors.InvalidArgumentError` if `condition` is not true.
    149     @compatibility{eager} returns None.
    150 
    151   Raises:
    152     @compatibility{eager} `tf.errors.InvalidArgumentError` if `condition`
    153     is not true
    154   """
    155   if context.in_eager_mode():
    156     if not condition:
    157       xs = ops.convert_n_to_tensor(data)
    158       data_str = [_summarize_eager(x, summarize) for x in xs]
    159       raise errors.InvalidArgumentError(
    160           node_def=None,
    161           op=None,
    162           message="Expected '%s' to be true. Summarized data: %s" %
    163           (condition, "\n".join(data_str)))
    164     return
    165 
    166   with ops.name_scope(name, "Assert", [condition, data]) as name:
    167     xs = ops.convert_n_to_tensor(data)
    168     if all([x.dtype in {dtypes.string, dtypes.int32} for x in xs]):
    169       # As a simple heuristic, we assume that string and int32 are
    170       # on host to avoid the need to use cond. If it is not case,
    171       # we will pay the price copying the tensor to host memory.
    172       return gen_logging_ops._assert(condition, data, summarize, name="Assert")
    173     else:
    174       condition = ops.convert_to_tensor(condition, name="Condition")
    175 
    176       def true_assert():
    177         return gen_logging_ops._assert(
    178             condition, data, summarize, name="Assert")
    179 
    180       guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard")
    181       return guarded_assert.op
    182 
    183 
    184 def _Identity(data, name=None):
    185   """Return a tensor with the same shape and contents as the input tensor.
    186 
    187   Args:
    188     data: A Tensor.
    189     name: A name for this operation (optional).
    190 
    191   Returns:
    192     A Tensor with the same type and value as the input Tensor.
    193   """
    194   data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
    195   if isinstance(data, ops.Tensor):
    196     if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
    197       return gen_array_ops._ref_identity(data, name=name)
    198     else:
    199       return array_ops.identity(data, name=name)
    200   else:
    201     if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
    202       raise TypeError("Type %s not supported" % type(data))
    203     values = _Identity(data.values, name=name)
    204     indices = array_ops.identity(data.indices, name="indices")
    205     if isinstance(data, ops.IndexedSlices):
    206       dense_shape = data.dense_shape
    207       if dense_shape is not None:
    208         dense_shape = array_ops.identity(dense_shape, name="dense_shape")
    209       return ops.IndexedSlices(values, indices, dense_shape)
    210     else:
    211       dense_shape = array_ops.identity(data.dense_shape, name="dense_shape")
    212       return sparse_tensor.SparseTensor(indices, values, dense_shape)
    213 
    214 
    215 def _NextIteration(data, name=None):
    216   data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
    217   if isinstance(data, ops.Tensor):
    218     if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
    219       return ref_next_iteration(data, name=name)
    220     else:
    221       return next_iteration(data, name=name)
    222   else:
    223     if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
    224       raise TypeError("Type %s not supported" % type(data))
    225     values = _NextIteration(data.values, name=name)
    226     indices = next_iteration(data.indices, name="indices")
    227     if isinstance(data, ops.IndexedSlices):
    228       dense_shape = data.dense_shape
    229       if dense_shape is not None:
    230         dense_shape = next_iteration(dense_shape, name="dense_shape")
    231       return ops.IndexedSlices(values, indices, dense_shape)
    232     else:
    233       dense_shape = next_iteration(data.dense_shape, name="dense_shape")
    234       return sparse_tensor.SparseTensor(indices, values, dense_shape)
    235 
    236 
    237 def _Enter(data,
    238            frame_name,
    239            is_constant=False,
    240            parallel_iterations=10,
    241            use_ref=True,
    242            use_input_shape=True,
    243            name=None):
    244   """Creates or finds a child frame, and makes `data` available to it.
    245 
    246   The unique `frame_name` is used by the `Executor` to identify frames. If
    247   `is_constant` is true, `data` is a constant in the child frame; otherwise
    248   it may be changed in the child frame. At most `parallel_iterations`
    249   iterations are run in parallel in the child frame.
    250 
    251   Args:
    252     data: The tensor to be made available to the child frame.
    253     frame_name: The name of the child frame.
    254     is_constant: If true, the output is constant within the child frame.
    255     parallel_iterations: The number of iterations allowed to run in parallel.
    256     use_ref: If true, use ref_enter if data is of ref type.
    257     name: A name for this operation (optional).
    258 
    259   Returns:
    260     The same tensor as `data`.
    261   """
    262   data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
    263   if isinstance(data, ops.Tensor):
    264     if data.dtype._is_ref_dtype and use_ref:  # pylint: disable=protected-access
    265       result = gen_control_flow_ops._ref_enter(
    266           data, frame_name, is_constant, parallel_iterations, name=name)
    267     else:
    268       result = gen_control_flow_ops._enter(
    269           data, frame_name, is_constant, parallel_iterations, name=name)
    270     if use_input_shape:
    271       result.set_shape(data.get_shape())
    272     return result
    273   else:
    274     if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
    275       raise TypeError("Type %s not supported" % type(data))
    276     values = _Enter(
    277         data.values,
    278         frame_name,
    279         is_constant,
    280         parallel_iterations=parallel_iterations,
    281         use_input_shape=use_input_shape,
    282         name=name)
    283     indices = gen_control_flow_ops._enter(
    284         data.indices,
    285         frame_name,
    286         is_constant,
    287         parallel_iterations,
    288         name="indices")
    289     if use_input_shape:
    290       indices.set_shape(data.indices.get_shape())
    291     if isinstance(data, ops.IndexedSlices):
    292       dense_shape = data.dense_shape
    293       if dense_shape is not None:
    294         dense_shape = gen_control_flow_ops._enter(
    295             dense_shape,
    296             frame_name,
    297             is_constant,
    298             parallel_iterations,
    299             name="dense_shape")
    300         if use_input_shape:
    301           dense_shape.set_shape(data.dense_shape.get_shape())
    302       return ops.IndexedSlices(values, indices, dense_shape)
    303     else:
    304       dense_shape = gen_control_flow_ops._enter(
    305           data.dense_shape,
    306           frame_name,
    307           is_constant,
    308           parallel_iterations,
    309           name="dense_shape")
    310       if use_input_shape:
    311         dense_shape.set_shape(data.dense_shape.get_shape())
    312       return sparse_tensor.SparseTensor(indices, values, dense_shape)
    313 
    314 
    315 def exit(data, name=None):  # pylint: disable=redefined-builtin
    316   """Exits the current frame to its parent frame.
    317 
    318   Exit makes its input `data` available to the parent frame.
    319 
    320   Args:
    321     data: The tensor to be made available to the parent frame.
    322     name: A name for this operation (optional).
    323 
    324   Returns:
    325     The same tensor as `data`.
    326   """
    327   data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
    328   if isinstance(data, ops.Tensor):
    329     if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
    330       return gen_control_flow_ops._ref_exit(data, name)
    331     else:
    332       return gen_control_flow_ops._exit(data, name)
    333   else:
    334     if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
    335       raise TypeError("Type %s not supported" % type(data))
    336     values = exit(data.values, name=name)
    337     indices = gen_control_flow_ops._exit(data.indices, name="indices")
    338     if isinstance(data, ops.IndexedSlices):
    339       dense_shape = data.dense_shape
    340       if dense_shape is not None:
    341         dense_shape = gen_control_flow_ops._exit(dense_shape, name)
    342       return ops.IndexedSlices(values, indices, dense_shape)
    343     else:
    344       dense_shape = gen_control_flow_ops._exit(data.dense_shape, name)
    345       return sparse_tensor.SparseTensor(indices, values, dense_shape)
    346 
    347 
    348 def switch(data, pred, dtype=None, name=None):
    349   """Forwards `data` to an output determined by `pred`.
    350 
    351   If `pred` is false, the `data` input is forwarded to the first output.
    352   Otherwise, the data goes to the second output.
    353 
    354   This op handles `Tensor`s and `IndexedSlices`.
    355 
    356   Args:
    357     data: The tensor to be forwarded to the appropriate output.
    358     pred: A scalar that specifies which output port will receive data.
    359     dtype: Optional element type for the returned tensor. If missing,
    360            the type is inferred from the type of `value`.
    361     name: A name for this operation (optional).
    362 
    363   Returns:
    364     `(output_false, output_true)`: If `pred` is true, data will be forwarded
    365     to `output_true`, otherwise it goes to `output_false`.
    366   """
    367   with ops.name_scope(name, "Switch", [data, pred]) as name:
    368     data = ops.internal_convert_to_tensor_or_indexed_slices(
    369         data, dtype=dtype, name="data", as_ref=True)
    370     pred = ops.convert_to_tensor(pred, name="pred")
    371     if isinstance(data, ops.Tensor):
    372       return gen_control_flow_ops._switch(data, pred, name=name)
    373     else:
    374       if not isinstance(data, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
    375         raise TypeError("Type %s not supported" % type(data))
    376       val, ind = data.values, data.indices
    377       val_f, val_t = gen_control_flow_ops._switch(val, pred, name=name)
    378       ind_f, ind_t = gen_control_flow_ops._switch(ind, pred, name="indices")
    379       if isinstance(data, ops.IndexedSlices):
    380         dense_shape = data.dense_shape
    381         if dense_shape is not None:
    382           dense_shape_f, dense_shape_t = gen_control_flow_ops._switch(
    383               dense_shape, pred, name="dense_shape")
    384         else:
    385           dense_shape_f, dense_shape_t = None, None
    386         return (ops.IndexedSlices(val_f, ind_f, dense_shape_f),
    387                 ops.IndexedSlices(val_t, ind_t, dense_shape_t))
    388       else:
    389         dense_shape = data.dense_shape
    390         dense_shape_f, dense_shape_t = gen_control_flow_ops._switch(
    391             data.dense_shape, pred, name="dense_shape")
    392         return (sparse_tensor.SparseTensor(ind_f, val_f, dense_shape_f),
    393                 sparse_tensor.SparseTensor(ind_t, val_t, dense_shape_t))
    394 
    395 
    396 def _SwitchRefOrTensor(data, pred, name="Switch"):
    397   """Forwards `data` to an output determined by `pred`.
    398 
    399   If `pred` is false, the `data` input is forwarded to the first output.
    400   Otherwise, the data goes to the second output.
    401 
    402   This op handles `Tensor`s and `IndexedSlices`.
    403 
    404   Args:
    405     data: The tensor to be forwarded to the appropriate output.
    406     pred: A scalar that specifies which output port will receive data.
    407     name: A name for this operation (optional).
    408 
    409   Returns:
    410     `(output_false, output_true)`: If `pred` is true, data will be forwarded to
    411     `output_true`, otherwise it goes to `output_false`.
    412 
    413   Raises:
    414     TypeError: if data is not a Tensor or IndexedSlices
    415   """
    416   data = ops.convert_to_tensor_or_indexed_slices(data, name="data")
    417   # NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below
    418   # addresses the following scenario.
    419   #
    420   # Assume you execute Optimizer.apply_gradients() in a branch of a cond().
    421   #
    422   # 1. The update op is created inside a `with ops.colocate(var):` block
    423   #
    424   # 2. Some tensor `data` is captured and a switch is created in a
    425   #    `with ops.colocate_with(data):` block.
    426   #
    427   # with ops.colocate_with(var):
    428   #  with ops.colocate_with(data):
    429   #    op = ...
    430   #
    431   # var and data may be pinned to different devices, so we want to ops
    432   # created within ops.colocate_with(data) to ignore the existing stack.
    433   with ops.colocate_with(data, ignore_existing=True):
    434     if isinstance(data, ops.Tensor):
    435       if data.dtype._is_ref_dtype:  # pylint: disable=protected-access
    436         return ref_switch(data, pred, name=name)
    437     return switch(data, pred, name=name)
    438 
    439 
    440 def merge(inputs, name=None):
    441   """Returns the value of an available element of `inputs`.
    442 
    443   This op tests each of the tensors in `inputs` in turn to determine if any of
    444   them is available. If it finds an available tensor, it returns it and its
    445   index in `inputs`.
    446 
    447   It is an error if more than one tensor in `inputs` is available. If no tensor
    448   in `inputs` is available, the returned tensor and index are not set.
    449 
    450   This op handles both `Tensor`s and `IndexedSlices`. If inputs has a mix of
    451   `Tensor`s and `IndexedSlices`, all inputs are converted to IndexedSlices
    452   before merging.
    453 
    454   Args:
    455     inputs: The input tensors, at most one of which is available.
    456     name: A name for this operation (optional).
    457 
    458   Returns:
    459     A tuple containing the chosen input tensor and its index in `inputs`.
    460 
    461   Raises:
    462     ValueError: If any of the inputs is None, or inputs are IndexedSlices and
    463       some but not all have a dense_shape property.
    464   """
    465   if any([inp is None for inp in inputs]):
    466     raise ValueError("At least one of the merge inputs is None: %s" % inputs)
    467   with ops.name_scope(name, "Merge", inputs) as name:
    468     inputs = [
    469         ops.internal_convert_to_tensor_or_indexed_slices(inp, as_ref=True)
    470         for inp in inputs
    471     ]
    472     if all([isinstance(v, ops.Tensor) for v in inputs]):
    473       if all([v.dtype._is_ref_dtype for v in inputs]):  # pylint: disable=protected-access
    474         return gen_control_flow_ops._ref_merge(inputs, name)
    475       else:
    476         return gen_control_flow_ops._merge(inputs, name)
    477     elif all([isinstance(v, sparse_tensor.SparseTensor) for v in inputs]):
    478       # Only handle the case when all inputs are SparseTensor.
    479       values, _ = merge([inp.values for inp in inputs], name=name)
    480       indices, chosen_index = gen_control_flow_ops._merge(
    481           [inp.indices for inp in inputs], name="indices")
    482       dense_shape, _ = gen_control_flow_ops._merge(
    483           [inp.dense_shape for inp in inputs], name="dense_shape")
    484       return (sparse_tensor.SparseTensor(indices, values, dense_shape),
    485               chosen_index)
    486     else:
    487       # For now convert all the inputs as IndexedSlices.
    488       inputs = math_ops._as_indexed_slices_list(inputs, optimize=False)
    489       values, _ = merge([inp.values for inp in inputs], name=name)
    490       indices, chosen_index = gen_control_flow_ops._merge(
    491           [inp.indices for inp in inputs], name="indices")
    492       if any(inp.dense_shape is not None for inp in inputs):
    493         if any(inp.dense_shape is None for inp in inputs):
    494           raise ValueError("Either all merged IndexedSlices must have a "
    495                            "dense_shape, or none must have a dense_shape.")
    496         dense_shape, _ = gen_control_flow_ops._merge(
    497             [inp.dense_shape for inp in inputs], name="dense_shape")
    498       else:
    499         dense_shape = None
    500       return ops.IndexedSlices(values, indices, dense_shape), chosen_index
    501 
    502 
    503 # pylint: enable=protected-access
    504 
    505 
    506 def _convert_tensorarray_to_flow(tensor_or_tensor_array):
    507   if isinstance(tensor_or_tensor_array, tensor_array_ops.TensorArray):
    508     return tensor_or_tensor_array.flow
    509   else:
    510     return tensor_or_tensor_array
    511 
    512 
    513 def _make_tensor_array(ta, t_or_flow):
    514   # pylint: disable=protected-access
    515   new_ta = tensor_array_ops.TensorArray(
    516       dtype=ta.dtype,
    517       handle=ta.handle,
    518       flow=t_or_flow,
    519       infer_shape=ta._infer_shape,
    520       colocate_with_first_write_call=ta._colocate_with_first_write_call)
    521   new_ta._colocate_with = ta._colocate_with
    522   new_ta._element_shape = ta._element_shape
    523   # pylint: enable=protected-access
    524   return new_ta
    525 
    526 
    527 def _convert_flows_to_tensorarrays(tensors_or_tensorarrays, tensors_or_flows):
    528   if len(tensors_or_tensorarrays) != len(tensors_or_flows):
    529     raise ValueError(
    530         "Lengths of original Tensor list and new list do not match: %d vs. %d" %
    531         (len(tensors_or_tensorarrays), len(tensors_or_flows)))
    532   return [
    533       _make_tensor_array(ta, t_or_flow)
    534       if isinstance(ta, tensor_array_ops.TensorArray) else t_or_flow
    535       for (ta, t_or_flow) in zip(tensors_or_tensorarrays, tensors_or_flows)
    536   ]
    537 
    538 
    539 def _ShapeLessThanOrEqual(shape1, shape2):
    540   if shape2.dims is None:
    541     return True
    542   if shape1.ndims != shape2.ndims:
    543     return False
    544   for dim1, dim2 in zip(shape1.dims, shape2.dims):
    545     if dim2.value is not None and dim1.value != dim2.value:
    546       return False
    547   return True
    548 
    549 
    550 def _SetShapeInvariants(input_vars, enter_vars, shapes):
    551   """Set the shapes of the tensors in `enter_vars` to `shapes`.
    552 
    553   Args:
    554     input_vars: A list of tensors that are inputs to `enter_vars`.
    555     enter_vars: A list of tensors whose shapes will be set.
    556     shapes: A (possibly nested) list of shapes.
    557 
    558   Raises:
    559     ValueError: If any tensor in `enter_vars` has a less specific shape
    560       than its corresponding shape in `shapes`.
    561   """
    562   if shapes is None:
    563     return
    564   flat_shapes = nest.flatten(shapes)
    565   if not all([isinstance(s, tensor_shape.TensorShape) for s in flat_shapes]):
    566     raise ValueError("`shapes` must be a (possibly nested) list of shapes.")
    567   # Check that the shapes of the inputs are less than the shape invariants,
    568   # and set the shapes of `enter_vars` to the shape invariants.
    569   for inp, var, shape in zip(input_vars, enter_vars, flat_shapes):
    570     if isinstance(var, ops.Tensor):
    571       if not _ShapeLessThanOrEqual(inp.get_shape(), shape):
    572         raise ValueError(
    573             "The shape invariant specified for %s is not compatible with "
    574             "the initial shape of the loop variable. It enters the loop "
    575             "with shape %s, but the specified shape invariant is %s." %
    576             (inp.name, inp.get_shape(), shape))
    577       var.set_shape(shape)
    578     else:
    579       if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
    580         raise TypeError("Type %s not supported" % type(var))
    581       if isinstance(var, ops.IndexedSlices):
    582         if not _ShapeLessThanOrEqual(inp.values.get_shape(), shape):
    583           raise ValueError(
    584               "The shape invariant specified for %s is not compatible with "
    585               "the initial shape of the values tensor of this IndexedSlices. "
    586               "It enters the loop with shape %s, but the specified shape "
    587               "invariant is %s." % (inp.values.name, inp.values.get_shape(),
    588                                     shape))
    589         var.values.set_shape(shape)
    590         var.indices.set_shape(tensor_shape.TensorShape([shape[0]]))
    591         if var.dense_shape is not None:
    592           var.dense_shape.set_shape(tensor_shape.TensorShape([shape.ndims]))
    593       else:
    594         if not _ShapeLessThanOrEqual(inp.dense_shape.get_shape(), shape):
    595           raise ValueError(
    596               "The shape invariant specified for %s is not compatible with "
    597               "the initial shape of the shape tensor of this SparseTensor. "
    598               "It enters the loop with shape %s, but the specified shape "
    599               "invariant is %s." % (inp.dense_shape.name,
    600                                     inp.dense_shape.get_shape(), shape))
    601         var.values.set_shape(tensor_shape.TensorShape([None]))
    602         var.indices.set_shape(tensor_shape.TensorShape([None, shape.ndims]))
    603         var.dense_shape.set_shape(shape)
    604 
    605 
    606 def _EnforceShapeInvariant(merge_var, next_var):
    607   """Check if the shapes of the loops variables are invariants.
    608 
    609   Args:
    610     merge_vars: The list of tensors representing the initial values of the
    611       loop variables.
    612     next_vars: The list of tensors representing the values of the loop
    613       variables after one loop iteration.
    614 
    615   Raises:
    616     ValueError: If any tensor in `merge_vars` has a more specific shape than
    617       its correspnding tensor in `next_var`.
    618   """
    619   if isinstance(merge_var, ops.Tensor):
    620     m_shape = merge_var.get_shape()
    621     n_shape = next_var.get_shape()
    622     if not _ShapeLessThanOrEqual(n_shape, m_shape):
    623       # TODO(skyewm): get original loop input that caused the shape error and
    624       # report its name instead of the merge node's.
    625       raise ValueError(
    626           "The shape for %s is not an invariant for the loop. It enters "
    627           "the loop with shape %s, but has shape %s after one iteration. "
    628           "Provide shape invariants using either the `shape_invariants` "
    629           "argument of tf.while_loop or set_shape() on the loop variables." %
    630           (merge_var.name, m_shape, n_shape))
    631   else:
    632     if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
    633       raise TypeError("Type %s not supported" % type(var))
    634     if isinstance(var, ops.IndexedSlices):
    635       m_values_shape = merge_var.values.get_shape()
    636       m_indices_shape = merge_var.indices.get_shape()
    637       m_shape_shape = tensor_shape.TensorShape(None)
    638       if merge_var.dense_shape is not None:
    639         m_shape_shape = merge_var.dense_shape.get_shape()
    640       n_values_shape = next_var.values.get_shape()
    641       n_indices_shape = next_var.indices.get_shape()
    642       n_shape_shape = tensor_shape.TensorShape(None)
    643       if next_var.dense_shape is not None:
    644         n_shape_shape = next_var.dense_shape.get_shape()
    645       if (not _ShapeLessThanOrEqual(n_values_shape, m_values_shape) or
    646           not _ShapeLessThanOrEqual(n_indices_shape, m_indices_shape)):
    647         if not _ShapeLessThanOrEqual(n_values_shape, m_values_shape):
    648           raise ValueError(
    649               "The shape for %s is not an invariant for the loop. It enters "
    650               "the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) "
    651               "after one iteration. Provide shape invariants using either the "
    652               "`shape_invariants` argument of tf.while_loop or set_shape() "
    653               "on the loop variables." %
    654               (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape,
    655                n_values_shape, n_indices_shape, n_shape_shape))
    656     else:
    657       m_values_shape = merge_var.values.get_shape()
    658       m_indices_shape = merge_var.indices.get_shape()
    659       m_shape_shape = merge_var.dense_shape.get_shape()
    660       n_values_shape = next_var.values.get_shape()
    661       n_indices_shape = next_var.indices.get_shape()
    662       n_shape_shape = next_var.dense_shape.get_shape()
    663       if (not _ShapeLessThanOrEqual(n_values_shape, m_values_shape) or
    664           not _ShapeLessThanOrEqual(n_indices_shape, m_indices_shape) or
    665           not _ShapeLessThanOrEqual(n_shape_shape, m_shape_shape)):
    666         raise ValueError(
    667             "The shape for %s is not an invariant for the loop. It enters "
    668             "the loop with shape (%s, %s, %s), but has shape (%s, %s, %s) "
    669             "after one iteration. Provide shape invariants using either "
    670             "the `shape_invariants` argument of tf.while_loop or set_shape() "
    671             "on the loop variables." %
    672             (merge_var.name, m_values_shape, m_indices_shape, m_shape_shape,
    673              n_values_shape, n_indices_shape, n_shape_shape))
    674 
    675 
    676 def _AddNextAndBackEdge(m, v, enforce_shape_invariant=True):
    677   """Add NextIteration and back edge from v to m."""
    678   if isinstance(m, ops.Tensor):
    679     v = ops.convert_to_tensor(v)
    680     v = _NextIteration(v)
    681     if enforce_shape_invariant:
    682       # Make sure the shapes of loop outputs are correct. We do this before
    683       # calling _update_input, which will raise a less-helpful error message if
    684       # the types don't match.
    685       # TODO(skyewm): call this for other cases below (needs testing)
    686       _EnforceShapeInvariant(m, v)
    687     m.op._update_input(1, v)  # pylint: disable=protected-access
    688   elif isinstance(m, ops.IndexedSlices):
    689     # pylint: disable=protected-access
    690     v = math_ops._as_indexed_slices(v, optimize=False)
    691     v = _NextIteration(v)
    692     m.values.op._update_input(1, v.values)
    693     m.indices.op._update_input(1, v.indices)
    694     # pylint: enable=protected-access
    695     if m.dense_shape is not None:
    696       if v.dense_shape is None:
    697         raise ValueError("Must have dense shape: %s" % v.name)
    698       m.dense_shape.op._update_input(1, v.dense_shape)
    699   elif isinstance(m, sparse_tensor.SparseTensor):
    700     if not isinstance(v, sparse_tensor.SparseTensor):
    701       raise ValueError("Must be a sparse tensor: %s" % v.name)
    702     v = _NextIteration(v)
    703     # pylint: disable=protected-access
    704     m.values.op._update_input(1, v.values)
    705     m.indices.op._update_input(1, v.indices)
    706     m.dense_shape.op._update_input(1, v.dense_shape)
    707     # pylint: enable=protected-access
    708   else:
    709     raise TypeError("Type %s not supported" % type(m))
    710   return v
    711 
    712 
    713 def GetMaxSizeFromNestedMaximumIterations(value, while_ctxt):
    714   """Calculate a max_size for use by stack ops inside an XLA while_loop.
    715 
    716   Args:
    717     value: The value inside the while_loop forward context.  Used for printing
    718       error messages.
    719     while_ctxt: The forward context inside which value resides.  This does
    720       not always match the value's immediate context, as `value` may be
    721       inside e.g. a cond context inside the while_loop.
    722 
    723   Returns:
    724     A tensor containing the `max_size` to feed to a Stack initializer.
    725 
    726   Raises:
    727     ValueError: If `value` is nested inside a `while_loop` that either
    728       lacks a `maximum_iterations` parameter, or the `maximum_iterations`
    729       parameter:
    730 
    731         - is inside a `while_loop` that is a parent of the calling context, and
    732         - cannot be evaluated at graph build time to a constant.
    733   """
    734   value_name = value.name
    735   # curr_ctxt is the context that tf.gradients was called in.
    736   curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
    737 
    738   curr_ctxt_name = curr_ctxt.name if curr_ctxt is not None else ""
    739   max_size = constant_op.constant(1)
    740 
    741   # Loop through all containing while contexts between value and the
    742   # current context, multiplying together each context's
    743   # max_iterations to get the maximum stack size.
    744   while while_ctxt not in (None, curr_ctxt):
    745     max_iter = while_ctxt.maximum_iterations
    746     if max_iter is None:
    747       raise ValueError(
    748           "Cannot create a gradient accumulator for tensor '%s' inside "
    749           "XLA while_loop because maximum_iterations was not passed to "
    750           "the tf.while_loop call ('%s')." % (value_name, while_ctxt.name))
    751 
    752     # pylint: disable=protected-access
    753     max_iter_ctxt = max_iter.op._get_control_flow_context()
    754     # pylint: enable=protected-access
    755 
    756     # If max_iter_ctxt (non-strictly) contains curr_ctxt, then it's OK to use.
    757     if util.IsContainingContext(curr_ctxt, max_iter_ctxt):
    758       max_size *= max_iter
    759     else:
    760       # We cannot use max_iter because it's defined in a nested while
    761       # or cond context, so will fail if we try to use it as input to
    762       # any ops in curr_ctxt (e.g. max_size or the final accumulator
    763       # stack). Attempt to get a constant value out to use instead.
    764       const_max_iter = tensor_util.constant_value(max_iter)
    765       if const_max_iter is None:
    766         raise ValueError(
    767             "Cannot create a gradient accumulator for tensor '%s' inside XLA "
    768             "while_loop. maximum_iterations tensor '%s' for while_loop context "
    769             "'%s' must be statically known (e.g. a constant value or known "
    770             "shape dimension), or be defined at or outside the while loop "
    771             "context '%s' (currently defined in '%s')." %
    772             (value_name, max_iter.name, while_ctxt.name, curr_ctxt_name,
    773              max_iter_ctxt.name))
    774       max_size *= const_max_iter
    775 
    776     # Find the next outer WhileContext (or stop if we reach the
    777     # tf.gradient's context).
    778     while_ctxt = util.GetContainingWhileContext(
    779         while_ctxt.outer_context, stop_ctxt=curr_ctxt)
    780 
    781   return max_size
    782 
    783 
    784 class GradLoopState(object):
    785   """The state used for constructing the gradient graph for a while loop.
    786 
    787   We create a GradLoopState for each while loop in forward and its
    788   corresponding while loop in backprop. This gives us access to both
    789   the forward and the backprop WhileContexts.
    790 
    791   During the construction of gradient graph, any time when we detect
    792   a forward value that is needed for backprop, we create a history
    793   accumulator and add it to `history_map`. Any time when we backprop
    794   a loop switch op (in _SwitchGrad), we add the grad merge op in
    795   `switch_map`.
    796   """
    797 
    798   def __init__(self, forward_ctxt, outer_grad_state):
    799     # The grad loop state for the outer while loop.
    800     self._outer_grad_state = None
    801 
    802     # The while loop context for forward.
    803     self._forward_context = None
    804 
    805     # The loop counter added by AddForwardLoopCounter. It is the value
    806     # of the loop counter for the next iteration.
    807     self._forward_index = None
    808 
    809     # A sync op for forward.
    810     self._forward_sync = None
    811 
    812     # The while loop context for backprop.
    813     self._grad_context = None
    814 
    815     # The loop counter added by AddBackpropLoopCounter. It is the value
    816     # of the loop counter for the current iteration.
    817     self._grad_index = None
    818 
    819     # A sync op for backprop.
    820     self._grad_sync = None
    821 
    822     # Information needed by backprop.
    823     self._history_map = {}
    824     self._switch_map = {}
    825     self._unused_exits = []
    826     self._deferred_exits = []
    827     self._forward_loop_exits = list(forward_ctxt.loop_exits)
    828     self._pending_exits_count = len(forward_ctxt.loop_exits)
    829 
    830     self._outer_grad_state = outer_grad_state
    831     if outer_grad_state:
    832       outer_forward_ctxt = outer_grad_state.forward_context
    833     else:
    834       outer_forward_ctxt = forward_ctxt.outer_context
    835 
    836     # Add the forward loop counter.
    837     if outer_forward_ctxt:
    838       outer_forward_ctxt.Enter()
    839     cnt, forward_index = forward_ctxt.AddForwardLoopCounter(outer_grad_state)
    840     if outer_forward_ctxt:
    841       outer_forward_ctxt.Exit()
    842     self._forward_context = forward_ctxt
    843     self._forward_index = forward_index
    844 
    845     # Add the backprop WhileContext, and the backprop loop counter.
    846     if outer_grad_state:
    847       # This is a nested loop. Remember the iteration counts for each
    848       # execution of this inner loop.
    849       outer_forward_ctxt.AddName(cnt.name)
    850       history_cnt = outer_grad_state.AddForwardAccumulator(cnt)
    851 
    852       outer_grad_ctxt = outer_grad_state.grad_context
    853       outer_grad_ctxt.Enter()
    854       self._grad_context = WhileContext(
    855           maximum_iterations=forward_ctxt.maximum_iterations,
    856           parallel_iterations=forward_ctxt.parallel_iterations,
    857           back_prop=forward_ctxt.back_prop,
    858           swap_memory=forward_ctxt.swap_memory,
    859           name=forward_ctxt.name,
    860           grad_state=self)
    861       real_cnt = outer_grad_state.AddBackpropAccumulatedValue(history_cnt, cnt)
    862       self._grad_index = self._grad_context.AddBackpropLoopCounter(
    863           real_cnt, outer_grad_state)
    864       outer_grad_ctxt.Exit()
    865     else:
    866       if outer_forward_ctxt:
    867         outer_forward_ctxt.Enter()
    868       self._grad_context = WhileContext(
    869           maximum_iterations=forward_ctxt.maximum_iterations,
    870           parallel_iterations=forward_ctxt.parallel_iterations,
    871           back_prop=forward_ctxt.back_prop,
    872           swap_memory=forward_ctxt.swap_memory,
    873           name=forward_ctxt.name,
    874           grad_state=self)
    875       self._grad_index = self._grad_context.AddBackpropLoopCounter(
    876           cnt, outer_grad_state)
    877       if outer_forward_ctxt:
    878         outer_forward_ctxt.Exit()
    879 
    880   @property
    881   def outer_grad_state(self):
    882     """The grad loop state for outer loop."""
    883     return self._outer_grad_state
    884 
    885   @property
    886   def forward_context(self):
    887     """The while loop context for forward."""
    888     return self._forward_context
    889 
    890   @property
    891   def forward_index(self):
    892     """The loop index of forward loop."""
    893     return self._forward_index
    894 
    895   @property
    896   def forward_sync(self):
    897     """A control trigger node for synchronization in the forward loop.
    898 
    899     One main use is to keep the push ops of a stack executed in the
    900     iteration order.
    901     """
    902     if self._forward_sync is None:
    903       with ops.control_dependencies(None):
    904         self._forward_sync = control_trigger(name="f_sync")
    905       self._forward_sync._set_control_flow_context(self._forward_context)
    906       self._forward_index.op._add_control_input(self._forward_sync)
    907     return self._forward_sync
    908 
    909   @property
    910   def grad_context(self):
    911     """The corresponding WhileContext for gradient."""
    912     return self._grad_context
    913 
    914   @property
    915   def grad_index(self):
    916     """The loop index of backprop loop."""
    917     return self._grad_index
    918 
    919   @property
    920   def grad_sync(self):
    921     """A control trigger node for synchronization in the grad loop.
    922 
    923     One main use is to keep the pop ops of a stack executed in the
    924     iteration order.
    925     """
    926     if self._grad_sync is None:
    927       with ops.control_dependencies(None):
    928         self._grad_sync = control_trigger(name="b_sync")
    929       self._grad_sync._set_control_flow_context(self._grad_context)
    930       self._grad_index.op._add_control_input(self._grad_sync)
    931       if self._grad_context.outer_context:
    932         self._grad_context.outer_context.AddInnerOp(self._grad_sync)
    933     return self._grad_sync
    934 
    935   @property
    936   def history_map(self):
    937     """The map that records all the tensors needed for backprop."""
    938     return self._history_map
    939 
    940   @property
    941   def switch_map(self):
    942     """The map that records all the Switch ops for the while loop."""
    943     return self._switch_map
    944 
    945   @property
    946   def unused_exits(self):
    947     """The list of "unused" exits."""
    948     return self._unused_exits
    949 
    950   @property
    951   def deferred_exits(self):
    952     """The list of "deferred" exits."""
    953     return self._deferred_exits
    954 
    955   @property
    956   def forward_loop_exits(self):
    957     """The list of exits of the forward loop."""
    958     return self._forward_loop_exits
    959 
    960   @property
    961   def pending_exits_count(self):
    962     """The number of exits we expect to see but haven't."""
    963     return self._pending_exits_count
    964 
    965   @pending_exits_count.setter
    966   def pending_exits_count(self, cnt):
    967     """Set the pending count to cnt."""
    968     self._pending_exits_count = cnt
    969 
    970   def AddForwardAccumulator(self, value, dead_branch=False):
    971     """Add an accumulator for each forward tensor that is needed in backprop.
    972 
    973     This is added to the forward loop at the first time when a tensor
    974     in the forward loop is used by backprop gradient computation loop.
    975     We create an accumulator that accumulates the value of tensor at each
    976     iteration. Called in the control flow context where gradients() is called.
    977 
    978     The pseudocode is:
    979     ```
    980       acc = stack();
    981       while (_pivot) {
    982         acc = stack_push(acc, value);
    983       }
    984     ```
    985 
    986     We make sure that the stack push op in one iteration is executed before
    987     next iteration. This is achieved by adding a control edge from
    988     `forward_index.op.inputs[0].op` to the push op, and another control
    989     edge from the push op to either `forward_index.op` or `forward_sync`.
    990 
    991     Args:
    992       value: The source tensor in forward that is to be accumulated.
    993       dead_branch: True iff the tensor is on a dead branch of a cond.
    994 
    995     Returns:
    996       The stack that contains the accumulated history of the tensor.
    997 
    998     Raises:
    999       TypeError: For internal errors involving the value condition context.
   1000       ValueError: If `value` is inside a XLA scope and a valid max size
   1001         for the stack can't be found.
   1002     """
   1003     # curr_ctxt is the context that tf.gradients was called in.
   1004     curr_ctxt = ops.get_default_graph()._get_control_flow_context()  # pylint: disable=protected-access
   1005     with ops.control_dependencies(None):
   1006       if curr_ctxt:
   1007         curr_ctxt.Enter()
   1008       with ops.colocate_with(value):
   1009         # We only need to pass maximum_iterations to the stack if
   1010         # we're inside an XLA context.
   1011         if not util.IsInXLAContext(value.op):
   1012           max_size = constant_op.constant(-1, dtypes.int32)
   1013         else:
   1014           max_size = GetMaxSizeFromNestedMaximumIterations(
   1015               value, self.forward_context)
   1016         # pylint: disable=protected-access
   1017         acc = gen_data_flow_ops._stack_v2(
   1018             max_size=max_size, elem_type=value.dtype.base_dtype, name="f_acc")
   1019         # pylint: enable=protected-access
   1020       if curr_ctxt:
   1021         curr_ctxt.Exit()
   1022 
   1023       # Make acc available in the forward context.
   1024       enter_acc = self.forward_context.AddValue(acc)
   1025 
   1026       # Add the stack_push op in the context of value.op.
   1027       swap_enabled = self.forward_context.swap_memory
   1028       value_ctxt = util.GetOutputContext(value.op)
   1029       if value_ctxt == self.forward_context:
   1030         # value is not nested in the forward context.
   1031         self.forward_context.Enter()
   1032         # pylint: disable=protected-access
   1033         push = gen_data_flow_ops._stack_push_v2(
   1034             enter_acc, value, swap_memory=swap_enabled)
   1035         # pylint: enable=protected-access
   1036         self.forward_context.Exit()
   1037         # Protect stack push and order it before forward_index.
   1038         self.forward_index.op._add_control_input(push.op)
   1039       else:
   1040         # value is in a cond context within the forward context.
   1041         if not isinstance(value_ctxt, CondContext):
   1042           raise TypeError("value_ctxt is not a CondContext: %s" % value_ctxt)
   1043         if dead_branch:
   1044           # The special case for creating a zero tensor for a dead
   1045           # branch of a switch. See ControlFlowState.ZerosLike().
   1046           value_ctxt.outer_context.Enter()
   1047           # pylint: disable=protected-access
   1048           push = gen_data_flow_ops._stack_push_v2(
   1049               enter_acc, value, swap_memory=swap_enabled)
   1050           # pylint: enable=protected-access
   1051           value_ctxt.outer_context.Exit()
   1052           push.op._set_control_flow_context(value_ctxt)
   1053         else:
   1054           value_ctxt.Enter()
   1055           # pylint: disable=protected-access
   1056           push = gen_data_flow_ops._stack_push_v2(
   1057               enter_acc, value, swap_memory=swap_enabled)
   1058           # pylint: enable=protected-access
   1059           value_ctxt.Exit()
   1060         # Protect stack push and order it before forward_sync.
   1061         self.forward_sync._add_control_input(push.op)
   1062       # Order stack push after the successor of forward_index
   1063       add_op = self.forward_index.op.inputs[0].op
   1064       push.op._add_control_input(add_op)
   1065       return acc
   1066 
   1067   def AddBackpropAccumulatedValue(self, history_value, value,
   1068                                   dead_branch=False):
   1069     """Add the getter for an accumulated value in the grad context.
   1070 
   1071     This is added to the backprop loop. Called in the grad context to
   1072     get the value of an accumulated value. The stack pop op must be guarded
   1073     by the pred of the controlling cond.
   1074 
   1075     Args:
   1076       history_value: The history (a stack) of a value.
   1077       value: The value that is pushed onto the stack.
   1078       dead_branch: True iff the tensor is on a dead branch of a cond.
   1079 
   1080     Returns:
   1081       The current value (the top of the stack).
   1082     """
   1083     history_ctxt = history_value.op._get_control_flow_context()
   1084     # Find the cond context that controls history_value if any.
   1085     cond_ctxt = None
   1086     value_ctxt = value.op._get_control_flow_context()
   1087     while value_ctxt and value_ctxt != history_ctxt:
   1088       if isinstance(value_ctxt, CondContext):
   1089         cond_ctxt = value_ctxt
   1090         break
   1091       value_ctxt = value_ctxt.outer_context
   1092     with ops.control_dependencies(None):
   1093       self.grad_context.Enter()
   1094       if cond_ctxt:
   1095         # Guard stack pop with a switch if it is controlled by a cond.
   1096         grad_state = self
   1097         pred = None
   1098         while pred is None and grad_state:
   1099           pred = grad_state.history_map.get(cond_ctxt.pred.name)
   1100           grad_state = grad_state.outer_grad_state
   1101         if pred is None:
   1102           pred = cond_ctxt.pred
   1103         branch = (1 - cond_ctxt.branch) if dead_branch else cond_ctxt.branch
   1104         history_value = _SwitchRefOrTensor(history_value, pred)[branch]
   1105       # pylint: disable=protected-access
   1106       pop = gen_data_flow_ops._stack_pop_v2(history_value,
   1107                                             value.dtype.base_dtype)
   1108       # pylint: enable=protected-access
   1109       pop.set_shape(value.get_shape())
   1110       self.grad_context.Exit()
   1111     parallel_iterations = self.grad_context.parallel_iterations
   1112     if parallel_iterations > 1:
   1113       # All pops are ordered after pivot_for_body and before grad_sync.
   1114       self.grad_sync._add_control_input(pop.op)
   1115     return pop
   1116 
   1117   def GetRealValue(self, value):
   1118     """Get the real value of `value`.
   1119 
   1120     If backprop "uses" a value produced by forward inference, an accumulator
   1121     is added in the forward loop to accumulate its values.  We use the
   1122     accumulated value. This method must be called in the grad loop context.
   1123     `value` must be in forward and needed for backprop.
   1124 
   1125     Args:
   1126       value: A tensor to be captured.
   1127 
   1128     Returns:
   1129       The same tensor obtained from the saved history.
   1130     """
   1131     assert value.op.type not in ["Variable", "VariableV2"]
   1132     real_value = self._history_map.get(value.name)
   1133     if real_value is None:
   1134       cur_value = value
   1135       cur_grad_state = self
   1136       while True:
   1137         enter_op = util.GetLoopConstantEnter(cur_value)
   1138         if enter_op:
   1139           # Special case: cur_value comes from a constant Enter node.
   1140           cur_value = enter_op.inputs[0]
   1141           cur_grad_state = cur_grad_state.outer_grad_state
   1142           if cur_grad_state is None:
   1143             # We are now outside all nested loops for this gradient(),
   1144             # so `value` is a loop invariant and there is no need to
   1145             # save the history of value. Just make cur_value to enter
   1146             # the right control flow context.
   1147             real_value = self._grad_context.AddValue(cur_value)
   1148             break
   1149         elif constant_op.is_constant(cur_value):
   1150           # If the value to be forwarded is a constant, clone the constant in
   1151           # the gradient loop rather than using a stack.
   1152           # TODO(phawkins): consider hoisting the constant out of the loop
   1153           # instead.
   1154           real_value = constant_op.constant(
   1155               tensor_util.constant_value(cur_value), dtype=cur_value.dtype)
   1156           break
   1157         else:
   1158           # Record the history of this value in forward_ctxt.
   1159           self._grad_context.Exit()
   1160           history_value = cur_grad_state.AddForwardAccumulator(cur_value)
   1161           self._grad_context.Enter()
   1162           break
   1163 
   1164       if real_value is None:
   1165         # Add the stack pop op in the grad context.
   1166         real_value = cur_grad_state.AddBackpropAccumulatedValue(
   1167             history_value, cur_value)
   1168         if cur_grad_state != self:
   1169           real_value = self._grad_context.AddValue(real_value)
   1170       self._history_map[value.name] = real_value
   1171     return real_value
   1172 
   1173 
   1174 def _GetWhileContext(op):
   1175   """Get the WhileContext to which this op belongs."""
   1176   ctxt = op._get_control_flow_context()
   1177   if ctxt:
   1178     ctxt = ctxt.GetWhileContext()
   1179   return ctxt
   1180 
   1181 
   1182 class ControlFlowState(object):
   1183   """Maintain the mapping from the loops to their grad states."""
   1184 
   1185   def __init__(self):
   1186     self._map = {}  # maps forward loop context to GradLoopState
   1187 
   1188   def GetGradState(self, op, before):
   1189     """Return the grad state for this op if it's in a forward loop context."""
   1190     if before and util.IsLoopExit(op):
   1191       forward_ctxt = op._get_control_flow_context()
   1192       forward_ctxt = forward_ctxt.outer_context
   1193       if forward_ctxt:
   1194         forward_ctxt = forward_ctxt.GetWhileContext()
   1195     else:
   1196       forward_ctxt = _GetWhileContext(op)
   1197     if forward_ctxt:
   1198       return self._map.get(forward_ctxt)
   1199     return None
   1200 
   1201   def ProcessUnusedLoopExits(self, pending_count, to_ops_set):
   1202     """Process all the "unused" loop exits.
   1203 
   1204     The "unused" exits of the loops are added to `unused_exits`. An exit is
   1205     unused if its pending_count is 0. If there is an exit with real gradient,
   1206     all these deferred exits will enter the backprop loop with zero gradient.
   1207     Otherwise, they will enter the backprop loop with None. As an example,
   1208     people often write:
   1209 
   1210     ```python
   1211     v1, _ = tf.while_loop(p, b, [x1, x2])
   1212     result = gradients(v1, x1)
   1213     ```
   1214 
   1215     The exit node for x2 is not included by the betweenness analysis. But we
   1216     need to backprop x2 if x2 is involved in computing v1.
   1217 
   1218     Args:
   1219       pending_count: The number of backprop inputs for every op.
   1220       to_ops_set: The set of ops for ys in gradients(ys, xs)
   1221 
   1222     Returns:
   1223       The set of unused loop exits that we know at this point we need
   1224       to backprop.
   1225     """
   1226     loop_exits = []
   1227     for _, grad_state in self._map.items():
   1228       # pylint: disable=protected-access
   1229       for y in grad_state.forward_loop_exits:
   1230         if pending_count[y.op._id] == 0:
   1231           grad_state.pending_exits_count -= 1
   1232           if y.op._id not in to_ops_set:
   1233             grad_state.unused_exits.append(y)
   1234           if grad_state.pending_exits_count == 0:
   1235             loop_exits.extend(grad_state.unused_exits)
   1236       # Need to include Enters in backprop for higher-order gradients.
   1237       for y in grad_state.forward_context.loop_enters:
   1238         if pending_count[y.op._id] == 0:
   1239           pending_count[y.op._id] = 1
   1240       # pylint: enable=protected-access
   1241     return loop_exits
   1242 
   1243   def EnterGradWhileContext(self, op, before):
   1244     """Enter the WhileContext for gradient computation."""
   1245     grad_state = self.GetGradState(op, before)
   1246     if grad_state:
   1247       grad_state.grad_context.Enter()
   1248 
   1249   def ExitGradWhileContext(self, op, before):
   1250     """Exit the WhileContext for gradient computation."""
   1251     grad_state = self.GetGradState(op, before)
   1252     if grad_state:
   1253       grad_state.grad_context.Exit()
   1254 
   1255   def AddWhileContext(self, op, between_op_list, between_ops):
   1256     """Add the grad state for the while loop that op belongs to.
   1257 
   1258     Note that op is an Exit, and this method must be called in
   1259     the control flow context where gradients() is called.
   1260 
   1261     Note that this method modifies `between_op_list` and `between_ops`.
   1262     """
   1263     forward_ctxt = _GetWhileContext(op)
   1264     grad_state = self._map.get(forward_ctxt)
   1265     if grad_state is None:
   1266       # This is a new while loop so create a grad state for it.
   1267       outer_forward_ctxt = forward_ctxt.outer_context
   1268       if outer_forward_ctxt:
   1269         outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
   1270       outer_grad_state = None
   1271       if outer_forward_ctxt:
   1272         outer_grad_state = self._map.get(outer_forward_ctxt)
   1273       grad_state = GradLoopState(forward_ctxt, outer_grad_state)
   1274       self._map[forward_ctxt] = grad_state
   1275 
   1276       # We need to include all exits of a loop for backprop.
   1277       for loop_exit in grad_state.forward_loop_exits:
   1278         if not between_ops[loop_exit.op._id]:
   1279           between_ops[loop_exit.op._id] = True
   1280           between_op_list.append(loop_exit.op)
   1281 
   1282   def ZerosLikeForExit(self, val):
   1283     """Create zeros_like gradient for a loop exit.
   1284 
   1285     If the result of a loop variable is not used but is involved in
   1286     computing the result of some needed loop variable, we create a
   1287     zero-valued tensor that is fed as gradient for the Exit node of that
   1288     loop variable. Note that val.op is an Exit, and this method must be
   1289     called in the control flow context where gradients() is called.
   1290 
   1291     Args:
   1292       val: The output tensor of an Exit op.
   1293 
   1294     Returns:
   1295       A zero tensor of the same shape of val.
   1296     """
   1297     val_shape = val.get_shape()
   1298     forward_ctxt = val.op._get_control_flow_context()
   1299     outer_forward_ctxt = forward_ctxt.outer_context
   1300     if outer_forward_ctxt:
   1301       outer_forward_ctxt = outer_forward_ctxt.GetWhileContext()
   1302     outer_grad_state = None
   1303     if outer_forward_ctxt:
   1304       outer_grad_state = self._map.get(outer_forward_ctxt)
   1305     if outer_grad_state:
   1306       # This is a nested loop.
   1307       if val_shape.is_fully_defined():
   1308         # If the shape is known statically, just create a zero tensor
   1309         # with the right shape in the right context.
   1310         outer_grad_state.grad_context.Enter()
   1311         result = array_ops.zeros(val_shape.dims, val.dtype)
   1312         outer_grad_state.grad_context.Exit()
   1313       else:
   1314         # Only the shape of value is needed for backprop.
   1315         forward_ctxt.outer_context.Enter()
   1316         shape = array_ops.shape_internal(val, optimize=False)
   1317         forward_ctxt.outer_context.Exit()
   1318         # Save the shape to a stack.
   1319         history_shape = outer_grad_state.AddForwardAccumulator(shape)
   1320         # Get the shape back from the stack.
   1321         outer_grad_ctxt = outer_grad_state.grad_context
   1322         outer_grad_ctxt.Enter()
   1323         real_shape = outer_grad_state.AddBackpropAccumulatedValue(
   1324             history_shape, shape)
   1325         result = array_ops.zeros(real_shape, val.dtype)
   1326         outer_grad_ctxt.Exit()
   1327     else:
   1328       # This is not a nested loop.
   1329       if val_shape.is_fully_defined():
   1330         # If the shape is known statically, just create a zero tensor
   1331         # with the right shape.
   1332         result = array_ops.zeros(val_shape.dims, val.dtype)
   1333       else:
   1334         result = array_ops.zeros_like(val, optimize=False)
   1335     return result
   1336 
   1337   def ZerosLike(self, op, index):
   1338     """Create zeros_like for the specified output of an op.
   1339 
   1340     If op is in a while loop that is part of gradients(), this method
   1341     must be called in its grad loop context.
   1342 
   1343     Args:
   1344       op: A tensorflow operation.
   1345       index: the index for a specific output of the op.
   1346 
   1347     Returns:
   1348       A zero tensor of the same shape of op.outputs[index].
   1349     """
   1350     if util.IsLoopSwitch(op):
   1351       return None
   1352     dead_branch = util.IsSwitch(op)
   1353     forward_ctxt = _GetWhileContext(op)
   1354     grad_state = self._map.get(forward_ctxt)
   1355     if grad_state is None:
   1356       # op is not in a while loop that is part of gradients().
   1357       return ZerosLikeOutsideLoop(op, index)
   1358     op_ctxt = op._get_control_flow_context()
   1359     val = ops.convert_to_tensor(op.outputs[index], name="tensor")
   1360     shape = val.get_shape()
   1361     if shape.is_fully_defined():
   1362       # If the shape is known statically, just create a zero tensor with
   1363       # the right shape in the grad loop context.
   1364       result = constant_op.constant(0, shape=shape.dims, dtype=val.dtype)
   1365       if dead_branch:
   1366         # op is a cond switch. Guard the zero tensor with a switch.
   1367         pred = grad_state.history_map.get(op_ctxt.pred.name)
   1368         branch = op_ctxt.branch
   1369         result = _SwitchRefOrTensor(result, pred)[1 - branch]
   1370     else:
   1371       # Unknown shape so keep a history of the shape at runtime.
   1372       if dead_branch:
   1373         # Need to add a special switch to guard the value.
   1374         pred = op_ctxt.pred
   1375         branch = op_ctxt.branch
   1376         op_ctxt.outer_context.Enter()
   1377         val = _SwitchRefOrTensor(op.inputs[0], pred)[1 - branch]
   1378         zeros_shape = array_ops.shape_internal(val, optimize=False)
   1379         op_ctxt.outer_context.Exit()
   1380         val.op._set_control_flow_context(op_ctxt)
   1381         zeros_shape.op._set_control_flow_context(op_ctxt)
   1382       else:
   1383         op_ctxt.Enter()
   1384         zeros_shape = array_ops.shape_internal(val, optimize=False)
   1385         op_ctxt.Exit()
   1386 
   1387       # Add forward accumulator for shape.
   1388       grad_state.grad_context.Exit()
   1389       history_zeros_shape = grad_state.AddForwardAccumulator(
   1390           zeros_shape, dead_branch=dead_branch)
   1391       grad_state.grad_context.Enter()
   1392 
   1393       # Create a zero tensor with the right shape.
   1394       shape = grad_state.AddBackpropAccumulatedValue(history_zeros_shape,
   1395                                                      zeros_shape, dead_branch)
   1396       result = array_ops.zeros(shape, val.dtype)
   1397     return result
   1398 
   1399   def PostProcessing(self):
   1400     """Perform postprocessing at the end of gradients().
   1401 
   1402     We have created the gradient graph at this point. So this function
   1403     can be used to perform any postprocessing on the gradient graph.
   1404     We currently perform the following postprocessing:
   1405       1. Patch the gradient graph if the output of a loop variable
   1406          doesn't depend on its input.
   1407     """
   1408     for _, grad_state in self._map.items():
   1409       for _, b_merge in grad_state.switch_map.items():
   1410         if b_merge.op.inputs[0] == b_merge.op.inputs[1]:
   1411           # The value of this loop variable at iteration i+1 doesn't
   1412           # depend on its value at iteration i. So use zeros as the
   1413           # gradients for all iterations > 0.
   1414           dtype = b_merge.op.inputs[0].dtype
   1415           shape = b_merge.op.inputs[0].get_shape()
   1416           # pylint: disable=protected-access
   1417           if shape.is_fully_defined():
   1418             grad_state.grad_context.Enter()
   1419             # Create a zeros and use it for iterations > 0.
   1420             grad_val = constant_op.constant(0, dtype=dtype, shape=shape)
   1421             next_grad_val = _NextIteration(grad_val)
   1422             grad_state.grad_context.Exit()
   1423           else:
   1424             # Create a zeros in the outer grad context.
   1425             outer_grad_ctxt = grad_state.grad_context.outer_context
   1426             if outer_grad_ctxt:
   1427               outer_grad_ctxt.Enter()
   1428             enter_grad_op = b_merge.op.inputs[0].op
   1429             enter_grad = enter_grad_op.inputs[0]
   1430             grad_shape = array_ops.shape_internal(enter_grad, optimize=False)
   1431             grad_val = array_ops.zeros(grad_shape)
   1432             if outer_grad_ctxt:
   1433               outer_grad_ctxt.Exit()
   1434             # Use the zeros for iterations > 0.
   1435             grad_state.grad_context.Enter()
   1436             next_grad_val = _NextIteration(grad_val)
   1437             grad_state.grad_context.Exit()
   1438           b_merge.op._update_input(1, next_grad_val)
   1439           # pylint: enable=protected-access
   1440 
   1441 
   1442 def MaybeCreateControlFlowState(between_op_list, between_ops,
   1443                                 colocate_gradients_with_ops):
   1444   """Create the state for all the while loops involved in one gradients().
   1445 
   1446   We create a ControlFlowState when there are while loops involved in
   1447   gradients(). In gradients(), control flow logic is only invoked when
   1448   the ControlFlowState is not None.
   1449 
   1450   Note that this method modifies `between_op_list` and `between_ops`.
   1451   """
   1452   loop_state = None
   1453   for op in between_op_list:
   1454     if util.IsLoopExit(op):
   1455       if loop_state is None:
   1456         loop_state = ControlFlowState()
   1457       if colocate_gradients_with_ops:
   1458         with ops.colocate_with(op):
   1459           loop_state.AddWhileContext(op, between_op_list, between_ops)
   1460       else:
   1461         loop_state.AddWhileContext(op, between_op_list, between_ops)
   1462   return loop_state
   1463 
   1464 
   1465 def ZerosLikeOutsideLoop(op, index):
   1466   """Create zeros_like for the specified output of an op."""
   1467   val = op.outputs[index]
   1468   if not util.IsSwitch(op):
   1469     return array_ops.zeros_like(val, optimize=False)
   1470   else:
   1471     op_ctxt = op._get_control_flow_context()
   1472     if op_ctxt:
   1473       # We are in a cond context. Use a switch to create zeros only when needed.
   1474       pred = op_ctxt.pred
   1475       branch = op_ctxt.branch
   1476       switch_val = switch(op.inputs[0], pred)[1 - branch]
   1477       zeros_shape = array_ops.shape_internal(switch_val, optimize=False)
   1478       return array_ops.zeros(zeros_shape, dtype=val.dtype)
   1479     else:
   1480       return array_ops.zeros_like(val, optimize=False)
   1481 
   1482 
   1483 class ControlFlowContext(object):
   1484   """The base class for control flow context.
   1485 
   1486   The usage pattern is a sequence of (Enter, Exit) followed by a final
   1487   ExitResult.
   1488 
   1489   We maintain the following state for control flow contexts during graph
   1490   construction:
   1491    1. graph has _control_flow_context: the current context used to
   1492       construct new nodes. Changed by ctxt.Enter() and ctxt.Exit()
   1493    2. op has _control_flow_context: the context to which the op belongs.
   1494       Set at the time the op is created. Immutable.
   1495    3. A ControlFlowContext has _outer_context: the context in which this
   1496       context is created. Set at the time a context is created. Immutable.
   1497    4. A ControlFlowContext has _context_stack.
   1498       Pushed and popped by ctxt.Enter() and ctxt.Exit()
   1499   """
   1500 
   1501   def __init__(self, values_def=None, import_scope=None):
   1502     self._nested_contexts = []
   1503     self._outer_context = ops.get_default_graph()._get_control_flow_context()
   1504     if self._outer_context:
   1505       self._outer_context._nested_contexts.append(self)  # pylint: disable=protected-access
   1506     self._context_stack = []
   1507     if values_def:
   1508       self._init_values_from_proto(values_def, import_scope=import_scope)
   1509     else:
   1510       # Values that have been already seen in this context.
   1511       self._values = set()
   1512       # Values referenced by but external to this context.
   1513       self._external_values = {}
   1514 
   1515   def _init_values_from_proto(self, values_def, import_scope=None):
   1516     """Initializes values and external_values from `ValuesDef` protocol buffer.
   1517 
   1518     Args:
   1519       values_def: `ValuesDef` protocol buffer.
   1520       import_scope: Optional `string`. Name scope to add.
   1521     """
   1522     assert isinstance(values_def, control_flow_pb2.ValuesDef)
   1523     self._values = set(
   1524         ops.prepend_name_scope(value, import_scope)
   1525         for value in values_def.values)
   1526     g = ops.get_default_graph()
   1527     self._external_values = {}
   1528     for k, v in values_def.external_values.items():
   1529       k = ops.prepend_name_scope(k, import_scope)
   1530       self._external_values[k] = g.as_graph_element(
   1531           ops.prepend_name_scope(v, import_scope))
   1532     op_names = set([
   1533         op.split(":")[0]
   1534         for op in self._values - set(self._external_values.keys())
   1535     ])
   1536     for op in op_names:
   1537       # pylint: disable=protected-access
   1538       g.as_graph_element(op)._set_control_flow_context(self)
   1539       # pylint: enable=protected-access
   1540 
   1541   @property
   1542   def name(self):
   1543     return self._name
   1544 
   1545   @property
   1546   def outer_context(self):
   1547     """Return the context containing this context."""
   1548     return self._outer_context
   1549 
   1550   @property
   1551   def grad_state(self):
   1552     raise NotImplementedError("Abstract method")
   1553 
   1554   @property
   1555   def back_prop(self):
   1556     raise NotImplementedError("Abstract method")
   1557 
   1558   @abc.abstractmethod
   1559   def to_control_flow_context_def(self, context_def, export_scope=None):
   1560     """Serializes this into `context_def`.
   1561 
   1562     Args:
   1563       context_def: a `ControlFlowContextDef` protocol buffer.
   1564       export_scope: Optional `string`. Name scope to remove.
   1565     """
   1566     raise NotImplementedError("Abstract method")
   1567 
   1568   def _to_values_def(self, export_scope=None):
   1569     """Converts the values to a `ValuesDef` protocol buffer.
   1570 
   1571     Args:
   1572       export_scope: Optional `string`. Name scope to remove.
   1573 
   1574     Returns:
   1575       A `ValuesDef` protocol buffer.
   1576     """
   1577     values_def = control_flow_pb2.ValuesDef()
   1578     values_def.values.extend(
   1579         [ops.strip_name_scope(v, export_scope) for v in sorted(self._values)])
   1580     for k, v in self._external_values.items():
   1581       k = ops.strip_name_scope(k, export_scope)
   1582       values_def.external_values[k] = ops.strip_name_scope(v.name, export_scope)
   1583     return values_def
   1584 
   1585   def AddName(self, name):
   1586     self._values.add(name)
   1587 
   1588   # pylint: disable=protected-access
   1589   def Enter(self):
   1590     """Enter this control flow context."""
   1591     graph = ops.get_default_graph()
   1592     self._context_stack.append(graph._get_control_flow_context())
   1593     graph._set_control_flow_context(self)
   1594 
   1595   def Exit(self):
   1596     """Exit this control flow context."""
   1597     graph = ops.get_default_graph()
   1598     last_context = self._context_stack.pop()
   1599     graph._set_control_flow_context(last_context)
   1600 
   1601   def ExitResult(self, result):
   1602     """Make a list of tensors available in the outer context."""
   1603     if self._outer_context:
   1604       nest.map_structure(lambda x: self._outer_context.AddName(x.name), result)
   1605 
   1606   def GetWhileContext(self):
   1607     """Return the while context containing this context."""
   1608     if self._outer_context:
   1609       return self._outer_context.GetWhileContext()
   1610     return None
   1611 
   1612   def _IsInOuterContext(self, op):
   1613     op_ctxt = util.GetOutputContext(op)
   1614     outer_ctxt = self.outer_context
   1615     while outer_ctxt != op_ctxt:
   1616       if outer_ctxt is None:
   1617         return False
   1618       outer_ctxt = outer_ctxt.outer_context
   1619     return True
   1620 
   1621   def _RemoveExternalControlEdges(self, op):
   1622     """Remove any external control dependency on this op."""
   1623     while_ctxt = self.GetWhileContext()
   1624     # A control input of `op` is internal if it is in the same while
   1625     # loop context as the enclosing while loop context of self.
   1626     if while_ctxt is None:
   1627       internal_control_inputs = op.control_inputs
   1628     else:
   1629       internal_control_inputs = []
   1630       for x in op.control_inputs:
   1631         ctxt = util.GetOutputContext(x)
   1632         if ctxt is not None and ctxt.GetWhileContext() == while_ctxt:
   1633           internal_control_inputs.append(x)
   1634     external_control_inputs = []
   1635     if len(internal_control_inputs) != len(op.control_inputs):
   1636       external_control_inputs = list(set(op.control_inputs)
   1637                                      - set(internal_control_inputs))
   1638       op._remove_all_control_inputs()
   1639       op._add_control_inputs(internal_control_inputs)
   1640     return internal_control_inputs, external_control_inputs
   1641 
   1642   # pylint: enable=protected-access
   1643 
   1644   def AddInnerOp(self, op):
   1645     """Notifies a scope about an operator added to an inner scope."""
   1646     if self._outer_context:
   1647       self._outer_context.AddInnerOp(op)
   1648 
   1649   def GetControlPivot(self):
   1650     """Returns the pivot node for this context, or None."""
   1651     return None
   1652 
   1653   def IsWhileContext(self):
   1654     return False
   1655 
   1656   def IsCondContext(self):
   1657     return False
   1658 
   1659   def IsXLAContext(self):
   1660     return False
   1661 
   1662   def __str__(self):
   1663     return self.name
   1664 
   1665 
   1666 class CondContext(ControlFlowContext):
   1667   """The context for the conditional construct."""
   1668 
   1669   def __init__(self,
   1670                pred=None,
   1671                pivot=None,
   1672                branch=None,
   1673                name="cond_text",
   1674                context_def=None,
   1675                import_scope=None):
   1676     """Creates a `CondContext`.
   1677 
   1678     Args:
   1679       pred: The `boolean` tensor for the conditional predicate.
   1680       pivot: The predicate tensor in this branch.
   1681       branch: 0 or 1 representing this branch.
   1682       name: Name of the `CondContext` python object.
   1683       context_def: Optional `ContextDef` protocol buffer to initialize the
   1684         `CondContext` object from.
   1685       import_scope: Optional `string`. Name scope to add. Only used when
   1686         initialing from protocol buffer.
   1687     """
   1688     self._name = ops.get_default_graph().unique_name(name)
   1689 
   1690     if context_def:
   1691       self._init_from_proto(context_def, import_scope=import_scope)
   1692     else:
   1693       # Initializes the default fields.
   1694       ControlFlowContext.__init__(self)
   1695       self._pred = pred  # The boolean tensor for the cond predicate
   1696       self._pivot = pivot  # The predicate tensor in this branch
   1697       self._branch = branch  # 0 or 1 representing this branch
   1698 
   1699       # Values considered to have been already seen in this context.
   1700       self._values.add(pred.name)
   1701       self._values.add(pivot.name)
   1702 
   1703   def _init_from_proto(self, context_def, import_scope=None):
   1704     """Creates a new `CondContext` from protocol buffer.
   1705 
   1706     Args:
   1707       context_def: `CondContextDef` protocol buffer.
   1708       import_scope: Optional `string`. Name scope to add.
   1709     """
   1710     assert isinstance(context_def, control_flow_pb2.CondContextDef)
   1711     # Create from context_def.
   1712     g = ops.get_default_graph()
   1713     self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
   1714     self._pred = g.as_graph_element(
   1715         ops.prepend_name_scope(context_def.pred_name, import_scope))
   1716     self._pivot = g.as_graph_element(
   1717         ops.prepend_name_scope(context_def.pivot_name, import_scope))
   1718     self._branch = context_def.branch
   1719     super(CondContext, self).__init__(
   1720         values_def=context_def.values_def, import_scope=import_scope)
   1721 
   1722   @property
   1723   def pred(self):
   1724     return self._pred
   1725 
   1726   @property
   1727   def pivot(self):
   1728     return self._pivot
   1729 
   1730   @property
   1731   def branch(self):
   1732     return self._branch
   1733 
   1734   @property
   1735   def grad_state(self):
   1736     if self.GetWhileContext():
   1737       return self.GetWhileContext().grad_state
   1738     return None
   1739 
   1740   @property
   1741   def back_prop(self):
   1742     if self.GetWhileContext():
   1743       self.GetWhileContext().back_prop
   1744     return False
   1745 
   1746   def GetControlPivot(self):
   1747     return self._pivot
   1748 
   1749   def to_proto(self, export_scope=None):
   1750     """Converts a `CondContext` to a `CondContextDef` protocol buffer.
   1751 
   1752     Args:
   1753       export_scope: Optional `string`. Name scope to remove.
   1754 
   1755     Returns:
   1756       A `CondContextDef` protocol buffer.
   1757     """
   1758     if (export_scope is None or self.name.startswith(export_scope)):
   1759       context_def = control_flow_pb2.CondContextDef()
   1760       context_def.context_name = ops.strip_name_scope(self.name, export_scope)
   1761       context_def.pred_name = ops.strip_name_scope(self._pred.name,
   1762                                                    export_scope)
   1763       context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
   1764                                                     export_scope)
   1765       context_def.branch = self._branch
   1766       context_def.values_def.MergeFrom(super(CondContext, self)._to_values_def(
   1767           export_scope))
   1768       # TODO(b/72868227): enable this once the corresponding control_flow.proto
   1769       # changes have been checked in (they aren't checked in and this is
   1770       # disabled for now to ensure forwards compatibility).
   1771       if False:  # pylint: disable=using-constant-test
   1772         for nested in self._nested_contexts:
   1773           nested_def = context_def.nested_contexts.add()
   1774           nested.to_control_flow_context_def(nested_def)
   1775 
   1776       return context_def
   1777     else:
   1778       return None
   1779 
   1780   @staticmethod
   1781   def from_proto(context_def, import_scope=None):
   1782     """Returns a `CondContext` object created from `context_def`."""
   1783     ret = CondContext(context_def=context_def,
   1784                       import_scope=import_scope)
   1785 
   1786     # TODO(b/72868227): remove "if hasattr(...)" once the corresponding
   1787     # control_flow.proto changes have been checked in (they aren't checked in
   1788     # and this is here for now to ensure forwards compatibility).
   1789     if hasattr(context_def, "nested_contexts"):
   1790       ret.Enter()
   1791       for nested_def in context_def.nested_contexts:
   1792         from_control_flow_context_def(nested_def)
   1793       ret.Exit()
   1794     return ret
   1795 
   1796   def to_control_flow_context_def(self, context_def, export_scope=None):
   1797     context_def.cond_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
   1798 
   1799   def AddValue(self, val):
   1800     """Add `val` to the current context and its outer context recursively."""
   1801     if val.name in self._values:
   1802       # Use the real value if it comes from outer context. This is needed in
   1803       # particular for nested conds.
   1804       result = self._external_values.get(val.name)
   1805       result = val if result is None else result
   1806     else:
   1807       result = val
   1808       self._values.add(val.name)
   1809       if self._outer_context:
   1810         result = self._outer_context.AddValue(val)
   1811         self._values.add(result.name)
   1812       with ops.control_dependencies(None):
   1813         result = _SwitchRefOrTensor(result, self._pred)[self._branch]
   1814         if self._outer_context:
   1815           self._outer_context.AddInnerOp(result.op)
   1816 
   1817       result.op.graph.prevent_fetching(result.op)
   1818       # pylint: disable=protected-access
   1819       result.op._set_control_flow_context(self)
   1820       # pylint: enable=protected-access
   1821 
   1822       self._values.add(result.name)
   1823       self._external_values[val.name] = result
   1824     return result
   1825 
   1826   def AddOp(self, op):
   1827     self._AddOpInternal(op)
   1828 
   1829   def _AddOpInternal(self, op):
   1830     """Add `op` to the current context."""
   1831     if not op.inputs:
   1832       # Remove any external control dependency on this op
   1833       self._RemoveExternalControlEdges(op)
   1834       # pylint: disable=protected-access
   1835       op._add_control_input(self._pivot.op)
   1836       # pylint: enable=protected-access
   1837       for x in op.outputs:
   1838         self._values.add(x.name)
   1839     else:
   1840       for index in range(len(op.inputs)):
   1841         x = op.inputs[index]
   1842         real_x = self.AddValue(x)
   1843         if real_x != x:
   1844           # pylint: disable=protected-access
   1845           op._update_input(index, real_x)
   1846           # pylint: enable=protected-access
   1847       # Remove any external control dependency on this op.
   1848       self._RemoveExternalControlEdges(op)
   1849       for x in op.outputs:
   1850         self._values.add(x.name)
   1851       # pylint: disable=protected-access
   1852       if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
   1853         op._add_control_input(self._pivot.op)
   1854       # pylint: enable=protected-access
   1855 
   1856     if self._outer_context or not util.IsLoopExit(op):
   1857       op.graph.prevent_fetching(op)
   1858 
   1859     if self._outer_context:
   1860       self._outer_context.AddInnerOp(op)
   1861 
   1862   def _ProcessOutputTensor(self, val):
   1863     """Process an output tensor of a conditional branch."""
   1864     real_val = val
   1865     if val.name not in self._values:
   1866       # Handle the special case of lambda: x
   1867       self._values.add(val.name)
   1868       if self._outer_context:
   1869         real_val = self._outer_context.AddValue(val)
   1870         self._values.add(real_val.name)
   1871       real_val = _SwitchRefOrTensor(real_val, self._pred)[self._branch]
   1872       self._external_values[val.name] = real_val
   1873     else:
   1874       external_val = self._external_values.get(val.name)
   1875       if external_val is not None:
   1876         real_val = external_val
   1877     return real_val
   1878 
   1879   def _BuildCondTensor(self, v):
   1880     if isinstance(v, ops.Operation):
   1881       # Use pivot as the proxy for this op.
   1882       return with_dependencies([v], self._pivot)
   1883     elif isinstance(v, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
   1884       values = self._ProcessOutputTensor(v.values)
   1885       indices = self._ProcessOutputTensor(v.indices)
   1886       if isinstance(v, ops.IndexedSlices):
   1887         dense_shape = v.dense_shape
   1888         if dense_shape is not None:
   1889           dense_shape = self._ProcessOutputTensor(dense_shape)
   1890         return ops.IndexedSlices(values, indices, dense_shape)
   1891       else:
   1892         dense_shape = self._ProcessOutputTensor(v.dense_shape)
   1893         return sparse_tensor.SparseTensor(indices, values, dense_shape)
   1894     else:
   1895       v = nest.map_structure(_convert_tensorarray_to_flow, v)
   1896       return self._ProcessOutputTensor(ops.convert_to_tensor(v))
   1897 
   1898   def BuildCondBranch(self, fn):
   1899     """Add the subgraph defined by fn() to the graph."""
   1900     pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   1901     original_result = fn()
   1902     post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   1903     if len(post_summaries) > len(pre_summaries):
   1904       new_summaries = post_summaries[len(pre_summaries):]
   1905       summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   1906       summary_ref[:] = pre_summaries
   1907       with ops.control_dependencies(new_summaries):
   1908         if original_result is None:
   1909           return no_op(), None
   1910         else:
   1911           original_result = nest.map_structure(array_ops.identity,
   1912                                                original_result)
   1913     if original_result is None:
   1914       return None, None
   1915 
   1916     result = nest.map_structure(self._BuildCondTensor, original_result)
   1917     if not isinstance(result, (list, _basetuple)):
   1918       result = [result]
   1919     return original_result, result
   1920 
   1921   def IsCondContext(self):
   1922     return True
   1923 
   1924 
   1925 def _UnpackIfSingleton(res):
   1926   if isinstance(res, (list, _basetuple)) and len(res) == 1:
   1927     return res[0]
   1928   else:
   1929     return res
   1930 
   1931 
   1932 # pylint: disable=redefined-outer-name
   1933 # pylint: disable=g-doc-args
   1934 @tf_export("cond")
   1935 @deprecation.deprecated_args(
   1936     None, "fn1/fn2 are deprecated in favor of the true_fn/false_fn arguments.",
   1937     "fn1", "fn2")
   1938 def cond(pred,
   1939          true_fn=None,
   1940          false_fn=None,
   1941          strict=False,
   1942          name=None,
   1943          fn1=None,
   1944          fn2=None):
   1945   """Return `true_fn()` if the predicate `pred` is true else `false_fn()`.
   1946 
   1947   `true_fn` and `false_fn` both return lists of output tensors. `true_fn` and
   1948   `false_fn` must have the same non-zero number and type of outputs.
   1949 
   1950   Note that the conditional execution applies only to the operations defined in
   1951   `true_fn` and `false_fn`. Consider the following simple program:
   1952 
   1953   ```python
   1954   z = tf.multiply(a, b)
   1955   result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
   1956   ```
   1957 
   1958   If `x < y`, the `tf.add` operation will be executed and `tf.square`
   1959   operation will not be executed. Since `z` is needed for at least one
   1960   branch of the `cond`, the `tf.multiply` operation is always executed,
   1961   unconditionally.
   1962   Although this behavior is consistent with the dataflow model of TensorFlow,
   1963   it has occasionally surprised some users who expected a lazier semantics.
   1964 
   1965   Note that `cond` calls `true_fn` and `false_fn` *exactly once* (inside the
   1966   call to `cond`, and not at all during `Session.run()`). `cond`
   1967   stitches together the graph fragments created during the `true_fn` and
   1968   `false_fn` calls with some additional graph nodes to ensure that the right
   1969   branch gets executed depending on the value of `pred`.
   1970 
   1971   `tf.cond` supports nested structures as implemented in
   1972   `tensorflow.python.util.nest`. Both `true_fn` and `false_fn` must return the
   1973   same (possibly nested) value structure of lists, tuples, and/or named tuples.
   1974   Singleton lists and tuples form the only exceptions to this: when returned by
   1975   `true_fn` and/or `false_fn`, they are implicitly unpacked to single values.
   1976   This behavior is disabled by passing `strict=True`.
   1977 
   1978   Args:
   1979     pred: A scalar determining whether to return the result of `true_fn` or
   1980       `false_fn`.
   1981     true_fn: The callable to be performed if pred is true.
   1982     false_fn: The callable to be performed if pred is false.
   1983     strict: A boolean that enables/disables 'strict' mode; see above.
   1984     name: Optional name prefix for the returned tensors.
   1985 
   1986   Returns:
   1987     Tensors returned by the call to either `true_fn` or `false_fn`. If the
   1988     callables return a singleton list, the element is extracted from the list.
   1989 
   1990   Raises:
   1991     TypeError: if `true_fn` or `false_fn` is not callable.
   1992     ValueError: if `true_fn` and `false_fn` do not return the same number of
   1993       tensors, or return tensors of different types.
   1994 
   1995   Example:
   1996 
   1997   ```python
   1998   x = tf.constant(2)
   1999   y = tf.constant(5)
   2000   def f1(): return tf.multiply(x, 17)
   2001   def f2(): return tf.add(y, 23)
   2002   r = tf.cond(tf.less(x, y), f1, f2)
   2003   # r is set to f1().
   2004   # Operations in f2 (e.g., tf.add) are not executed.
   2005   ```
   2006 
   2007   """
   2008   # We needed to make true_fn/false_fn keyword arguments for
   2009   # backwards-compatibility. This check exists so that we can convert back to
   2010   # having them be positional arguments.
   2011   # TODO(josh11b): Make `true_fn` and `false_fn` positional arguments after
   2012   # `fn1` and `fn2` are deleted.
   2013   if fn1 is not None:
   2014     if true_fn is not None:
   2015       raise TypeError("cond(): true_fn and fn1 may not be set simultaneously.")
   2016     true_fn = fn1
   2017   elif true_fn is None:
   2018     raise TypeError("cond(): true_fn argument required")
   2019   if fn2 is not None:
   2020     if false_fn is not None:
   2021       raise TypeError("cond(): false_fn and fn2 may not be set simultaneously.")
   2022     false_fn = fn2
   2023   elif false_fn is None:
   2024     raise TypeError("cond(): false_fn argument required")
   2025 
   2026   if not callable(true_fn):
   2027     raise TypeError("true_fn must be callable.")
   2028   if not callable(false_fn):
   2029     raise TypeError("false_fn must be callable.")
   2030 
   2031   with ops.name_scope(name, "cond", [pred]):
   2032     if context.in_eager_mode():
   2033       if pred:
   2034         return _UnpackIfSingleton(true_fn())
   2035       return _UnpackIfSingleton(false_fn())
   2036 
   2037     # Add the Switch to the graph.
   2038     if isinstance(pred, bool):
   2039       raise TypeError("pred must not be a Python bool")
   2040     p_2, p_1 = switch(pred, pred)
   2041     pivot_1 = array_ops.identity(p_1, name="switch_t")
   2042     pivot_2 = array_ops.identity(p_2, name="switch_f")
   2043     pred = array_ops.identity(pred, name="pred_id")
   2044     # Disable the fetching of tensors that are only on one branch of cond.
   2045     for tensor in [p_1, p_2, pivot_1, pivot_2, pred]:
   2046       tensor.op.graph.prevent_fetching(tensor.op)
   2047 
   2048     # Build the graph for the true branch in a new context.
   2049     context_t = CondContext(pred, pivot_1, branch=1)
   2050     context_t.Enter()
   2051     orig_res_t, res_t = context_t.BuildCondBranch(true_fn)
   2052     if orig_res_t is None:
   2053       raise ValueError("true_fn must have a return value.")
   2054     context_t.ExitResult(res_t)
   2055     context_t.Exit()
   2056 
   2057     # Build the graph for the false branch in a new context.
   2058     context_f = CondContext(pred, pivot_2, branch=0)
   2059     context_f.Enter()
   2060     orig_res_f, res_f = context_f.BuildCondBranch(false_fn)
   2061     if orig_res_f is None:
   2062       raise ValueError("false_fn must have a return value.")
   2063     context_f.ExitResult(res_f)
   2064     context_f.Exit()
   2065 
   2066     if not strict:
   2067       orig_res_t = _UnpackIfSingleton(orig_res_t)
   2068       orig_res_f = _UnpackIfSingleton(orig_res_f)
   2069 
   2070     # Check that the return values of the two branches have the same structure.
   2071     try:
   2072       nest.assert_same_structure(orig_res_t, orig_res_f)
   2073     except TypeError as e:
   2074       raise TypeError(
   2075           "Incompatible return types of true_fn and false_fn: {}".format(e))
   2076     except ValueError as e:
   2077       raise ValueError(
   2078           "Incompatible return values of true_fn and false_fn: {}".format(e))
   2079 
   2080     # Add the final merge to the graph.
   2081     if not res_t:
   2082       raise ValueError("true_fn and false_fn must return at least one result.")
   2083 
   2084     res_t_flat = nest.flatten(res_t)
   2085     res_f_flat = nest.flatten(res_f)
   2086 
   2087     for x, y in zip(res_t_flat, res_f_flat):
   2088       assert ((isinstance(x, ops.IndexedSlices) and
   2089                isinstance(y, ops.IndexedSlices)) or
   2090               (isinstance(x, sparse_tensor.SparseTensor) and
   2091                isinstance(y, sparse_tensor.SparseTensor)) or
   2092               (isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor)))
   2093       val_x = x if isinstance(x, ops.Tensor) else x.values
   2094       val_y = y if isinstance(y, ops.Tensor) else y.values
   2095       if val_x.dtype.base_dtype != val_y.dtype.base_dtype:
   2096         raise ValueError(
   2097             "Outputs of true_fn and false_fn must have the same type: %s, %s" %
   2098             (val_x.dtype.name, val_y.dtype.name))
   2099 
   2100     merges = [merge(pair)[0] for pair in zip(res_f_flat, res_t_flat)]
   2101     merges = _convert_flows_to_tensorarrays(nest.flatten(orig_res_t), merges)
   2102 
   2103     # Only add non-nested conds to the collection. Any nested control flow will
   2104     # be encapsulated in the root context.
   2105     assert context_t.outer_context == context_f.outer_context
   2106     # TODO(b/72868227): remove "if True..." once the corresponding
   2107     # control_flow.proto changes have been checked in (they aren't checked in
   2108     # and this is disabled for now to ensure forwards compatibility).
   2109     if True or context_t.outer_context is None:
   2110       ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t)
   2111       ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f)
   2112 
   2113     merges = nest.pack_sequence_as(structure=orig_res_t, flat_sequence=merges)
   2114 
   2115     # Singleton lists and tuples are automatically unpacked if strict == False.
   2116     if not strict:
   2117       merges = _UnpackIfSingleton(merges)
   2118     return merges
   2119 
   2120 
   2121 # pylint: enable=g-doc-args
   2122 # pylint: enable=redefined-outer-name
   2123 
   2124 
   2125 def smart_cond(pred, true_fn=None, false_fn=None, name=None):
   2126   """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.
   2127 
   2128   If `pred` is a bool or has a constant value, we return either `true_fn()`
   2129   or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both.
   2130 
   2131   Arguments:
   2132     pred: A scalar determining whether to return the result of `true_fn` or
   2133       `false_fn`.
   2134     true_fn: The callable to be performed if pred is true.
   2135     false_fn: The callable to be performed if pred is false.
   2136     name: Optional name prefix when using `tf.cond`.
   2137 
   2138   Returns:
   2139     Tensors returned by the call to either `true_fn` or `false_fn`.
   2140 
   2141   Raises:
   2142     TypeError: If `true_fn` or `false_fn` is not callable.
   2143   """
   2144   if not callable(true_fn):
   2145     raise TypeError('`true_fn` must be callable.')
   2146   if not callable(false_fn):
   2147     raise TypeError('`false_fn` must be callable.')
   2148 
   2149   pred_value = smart_constant_value(pred)
   2150   if pred_value is not None:
   2151     if pred_value:
   2152       return true_fn()
   2153     else:
   2154       return false_fn()
   2155   else:
   2156     return cond(pred, true_fn=true_fn, false_fn=false_fn, name=name)
   2157 
   2158 
   2159 def smart_constant_value(pred):
   2160   """Return the bool value for `pred`, or None if `pred` had a dynamic value.
   2161 
   2162   Arguments:
   2163     pred: A scalar, either a Python bool or tensor.
   2164 
   2165   Returns:
   2166     True or False if `pred` has a constant boolean value, None otherwise.
   2167 
   2168   Raises:
   2169     TypeError: If `pred` is not a Tensor or bool.
   2170   """
   2171   if isinstance(pred, bool):
   2172     pred_value = pred
   2173   elif isinstance(pred, ops.Tensor):
   2174     pred_value = tensor_util.constant_value(pred)
   2175   else:
   2176     raise TypeError('`pred` must be a Tensor or a Python bool.')
   2177   return pred_value
   2178 
   2179 
   2180 def _resource_safe_shape(t):
   2181   """Returns the shape of t or the variable it points to."""
   2182   if t.dtype == dtypes.resource:
   2183     while t.op.inputs:
   2184       t = t.op.inputs[0]
   2185     return tensor_shape.TensorShape(t.op.get_attr("shape"))
   2186   return array_ops.shape_internal(t, optimize=False)
   2187 
   2188 
   2189 # TODO(yuanbyu): Consider having a unified notion of context for
   2190 # not only conditionals and loops but also control dependency and
   2191 # subgraphs.
   2192 class WhileContext(ControlFlowContext):
   2193   """The context for the loop construct."""
   2194 
   2195   def __init__(self,
   2196                maximum_iterations=None,
   2197                parallel_iterations=10,
   2198                back_prop=True,
   2199                swap_memory=False,
   2200                name="while_context",
   2201                grad_state=None,
   2202                context_def=None,
   2203                import_scope=None):
   2204     """"Creates a `WhileContext`.
   2205 
   2206     Args:
   2207       maximum_iterations: Optional upper bound on number of loop iterations.
   2208       parallel_iterations: The number of iterations allowed to run in parallel.
   2209       back_prop: Whether backprop is enabled for this while loop.
   2210       swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
   2211       name: Optional name prefix for the returned tensors.
   2212       grad_state: The gradient loop state.
   2213       context_def: Optional `WhileContextDef` protocol buffer to initialize
   2214         the `Whilecontext` python object from.
   2215       import_scope: Optional `string`. Name scope to add. Only used when
   2216         initialing from protocol buffer.
   2217     """
   2218     if context_def:
   2219       self._init_from_proto(context_def, import_scope=import_scope)
   2220     else:
   2221       ControlFlowContext.__init__(self)
   2222       self._init_from_args(maximum_iterations, parallel_iterations, back_prop,
   2223                            swap_memory, name)
   2224     # The gradient loop state.
   2225     self._grad_state = grad_state
   2226 
   2227   def _init_from_args(self, maximum_iterations, parallel_iterations, back_prop,
   2228                       swap_memory, name):
   2229     """Creates a new `WhileContext` from arguments.
   2230 
   2231     Args:
   2232       maximum_iterations: Optional upper bound on number of loop iterations.
   2233       parallel_iterations: The number of iterations allowed to run in parallel.
   2234       back_prop: Whether backprop is enabled for this while loop.
   2235       swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
   2236       name: Optional name prefix for the returned tensors.
   2237 
   2238     Raises:
   2239       ValueError: If `parallel_iterations` has invalid value.
   2240     """
   2241     if not isinstance(parallel_iterations, int) or (parallel_iterations <= 0):
   2242       raise ValueError("`parallel_iterations` must be a positive integer: "
   2243                        "%s" % parallel_iterations)
   2244     self._name = ops.get_default_graph().unique_name(name)
   2245     self._maximum_iterations = maximum_iterations
   2246     self._parallel_iterations = parallel_iterations
   2247     self._back_prop = back_prop
   2248     self._swap_memory = swap_memory
   2249     # We use this node to control constants created by the pred lambda.
   2250     self._pivot_for_pred = None
   2251     # We use this node to control constants created by the body lambda.
   2252     self._pivot_for_body = None
   2253     # The boolean tensor for loop termination condition. Used in code
   2254     # generation for gradient computation
   2255     self._pivot = None
   2256     # The list of exit tensors for loop variables.
   2257     self._loop_exits = []
   2258     # The list of enter tensors for loop variables.
   2259     self._loop_enters = []
   2260 
   2261   def _init_from_proto(self, context_def, import_scope=None):
   2262     """Creates a new `WhileContext` from protocol buffer.
   2263 
   2264     Args:
   2265       context_def: `WhileContextDef` protocol buffer.
   2266       import_scope: Optional `string`. Name scope to add.
   2267     """
   2268     assert isinstance(context_def, control_flow_pb2.WhileContextDef)
   2269     # Create from context_def.
   2270     g = ops.get_default_graph()
   2271     self._name = ops.prepend_name_scope(context_def.context_name, import_scope)
   2272     if context_def.maximum_iterations_name:
   2273       self._maximum_iterations = g.as_graph_element(
   2274           ops.prepend_name_scope(context_def.maximum_iterations_name,
   2275                                  import_scope))
   2276     else:
   2277       self._maximum_iterations = None
   2278     self._parallel_iterations = context_def.parallel_iterations
   2279     self._back_prop = context_def.back_prop
   2280     self._swap_memory = context_def.swap_memory
   2281     self._pivot_for_pred = g.as_graph_element(
   2282         ops.prepend_name_scope(context_def.pivot_for_pred_name, import_scope))
   2283     # We use this node to control constants created by the body lambda.
   2284     self._pivot_for_body = g.as_graph_element(
   2285         ops.prepend_name_scope(context_def.pivot_for_body_name, import_scope))
   2286     # The boolean tensor for loop termination condition. Used in code
   2287     # generation for gradient computation.
   2288     self._pivot = g.as_graph_element(
   2289         ops.prepend_name_scope(context_def.pivot_name, import_scope))
   2290     # The list of exit tensors for loop variables.
   2291     self._loop_exits = [
   2292         g.as_graph_element(ops.prepend_name_scope(exit_name, import_scope))
   2293         for exit_name in context_def.loop_exit_names
   2294     ]
   2295     # The list of enter tensors for loop variables.
   2296     self._loop_enters = [
   2297         g.as_graph_element(ops.prepend_name_scope(enter_name, import_scope))
   2298         for enter_name in context_def.loop_enter_names
   2299     ]
   2300     super(WhileContext, self).__init__(
   2301         values_def=context_def.values_def, import_scope=import_scope)
   2302 
   2303     # import_scope causes self.name to be different from the original serialized
   2304     # context's name. Rewrite "frame_name" attrs with the new name.
   2305     if import_scope:
   2306       for tensor_name in self._values:
   2307         op = g.as_graph_element(tensor_name).op
   2308         if util.IsLoopEnter(op):
   2309           # pylint: disable=protected-access
   2310           op._set_attr("frame_name",
   2311                        attr_value_pb2.AttrValue(s=compat.as_bytes(self.name)))
   2312           # pylint: enable=protected-access
   2313 
   2314   @property
   2315   def maximum_iterations(self):
   2316     """The maximum number of iterations that will be executed."""
   2317     return self._maximum_iterations
   2318 
   2319   @property
   2320   def parallel_iterations(self):
   2321     """The number of iterations allowed to run in parallel."""
   2322     return self._parallel_iterations
   2323 
   2324   @property
   2325   def back_prop(self):
   2326     """True iff backprop is enabled for this while loop."""
   2327     return self._back_prop
   2328 
   2329   @property
   2330   def swap_memory(self):
   2331     """True iff GPU-CPU memory swap is enabled for this while loop."""
   2332     return self._swap_memory
   2333 
   2334   @property
   2335   def pivot(self):
   2336     """The boolean tensor representing the loop termination condition."""
   2337     return self._pivot
   2338 
   2339   @property
   2340   def loop_enters(self):
   2341     """The list of enter tensors for loop variables."""
   2342     return self._loop_enters
   2343 
   2344   @property
   2345   def loop_exits(self):
   2346     """The list of exit tensors for loop variables."""
   2347     return self._loop_exits
   2348 
   2349   @property
   2350   def grad_state(self):
   2351     """The gradient loop state."""
   2352     return self._grad_state
   2353 
   2354   def to_proto(self, export_scope=None):
   2355     """Converts a `WhileContext` to a `WhileContextDef` protocol buffer.
   2356 
   2357     Args:
   2358       export_scope: Optional `string`. Name scope to remove.
   2359 
   2360     Returns:
   2361       A `WhileContextDef` protocol buffer.
   2362     """
   2363     if (export_scope is None or self.name.startswith(export_scope)):
   2364       context_def = control_flow_pb2.WhileContextDef()
   2365       context_def.context_name = ops.strip_name_scope(self.name, export_scope)
   2366       context_def.parallel_iterations = self._parallel_iterations
   2367       if self._maximum_iterations is not None:
   2368         context_def.maximum_iterations_name = ops.strip_name_scope(
   2369             self._maximum_iterations.name, export_scope)
   2370       context_def.back_prop = self._back_prop
   2371       context_def.swap_memory = self._swap_memory
   2372       context_def.pivot_for_pred_name = ops.strip_name_scope(
   2373           self._pivot_for_pred.name, export_scope)
   2374       context_def.pivot_for_body_name = ops.strip_name_scope(
   2375           self._pivot_for_body.name, export_scope)
   2376       context_def.pivot_name = ops.strip_name_scope(self._pivot.name,
   2377                                                     export_scope)
   2378       context_def.loop_exit_names.extend([
   2379           ops.strip_name_scope(l.name, export_scope) for l in self._loop_exits
   2380       ])
   2381       context_def.loop_enter_names.extend([
   2382           ops.strip_name_scope(l.name, export_scope) for l in self._loop_enters
   2383       ])
   2384       context_def.values_def.MergeFrom(
   2385           super(WhileContext, self)._to_values_def(
   2386               export_scope=export_scope))
   2387       # TODO(b/72868227): remove "if True..." once the corresponding
   2388       # control_flow.proto changes have been checked in (they aren't checked in
   2389       # and this is disabled for now to ensure forwards compatibility).
   2390       if False:  # pylint: disable=using-constant-test
   2391         for nested in self._nested_contexts:
   2392           nested_def = context_def.nested_contexts.add()
   2393           nested.to_control_flow_context_def(nested_def)
   2394 
   2395       return context_def
   2396     else:
   2397       return None
   2398 
   2399   def to_control_flow_context_def(self, context_def, export_scope=None):
   2400     context_def.while_ctxt.CopyFrom(self.to_proto(export_scope=export_scope))
   2401 
   2402   @staticmethod
   2403   def from_proto(context_def, import_scope=None):
   2404     """Returns a `WhileContext` object created from `context_def`.
   2405 
   2406     Args:
   2407       context_def: A `WhileContextDef` protocol buffer.
   2408       import_scope: Optional `string`. Name scope to add.
   2409 
   2410     Returns:
   2411       A `WhileContext` Python object.
   2412     """
   2413     ret = WhileContext(context_def=context_def,
   2414                        import_scope=import_scope)
   2415     # TODO(b/72868227): remove "if hasattr(...)" once the corresponding
   2416     # control_flow.proto changes have been checked in (they aren't checked in
   2417     # and this is disabled for now to ensure forwards compatibility).
   2418     if hasattr(context_def, "nested_contexts"):
   2419       ret.Enter()
   2420       for nested_def in context_def.nested_contexts:
   2421         from_control_flow_context_def(nested_def, import_scope=import_scope)
   2422       ret.Exit()
   2423     return ret
   2424 
   2425   def GetWhileContext(self):
   2426     return self
   2427 
   2428   def GetControlPivot(self):
   2429     if self._pivot_for_body is not None:
   2430       return self._pivot_for_body
   2431     return self._pivot_for_pred
   2432 
   2433   def AddValue(self, val):
   2434     """Add `val` to the current context and its outer context recursively."""
   2435     result = val
   2436     if val.name not in self._values:
   2437       self._values.add(val.name)
   2438 
   2439       # If we are in a grad context and val is from its forward context,
   2440       # use GetRealValue(), which adds the logic to save the history of
   2441       # val in forward.
   2442       grad_ctxt = ops.get_default_graph()._get_control_flow_context()
   2443       if grad_ctxt:
   2444         grad_ctxt = grad_ctxt.GetWhileContext()
   2445         if grad_ctxt.grad_state:
   2446           forward_ctxt = _GetWhileContext(val.op)
   2447           if util.IsLoopExit(val.op):
   2448             forward_ctxt = forward_ctxt.outer_context
   2449             if forward_ctxt:
   2450               forward_ctxt = forward_ctxt.GetWhileContext()
   2451           if forward_ctxt == grad_ctxt.grad_state.forward_context:
   2452             real_val = grad_ctxt.grad_state.GetRealValue(val)
   2453             self._external_values[val.name] = real_val
   2454             return real_val
   2455 
   2456       if self._outer_context is not None:
   2457         result = self._outer_context.AddValue(val)
   2458       # Create an Enter to make `result` known to this loop context.
   2459       with ops.control_dependencies(None):
   2460         enter = _Enter(
   2461             result,
   2462             self._name,
   2463             is_constant=True,
   2464             parallel_iterations=self._parallel_iterations)
   2465         enter.graph.prevent_feeding(enter)
   2466         if self._outer_context:
   2467           self._outer_context.AddInnerOp(enter.op)
   2468       # Fix the control inputs and control flow context of these enter ops.
   2469       self._FixControlInputsAndContext([enter])
   2470 
   2471       # Add `enter` in this context.
   2472       self._values.add(enter.name)
   2473       self._external_values[val.name] = enter
   2474       result = enter
   2475     else:
   2476       actual_val = self._external_values.get(val.name)
   2477       if actual_val is not None:
   2478         result = actual_val
   2479     return result
   2480 
   2481   def AddOp(self, op):
   2482     """Add `op` to the current context."""
   2483     # For a reduction op, if op is in a grad context and its input is from
   2484     # its forward context, moving op to the forward context means we would
   2485     # store the tensor after the reduction as opposed to the tensor before
   2486     # reduction, and therefore could significantly reduce memory consumption.
   2487     # For now, we do this only for a few ops.
   2488     if op.type in {"Shape", "Size", "Rank"}:
   2489       grad_ctxt = ops.get_default_graph()._get_control_flow_context()
   2490       if grad_ctxt:
   2491         grad_ctxt = grad_ctxt.GetWhileContext()
   2492         if grad_ctxt.grad_state:
   2493           op_input_forward_ctxt = _GetWhileContext(op.inputs[0].op)
   2494           if op_input_forward_ctxt == grad_ctxt.grad_state.forward_context:
   2495             op_input_ctxt = op.inputs[0].op._get_control_flow_context()
   2496             op._set_control_flow_context(op_input_ctxt)
   2497             op_input_ctxt._AddOpInternal(op)
   2498             return
   2499     self._AddOpInternal(op)
   2500 
   2501   def _AddOpInternal(self, op):
   2502     """Add `op` to the current context.
   2503 
   2504     We move any external control dependencies of the op to the loop pivot, to
   2505     ensure they get executed.
   2506     """
   2507     if not op.inputs:
   2508       # Remove any external control dependency on this op
   2509       control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
   2510       # Add a control edge from the control pivot to this op.
   2511       if not control_inputs:
   2512         # pylint: disable=protected-access
   2513         op._add_control_input(self.GetControlPivot().op)
   2514         # pylint: enable=protected-access
   2515       for x in op.outputs:
   2516         self._values.add(x.name)
   2517     else:
   2518       for index in range(len(op.inputs)):
   2519         x = op.inputs[index]
   2520         real_x = self.AddValue(x)
   2521         if real_x != x:
   2522           op._update_input(index, real_x)  # pylint: disable=protected-access
   2523       # Remove any external control dependency on this op.
   2524       _, external_inputs = self._RemoveExternalControlEdges(op)
   2525       # Add a control dependency to prevent loop invariants from
   2526       # enabling ops that should not be executed.
   2527       self._MaybeAddControlDependency(op)
   2528       for x in op.outputs:
   2529         self._values.add(x.name)
   2530     if external_inputs:
   2531       # Use an identity to pull control inputs as data inputs. Note that we
   2532       # ignore ops which don't have outputs. TODO(apassos): fix that
   2533       with ops.control_dependencies(None):
   2534         self.Enter()
   2535         external_inputs = [array_ops.identity(x.outputs[0]).op
   2536                            for x in external_inputs if x.outputs]
   2537         self.Exit()
   2538       op._add_control_inputs(external_inputs)  # pylint: disable=protected-access
   2539     if self._outer_context or not util.IsLoopExit(op):
   2540       op.graph.prevent_fetching(op)
   2541       for x in op.outputs:
   2542         op.graph.prevent_feeding(x)
   2543 
   2544     if self._outer_context:
   2545       self._outer_context.AddInnerOp(op)
   2546 
   2547   def _MaybeAddControlDependency(self, op):
   2548     """Add a control input to the op if it only depends on loop invariants."""
   2549 
   2550     def _IsOpFree(op):
   2551       """Determines if `op` needs a control dependency."""
   2552       if op.control_inputs:
   2553         return False
   2554       # pylint: disable=protected-access
   2555       if op.graph._is_function(op.type) or op.type == "SymbolicGradient":
   2556         return True
   2557       # pylint: enable=protected-access
   2558       for x in op.inputs:
   2559         if not util.IsLoopConstantEnter(x.op):
   2560           return False
   2561       return True
   2562 
   2563     if _IsOpFree(op):
   2564       # pylint: disable=protected-access
   2565       op._add_control_input(self.GetControlPivot().op)
   2566       # pylint: enable=protected-access
   2567 
   2568   def AddForwardLoopCounter(self, outer_grad_state):
   2569     """Adds a loop that counts the number of iterations.
   2570 
   2571     This is added to the forward loop at the time when we start to
   2572     create the loop for backprop gradient computation. Called in
   2573     the outer context of this forward context.
   2574 
   2575     The pseudocode is:
   2576       `n = 0; while (_pivot) { n++; }`
   2577 
   2578     Note that a control dependency is added to `n` to ensure the correct
   2579     execution order of stack push ops.
   2580 
   2581     Args:
   2582       outer_grad_state: The outer grad state. None if not nested.
   2583 
   2584     Returns:
   2585       The number of iterations taken by the forward loop and the loop index.
   2586     """
   2587     n = constant_op.constant(0, name="f_count")
   2588     if outer_grad_state is not None:
   2589       # Force the stack pushes of i-th execution of an inner loop to be ordered
   2590       # before the pushes of (i+1)-th execution of the same inner loop.
   2591       outer_add_op = outer_grad_state.forward_index.op.inputs[0].op
   2592       n.op._add_control_input(outer_add_op)  # pylint: disable=protected-access
   2593 
   2594     self.Enter()
   2595     self.AddName(n.name)
   2596     enter_n = _Enter(
   2597         n,
   2598         self._name,
   2599         is_constant=False,
   2600         parallel_iterations=self._parallel_iterations,
   2601         name="f_count")
   2602     self.loop_enters.append(enter_n)
   2603 
   2604     merge_n = merge([enter_n, enter_n])[0]
   2605     switch_n = switch(merge_n, self._pivot)
   2606 
   2607     index = math_ops.add(switch_n[1], 1)
   2608     next_n = _NextIteration(index)
   2609     merge_n.op._update_input(1, next_n)
   2610 
   2611     total_iterations = exit(switch_n[0], name="f_count")
   2612     self.loop_exits.append(total_iterations)
   2613     self.ExitResult([total_iterations])
   2614     self.Exit()
   2615     return total_iterations, next_n
   2616 
   2617   def AddBackpropLoopCounter(self, count, outer_grad_state):
   2618     """Add the backprop loop that controls the iterations.
   2619 
   2620     This is added to the backprop loop. It is used to control the loop
   2621     termination of the backprop loop. Called in the outer context of
   2622     this grad context.
   2623 
   2624     The pseudocode is:
   2625       `n = count; while (n >= 1) { n--; }`
   2626 
   2627     Note that a control dependency is added to `final_zero` to ensure the
   2628     correct execution order of stack pop ops.
   2629 
   2630     Args:
   2631       count: The number of iterations for backprop.
   2632       outer_grad_state: The outer grad state. None if not nested.
   2633 
   2634     Returns:
   2635       The loop index.
   2636     """
   2637     one = constant_op.constant(1, name="b_count")
   2638 
   2639     self.Enter()
   2640     self.AddName(count.name)
   2641     enter_count = _Enter(
   2642         count,
   2643         self._name,
   2644         is_constant=False,
   2645         parallel_iterations=self._parallel_iterations,
   2646         name="b_count")
   2647     self.loop_enters.append(enter_count)
   2648 
   2649     merge_count = merge([enter_count, enter_count])[0]
   2650     self._pivot_for_pred = merge_count
   2651 
   2652     pred = math_ops.greater_equal(merge_count, one)
   2653     self._pivot = loop_cond(pred, name="b_count")
   2654     switch_count = switch(merge_count, self._pivot)
   2655 
   2656     index = math_ops.subtract(switch_count[1], one)
   2657     self._pivot_for_body = index
   2658     next_count = _NextIteration(index)
   2659     merge_count.op._update_input(1, next_count)
   2660 
   2661     final_zero = exit(switch_count[0], name="b_count")
   2662     self.loop_exits.append(final_zero)
   2663     if outer_grad_state is not None:
   2664       # Force the stack pops of i-th execution of an inner loop to be ordered
   2665       # before the pops of (i+1)-th execution of the same inner loop.
   2666       # pylint: disable=protected-access
   2667       outer_grad_state.grad_sync._add_control_input(final_zero.op)
   2668       # pylint: enable=protected-access
   2669 
   2670     self.ExitResult([final_zero])
   2671     self.Exit()
   2672     return next_count
   2673 
   2674   def AddBackpropAccumulator(self, op, grad):
   2675     """Add an accumulation loop for every loop invariant.
   2676 
   2677     This is added to the backprop loop. It is used to accumulate partial
   2678     gradients within each loop iteration. Called when in the gradient while
   2679     context.
   2680 
   2681     The pseudocode is:
   2682       ```
   2683       acc = 0.0;
   2684       while (_pivot) {
   2685         acc += grad;
   2686       }
   2687       ```
   2688 
   2689     Args:
   2690       op: The Enter op for a loop invariant.
   2691       grad: The partial gradient of an iteration for a loop invariant.
   2692 
   2693     Returns:
   2694       The gradient for a loop invariant.
   2695     """
   2696     self.Exit()
   2697     # Create a zeros tensor with the right shape for acc. If we don't
   2698     # know the full shape statically, we will have to get the shape
   2699     # dynamically from the forward inference. Getting the shape right
   2700     # for the zeros is only needed for the base case when the loop exits
   2701     # without running any iterations.
   2702     shape = grad.get_shape()
   2703     if shape.is_fully_defined():
   2704       if self.outer_context:
   2705         self.outer_context.Enter()
   2706       acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")
   2707       if self.outer_context:
   2708         self.outer_context.Exit()
   2709     else:
   2710       value = op.inputs[0]
   2711       if (isinstance(self.outer_context, WhileContext) and
   2712           self.outer_context.grad_state is not None):
   2713         # We are in a nested while loop.
   2714         forward_ctxt = self.grad_state.forward_context
   2715         forward_ctxt.outer_context.Enter()
   2716         zeros_shape = array_ops.shape_internal(value, optimize=False)
   2717         forward_ctxt.outer_context.Exit()
   2718         outer_grad_state = self.grad_state.outer_grad_state
   2719         history_zeros_shape = outer_grad_state.AddForwardAccumulator(
   2720             zeros_shape)
   2721         self.outer_context.Enter()
   2722         real_shape = outer_grad_state.AddBackpropAccumulatedValue(
   2723             history_zeros_shape, zeros_shape)
   2724         acc = array_ops.zeros(real_shape, grad.dtype)
   2725         self.outer_context.Exit()
   2726       else:
   2727         if self.outer_context:
   2728           self.outer_context.Enter()
   2729         zeros_shape = array_ops.shape_internal(value, optimize=False)
   2730         acc = array_ops.zeros(zeros_shape, grad.dtype)
   2731         if self.outer_context:
   2732           self.outer_context.Exit()
   2733 
   2734     self.Enter()
   2735     self.AddName(acc.name)
   2736     enter_acc = _Enter(
   2737         acc,
   2738         self._name,
   2739         is_constant=False,
   2740         parallel_iterations=self._parallel_iterations,
   2741         name="b_acc")
   2742     self.loop_enters.append(enter_acc)
   2743 
   2744     merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
   2745     switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot)
   2746 
   2747     add_acc = math_ops.add(switch_acc_true, grad)
   2748     next_acc = _NextIteration(add_acc)
   2749     merge_acc.op._update_input(1, next_acc)  # pylint: disable=protected-access
   2750 
   2751     result_acc = exit(switch_acc_false, name="b_acc")
   2752     self.loop_exits.append(result_acc)
   2753     self.ExitResult([result_acc])
   2754     return result_acc
   2755 
   2756   def AddBackpropIndexedSlicesAccumulator(self, op, grad):
   2757     """This is used for accumulating gradients that are IndexedSlices.
   2758 
   2759     This is essentially the equivalent of AddBackpropAccumulator but optimized
   2760     for things like updating embeddings from within a while loop.
   2761 
   2762     Args:
   2763       op: The Enter op for a loop invariant.
   2764       grad: The partial gradients represented as an IndexedSlices.
   2765 
   2766     Returns:
   2767       The accumulated IndexedSlices gradient of the loop invariant.
   2768     """
   2769     values = grad.values
   2770     indices = grad.indices
   2771     dense_shape = grad.dense_shape
   2772 
   2773     self.Exit()
   2774     if self.outer_context:
   2775       self.outer_context.Enter()
   2776     if values.get_shape().is_fully_defined():
   2777       values_shape = tensor_shape.TensorShape(
   2778           [tensor_shape.Dimension(1)] + values.get_shape().dims[1:])
   2779       if self.outer_context:
   2780         self.outer_context.Enter()
   2781       values_acc = constant_op.constant(
   2782           0, values.dtype, shape=values_shape, name="b_acc")
   2783       if self.outer_context:
   2784         self.outer_context.Exit()
   2785     else:
   2786       values_shape = _resource_safe_shape(op.inputs[0])[1:]
   2787       values_shape = array_ops.concat([[1], values_shape], 0)
   2788       values_acc = array_ops.zeros(values_shape, dtype=values.dtype)
   2789     indices_acc = constant_op.constant([0], indices.dtype)
   2790     shape_acc = None
   2791     if dense_shape is not None:
   2792       if dense_shape.get_shape().is_fully_defined():
   2793         if self.outer_context:
   2794           self.outer_context.Enter()
   2795         shape_acc = constant_op.constant(
   2796             0, dense_shape.dtype, shape=dense_shape.get_shape())
   2797         if self.outer_context:
   2798           self.outer_context.Exit()
   2799       else:
   2800         shape_acc = array_ops.zeros_like(
   2801             array_ops.shape_internal(op.inputs[0], optimize=False),
   2802             optimize=False)
   2803 
   2804     if self.outer_context:
   2805       self.outer_context.Exit()
   2806 
   2807     self.Enter()
   2808     self.AddName(values_acc.name)
   2809     self.AddName(indices_acc.name)
   2810     init_acc = [indices_acc, values_acc]
   2811     if shape_acc is not None:
   2812       self.AddName(shape_acc.name)
   2813       init_acc.append(shape_acc)
   2814 
   2815     # Set use_input_shape=False since the accumulator tensors will grow in
   2816     # size. If use_input_shape=True, the _update_input call below will result in
   2817     # incompatible shapes.
   2818     enter_acc = [
   2819         _Enter(
   2820             x,
   2821             self._name,
   2822             is_constant=False,
   2823             parallel_iterations=self._parallel_iterations,
   2824             use_input_shape=False,
   2825             name="b_acc") for x in init_acc
   2826     ]
   2827     # Manually set appropriate partial shapes.
   2828     enter_acc[0].set_shape([None])
   2829     if values_acc.shape.dims is not None:
   2830       enter_acc[1].set_shape([None] + values_acc.shape.as_list()[1:])
   2831     self.loop_enters.extend(enter_acc)
   2832 
   2833     merge_acc = [merge([x, x], name="b_acc")[0] for x in enter_acc]
   2834     switch_acc = [switch(x, self._pivot) for x in merge_acc]
   2835 
   2836     # The actual accumulation.
   2837     acc_indexed_slices = [
   2838         array_ops.concat([xa[1], xv], 0)
   2839         for xa, xv in zip(switch_acc[:2], [indices, values])
   2840     ]
   2841     if shape_acc is not None:
   2842       # For the shape we just keep the maximum
   2843       acc_indexed_slices.append(math_ops.maximum(dense_shape, switch_acc[2][1]))
   2844 
   2845     next_acc = [_NextIteration(x) for x in acc_indexed_slices]
   2846     for xm, xn in zip(merge_acc, next_acc):
   2847       xm.op._update_input(1, xn)  # pylint: disable=protected-access
   2848 
   2849     exit_acc = [exit(x[0], name="b_acc") for x in switch_acc]
   2850     self.loop_exits.extend(exit_acc)
   2851 
   2852     self.ExitResult(exit_acc)
   2853     return ops.IndexedSlices(
   2854         indices=exit_acc[0],
   2855         values=exit_acc[1],
   2856         dense_shape=exit_acc[2] if shape_acc is not None else None)
   2857 
   2858   def _InitializeValues(self, values):
   2859     """Makes the values known to this context."""
   2860     self._values = set()
   2861     for x in values:
   2862       if isinstance(x, ops.Tensor):
   2863         self._values.add(x.name)
   2864       else:
   2865         self._values.add(x.values.name)
   2866         self._values.add(x.indices.name)
   2867         if isinstance(x, ops.IndexedSlices):
   2868           dense_shape = x.dense_shape
   2869         elif isinstance(x, sparse_tensor.SparseTensor):
   2870           dense_shape = x.dense_shape
   2871         else:
   2872           raise TypeError("Type %s not supported" % type(x))
   2873         if dense_shape is not None:
   2874           self._values.add(dense_shape.name)
   2875 
   2876   def _BuildLoop(self, pred, body, original_loop_vars, loop_vars,
   2877                  shape_invariants):
   2878     """Core: Add the loop termination condition and body to the graph."""
   2879     flat_loop_vars = nest.flatten(original_loop_vars)
   2880 
   2881     # Let the context know the loop variables so the loop variables
   2882     # would be added in the outer contexts properly.
   2883     self._InitializeValues(loop_vars)
   2884     real_vars = loop_vars
   2885     if self._outer_context:
   2886       real_vars = [self._outer_context.AddValue(x) for x in loop_vars]
   2887     with ops.control_dependencies(None):
   2888       enter_vars = [
   2889           _Enter(
   2890               x,
   2891               self._name,
   2892               is_constant=False,
   2893               parallel_iterations=self._parallel_iterations,
   2894               use_input_shape=(shape_invariants is None)) for x in real_vars
   2895       ]
   2896       for x in enter_vars:
   2897         x.graph.prevent_feeding(x)
   2898         if self._outer_context:
   2899           self._outer_context.AddInnerOp(x.op)
   2900 
   2901     # Finds the closest enclosing non-None control pivot.
   2902     outer_context = self._outer_context
   2903     control_pivot = None
   2904     while outer_context is not None and control_pivot is None:
   2905       control_pivot = outer_context.GetControlPivot()
   2906       # pylint: disable=protected-access
   2907       outer_context = outer_context._outer_context
   2908       # pylint: enable=protected-access
   2909 
   2910     if control_pivot is not None:
   2911       for var in enter_vars:
   2912         if util.IsLoopConstantEnter(var.op.inputs[0].op):
   2913           # pylint: disable=protected-access
   2914           var.op._add_control_input(control_pivot.op)
   2915           # pylint: enable=protected-access
   2916     _SetShapeInvariants(real_vars, enter_vars, shape_invariants)
   2917 
   2918     # Fix the control inputs and control flow context of these enter ops.
   2919     self._FixControlInputsAndContext(enter_vars)
   2920     self._InitializeValues(enter_vars)
   2921     self._loop_enters = enter_vars
   2922 
   2923     merge_vars = [merge([x, x])[0] for x in enter_vars]
   2924     self._pivot_for_pred = merge_vars[0]
   2925 
   2926     # Build the graph for pred.
   2927     merge_vars_with_tensor_arrays = (
   2928         _convert_flows_to_tensorarrays(flat_loop_vars, merge_vars))
   2929     packed_vars = nest.pack_sequence_as(
   2930         structure=original_loop_vars,
   2931         flat_sequence=merge_vars_with_tensor_arrays)
   2932     c = ops.convert_to_tensor(pred(*packed_vars))
   2933     self._pivot = loop_cond(c, name="LoopCond")
   2934     switch_vars = [_SwitchRefOrTensor(x, self._pivot) for x in merge_vars]
   2935 
   2936     # Build the graph for body.
   2937     vars_for_body = [_Identity(x[1]) for x in switch_vars]
   2938     self._pivot_for_body = vars_for_body[0]
   2939     # Convert TensorArray flow variables inside the context back into
   2940     # their associated TensorArrays for calling the body.
   2941     vars_for_body_with_tensor_arrays = (
   2942         _convert_flows_to_tensorarrays(flat_loop_vars, vars_for_body))
   2943     packed_vars_for_body = nest.pack_sequence_as(
   2944         structure=original_loop_vars,
   2945         flat_sequence=vars_for_body_with_tensor_arrays)
   2946     pre_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   2947     body_result = body(*packed_vars_for_body)
   2948     post_summaries = ops.get_collection(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   2949     if not nest.is_sequence(body_result):
   2950       body_result = [body_result]
   2951     if len(post_summaries) > len(pre_summaries):
   2952       new_summaries = post_summaries[len(pre_summaries):]
   2953       summary_ref = ops.get_collection_ref(ops.GraphKeys._SUMMARY_COLLECTION)  # pylint: disable=protected-access
   2954       summary_ref[:] = pre_summaries
   2955       with ops.control_dependencies(new_summaries):
   2956 
   2957         def map_fn(x):
   2958           # TODO(apassos) figure out how to trigger with tensor arrays as well
   2959           if isinstance(x, tensor_array_ops.TensorArray):
   2960             return x
   2961           return array_ops.identity(x)
   2962 
   2963         body_result = nest.map_structure(map_fn, body_result)
   2964 
   2965     # Compare the structure types of input and output of body.
   2966     # For backwards compatibility, the first layer is forced to a list
   2967     # during this comparison, because inputs are typically lists and
   2968     # outputs of the body are typically tuples.
   2969     nest.assert_same_structure(list(packed_vars_for_body), list(body_result))
   2970 
   2971     # Store body_result to keep track of TensorArrays returned by body
   2972     original_body_result = body_result
   2973     # Convert TensorArrays returned by body into their flow variables
   2974     result = nest.map_structure(_convert_tensorarray_to_flow,
   2975                                 nest.flatten(body_result))
   2976     result = ops.convert_n_to_tensor_or_indexed_slices(result)
   2977 
   2978     # Add NextIteration and the back edges to complete the loop.
   2979     if len(merge_vars) != len(result):
   2980       raise ValueError("Number of inputs and outputs of body must match "
   2981                        "loop_vars: %d, %d" % (len(merge_vars), len(result)))
   2982     next_vars = []
   2983     for m, v in zip(merge_vars, result):
   2984       next_vars.append(_AddNextAndBackEdge(m, v))
   2985 
   2986     # Add the exit ops.
   2987     exit_vars = [exit(x[0]) for x in switch_vars]
   2988     self._loop_exits = exit_vars
   2989 
   2990     # Exit the loop.
   2991     self.ExitResult(exit_vars)
   2992 
   2993     return original_body_result, exit_vars
   2994 
   2995   def BuildLoop(self, pred, body, loop_vars, shape_invariants):
   2996     """Add the loop termination condition and body to the graph."""
   2997 
   2998     # Keep original_loop_vars to identify which are TensorArrays
   2999     original_loop_vars = loop_vars
   3000     # Convert TensorArrays to their flow variables
   3001     loop_vars = nest.map_structure(_convert_tensorarray_to_flow,
   3002                                    nest.flatten(loop_vars))
   3003     loop_vars = ops.convert_n_to_tensor_or_indexed_slices(loop_vars)
   3004     try:
   3005       self.Enter()
   3006       original_body_result, exit_vars = self._BuildLoop(
   3007           pred, body, original_loop_vars, loop_vars, shape_invariants)
   3008     finally:
   3009       self.Exit()
   3010 
   3011     flat_result = nest.flatten(original_body_result)
   3012     # Convert TensorArray flow variables outside the context back into
   3013     # their associated TensorArrays for returning to caller.
   3014     exit_vars_with_tensor_arrays = (
   3015         _convert_flows_to_tensorarrays(flat_result, exit_vars))
   3016     packed_exit_vars = nest.pack_sequence_as(
   3017         structure=original_body_result,
   3018         flat_sequence=exit_vars_with_tensor_arrays)
   3019     return (packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars)
   3020 
   3021   def _FixControlInputsAndContext(self, enters):
   3022     graph = ops.get_default_graph()
   3023     # pylint: disable=protected-access
   3024     for e in enters:
   3025       if isinstance(e, ops.Tensor):
   3026         xs = [e]
   3027       else:
   3028         if not isinstance(e, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
   3029           raise TypeError("Type %s not supported" % type(e))
   3030         xs = [e.values, e.indices]
   3031         shape = e.dense_shape
   3032         if shape is not None:
   3033           xs.append(shape)
   3034       for x in xs:
   3035         inp_op = x.op.inputs[0].op
   3036         control_inputs = graph._control_dependencies_for_inputs([inp_op])
   3037         outer_control_inputs = [
   3038             op for op in control_inputs if self._IsInOuterContext(op)
   3039         ]
   3040         x.op._set_control_flow_context(self)
   3041         x.op._add_control_inputs(outer_control_inputs)
   3042         graph._record_op_seen_by_control_dependencies(x.op)
   3043     # pylint: enable=protected-access
   3044 
   3045   def IsWhileContext(self):
   3046     return True
   3047 
   3048 
   3049 # pylint: disable=redefined-outer-name
   3050 @tf_export("while_loop")
   3051 def while_loop(cond,
   3052                body,
   3053                loop_vars,
   3054                shape_invariants=None,
   3055                parallel_iterations=10,
   3056                back_prop=True,
   3057                swap_memory=False,
   3058                name=None,
   3059                maximum_iterations=None):
   3060   """Repeat `body` while the condition `cond` is true.
   3061 
   3062   `cond` is a callable returning a boolean scalar tensor. `body` is a callable
   3063   returning a (possibly nested) tuple, namedtuple or list of tensors of the same
   3064   arity (length and structure) and types as `loop_vars`. `loop_vars` is a
   3065   (possibly nested) tuple, namedtuple or list of tensors that is passed to both
   3066   `cond` and `body`. `cond` and `body` both take as many arguments as there are
   3067   `loop_vars`.
   3068 
   3069   In addition to regular Tensors or IndexedSlices, the body may accept and
   3070   return TensorArray objects.  The flows of the TensorArray objects will
   3071   be appropriately forwarded between loops and during gradient calculations.
   3072 
   3073   Note that `while_loop` calls `cond` and `body` *exactly once* (inside the
   3074   call to `while_loop`, and not at all during `Session.run()`). `while_loop`
   3075   stitches together the graph fragments created during the `cond` and `body`
   3076   calls with some additional graph nodes to create the graph flow that
   3077   repeats `body` until `cond` returns false.
   3078 
   3079   For correctness, `tf.while_loop()` strictly enforces shape invariants for
   3080   the loop variables. A shape invariant is a (possibly partial) shape that
   3081   is unchanged across the iterations of the loop. An error will be raised
   3082   if the shape of a loop variable after an iteration is determined to be more
   3083   general than or incompatible with its shape invariant. For example, a shape
   3084   of [11, None] is more general than a shape of [11, 17], and [11, 21] is not
   3085   compatible with [11, 17]. By default (if the argument `shape_invariants` is
   3086   not specified), it is assumed that the initial shape of each tensor in
   3087   `loop_vars` is the same in every iteration. The `shape_invariants` argument
   3088   allows the caller to specify a less specific shape invariant for each loop
   3089   variable, which is needed if the shape varies between iterations. The
   3090   @{tf.Tensor.set_shape}
   3091   function may also be used in the `body` function to indicate that
   3092   the output loop variable has a particular shape. The shape invariant for
   3093   SparseTensor and IndexedSlices are treated specially as follows:
   3094 
   3095   a) If a loop variable is a SparseTensor, the shape invariant must be
   3096   TensorShape([r]) where r is the rank of the dense tensor represented
   3097   by the sparse tensor. It means the shapes of the three tensors of the
   3098   SparseTensor are ([None], [None, r], [r]). NOTE: The shape invariant here
   3099   is the shape of the SparseTensor.dense_shape property. It must be the shape of
   3100   a vector.
   3101 
   3102   b) If a loop variable is an IndexedSlices, the shape invariant must be
   3103   a shape invariant of the values tensor of the IndexedSlices. It means
   3104   the shapes of the three tensors of the IndexedSlices are (shape, [shape[0]],
   3105   [shape.ndims]).
   3106 
   3107   `while_loop` implements non-strict semantics, enabling multiple iterations
   3108   to run in parallel. The maximum number of parallel iterations can be
   3109   controlled by `parallel_iterations`, which gives users some control over
   3110   memory consumption and execution order. For correct programs, `while_loop`
   3111   should return the same result for any parallel_iterations > 0.
   3112 
   3113   For training, TensorFlow stores the tensors that are produced in the
   3114   forward inference and are needed in back propagation. These tensors are a
   3115   main source of memory consumption and often cause OOM errors when training
   3116   on GPUs. When the flag swap_memory is true, we swap out these tensors from
   3117   GPU to CPU. This for example allows us to train RNN models with very long
   3118   sequences and large batches.
   3119 
   3120   Args:
   3121     cond: A callable that represents the termination condition of the loop.
   3122     body: A callable that represents the loop body.
   3123     loop_vars: A (possibly nested) tuple, namedtuple or list of numpy array,
   3124       `Tensor`, and `TensorArray` objects.
   3125     shape_invariants: The shape invariants for the loop variables.
   3126     parallel_iterations: The number of iterations allowed to run in parallel.
   3127       It must be a positive integer.
   3128     back_prop: Whether backprop is enabled for this while loop.
   3129     swap_memory: Whether GPU-CPU memory swap is enabled for this loop.
   3130     name: Optional name prefix for the returned tensors.
   3131     maximum_iterations: Optional maximum number of iterations of the while loop
   3132       to run.  If provided, the `cond` output is AND-ed with an additional
   3133       condition ensuring the number of iterations executed is no greater than
   3134       `maximum_iterations`.
   3135 
   3136   Returns:
   3137     The output tensors for the loop variables after the loop. When the length
   3138     of `loop_vars` is 1 this is a Tensor, TensorArray or IndexedSlice and when
   3139     the length of `loop_vars` is greater than 1 it returns a list.
   3140 
   3141   Raises:
   3142     TypeError: if `cond` or `body` is not callable.
   3143     ValueError: if `loop_vars` is empty.
   3144 
   3145   Example:
   3146 
   3147   ```python
   3148   i = tf.constant(0)
   3149   c = lambda i: tf.less(i, 10)
   3150   b = lambda i: tf.add(i, 1)
   3151   r = tf.while_loop(c, b, [i])
   3152   ```
   3153 
   3154   Example with nesting and a namedtuple:
   3155 
   3156   ```python
   3157   import collections
   3158   Pair = collections.namedtuple('Pair', 'j, k')
   3159   ijk_0 = (tf.constant(0), Pair(tf.constant(1), tf.constant(2)))
   3160   c = lambda i, p: i < 10
   3161   b = lambda i, p: (i + 1, Pair((p.j + p.k), (p.j - p.k)))
   3162   ijk_final = tf.while_loop(c, b, ijk_0)
   3163   ```
   3164 
   3165   Example using shape_invariants:
   3166 
   3167   ```python
   3168   i0 = tf.constant(0)
   3169   m0 = tf.ones([2, 2])
   3170   c = lambda i, m: i < 10
   3171   b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
   3172   tf.while_loop(
   3173       c, b, loop_vars=[i0, m0],
   3174       shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
   3175   ```
   3176   
   3177   Example which demonstrates non-strict semantics: In the following 
   3178   example, the final value of the counter `i` does not depend on `x`. So 
   3179   the `while_loop` can increment the counter parallel to updates of `x`. 
   3180   However, because the loop counter at one loop iteration depends
   3181   on the value at the previous iteration, the loop counter itself cannot
   3182   be incremented in parallel. Hence if we just want the final value of the 
   3183   counter (which we print on the line `print(sess.run(i))`), then  
   3184   `x` will never be incremented, but the counter will be updated on a 
   3185   single thread. Conversely, if we want the value of the output (which we
   3186   print on the line `print(sess.run(out).shape)`), then the counter may be 
   3187   incremented on its own thread, while `x` can be incremented in 
   3188   parallel on a separate thread. In the extreme case, it is conceivable 
   3189   that the thread incrementing the counter runs until completion before 
   3190   `x` is incremented even a single time. The only thing that can never 
   3191   happen is that the thread updating `x` can never get ahead of the 
   3192   counter thread because the thread incrementing `x` depends on the value 
   3193   of the counter. 
   3194   ```python
   3195   import tensorflow as tf
   3196 
   3197   n = 10000
   3198   x = tf.constant(list(range(n)))
   3199   c = lambda i, x: i < n
   3200   b = lambda i, x: (tf.Print(i + 1, [i]), tf.Print(x + 1, [i], "x:"))
   3201   i, out = tf.while_loop(c, b, (0, x))
   3202   with tf.Session() as sess:
   3203       print(sess.run(i))  # prints [0] ... [9999]
   3204 
   3205       # The following line may increment the counter and x in parallel. 
   3206       # The counter thread may get ahead of the other thread, but not the 
   3207       # other way around. So you may see things like 
   3208       # [9996] x:[9987]
   3209       # meaning that the counter thread is on iteration 9996, 
   3210       # while the other thread is on iteration 9987
   3211       print(sess.run(out).shape)  
   3212   ```
   3213 
   3214   """
   3215   with ops.name_scope(name, "while", loop_vars):
   3216     if not loop_vars:
   3217       raise ValueError("No loop variables provided")
   3218     if not callable(cond):
   3219       raise TypeError("cond must be callable.")
   3220     if not callable(body):
   3221       raise TypeError("body must be callable.")
   3222     if parallel_iterations < 1:
   3223       raise TypeError("parallel_iterations must be a positive integer.")
   3224 
   3225     if maximum_iterations is not None:
   3226       maximum_iterations = ops.convert_to_tensor(
   3227           maximum_iterations, name="maximum_iterations")
   3228       if maximum_iterations.shape.ndims != 0:
   3229         raise ValueError("maximum_iterations must be a scalar, saw shape: %s" %
   3230                          maximum_iterations.shape)
   3231 
   3232       counter = constant_op.constant(
   3233           0, dtype=maximum_iterations.dtype, name="iteration_counter")
   3234       orig_cond = cond
   3235       orig_body = body
   3236       if len(loop_vars) == 1:
   3237         loop_vars = (counter, loop_vars[0])
   3238         cond = lambda i, lv: (  # pylint: disable=g-long-lambda
   3239             math_ops.logical_and(i < maximum_iterations, orig_cond(lv)))
   3240         body = lambda i, lv: (i + 1, orig_body(lv))
   3241       else:
   3242         loop_vars = (counter, loop_vars)
   3243         cond = lambda i, lv: (  # pylint: disable=g-long-lambda
   3244             math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
   3245         body = lambda i, lv: (i + 1, orig_body(*lv))
   3246 
   3247     if context.in_eager_mode():
   3248       while cond(*loop_vars):
   3249         loop_vars = body(*loop_vars)
   3250       if maximum_iterations is not None:
   3251         return loop_vars[1]
   3252       else:
   3253         return loop_vars
   3254 
   3255     if shape_invariants is not None:
   3256       if maximum_iterations is not None:
   3257         shape_invariants = (tensor_shape.TensorShape([]), shape_invariants)
   3258       nest.assert_same_structure(loop_vars, shape_invariants)
   3259 
   3260     loop_context = WhileContext(
   3261         maximum_iterations=maximum_iterations,
   3262         parallel_iterations=parallel_iterations,
   3263         back_prop=back_prop,
   3264         swap_memory=swap_memory)
   3265     # Only add non-nested loops to the collection. Any nested control flow will
   3266     # be encapsulated in the root context.
   3267     # TODO(b/72868227): enable condition once the corresponding
   3268     # control_flow.proto changes have been checked in (they aren't checked in
   3269     # and this is disabled for now to ensure forwards compatibility).
   3270     if True or loop_context.outer_context is None:
   3271       ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
   3272     result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
   3273     if maximum_iterations is not None:
   3274       return result[1]
   3275     else:
   3276       return result
   3277 
   3278 
   3279 # pylint: enable=redefined-outer-name
   3280 
   3281 
   3282 def _AsTensorList(x, p):
   3283   """Return x as a list of Tensors or IndexedSlices.
   3284 
   3285   For entries of `x` that are Operations, this returns an Identity of `p`
   3286   with a dependency on the operation.
   3287 
   3288   Args:
   3289     x: A Tensor/IndexedSlices/Operation or a list or tuple of them.
   3290     p: A Tensor to return for entries in `x` that are Operations.
   3291 
   3292   Returns:
   3293     A list of Tensors or IndexedSlices.
   3294   """
   3295   if not isinstance(x, (list, _basetuple)):
   3296     x = [x]
   3297 
   3298   l = []
   3299   for v in x:
   3300     if isinstance(v, ops.Operation):
   3301       v = with_dependencies([v], p)
   3302     v = ops.convert_to_tensor_or_indexed_slices(v)
   3303     if isinstance(v, ops.Tensor):
   3304       l.append(array_ops.identity(v))
   3305     else:
   3306       l.append(
   3307           ops.IndexedSlices(
   3308               array_ops.identity(v.values), array_ops.identity(v.indices)))
   3309   return l
   3310 
   3311 
   3312 def _CheckResults(a, b):
   3313   assert len(a) == len(b), (
   3314       "Values returned by a() and b() must have the same length.")
   3315   for x, y in zip(a, b):
   3316     assert x.dtype == y.dtype, (
   3317         "Values returned by a() [%s] and b() [%s] must have "
   3318         "the same type: %s, %s." % (x.name, y.name, x.dtype.name, y.dtype.name))
   3319 
   3320 
   3321 def with_dependencies(dependencies, output_tensor, name=None):
   3322   """Produces the content of `output_tensor` only after `dependencies`.
   3323 
   3324   In some cases, a user may want the output of an operation to be
   3325   consumed externally only after some other dependencies have run
   3326   first. This function ensures returns `output_tensor`, but only after all
   3327   operations in `dependencies` have run. Note that this means that there is
   3328   no guarantee that `output_tensor` will be evaluated after any `dependencies`
   3329   have run.
   3330 
   3331   See also @{tf.tuple$tuple} and @{tf.group$group}.
   3332 
   3333   Args:
   3334     dependencies: Iterable of operations to run before this op finishes.
   3335     output_tensor: A `Tensor` or `IndexedSlices` that will be returned.
   3336     name: (Optional) A name for this operation.
   3337 
   3338   Returns:
   3339     Same as `output_tensor`.
   3340 
   3341   Raises:
   3342     TypeError: if `output_tensor` is not a `Tensor` or `IndexedSlices`.
   3343   """
   3344   if context.in_eager_mode():
   3345     return output_tensor
   3346   with ops.name_scope(name, "control_dependency",
   3347                       list(dependencies) + [output_tensor]) as name:
   3348     with ops.colocate_with(output_tensor):
   3349       with ops.control_dependencies(dependencies):
   3350         output_tensor = ops.convert_to_tensor_or_indexed_slices(output_tensor)
   3351         if isinstance(output_tensor, ops.Tensor):
   3352           return _Identity(output_tensor, name=name)
   3353         else:
   3354           return ops.IndexedSlices(
   3355               _Identity(output_tensor.values, name=name), output_tensor.indices,
   3356               output_tensor.dense_shape)
   3357 
   3358 
   3359 def _GroupControlDeps(dev, deps, name=None):
   3360   with ops.control_dependencies(deps):
   3361     if dev is None:
   3362       return no_op(name=name)
   3363     else:
   3364       with ops.device(dev):
   3365         return no_op(name=name)
   3366 
   3367 
   3368 # TODO(touts): Accept "inputs" as a list.
   3369 @tf_export("group")
   3370 def group(*inputs, **kwargs):
   3371   """Create an op that groups multiple operations.
   3372 
   3373   When this op finishes, all ops in `inputs` have finished. This op has no
   3374   output.
   3375 
   3376   See also @{tf.tuple$tuple} and
   3377   @{tf.control_dependencies$control_dependencies}.
   3378 
   3379   Args:
   3380     *inputs: Zero or more tensors to group.
   3381     name: A name for this operation (optional).
   3382 
   3383   Returns:
   3384     An Operation that executes all its inputs.
   3385 
   3386   Raises:
   3387     ValueError: If an unknown keyword argument is provided.
   3388   """
   3389   if context.in_eager_mode():
   3390     return None
   3391   name = kwargs.pop("name", None)
   3392   if kwargs:
   3393     raise ValueError("Unknown keyword arguments: " + ", ".join(kwargs.keys()))
   3394   with ops.name_scope(name, "group_deps", inputs) as name:
   3395     # Grouping no inputs means do nothing
   3396     if not inputs:
   3397       return no_op(name=name)
   3398 
   3399     # Sorts *inputs according to their devices.
   3400     ops_on_device = {}  # device -> operations specified on the device.
   3401     for inp in nest.flatten(inputs):
   3402       if not hasattr(inp, "device"):
   3403         raise TypeError("Expected tf.group() expected Tensor arguments not "
   3404                         "'%s' with type '%s'" % (inp, type(inp)))
   3405       if not hasattr(inp, "device"):
   3406         if isinstance(inp, list):
   3407           raise TypeError("To call tf.group() with a list, use "
   3408                           "tf.group(*[...]) not tf.group([...]).")
   3409         raise TypeError("Expected tf.group() expected Tensor arguments not "
   3410                         "'%s' with type '%s'" % (inp, type(inp)))
   3411       dev = inp.device
   3412       if dev in ops_on_device:
   3413         ops_on_device[dev].append(inp)
   3414       else:
   3415         ops_on_device[dev] = [inp]
   3416     if len(ops_on_device) == 1:
   3417       # 1-level tree. The root node is the returned NoOp node.
   3418       (dev, deps), = ops_on_device.items()
   3419       return _GroupControlDeps(dev, deps, name=name)
   3420 
   3421     # 2-level tree. The root node is the returned NoOp node.
   3422     # deps contains 1 NoOp node for each device.
   3423     deps = []
   3424 
   3425     def device_key(dev):
   3426       """A sort key that allows None to be compared to strings."""
   3427       return "" if dev is None else dev
   3428 
   3429     for dev in sorted(six.iterkeys(ops_on_device), key=device_key):
   3430       deps.append(_GroupControlDeps(dev, ops_on_device[dev]))
   3431 
   3432     with ops.control_dependencies(deps):
   3433       return no_op(name=name)
   3434 
   3435 
   3436 @tf_export("tuple")
   3437 def tuple(tensors, name=None, control_inputs=None):  # pylint: disable=redefined-builtin
   3438   """Group tensors together.
   3439 
   3440   This creates a tuple of tensors with the same values as the `tensors`
   3441   argument, except that the value of each tensor is only returned after the
   3442   values of all tensors have been computed.
   3443 
   3444   `control_inputs` contains additional ops that have to finish before this op
   3445   finishes, but whose outputs are not returned.
   3446 
   3447   This can be used as a "join" mechanism for parallel computations: all the
   3448   argument tensors can be computed in parallel, but the values of any tensor
   3449   returned by `tuple` are only available after all the parallel computations
   3450   are done.
   3451 
   3452   See also @{tf.group$group} and
   3453   @{tf.control_dependencies$control_dependencies}.
   3454 
   3455   Args:
   3456     tensors: A list of `Tensor`s or `IndexedSlices`, some entries can be `None`.
   3457     name: (optional) A name to use as a `name_scope` for the operation.
   3458     control_inputs: List of additional ops to finish before returning.
   3459 
   3460   Returns:
   3461     Same as `tensors`.
   3462 
   3463   Raises:
   3464     ValueError: If `tensors` does not contain any `Tensor` or `IndexedSlices`.
   3465     TypeError: If `control_inputs` is not a list of `Operation` or `Tensor`
   3466       objects.
   3467 
   3468   """
   3469   if context.in_eager_mode():
   3470     return tensors
   3471   with ops.name_scope(name, "tuple", tensors) as name:
   3472     gating_ops = [t.op for t in tensors if t is not None]
   3473     if control_inputs:
   3474       for c in control_inputs:
   3475         if isinstance(c, ops.Tensor):
   3476           c = c.op
   3477         elif not isinstance(c, ops.Operation):
   3478           raise TypeError("Control input must be Operation or Tensor: %s" % c)
   3479         gating_ops.append(c)
   3480     # Note that in order to ensure ordering in the pbtxt, we must take care to
   3481     # ensure the order here.
   3482     gating_ops = sorted(set(gating_ops), key=lambda op: op._id)  # Uniquify ops.
   3483     if not gating_ops:
   3484       raise ValueError("Must have at least one Tensor: %s" % tensors)
   3485     gate = group(*gating_ops)
   3486     tpl = []
   3487     for t in tensors:
   3488       if t is not None:
   3489         tpl.append(with_dependencies([gate], t))
   3490       else:
   3491         tpl.append(None)
   3492     return tpl
   3493 
   3494 
   3495 def _assert_at_most_n_true(predicates, n, msg):
   3496   """Returns an Assert op that checks that at most n predicates are True.
   3497 
   3498   Args:
   3499     predicates: list of bool scalar tensors.
   3500     n: maximum number of true predicates allowed.
   3501     msg: Error message.
   3502   """
   3503   preds_c = array_ops.stack(predicates, name="preds_c")
   3504   num_true_conditions = math_ops.reduce_sum(
   3505       math_ops.cast(preds_c, dtypes.int32), name="num_true_conds")
   3506   condition = math_ops.less_equal(num_true_conditions,
   3507                                   constant_op.constant(n, name="n_true_conds"))
   3508   preds_names = ", ".join(getattr(p, "name", "?") for p in predicates)
   3509   error_msg = [
   3510       "%s: more than %d conditions (%s) evaluated as True:" %
   3511       (msg, n, preds_names), preds_c
   3512   ]
   3513   return Assert(condition, data=error_msg, summarize=len(predicates))
   3514 
   3515 
   3516 def _case_create_default_action(predicates, actions):
   3517   """Creates default action for a list of actions and their predicates.
   3518 
   3519   It uses the input actions to select an arbitrary as default and makes sure
   3520   that corresponding predicates have valid values.
   3521 
   3522   Args:
   3523     predicates: a list of bool scalar tensors
   3524     actions: a list of callable objects which return tensors.
   3525 
   3526   Returns:
   3527     a callable
   3528   """
   3529   k = len(predicates) - 1  # could pick any
   3530   predicate, action = predicates[k], actions[k]
   3531   other_predicates, other_actions = predicates[:k], actions[:k]
   3532 
   3533   def default_action():
   3534     others_msg = ("Implementation error: "
   3535                   "selected default action #%d was called, but some of other "
   3536                   "predicates are True: " % k)
   3537     default_msg = ("Input error: "
   3538                    "None of conditions evaluated as True:",
   3539                    array_ops.stack(predicates, name="preds_c"))
   3540     with ops.control_dependencies([
   3541         _assert_at_most_n_true(other_predicates, n=0, msg=others_msg),
   3542         Assert(predicate, data=default_msg)
   3543     ]):
   3544       return action()
   3545 
   3546   return default_action, other_predicates, other_actions
   3547 
   3548 
   3549 def _case_verify_and_canonicalize_args(pred_fn_pairs, exclusive, name):
   3550   """Verifies input arguments for the case function.
   3551 
   3552   Args:
   3553     pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
   3554                    callable which returns a list of tensors.
   3555     exclusive: True iff at most one predicate is allowed to evaluate to `True`.
   3556     name: A name for the case operation.
   3557 
   3558   Raises:
   3559     TypeError: If `pred_fn_pairs` is not a list/dictionary.
   3560     TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
   3561     TypeError: If `fns[i]` is not callable for any i, or `default` is not
   3562                callable.
   3563 
   3564   Returns:
   3565     a tuple <list of scalar bool tensors, list of callables>.
   3566   """
   3567   if not isinstance(pred_fn_pairs, (list, _basetuple, dict)):
   3568     raise TypeError("fns must be a list, tuple, or dict")
   3569 
   3570   if isinstance(pred_fn_pairs, collections.OrderedDict):
   3571     pred_fn_pairs = pred_fn_pairs.items()
   3572   elif isinstance(pred_fn_pairs, dict):
   3573     pred_fn_pairs = sorted(pred_fn_pairs.items(), key=lambda item: item[0].name)
   3574     if not exclusive:
   3575       logging.warn("%s: An unordered dictionary of predicate/fn pairs was "
   3576                    "provided, but exclusive=False. The order of conditional "
   3577                    "tests is deterministic but not guaranteed.", name)
   3578   for pred_fn_pair in pred_fn_pairs:
   3579     if not isinstance(pred_fn_pair, _basetuple) or len(pred_fn_pair) != 2:
   3580       raise TypeError("Each entry in pred_fn_pairs must be a 2-tuple")
   3581     pred, fn = pred_fn_pair
   3582     if pred.dtype != dtypes.bool:
   3583       raise TypeError("pred must be of type bool: %s", pred.name)
   3584     if not callable(fn):
   3585       raise TypeError("fn for pred %s must be callable." % pred.name)
   3586   predicates, actions = zip(*pred_fn_pairs)
   3587   return predicates, actions
   3588 
   3589 
   3590 @tf_export("case")
   3591 def case(pred_fn_pairs,
   3592          default=None,
   3593          exclusive=False,
   3594          strict=False,
   3595          name="case"):
   3596   """Create a case operation.
   3597 
   3598   The `pred_fn_pairs` parameter is a dict or list of pairs of size N.
   3599   Each pair contains a boolean scalar tensor and a python callable that
   3600   creates the tensors to be returned if the boolean evaluates to True.
   3601   `default` is a callable generating a list of tensors. All the callables
   3602   in `pred_fn_pairs` as well as `default` (if provided) should return the same
   3603   number and types of tensors.
   3604 
   3605   If `exclusive==True`, all predicates are evaluated, and an exception is
   3606   thrown if more than one of the predicates evaluates to `True`.
   3607   If `exclusive==False`, execution stops at the first predicate which
   3608   evaluates to True, and the tensors generated by the corresponding function
   3609   are returned immediately. If none of the predicates evaluate to True, this
   3610   operation returns the tensors generated by `default`.
   3611 
   3612   `tf.case` supports nested structures as implemented in
   3613   `tensorflow.python.util.nest`. All of the callables must return the same
   3614   (possibly nested) value structure of lists, tuples, and/or named tuples.
   3615   Singleton lists and tuples form the only exceptions to this: when returned by
   3616   a callable, they are implicitly unpacked to single values. This
   3617   behavior is disabled by passing `strict=True`.
   3618 
   3619   If an unordered dictionary is used for `pred_fn_pairs`, the order of the
   3620   conditional tests is not guaranteed. However, the order is guaranteed to be
   3621   deterministic, so that variables created in conditional branches are created
   3622   in fixed order across runs.
   3623 
   3624   **Example 1:**
   3625 
   3626   Pseudocode:
   3627 
   3628   ```
   3629   if (x < y) return 17;
   3630   else return 23;
   3631   ```
   3632 
   3633   Expressions:
   3634 
   3635   ```python
   3636   f1 = lambda: tf.constant(17)
   3637   f2 = lambda: tf.constant(23)
   3638   r = case([(tf.less(x, y), f1)], default=f2)
   3639   ```
   3640 
   3641   **Example 2:**
   3642 
   3643   Pseudocode:
   3644 
   3645   ```
   3646   if (x < y && x > z) raise OpError("Only one predicate may evaluate true");
   3647   if (x < y) return 17;
   3648   else if (x > z) return 23;
   3649   else return -1;
   3650   ```
   3651 
   3652   Expressions:
   3653 
   3654   ```python
   3655   def f1(): return tf.constant(17)
   3656   def f2(): return tf.constant(23)
   3657   def f3(): return tf.constant(-1)
   3658   r = case({tf.less(x, y): f1, tf.greater(x, z): f2},
   3659            default=f3, exclusive=True)
   3660   ```
   3661 
   3662   Args:
   3663     pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
   3664                    callable which returns a list of tensors.
   3665     default: Optional callable that returns a list of tensors.
   3666     exclusive: True iff at most one predicate is allowed to evaluate to `True`.
   3667     strict: A boolean that enables/disables 'strict' mode; see above.
   3668     name: A name for this operation (optional).
   3669 
   3670   Returns:
   3671     The tensors returned by the first pair whose predicate evaluated to True, or
   3672     those returned by `default` if none does.
   3673 
   3674   Raises:
   3675     TypeError: If `pred_fn_pairs` is not a list/dictionary.
   3676     TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
   3677     TypeError: If `fns[i]` is not callable for any i, or `default` is not
   3678                callable.
   3679   """
   3680   predicates, actions = _case_verify_and_canonicalize_args(
   3681       pred_fn_pairs, exclusive, name)
   3682   with ops.name_scope(name, "case", [predicates]):
   3683     if default is None:
   3684       default, predicates, actions = _case_create_default_action(
   3685           predicates, actions)
   3686     fn = default
   3687     # To eval conditions in direct order we create nested conditions in reverse:
   3688     #   cond(c[0], true_fn=.., false_fn=cond(c[1], ...))
   3689     for predicate, action in reversed(list(zip(predicates, actions))):
   3690       fn = functools.partial(
   3691           cond, predicate, true_fn=action, false_fn=fn, strict=strict)
   3692     if exclusive:
   3693       with ops.control_dependencies([
   3694           _assert_at_most_n_true(
   3695               predicates, n=1, msg="Input error: exclusive=True")
   3696       ]):
   3697         return fn()
   3698     else:
   3699       return fn()
   3700 
   3701 
   3702 class XLAControlFlowContext(ControlFlowContext):
   3703   """Base class for XLA and TPU control flow contexts."""
   3704 
   3705   def __init__(self):
   3706     super(XLAControlFlowContext, self).__init__()
   3707     self._name = "XLAControlFlowContext"
   3708 
   3709   def IsXLAContext(self):
   3710     return True
   3711 
   3712   def AddOp(self, _):
   3713     pass
   3714 
   3715   def AddValue(self, x):
   3716     return x
   3717 
   3718 
   3719 def from_control_flow_context_def(context_def, import_scope=None):
   3720   """Deserializes `context_def` into the appropriate ControlFlowContext.
   3721 
   3722   Args:
   3723     context_def: ControlFlowContextDef proto
   3724     import_scope: Optional `string`. Name scope to add.
   3725 
   3726   Returns:
   3727     A ControlFlowContext subclass
   3728   """
   3729   if context_def.HasField("cond_ctxt"):
   3730     return CondContext.from_proto(context_def.cond_ctxt,
   3731                                   import_scope=import_scope)
   3732   if context_def.HasField("while_ctxt"):
   3733     return WhileContext.from_proto(context_def.while_ctxt,
   3734                                    import_scope=import_scope)
   3735   raise NotImplementedError("Unknown ControlFlowContextDef field: %s"
   3736                             % context_def.WhichOneof("ctxt"))
   3737 
   3738 
   3739 ops.register_proto_function(
   3740     ops.GraphKeys.COND_CONTEXT,
   3741     proto_type=control_flow_pb2.CondContextDef,
   3742     to_proto=CondContext.to_proto,
   3743     from_proto=CondContext.from_proto)
   3744 
   3745 ops.register_proto_function(
   3746     ops.GraphKeys.WHILE_CONTEXT,
   3747     proto_type=control_flow_pb2.WhileContextDef,
   3748     to_proto=WhileContext.to_proto,
   3749     from_proto=WhileContext.from_proto)
   3750