# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=g-long-lambda
"""Tests for tensorflow.ops.control_flow_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math
import sys
import time

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import device_lib
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as eager_function
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_v2  # pylint: disable=unused-import
# pylint: disable=unused-import
from tensorflow.python.ops.ragged import ragged_factory_ops
# pylint: enable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.training import adam
from tensorflow.python.training import gradient_descent
from tensorflow.python.util import nest


def check_consumers(graph):
  """Sanity check on the consumer list of the tensors."""

  consumer_count = {}
  for op in graph.get_operations():
    for v in op.inputs:
      cnt = consumer_count.get(v, 0)
      consumer_count[v] = cnt + 1
  for k, v in consumer_count.items():
    if len(k.consumers()) != v:
      return False
  return True


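# A minimal usage sketch for check_consumers. The helper below is
# hypothetical, added only for illustration, and is not part of the original
# test suite: it builds a trivial graph and confirms that the recorded
# consumer bookkeeping is consistent.
def _example_check_consumers():  # Hypothetical, illustrative only.
  g = ops.Graph()
  with g.as_default():
    a = constant_op.constant(1.0)
    _ = math_ops.negative(a)  # `a` gains exactly one consumer edge.
  assert check_consumers(g)

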
def all_fetchables():
  tensor_names = []
  graph = ops.get_default_graph()
  for op in graph.get_operations():
    for t in op.outputs:
      if graph.is_fetchable(t):
        tensor_names.append(t.name)
  return tensor_names


def all_feedables():
  feedable_tensors = []
  graph = ops.get_default_graph()
  for op in graph.get_operations():
    for t in op.inputs:
      if graph.is_feedable(t):
        feedable_tensors.append(t)
  return feedable_tensors


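# Note on the two helpers above: "fetchable" tensors are those Session.run()
# may return as outputs, while "feedable" tensors are those whose values
# Session.run() may override via feed_dict. The control-flow machinery marks
# some internal tensors as unfetchable/unfeedable, which is what testFetchable
# and testFeedable below exercise.

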
def opt_cfg():
  return config_pb2.ConfigProto(
      allow_soft_placement=True,
      graph_options=config_pb2.GraphOptions(
          optimizer_options=config_pb2.OptimizerOptions(
              opt_level=config_pb2.OptimizerOptions.L1,
              do_function_inlining=True,
              do_constant_folding=True)))


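# opt_cfg() enables graph-level optimizations (opt level L1 with function
# inlining and constant folding). Sessions built with this config are used in
# testCondSwitchIdentity and testCondRecvIdentity below to check that the
# optimizers do not remove Identity nodes the control flow depends on.

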
def isum(s, maximum_iterations=None):
  i = constant_op.constant(0, name="i")
  c = lambda i, s: math_ops.less(i, 10)
  b = lambda i, s: [math_ops.add(i, 1), math_ops.add(i, s)]
  _, r_s = control_flow_ops.while_loop(
      c, b, [i, s], maximum_iterations=maximum_iterations)
  return r_s


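# For reference: isum runs its body for i = 0..9, so isum(s) evaluates to
# s + (0 + 1 + ... + 9) = s + 45; e.g. isum(constant_op.constant(0)) is 45.

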
@test_util.with_control_flow_v2
class ControlFlowTest(test.TestCase):

  @test_util.run_v1_only("b/120545219")
  def testRefIdentity(self):
    with self.cached_session():
      v = variables.VariableV1(7)

      v = control_flow_ops._Identity(v)
      op = state_ops.assign(v, 9)
      v2 = control_flow_ops.with_dependencies([op], v)

      self.assertIsInstance(v2, ops.Tensor)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(9, self.evaluate(v2))

  @test_util.run_v1_only("b/120545219")
  def testRefEnter(self):
    with self.cached_session():
      v = variables.VariableV1(7)

      enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True)
      nine = constant_op.constant(9)
      enter_nine = gen_control_flow_ops.enter(nine, "foo_1")
      op = state_ops.assign(enter_v, enter_nine)
      v2 = control_flow_ops.with_dependencies([op], enter_v)
      v3 = control_flow_ops.exit(v2)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(9, self.evaluate(v3))

  @test_util.run_v1_only("b/120545219")
  def testRefSwitch(self):
    with self.cached_session():
      v = variables.VariableV1(7)

      p = constant_op.constant(True)
      v1 = control_flow_ops._SwitchRefOrTensor(v._ref(), p)  # pylint: disable=protected-access
      v2 = state_ops.assign(v1[1], 9)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(9, self.evaluate(v2))

  def testEnterMulExit(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      enter_data = gen_control_flow_ops.enter(data, "foo_1", False)
      five = constant_op.constant(5)
      enter_five = gen_control_flow_ops.enter(five, "foo_1", False)
      mul_op = math_ops.multiply(enter_data, enter_five)
      exit_op = control_flow_ops.exit(mul_op)

      result = self.evaluate(exit_op)
    self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)

  @test_util.run_deprecated_v1
  def testEnterShapePropagation(self):
    with self.cached_session():
      v = variables.Variable([0.0, 0.0], dtype=dtypes.float32)

      # If is_constant=True, the shape information should be propagated.
      enter_v_constant = gen_control_flow_ops.enter(
          v, "frame1", is_constant=True)
      self.assertEqual(enter_v_constant.shape, [2])

      # Otherwise, the shape should be unknown.
      enter_v_non_constant = gen_control_flow_ops.enter(
          v, "frame2", is_constant=False)
      self.assertEqual(enter_v_non_constant.shape, None)

  @test_util.run_v1_only("b/120545219")
  def testSwitchMergeIndexedSlices(self):
    with self.cached_session():
      values = constant_op.constant([1, 2, 3, 4, 5, 6])
      indices = constant_op.constant([0, 2, 4, 6, 8, 10])
      data = ops.IndexedSlices(values, indices)
      pred = ops.convert_to_tensor(True)
      switch_op = control_flow_ops.switch(data, pred)
      merge_op = control_flow_ops.merge(switch_op)[0]

      val = merge_op.values
      ind = merge_op.indices
    self.assertAllEqual(np.arange(1, 7), val)
    self.assertAllEqual(np.arange(0, 12, 2), ind)

  @test_util.run_v1_only("b/120545219")
  def testSwitchDeadBranch(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = ops.convert_to_tensor(True, name="ports")
      switch_op = control_flow_ops.switch(data, ports)
      dead_branch = array_ops.identity(switch_op[0])

      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          lambda e: "Retval[0] does not have value" in str(e)):
        self.evaluate(dead_branch)

  @test_util.run_v1_only("b/120545219")
  def testSwitchMergeLess(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      zero = ops.convert_to_tensor(0)
      one = ops.convert_to_tensor(1)
      less_op = math_ops.less(zero, one)
      switch_op = control_flow_ops.switch(data, less_op)
      merge_op = control_flow_ops.merge(switch_op)[0]

      result = self.evaluate(merge_op)
    self.assertAllEqual(np.arange(1, 7), result)

  @test_util.run_v1_only("b/120545219")
  def testSwitchMergeAddIdentity(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = ops.convert_to_tensor(False, name="ports")
      switch_op = control_flow_ops.switch(data, ports)
      one = constant_op.constant(1)
      add_op = math_ops.add(switch_op[0], one)
      id_op = array_ops.identity(switch_op[1])
      merge_op = control_flow_ops.merge([add_op, id_op])[0]

      result = self.evaluate(merge_op)
    self.assertAllEqual(np.array([x + 1 for x in [1, 2, 3, 4, 5, 6]]), result)

  @test_util.run_v1_only("b/120545219")
  def testSwitchMergeAddMul(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = ops.convert_to_tensor(True, name="ports")
      switch_op = control_flow_ops.switch(data, ports)
      one = constant_op.constant(1)
      add_op = math_ops.add(switch_op[0], one)
      five = constant_op.constant(5)
      mul_op = math_ops.multiply(switch_op[1], five)
      merge_op = control_flow_ops.merge([add_op, mul_op])[0]

      result = self.evaluate(merge_op)
    self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)

  @test_util.run_v1_only("b/120545219")
  def testLoop_false(self):
    with self.cached_session():
      false = ops.convert_to_tensor(False)
      n = constant_op.constant(10)

      enter_false = gen_control_flow_ops.enter(false, "foo_1", False)
      enter_n = gen_control_flow_ops.enter(n, "foo_1", False)

      merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0]
      switch_n = control_flow_ops.switch(merge_n, enter_false)
      exit_n = control_flow_ops.exit(switch_n[0])
      next_n = control_flow_ops.next_iteration(switch_n[0])
      merge_n.op._update_input(1, next_n)

      result = self.evaluate(exit_n)
    self.assertAllEqual(10, result)

  @test_util.run_deprecated_v1
  def testLoop_1(self):
    with self.cached_session():
      zero = constant_op.constant(0)
      one = constant_op.constant(1)
      n = constant_op.constant(10)

      enter_i = gen_control_flow_ops.enter(zero, "foo", False)
      enter_one = gen_control_flow_ops.enter(one, "foo", True)
      enter_n = gen_control_flow_ops.enter(n, "foo", True)

      with ops.device(test.gpu_device_name()):
        merge_i = control_flow_ops.merge([enter_i, enter_i])[0]

      less_op = math_ops.less(merge_i, enter_n)
      cond_op = control_flow_ops.loop_cond(less_op)
      switch_i = control_flow_ops.switch(merge_i, cond_op)

      add_i = math_ops.add(switch_i[1], enter_one)

      next_i = control_flow_ops.next_iteration(add_i)
      merge_i.op._update_input(1, next_i)

      exit_i = control_flow_ops.exit(switch_i[0])
      result = self.evaluate(exit_i)
    self.assertAllEqual(10, result)

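  # The tests above hand-wire the dataflow loop that while_loop normally
  # builds: Enter brings values into the loop frame, Merge joins the initial
  # value with the NextIteration back edge, LoopCond and Switch route values
  # either back into the body (true output) or to Exit (false output), and
  # merge_i.op._update_input(1, next_i) closes the cycle, since the back edge
  # does not yet exist when the Merge node is first created.
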
  @test_util.run_v1_only("b/120545219")
  def testLoop_2(self):
    with self.cached_session():
      zero = constant_op.constant(0)
      one = constant_op.constant(1)
      n = constant_op.constant(10)

      enter_i = gen_control_flow_ops.enter(zero, "foo", False)
      enter_one = gen_control_flow_ops.enter(one, "foo", True)
      enter_n = gen_control_flow_ops.enter(n, "foo", True)

      merge_i = control_flow_ops.merge([enter_i, enter_i])[0]

      less_op = math_ops.less(merge_i, enter_n)
      cond_op = control_flow_ops.loop_cond(less_op)
      switch_i = control_flow_ops.switch(merge_i, cond_op)

      add_i = math_ops.add(switch_i[1], enter_one)

      with ops.device(test.gpu_device_name()):
        next_i = control_flow_ops.next_iteration(add_i)
      merge_i.op._update_input(1, next_i)

      exit_i = control_flow_ops.exit(switch_i[0])
      result = self.evaluate(exit_i)
    self.assertAllEqual(10, result)

  @test_util.run_v1_only("b/120545219")
  def testDifferentFrame(self):
    with self.cached_session():
      data = array_ops.placeholder(dtypes.float32, shape=[])
      enter_1 = gen_control_flow_ops.enter(data, "foo_1", False)
      enter_2 = gen_control_flow_ops.enter(data, "foo_2", False)
      res = math_ops.add(enter_1, enter_2)
      with self.assertRaisesOpError("has inputs from different frames"):
        res.eval(feed_dict={data: 1.0})

  @test_util.run_deprecated_v1
  def testCondBool(self):
    values = constant_op.constant(10)
    fn1 = lambda: math_ops.add(values, 1)
    fn2 = lambda: math_ops.subtract(values, 1)
    with self.assertRaisesRegexp(TypeError, "must not be a Python bool"):
      _ = control_flow_ops.cond(False, fn1, fn2)

  @test_util.run_deprecated_v1
  def testCondInt(self):
    p = array_ops.placeholder(dtypes.bool, shape=[])
    v = constant_op.constant(10)
    fn1 = lambda: math_ops.add(v, 1)
    fn2 = lambda: math_ops.subtract(v, 1)
    y = control_flow_ops.cond(p, fn1, fn2)
    grad = gradients_impl.gradients(y, [v])
    self.assertAllEqual([None], grad)

  def testCondOutputShape(self):
    x = constant_op.constant(1.0)
    b = control_flow_ops.cond(
        constant_op.constant(True), lambda: math_ops.square(x),
        lambda: math_ops.subtract(x, 1.))
    self.assertEqual(b.shape, tensor_shape.scalar())

  @test_util.run_v1_only("b/120545219")
  def testFetchable(self):
    with self.cached_session() as sess:
      x = array_ops.placeholder(dtypes.float32)
      control_flow_ops.cond(
          constant_op.constant(True), lambda: x + 2, lambda: x + 0)
      graph = ops.get_default_graph()
      for op in graph.get_operations():
        for t in op.inputs:
          if graph.is_fetchable(t.op):
            sess.run(t, feed_dict={x: 3})
          else:
            with self.assertRaisesRegexp(ValueError,
                                         "has been marked as not fetchable"):
              sess.run(t, feed_dict={x: 3})

  @test_util.disable_control_flow_v2("Not relevant")
  @test_util.run_v1_only("b/120545219")
  def testFeedable(self):
    with self.cached_session() as sess:
      c = constant_op.constant(2)
      i0 = constant_op.constant(0)
      r = control_flow_ops.while_loop(lambda i: i < 1000,
                                      lambda i: math_ops.square(c) + i, [i0])
      self.assertEqual(1000, r.eval(feed_dict={i0: 0}))
      feedable_tensors = all_feedables()
      for t in feedable_tensors:
        sess.run(r, feed_dict={t: 3})
      graph = ops.get_default_graph()
      for op in graph.get_operations():
        for t in op.inputs:
          if t not in feedable_tensors and t.dtype is dtypes.int32:
            with self.assertRaisesRegexp(ValueError, "may not be fed"):
              sess.run(r, feed_dict={t: 3})

  @test_util.run_v1_only("b/120545219")
  def testCondIndexedSlices(self):
    with self.cached_session():
      values = constant_op.constant(10)
      indices = constant_op.constant(0)
      x = ops.IndexedSlices(values, indices)
      pred = math_ops.less(1, 2)
      fn1 = lambda: ops.IndexedSlices(math_ops.add(x.values, 1), indices)
      fn2 = lambda: ops.IndexedSlices(math_ops.subtract(x.values, 1), indices)
      r = control_flow_ops.cond(pred, fn1, fn2)

      val = r.values
      ind = r.indices
    self.assertAllEqual(11, val)
    self.assertAllEqual(0, ind)

  def testCondMismatchedIndexedSlices(self):
    @def_function.function
    def foo():
      values = constant_op.constant(10)
      indices = constant_op.constant(0)
      x = ops.IndexedSlices(values, indices)
      v1_msg = "The two structures don't have the same nested structure"
      v2_msg = ("true_fn and false_fn arguments to tf.cond must have the same "
                "number, type, and overall structure of return values.")
      with self.assertRaisesRegexp(
          TypeError,
          v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
        # The false branch intentionally returns a plain Tensor while the
        # true branch returns an IndexedSlices, so the structures mismatch.
        control_flow_ops.cond(
            constant_op.constant(True),
            lambda: ops.IndexedSlices(math_ops.add(x.values, 1), indices),
            lambda: math_ops.add(x.values, 1))
    foo()

  def testCondSparseTensor(self):
    values = constant_op.constant([2.0, 4.0], name="values")
    indices = constant_op.constant([[0], [3]],
                                   dtype=dtypes.int64,
                                   name="indices")
    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
    pred = math_ops.less(1, 2)
    fn1 = lambda: sparse_tensor.SparseTensor(
        indices + 1, x.values + 1, dense_shape=shape)
    fn2 = lambda: sparse_tensor.SparseTensor(
        indices, x.values - 1, dense_shape=shape)
    r = control_flow_ops.cond(pred, fn1, fn2)
    self.assertAllEqual([3.0, 5.0], r.values)
    self.assertAllEqual([[1], [4]], r.indices)
    self.assertAllEqual(r.values.get_shape(), (2,))

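  # Note: both branches of a cond must return the same structure. For
  # composite tensors such as SparseTensor (above) and RaggedTensor (below),
  # the cond output is rebuilt from the component tensors each branch returns.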
  def testCondRaggedTensor(self):
    rt = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
    pred = math_ops.less(1, 2)
    fn1 = lambda: array_ops.concat([rt + 2, [[100]]], axis=0)
    fn2 = lambda: rt[:2] - 2
    result = control_flow_ops.cond(pred, fn1, fn2)
    self.assertAllEqual([3, 4, 5, 6, 7, 8, 100], result.values)
    self.assertAllEqual([0, 2, 3, 6, 7], result.row_splits)

  @test_util.run_v1_only("b/120545219")
  def testCondResource(self):
    with self.cached_session():
      rv = resource_variable_ops.ResourceVariable(True)
      self.evaluate(variables.global_variables_initializer())
      t = ops.convert_to_tensor(1.0)

      def case():
        assign = resource_variable_ops.assign_variable_op(rv.handle, False)
        with ops.control_dependencies([assign]):
          return array_ops.identity(t)

      self.assertEqual(
          1.0, self.evaluate(control_flow_ops.cond(rv, case, lambda: t)))

  @test_util.run_v1_only("b/120545219")
  def testCondWithTensorArrayGrad(self):
    with self.cached_session() as sess:
      with ops.device(test.gpu_device_name()):
        pred = array_ops.placeholder(dtypes.bool, [])
        x = constant_op.constant([1.0, 2.0, 3.0])
        y = control_flow_ops.cond(
            pred, lambda: map_fn.map_fn(lambda z: z * 2.0, x),
            lambda: constant_op.constant([1.0, 1.0, 1.0]))
        g = gradients_impl.gradients(y, x)[0]

      self.assertAllEqual(sess.run(g, {pred: True}), [2.0, 2.0, 2.0])
      self.assertAllEqual(sess.run(g, {pred: False}), [0.0, 0.0, 0.0])

  @test_util.disable_control_flow_v2("b/113293074")
  @test_util.run_v1_only("b/120545219")
  def testCondIndexedSlicesDifferentTypes(self):
    with self.cached_session():
      values = constant_op.constant(10)
      i_32 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int32)
      i_64 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int64)
      x = ops.IndexedSlices(values, i_32)
      pred = math_ops.less(1, 2)
      fn1 = lambda: ops.IndexedSlices(math_ops.add(x.values, 1), i_32)
      fn2 = lambda: ops.IndexedSlices(math_ops.subtract(x.values, 1), i_64)
      r = control_flow_ops.cond(pred, fn1, fn2)

      val = r.values
      ind = r.indices
    self.assertAllEqual(11, val)
    self.assertAllEqual(0, ind)
    self.assertEqual(ind.dtype, np.int64)

  @test_util.run_v1_only("b/120545219")
  def testCondColocation(self):
    with self.session(use_gpu=True):
      with ops.device("/cpu:0"):
        v = variables.Variable(7.0)

      x = constant_op.constant(10.0)
      pred = math_ops.less(1.0, 2.0)
      fn1 = lambda: math_ops.add(v, 1.0)
      fn2 = lambda: math_ops.subtract(x, 1.0)
      r = control_flow_ops.cond(pred, fn1, fn2)

      for op in x.graph.get_operations():
        if op.name == "cond/Add/Switch":
          self.assertDeviceEqual(op.device, "/cpu:0")

  def _testCond_1(self, use_gpu):
    with self.cached_session(use_gpu=use_gpu):
      x = constant_op.constant(10)
      pred = math_ops.less(1, 2)
      fn1 = lambda: math_ops.add(x, 1)
      fn2 = lambda: math_ops.subtract(x, 1)
      r = control_flow_ops.cond(pred, fn1, fn2)

      result = self.evaluate(r)
    self.assertAllEqual(11, result)

  def testCond_1(self):
    self._testCond_1(use_gpu=False)
    # TODO(b/116526896): Enable GPU tests.
    # self._testCond_1(use_gpu=True)

  def testCond_2(self):
    with self.cached_session():
      x = constant_op.constant(10)
      r = control_flow_ops.cond(
          math_ops.less(1, 0), lambda: math_ops.add(x, 1),
          lambda: math_ops.subtract(x, 1))
      result = self.evaluate(r)
    self.assertAllEqual(9, result)

  def testCond_3(self):
    with self.cached_session():
      x = constant_op.constant(10)
      pred = math_ops.less(1, 2)
      fn1 = lambda: math_ops.add(x, 1)
      fn2 = lambda: math_ops.subtract(x, 1)
      fn3 = lambda: math_ops.add(control_flow_ops.cond(pred, fn1, fn2), 1)
      r = control_flow_ops.cond(pred, fn3, fn2)

      result = self.evaluate(r)
    self.assertAllEqual(12, result)

  @test_util.disable_xla("b/128638446")
  @test_util.run_in_graph_and_eager_modes
  def testCondPruning(self):
    v1 = variables.Variable(7)
    v2 = variables.Variable(7)
    v3 = variables.Variable(7)

    def f():
      age = constant_op.constant(3)
      max_age = constant_op.constant(2)
      pred = math_ops.greater(age, max_age)
      fn1 = lambda: [state_ops.assign(v1, 1).op, state_ops.assign(v2, 2).op]
      fn2 = lambda: [state_ops.assign(v3, 3).op, constant_op.constant(10).op]
      r = control_flow_ops.cond(pred, fn1, fn2)
      self.assertEqual(len(r), 2)
      return r[1]

    f_defun = eager_function.defun(f)

    if not context.executing_eagerly():
      with self.cached_session():
        self.evaluate(variables.global_variables_initializer())
        result = self.evaluate(f())
        self.assertEqual(True, result)
        # Only second cond result was fetched, so v1 assign shouldn't run.
        self.assertEqual(7, self.evaluate(v1))
        self.assertEqual(2, self.evaluate(v2))
        self.assertEqual(7, self.evaluate(v3))

    result = f_defun()
    self.assertEqual(True, self.evaluate(result))
    # Both v1 and v2 branch assignments should be run in defun.
    self.assertEqual(1, self.evaluate(v1))
    self.assertEqual(2, self.evaluate(v2))
    self.assertEqual(7, self.evaluate(v3))

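  # The pruning contrast exercised above: under Session.run only ops that the
  # fetched output depends on execute, so the v1 assign (whose cond output was
  # not fetched) is skipped, while a defun's automatic control dependencies
  # make every stateful op in the taken branch run regardless of which
  # outputs are fetched.
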
  def testCond_5(self):
    with self.cached_session():
      alive = constant_op.constant(True, name="alive")
      count = constant_op.constant(0, name="count")

      def body(i):
        return control_flow_ops.cond(
            alive, lambda: [math_ops.less(i, 3), math_ops.add(count, 1)],
            lambda: [alive, count])

      for i in range(10):
        alive, count = body(i)
      self.assertAllEqual(4, self.evaluate(count))

  @test_util.run_v1_only("b/120545219")
  def testCond_6(self):
    with self.cached_session():
      v1 = variables.Variable([7])

      age = constant_op.constant(3)
      pred = math_ops.greater(age, 4)
      fn1 = lambda: age
      fn2 = lambda: v1
      r = control_flow_ops.cond(pred, fn1, fn2)

      self.evaluate(variables.global_variables_initializer())
      result = self.evaluate(r)
      self.assertAllEqual(np.array([7]), result)

  def testCond_7(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: [math_ops.add(x, 1), math_ops.add(x, 2)]
      fn2 = lambda: [y, y]
      r = control_flow_ops.cond(pred, fn1, fn2)
      self.assertAllEqual([11, 12], self.evaluate(r))

  @test_util.run_gpu_only
  @test_util.run_deprecated_v1
  def testCond_Device(self):
    x = constant_op.constant(-10.)

    # True branch function defined outside of device scope
    def true_fn():
      return math_ops.exp(x)

    with ops.device("CPU:0"):
      r = control_flow_ops.cond(
          constant_op.constant(True), true_fn, lambda: 0.)
      self.assertIn("cpu", r.device.lower())

    with session.Session() as sess:
      options = config_pb2.RunOptions(output_partition_graphs=True)
      run_metadata = config_pb2.RunMetadata()
      sess.run(r, options=options, run_metadata=run_metadata)
      # We expect that everything runs on CPU, even if GPU is available.
      self.assertEqual(len(run_metadata.partition_graphs), 1)

  def _count_matching_switch_nodes_on_device(self, run_metadata, device_str):
    """Returns the number of float32 Switch nodes placed on `device_str`."""
    device_graphs = [
        g for g in run_metadata.partition_graphs
        if device_str in g.node[0].device
    ]
    self.assertLen(device_graphs, 1)
    switch_nodes = [
        n for n in device_graphs[0].node if n.op == "Switch" and
        n.attr["T"].type == dtypes.float32.as_datatype_enum
    ]
    return len(switch_nodes)

  @test_util.run_gpu_only
  @test_util.run_deprecated_v1
  def testCondSwitchColocatedWithInputWhenInputOnCPU(self):
    x = array_ops.placeholder(dtypes.float32)

    # `arg` is used in the cond then branch, so a Switch node is created for
    # it. We test that the Switch node gets placed on the same device as
    # `arg`. We force `arg` to be on CPU here.
    with ops.device("CPU:0"):
      arg = x + 10.

    def true_fn():
      with ops.device("CPU:0"):
        return arg + 1

    r = control_flow_ops.cond(constant_op.constant(True), true_fn, lambda: 0.)

    with session.Session() as sess:
      run_metadata = config_pb2.RunMetadata()
      options = config_pb2.RunOptions(output_partition_graphs=True)
      sess.run(
          r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
      self.assertEqual(len(run_metadata.partition_graphs), 2)
      # Check that the Switch for `arg` gets placed on CPU.
      self.assertEqual(
          self._count_matching_switch_nodes_on_device(run_metadata, "CPU"), 1)
      self.assertEqual(
          self._count_matching_switch_nodes_on_device(run_metadata, "GPU"), 0)

  @test_util.run_gpu_only
  @test_util.run_deprecated_v1
  def testCondSwitchColocatedWithInputWhenInputOnGPU(self):
    x = array_ops.placeholder(dtypes.float32)

    # `arg` is used in the cond then branch, so a Switch node is created for
    # it. We test that the Switch node gets placed on the same device as
    # `arg`. Note: `arg` gets placed on GPU by default by the placer.
    arg = x + 10.

    def true_fn():
      with ops.device("CPU:0"):
        return arg + 1

    r = control_flow_ops.cond(constant_op.constant(True), true_fn, lambda: 0.)

    with session.Session() as sess:
      run_metadata = config_pb2.RunMetadata()
      options = config_pb2.RunOptions(output_partition_graphs=True)
      sess.run(
          r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
      self.assertEqual(len(run_metadata.partition_graphs), 2)
      # Check that the Switch for `arg` gets placed on GPU.
      self.assertEqual(
          self._count_matching_switch_nodes_on_device(run_metadata, "CPU"), 0)
      self.assertEqual(
          self._count_matching_switch_nodes_on_device(run_metadata, "GPU"), 1)

  def testCondListOutput(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: [math_ops.add(x, y), math_ops.add(x, y)]
      fn2 = lambda: [y, y]
      r = control_flow_ops.cond(pred, fn1, fn2)
      test_result = self.evaluate(r)
      self.assertListEqual([210, 210], test_result)

  def testTupleOutput(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: (math_ops.add(x, y), math_ops.add(x, y))
      fn2 = lambda: (y, y)
      r = control_flow_ops.cond(pred, fn1, fn2)
      test_result = self.evaluate(r)
      self.assertTupleEqual((210, 210), test_result)

  def testDictOutput(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: {"a": math_ops.add(x, y), "b": math_ops.add(x, y)}
      fn2 = lambda: {"a": y, "b": y}
      r = control_flow_ops.cond(pred, fn1, fn2)
      test_result = self.evaluate(r)
      self.assertDictEqual({"a": 210, "b": 210}, test_result)

  @test_util.run_deprecated_v1
  def testEmbeddedListOutput(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: [[math_ops.add(x, y), math_ops.add(x, y)]]
      fn2 = lambda: [[y, y]]
      # Pass the strict=True flag, since cond_v2 otherwise allows singleton
      # tensors in nested output structures to be unpacked.
      r = control_flow_ops.cond(pred, fn1, fn2, strict=True)
      test_result = self.evaluate(r)
      self.assertListEqual([[210, 210]], test_result)

  def testEmbeddedTupleOutput(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: ((math_ops.add(x, y), math_ops.add(x, y)))
      fn2 = lambda: ((y, y))
      r = control_flow_ops.cond(pred, fn1, fn2)
      test_result = self.evaluate(r)
      self.assertTupleEqual(((210, 210)), test_result)

  def testEmbeddedDictOutput(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: {"a": {"c": math_ops.add(x, y)},
                     "b": {"d": math_ops.add(x, y)}}
      fn2 = lambda: {"a": {"c": y},
                     "b": {"d": y}}
      r = control_flow_ops.cond(pred, fn1, fn2)
      test_result = self.evaluate(r)
      self.assertDictEqual({"a": {"c": 210}, "b": {"d": 210}}, test_result)

  @test_util.run_v1_only("b/120545219")
  def testCheckNestedOutputStruct(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: {"a": math_ops.add(x, y), "b": math_ops.add(x, y)}
      fn2 = lambda: {"c": y, "d": y}
      v1_msg = "The two structures don't have the same nested structure"
      v2_msg = ("true_fn and false_fn arguments to tf.cond must have the same "
                "number, type, and overall structure of return values.")
      with self.assertRaisesRegexp(
          TypeError if control_flow_util.ENABLE_CONTROL_FLOW_V2 else ValueError,
          v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
        control_flow_ops.cond(pred, fn1, fn2)

  @test_util.run_deprecated_v1
  def testCondRef(self):
    with self.cached_session():
      x = gen_state_ops.variable(
          shape=[1],
          dtype=dtypes.float32,
          name="x",
          container="",
          shared_name="")
      true_fn = lambda: x
      false_fn = lambda: constant_op.constant([2.0])
      r = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn)
      self.assertAllEqual([2.0], self.evaluate(r))

  @test_util.disable_control_flow_v2("b/79881896 (placeholder)")
  @test_util.run_v1_only("b/120545219")
  def testCondWithControl(self):
    with self.cached_session():
      control_holder = array_ops.placeholder(dtypes.float32, shape=())
      a = constant_op.constant(3)

      def true_branch():
        with ops.control_dependencies([control_holder]):
          _ = a + 1
        return a + 2

      r = control_flow_ops.cond(
          constant_op.constant(True), true_branch,
          lambda: constant_op.constant(1))
      self.assertEqual(5, self.evaluate(r))

  @test_util.run_v1_only("b/120545219")
  def testUninitializedRefIdentity(self):
    with self.cached_session() as sess:
      v = gen_state_ops.variable(
          shape=[1],
          dtype=dtypes.float32,
          name="v",
          container="",
          shared_name="")
      inited = state_ops.is_variable_initialized(v)
      v_f, v_t = control_flow_ops.ref_switch(v, inited)
      # Both v_f and v_t are uninitialized references. However, an actual use
      # of the reference in the 'true' branch in the 'tf.identity' op will
      # not 'fire' when v is uninitialized, so this is a valid construction.
      # This test verifies that ref_identity allows an uninitialized ref as
      # input, so that this construction is allowed.
      v_f_op = gen_array_ops.ref_identity(v_f)
      v_t_op = gen_array_ops.ref_identity(v_t)
      with ops.control_dependencies([v_f_op]):
        assign_v = state_ops.assign(v, [1.0])
      with ops.control_dependencies([v_t_op]):
        orig_v = array_ops.identity(v)
      merged_op = control_flow_ops.merge([assign_v, orig_v])
      self.assertAllEqual([1.0], self.evaluate(merged_op.output))

  def testCondSwitchIdentity(self):
    # Make sure the switch identity is not removed by optimization.
    with session.Session(config=opt_cfg()) as sess:
      pred = constant_op.constant(True)

      def fn1():
        return control_flow_ops.no_op()

      def fn2():
        return control_flow_ops.Assert(False, ["Wrong branch!!!"])

      r = control_flow_ops.cond(pred, fn1, fn2)
      self.evaluate(r)

  def testCondRecvIdentity(self):
    # Make sure the recv identity is not removed by optimization.
    with session.Session(config=opt_cfg()) as sess:
      with ops.device(test.gpu_device_name()):
        pred = constant_op.constant(True)

      def fn1():
        return control_flow_ops.no_op()

      def fn2():
        with ops.device("/cpu:0"):
          return control_flow_ops.Assert(False, ["Wrong branch!!!"])

      r = control_flow_ops.cond(pred, fn1, fn2)
      self.evaluate(r)

  @test_util.run_v1_only("b/120545219")
  def testCondGrad_1(self):
    with self.cached_session():
      x = constant_op.constant(10.0, name="x")
      pred = math_ops.less(1, 2)
      fn1 = lambda: array_ops.identity(x)
      fn2 = lambda: array_ops.identity(x)
      r = control_flow_ops.cond(pred, fn1, fn2)

      grad = gradients_impl.gradients(r, [x])[0]
      self.assertAllEqual(1.0, self.evaluate(grad))

  @test_util.run_deprecated_v1
  def testCondGrad_2(self):
    with self.cached_session():
      c = array_ops.placeholder(dtypes.int32, shape=[])
      x = constant_op.constant(10.0)
      pred = math_ops.less(c, 2)
      fn1 = lambda: math_ops.multiply(x, 42.0)
      fn2 = lambda: math_ops.multiply(x, 3.0)
      r = control_flow_ops.cond(pred, fn1, fn2)

      grad = gradients_impl.gradients(r, [x])[0]
      self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1}))
      self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))

  @test_util.disable_control_flow_v2(
      "b/110550782 (gradient w.r.t external variable)")
  @test_util.run_deprecated_v1
  def testCondGrad_3(self):
    with self.cached_session():
      c = array_ops.placeholder(dtypes.int32, shape=[])
      ox = constant_op.constant(10.0)
      pred = math_ops.less(c, 2)

      def fn1(x):
        m = x * x
        return gradients_impl.gradients(m, [ox])[0]

      fn2 = lambda: math_ops.multiply(ox, 3.0)
      y = math_ops.multiply(7.0, ox)
      r = control_flow_ops.cond(pred, lambda: fn1(y), fn2)

      self.assertAllEqual(980.0, r.eval(feed_dict={c: 1}))
      self.assertAllEqual(30.0, r.eval(feed_dict={c: 3}))

  @test_util.run_deprecated_v1
  def testCondGradMultiDevice(self):
    config = config_pb2.ConfigProto(device_count={"CPU": 2},
                                    allow_soft_placement=True)
    with self.cached_session(use_gpu=True, config=config) as sess:
      pred = array_ops.placeholder(dtypes.bool, [])
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.placeholder(dtypes.float32)

      with ops.device("/cpu:0"):
        z = control_flow_ops.cond(pred, lambda: x * y * 2.0, lambda: 2.0)

      with ops.device("/cpu:1"):
        grad = gradients_impl.gradients(z, x)[0]

      with ops.device("/cpu:0"):
        grad_grad = gradients_impl.gradients(grad, x)[0]

      self.assertEqual(sess.run(grad, {pred: True, x: 1.0, y: 2.0}), 4.0)
      self.assertEqual(sess.run(grad, {pred: False, x: 1.0, y: 2.0}), 0.0)

      # v1 control flow gets None second derivative for some reason.
      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
        self.assertIsNone(grad_grad)
        return

      self.assertEqual(sess.run(grad_grad, {pred: True, x: 1.0, y: 2.0}), 0.0)
      self.assertEqual(sess.run(grad_grad, {pred: False, x: 1.0, y: 2.0}), 0.0)

  @test_util.run_v1_only("b/120545219")
  def testNestedCond_Simple(self):
    with self.cached_session():
      x = constant_op.constant(0., name="X")
      y = control_flow_ops.cond(
          constant_op.constant(True), lambda: x,
          lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x))
      result = gradients_impl.gradients(y, x)[0]
      self.assertEqual(1.0, self.evaluate(result))

      z = control_flow_ops.cond(
          constant_op.constant(False), lambda: x,
          lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x))
      result = gradients_impl.gradients(z, x)[0]
      self.assertEqual(1.0, self.evaluate(result))

  @test_util.disable_control_flow_v2("b/113327884")
  @test_util.run_v1_only("b/120545219")
  def testCondGrad_Gather(self):
    with self.cached_session() as sess:
      v1 = variables.Variable([1.0, 42.0])
      c = array_ops.placeholder(dtypes.int32, shape=[])
      pred = math_ops.less(c, 2)
      fn1 = lambda: array_ops.identity(v1)
      fn2 = lambda: array_ops.gather(v1, [1, 1])
      r = control_flow_ops.cond(pred, fn1, fn2)
      grad = gradients_impl.gradients(r, [v1])[0]
      self.evaluate(variables.global_variables_initializer())
      # Should just be [1, 1], but possibly a sparse representation
      gv, gi = sess.run([grad.values, grad.indices], feed_dict={c: 1})
      dense_gv = [
          sum(y for (x, y) in zip(gi, gv) if x == i) for i in range(2)
      ]
      self.assertAllEqual(dense_gv, [1.0, 1.0])
      # Should be [0, 2], as the else forwards v1[1] twice
      gv, gi = sess.run([grad.values, grad.indices], feed_dict={c: 3})
      dense_gv = [
          sum(y for (x, y) in zip(gi, gv) if x == i) for i in range(2)
      ]
      self.assertAllEqual(dense_gv, [0.0, 2.0])

  @test_util.run_deprecated_v1
  def testCondGrad_ResourceVarSparseRead(self):
    # NOTE(skyewm): this test is interesting because the
    # ResourceVariable.sparse_read gradient function returns IndexedSlices.
    var = resource_variable_ops.ResourceVariable(
        np.ones((4, 2), dtype=np.float32))
    x = constant_op.constant(1.0)
    r = control_flow_ops.cond(
        constant_op.constant(True),
        lambda: x * math_ops.reduce_sum(var.sparse_read([1, 2])),
        lambda: constant_op.constant(np.zeros((2, 3)),
                                     dtype=dtypes.float32))
    grad = gradients_impl.gradients(r, var)[0]

    self.evaluate(variables.global_variables_initializer())
    grad_val = self.evaluate(grad)
    self.assertIsInstance(grad_val, ops.IndexedSlicesValue)
    self.assertAllEqual(gradient_checker_v2._to_numpy(grad_val), [[0., 0.],
                                                                  [1., 1.],
                                                                  [1., 1.],
                                                                  [0., 0.]])

  @test_util.disable_xla("b/128643464")
  def testCondGrad_MultiGather(self):
    # NOTE(skyewm): this test is interesting because the array_ops.gather and
    # ResourceVariable.sparse_read gradient functions return IndexedSlices.
    var = resource_variable_ops.ResourceVariable(
        np.ones((4, 2), dtype=np.float32))
    x1 = constant_op.constant(np.ones((3, 3), dtype=np.float32))
    x2 = constant_op.constant(2.0)

    def true_fn():
      y1 = var.sparse_read([1, 2])
      y2 = array_ops.gather(x1, [2]) * x2
      y3 = x2 * [1., 1., 1.]
      return y1, y2, y3

    def false_fn():
      y1 = np.zeros((2, 2), dtype=np.float32)
      y2 = array_ops.gather(x1, [2]) * x2
      y3 = array_ops.gather(x1, [2])
      return y1, y2, y3

    @def_function.function
    def foo():
      r = control_flow_ops.cond(constant_op.constant(True), true_fn, false_fn)
      return gradients_impl.gradients(r, [var, x1, x2])

    grad = foo()
    self.evaluate(variables.global_variables_initializer())
    var_grad, x1_grad, x2_grad = self.evaluate(grad)
    self.assertIsInstance(var_grad, ops.IndexedSlicesValue)
    self.assertAllEqual(gradient_checker_v2._to_numpy(var_grad), [[0., 0.],
                                                                  [1., 1.],
                                                                  [1., 1.],
                                                                  [0., 0.]])
    self.assertIsInstance(x1_grad, ops.IndexedSlicesValue)
    self.assertAllEqual(gradient_checker_v2._to_numpy(x1_grad), [[0., 0., 0.],
                                                                 [0., 0., 0.],
                                                                 [2., 2., 2.]])
    self.assertEqual(gradient_checker_v2._to_numpy(x2_grad), 6.)

   1103   @test_util.run_v1_only("b/120545219")
   1104   def testCondPredicateTensor(self):
   1105     """Regression test for lowering predicate from non-first output of an op."""
   1106 
   1107     @eager_function.defun
   1108     def foo():
   1109       return constant_op.constant("foo"), constant_op.constant(True)
   1110 
   1111     r = control_flow_ops.cond(foo()[1], lambda: 1.0, lambda: 2.0)
   1112     self.assertEqual(self.evaluate(r), 1.0)
   1113 
   1114   @test_util.run_v1_only("Tests Session.run() pruning logic.")
   1115   def testCondFeedConstantPredicate(self):
   1116     with self.cached_session() as sess:
   1117       value = constant_op.constant(37.0)
   1118       predicate = constant_op.constant(True)
   1119       cond_output = control_flow_ops.cond(
   1120           predicate, lambda: constant_op.constant(0.0), lambda: value)
   1121       result = array_ops.identity(cond_output)
   1122       self.assertEqual(37.0, sess.run(result, feed_dict={predicate: False}))
   1123       self.assertEqual(0.0, sess.run(result, feed_dict={predicate: True}))
   1124       self.assertEqual(0.0, sess.run(result))
   1125 
   1126   @test_util.run_v1_only("Tests Session.run() pruning logic.")
   1127   def testCondFeedPlaceholderWithDefaultPredicate(self):
   1128     with self.cached_session() as sess:
   1129       value = constant_op.constant(37.0)
   1130       predicate = array_ops.placeholder_with_default(
   1131           constant_op.constant(True), [])
   1132       cond_output = control_flow_ops.cond(
   1133           predicate, lambda: constant_op.constant(0.0), lambda: value)
   1134       result = array_ops.identity(cond_output)
   1135       self.assertAllEqual(37.0, sess.run(result, feed_dict={predicate: False}))
   1136       self.assertAllEqual(0.0, sess.run(result, feed_dict={predicate: True}))
   1137       self.assertAllEqual(0.0, sess.run(result))
   1138 
   1139   @test_util.disable_xla("b/128644469 PrintV2")
   1140   @test_util.run_in_graph_and_eager_modes
   1141   def testCondAutoControlDeps(self):
   1142 
   1143     def branch_fn():
   1144       logging_ops.print_v2("A")
   1145       logging_ops.print_v2("B")
   1146       with ops.control_dependencies([logging_ops.print_v2("C")]):
   1147         return constant_op.constant(10)
   1148 
   1149     def build_cond():
   1150       return control_flow_ops.cond(
   1151           constant_op.constant(True), branch_fn, lambda: 0)
   1152 
   1153     def build_nested_cond():
   1154       return control_flow_ops.cond(
   1155           constant_op.constant(True), build_cond, lambda: 0)
   1156 
   1157     # In v1 graph mode, pruning should make only "C" print.
   1158     if not context.executing_eagerly():
   1159       with self.cached_session():
   1160         with self.captureWritesToStream(sys.stderr) as printed:
   1161           self.assertEqual(self.evaluate(build_cond()), 10)
   1162         self.assertEqual(printed.contents(), "C\n")
   1163 
   1164         with self.captureWritesToStream(sys.stderr) as printed:
   1165           self.assertEqual(self.evaluate(build_nested_cond()), 10)
   1166         self.assertEqual(printed.contents(), "C\n")
   1167 
   1168     # In defuns, all prints should execute in program order.
   1169     # This doesn't work with legacy control flow.
   1170     if control_flow_util.ENABLE_CONTROL_FLOW_V2:
   1171 
   1172       @eager_function.defun
   1173       def cond():
   1174         return build_cond()
   1175 
   1176       with self.captureWritesToStream(sys.stderr) as printed:
   1177         self.assertEqual(self.evaluate(cond()), 10)
   1178       self.assertTrue(printed.contents().endswith("A\nB\nC\n"),
   1179                       printed.contents())
   1180 
   1181       @eager_function.defun
   1182       def nested_cond():
   1183         return build_nested_cond()
   1184 
   1185       with self.captureWritesToStream(sys.stderr) as printed:
   1186         self.assertEqual(self.evaluate(nested_cond()), 10)
   1187       self.assertTrue(printed.contents().endswith("A\nB\nC\n"),
   1188                       printed.contents())
   1189 
   1190     # wrap_function should prune.
   1191     def pruned_cond():
   1192       return build_cond()
   1193     pruned_cond = wrap_function.wrap_function(pruned_cond, [])
   1194 
   1195     with self.captureWritesToStream(sys.stderr) as printed:
   1196       self.assertEqual(self.evaluate(pruned_cond()), 10)
   1197     self.assertEqual(printed.contents(), "C\n")
   1198 
   1199     def pruned_nested_cond():
   1200       return build_nested_cond()
   1201     pruned_nested_cond = wrap_function.wrap_function(pruned_nested_cond, [])
   1202 
   1203     with self.captureWritesToStream(sys.stderr) as printed:
   1204       self.assertEqual(self.evaluate(pruned_nested_cond()), 10)
   1205     self.assertEqual(printed.contents(), "C\n")
   1206 
   1207   @test_util.disable_xla("b/128643646 PrintV2")
   1208   @test_util.run_in_graph_and_eager_modes
   1209   def testWhileAutoControlDeps(self):
   1210     # Legacy while_loop fails this test because it produces deprecation notices
   1211     # in stderr.
   1212     if not control_flow_util.ENABLE_CONTROL_FLOW_V2: return
   1213 
   1214     def cond(i, unused_x):
   1215       logging_ops.print_v2("A")
   1216       return i < 2
   1217 
   1218     def body(i, x):
   1219       logging_ops.print_v2("B")
   1220       with ops.control_dependencies([logging_ops.print_v2("C")]):
   1221         x = array_ops.identity(x)
   1222       with ops.control_dependencies([logging_ops.print_v2("D")]):
   1223         return i + 1, x
   1224 
   1225     def build_while():
   1226       return control_flow_ops.while_loop(
   1227           cond, body, [constant_op.constant(0), constant_op.constant(0)])
   1228 
   1229     def build_nested_while():
   1230       return control_flow_ops.cond(
   1231           constant_op.constant(True), build_while, lambda: [0, 0])
   1232 
   1233     # In v1 graph mode, pruning should make only "D" print.
   1234     if not context.executing_eagerly():
   1235       with self.cached_session():
   1236         with self.captureWritesToStream(sys.stderr) as printed:
   1237           self.assertEqual(self.evaluate(build_while()[0]), 2)
   1238         self.assertTrue(printed.contents().endswith("D\nD\n"),
   1239                         printed.contents())
   1240 
   1241         with self.captureWritesToStream(sys.stderr) as printed:
   1242           self.assertEqual(self.evaluate(build_nested_while()[0]), 2)
   1243         self.assertTrue(printed.contents().endswith("D\nD\n"),
   1244                         printed.contents())
   1245 
   1246     # In defuns, all prints should execute in program order.
   1247     @eager_function.defun
   1248     def while_loop():
   1249       return build_while()[0]
   1250 
   1251     with self.captureWritesToStream(sys.stderr) as printed:
   1252       self.assertEqual(self.evaluate(while_loop()), 2)
   1253     self.assertTrue(printed.contents().endswith("A\nB\nC\nD\nA\nB\nC\nD\nA\n"),
   1254                     printed.contents())
   1255 
   1256     @eager_function.defun
   1257     def nested_while_loop():
   1258       return build_nested_while()[0]
   1259 
   1260     # TODO(b/117840611): calling nested_while_loop fails in eager
   1261     if not context.executing_eagerly():
   1262       with self.captureWritesToStream(sys.stderr) as printed:
   1263         self.assertEqual(self.evaluate(nested_while_loop()), 2)
   1264       self.assertTrue(
   1265           printed.contents().endswith("A\nB\nC\nD\nA\nB\nC\nD\nA\n"),
   1266           printed.contents())
   1267 
   1268     # wrap_function should prune.
   1269     def pruned_while():
   1270       return build_while()[0]
   1271     pruned_while = wrap_function.wrap_function(pruned_while, [])
   1272 
   1273     with self.captureWritesToStream(sys.stderr) as printed:
   1274       self.assertEqual(self.evaluate(pruned_while()), 2)
   1275     self.assertTrue(printed.contents().endswith("D\nD\n"), printed.contents())
   1276 
   1277     def pruned_nested_while():
   1278       return build_nested_while()[0]
   1279     pruned_nested_while = wrap_function.wrap_function(pruned_nested_while, [])
   1280 
   1281     # TODO(b/117840611): calling nested_while_loop fails in eager
   1282     if not context.executing_eagerly():
   1283       with self.captureWritesToStream(sys.stderr) as printed:
   1284         self.assertEqual(self.evaluate(pruned_nested_while()), 2)
   1285       self.assertTrue(printed.contents().endswith("D\nD\n"), printed.contents())
   1286 
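           # Summary note for the test above: inside a defun, automatic
           # control dependencies serialize the stateful print_v2 ops, so each
           # iteration emits "A" (cond) then "B", "C", "D" (body) in program
           # order, and the trailing "A" is the final, false cond evaluation.
           # In pruned v1 graphs only "D" survives, because the fetched loop
           # variable depends on it through the control_dependencies block.
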
   1287   # Microbenchmark: 256,000 iterations/s.
   1288   def testWhile_1(self):
   1289     with self.cached_session():
   1290       n = constant_op.constant(0)
   1291       c = lambda x: math_ops.less(x, 10000)
   1292       b = lambda x: math_ops.add(x, 1)
   1293       r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
   1294       self.assertEqual(10000, self.evaluate(r))
   1295 
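           # Usage sketch (assuming only names imported in this file):
           # parallel_iterations bounds how many loop iterations may run
           # concurrently; for a pure body like the one above it does not
           # change the result, only the available parallelism:
           #
           #   r = control_flow_ops.while_loop(
           #       lambda x: x < 10, lambda x: x + 1,
           #       [constant_op.constant(0)], parallel_iterations=1)
           #   # r still evaluates to 10, as with the default of 10.
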
   1296   @test_util.run_v1_only("b/120545219")
   1297   def testWhileExternalControlDependencies(self):
   1298     with self.cached_session():
   1299       v = variables.Variable(0.0)
   1300       v.initializer.run()
   1301       increment = v.assign_add(1.0).read_value()
   1302 
   1303       def body_fn(i):
   1304         with ops.control_dependencies([increment]):
   1305           return i + 1
   1306 
   1307       result = control_flow_ops.while_loop(cond=lambda i: i < 2,
   1308                                            body=body_fn, loop_vars=[1])
   1309       self.assertAllEqual(result, 2)
   1310       self.assertAllEqual(v.read_value(), 1.0)
   1311 
   1312   @test_util.run_v1_only("b/120545219")
   1313   def testWhileExternalControlDependenciesNoInput(self):
   1314     with self.cached_session():
   1315       v = variables.Variable(0.0)
   1316       v.initializer.run()
   1317       # TODO(apassos): figure out why the reading is necessary here.
   1318       increment = v.assign_add(1.0).read_value()
   1319 
   1320       def body_fn(unused_i):
   1321         with ops.control_dependencies([increment]):
   1322           return constant_op.constant(5, name="five")
   1323 
   1324       result = control_flow_ops.while_loop(cond=lambda i: i < 5,
   1325                                            body=body_fn, loop_vars=[0])
   1326       self.evaluate(result)
   1327       self.assertAllEqual(self.evaluate(v), 1.0)
   1328 
   1329   @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
   1330   @test_util.run_v1_only("b/120545219")
   1331   def testWhileWithRefs_1(self):
   1332     with self.cached_session() as sess:
   1333       x = variables.VariableV1(0)._ref()  # pylint: disable=protected-access
   1334       i = constant_op.constant(0)
   1335       c = lambda i, x: math_ops.less(i, 100)
   1336 
   1337       self.assertEqual(x.dtype, dtypes.int32_ref)
   1338 
   1339       def b(i, x):
   1340         self.assertEqual(x.dtype, dtypes.int32_ref)
   1341         return (i + 1, gen_array_ops.ref_identity(x))
   1342 
   1343       r = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=5)
   1344 
   1345       self.evaluate(variables.global_variables_initializer())
   1346 
   1347       self.assertEqual(r[0].dtype, dtypes.int32)
   1348       self.assertEqual(r[1].dtype, dtypes.int32_ref)
   1349 
   1350       value_i, value_x = self.evaluate(r)
   1351 
   1352     self.assertEqual(100, value_i)
   1353     self.assertEqual(0, value_x)
   1354 
   1355   def testWhile_2(self):
   1356     with self.cached_session():
   1357       s = constant_op.constant(0)
   1358       r = isum(s)
   1359       self.assertAllEqual(45, self.evaluate(r))
   1360 
   1361   def testWhileWithMaximumIterations(self):
   1362     with self.cached_session():
   1363       s = constant_op.constant([1, 2, 3, 4, 5])
   1364       r = isum(s, maximum_iterations=3)
   1365       self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], self.evaluate(r))
   1366 
   1367   @test_util.run_v1_only("b/120545219")
   1368   def testWhileWithMaximumIterationsAndSingleArgument(self):
   1369     with self.cached_session():
   1370       r = control_flow_ops.while_loop(
   1371           lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
   1372       self.assertEqual(1, self.evaluate(r))
   1373 
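           # Sketch of the maximum_iterations contract exercised above: the
           # loop stops when the cond is false or the cap is reached,
           # whichever comes first. Assuming only names imported in this file:
           #
           #   r = control_flow_ops.while_loop(
           #       lambda i: i < 100, lambda i: i + 1,
           #       [constant_op.constant(0)], maximum_iterations=3)
           #   # r evaluates to 3 even though the cond allows 100 iterations.
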
   1374   @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
   1375   @test_util.run_v1_only("b/120545219")
   1376   def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
   1377     v = constant_op.constant(1.0)
   1378 
   1379     def training_loop_with_gradient(i):
   1380       out = control_flow_ops.while_loop(
   1381           lambda i_, _: i_ < 3,
   1382           lambda i_, j: [i_ + 1, j * v], [0, 1.0],
   1383           maximum_iterations=i)
   1384       g = gradients_impl.gradients(out, v)
   1385       with ops.control_dependencies(g):
   1386         return i + 1
   1387 
   1388     xla_context = control_flow_ops.XLAControlFlowContext()
   1389     xla_context.Enter()
    1390     # Create a training loop and ensure we can compute gradients of a
    1391     # while_loop inside the training loop.
   1392     loop = control_flow_ops.while_loop(lambda i: i < 3,
   1393                                        training_loop_with_gradient, [0])
   1394     xla_context.Exit()
   1395 
   1396     loop_execute = array_ops.identity(loop)  # Because loop is not fetchable.
   1397 
   1398     # Should execute without issue.
   1399     self.assertEqual(3, self.evaluate(loop_execute))
   1400 
   1401   @test_util.run_v1_only("b/120545219")
   1402   def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
   1403     if control_flow_util.ENABLE_CONTROL_FLOW_V2:
   1404       self.skipTest("WhileV2 does lazy evaluation of maximum_iterations")
   1405     v = constant_op.constant(1.0)
   1406 
   1407     def inner_body(i, x):
   1408       out = control_flow_ops.while_loop(
   1409           lambda i, _: i < 3,
   1410           lambda i, j: [i + 1, j * v], [0, x],
   1411           maximum_iterations=i)
   1412       return out
   1413 
   1414     def create_while_loop(maximum_iterations=None):
   1415       return control_flow_ops.while_loop(
   1416           lambda i, _: i < 3,
   1417           inner_body, [0, 1.0],
   1418           maximum_iterations=maximum_iterations)
   1419 
   1420     loop_no_xla = create_while_loop(maximum_iterations=5)
   1421     # maximum_iterations is fine outside of an XLA scope
   1422     gs = gradients_impl.gradients(loop_no_xla, v)
   1423     self.evaluate(gs)  # This should execute without error.
   1424 
   1425     xla_context = control_flow_ops.XLAControlFlowContext()
   1426     xla_context.Enter()
   1427     loop_no_maxiter = create_while_loop()
   1428     loop_with_maxiter = create_while_loop(maximum_iterations=2)
   1429     xla_context.Exit()
   1430 
   1431     with self.assertRaisesRegexp(
   1432         ValueError,
   1433         r"Cannot create a gradient accumulator for tensor '.+' inside "
   1434         r"XLA while_loop because maximum_iterations was not passed to "
   1435         r"the tf.while_loop call \('.+'\)."):
   1436       _ = gradients_impl.gradients(loop_no_maxiter, v)
   1437 
   1438     with self.assertRaisesRegexp(
   1439         ValueError,
   1440         r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
   1441         r"while_loop. maximum_iterations tensor '.+' for while_loop context "
   1442         r"'.+' must be statically known \(e.g. a constant value or known "
   1443         r"shape dimension\), or be defined at or outside the while loop "
   1444         r"context '.*' \(currently defined in '.*'\)"):
   1445       _ = gradients_impl.gradients(loop_with_maxiter, v)
   1446 
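           # The rule checked above, restated: to backprop through a
           # while_loop inside an XLA context, legacy control flow must size
           # its gradient accumulators up front, so maximum_iterations must be
           # provided and statically known (or be defined at or outside the
           # loop's context). That is why gradients of loop_no_maxiter and
           # loop_with_maxiter above raise, while loop_no_xla does not.
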
   1447   @test_util.run_v1_only("b/120545219")
   1448   def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
   1449     v = constant_op.constant(1.0)
   1450 
   1451     def create_while_loop():
   1452       max_iter_holder = []
   1453 
   1454       def create_mi():
   1455         max_iter_holder.append(array_ops.placeholder(dtypes.int32, shape=()))
   1456         return 1.0
   1457 
   1458       _ = control_flow_ops.cond(
   1459           constant_op.constant(True), create_mi, create_mi)
   1460 
   1461       return control_flow_ops.while_loop(
   1462           lambda i, _: i < 3,
   1463           lambda i, x: (i + 1, v * x), (0, 1.0),
   1464           maximum_iterations=max_iter_holder[0])
   1465 
   1466     if control_flow_util.ENABLE_CONTROL_FLOW_V2:
   1467       xla_context = control_flow_ops.XLAControlFlowContext()
   1468       xla_context.Enter()
   1469       with self.assertRaisesRegexp(
   1470           ValueError, r"Tensor.*Placeholder:0.* must be from the same graph.*"):
   1471         loop = create_while_loop()
   1472       xla_context.Exit()
   1473     else:
   1474       xla_context = control_flow_ops.XLAControlFlowContext()
   1475       xla_context.Enter()
   1476       loop = create_while_loop()
   1477       xla_context.Exit()
   1478       with self.assertRaisesRegexp(
   1479           ValueError,
   1480           r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
   1481           r"while_loop. maximum_iterations tensor '.*Placeholder:0' for "
   1482           r"while_loop context '.+' must be statically known \(e.g. a constant "
   1483           r"value or known shape dimension\), or be defined at or outside the "
   1484           r"while loop context '' \(currently defined in 'cond/.+'\)"):
   1485         _ = gradients_impl.gradients(loop, v)
   1486 
   1487   @test_util.run_v1_only("b/120545219")
   1488   def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self):
   1489     if test_util.is_gpu_available():
   1490       self.skipTest("b/128646372, b/128645947 fails in opensource build")
   1491 
   1492     v = constant_op.constant(1.0)
   1493 
   1494     p = array_ops.placeholder(dtype=dtypes.int32)
   1495 
   1496     def mid_body_builder(iterations):
   1497 
   1498       def mid_body(i, x):
   1499         r = control_flow_ops.while_loop(
   1500             lambda *_: True,
   1501             lambda i, x: (i + 1, v * x), (0, x),
   1502             maximum_iterations=iterations,
   1503             name="inner")
   1504         return (i + 1, gradients_impl.gradients(x + r[1], v)[0])
   1505 
   1506       return mid_body
   1507 
   1508     def outer_body(i, x):
   1509       iterations = array_ops.size(p, name="iterations")
   1510       return (i + 1, x + control_flow_ops.while_loop(
   1511           lambda *_: True,
   1512           mid_body_builder(iterations), (0, x),
   1513           maximum_iterations=iterations,
   1514           name="mid")[1])
   1515 
   1516     def create_while_loop():
   1517       with ops.device("/cpu:0"):
   1518         r = control_flow_ops.while_loop(
   1519             lambda *_: True,
   1520             outer_body, (0, 1.0),
   1521             maximum_iterations=5,
   1522             name="outer")
   1523         return array_ops.identity(r[1])
   1524 
   1525     xla_context = control_flow_ops.XLAControlFlowContext()
   1526     xla_context.Enter()
   1527     final_with_xla_context = create_while_loop()
   1528     xla_context.Exit()
   1529 
   1530     final_without_xla_context = create_while_loop()
   1531 
   1532     with self.session(use_gpu=False) as sess:
   1533       opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
   1534       run_metadata_without_xla_context = config_pb2.RunMetadata()
   1535       run_metadata = config_pb2.RunMetadata()
   1536 
   1537       final_value_without_xla_context = sess.run(
   1538           final_without_xla_context,
   1539           feed_dict={p: [0, 0, 0]},
   1540           options=opts,
   1541           run_metadata=run_metadata_without_xla_context)
   1542 
   1543       final_value_with_xla_context = sess.run(
   1544           final_with_xla_context,
   1545           feed_dict={p: [0, 0, 0]},
   1546           options=opts,
   1547           run_metadata=run_metadata)
   1548 
   1549       if control_flow_util.ENABLE_CONTROL_FLOW_V2:
    1550         # With while_v2 on XLA, run_metadata only contains the unlowered
    1551         # While op, so node_stats has no statistics for the pushes. As a
    1552         # loose check, we count the pushes in the lowered version instead.
   1553         node_stats = run_metadata_without_xla_context.step_stats.dev_stats[
   1554             0].node_stats
   1555         stack_push_op = "TensorListPushBack"
   1556       else:
   1557         node_stats = run_metadata.step_stats.dev_stats[0].node_stats
   1558         stack_push_op = "StackPushV2"
   1559       stack_push_count = len(
   1560           [x for x in node_stats if x.node_name.endswith(stack_push_op)])
    1561       # Pushes to the stack = product of the maximum_iterations values; the
    1562       # last two "3"s come from size(p) when p == [0, 0, 0] (see note below).
   1563       self.assertEqual(stack_push_count, 5 * 3 * 3, str(node_stats))
   1564 
   1565       self.assertAllClose(final_value_with_xla_context,
   1566                           final_value_without_xla_context)
   1567 
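           # Accounting note for the stack-push check above: legacy while
           # gradients push one intermediate per iteration of each loop, so
           # the total push count is the product of the nested trip counts:
           # 5 (outer) * 3 (mid) * 3 (inner) = 45, where the two 3s are
           # size(p) for the fed value p == [0, 0, 0].
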
    1568   # Have more than 10 parallel iterations and hence exercise the k-bound
    1569   # (bounded-parallelism) logic most of the time.
   1570   @test_util.run_deprecated_v1
   1571   def testWhile_3(self):
   1572     with self.cached_session():
   1573 
   1574       def compute(i, m, c, o):
   1575         m, c = [math_ops.add(m, 1), math_ops.add(c, 1)]
   1576         o = math_ops.add(o, m)
   1577         o = math_ops.add(o, c)
   1578         i = math_ops.add(i, 1)
   1579         return [i, m, c, o]
   1580 
   1581       i = ops.convert_to_tensor(0)
   1582       m = ops.convert_to_tensor(0)
   1583       c = ops.convert_to_tensor(0)
   1584       o = ops.convert_to_tensor(0)
   1585       d = ops.convert_to_tensor(100)
   1586       r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, d),
   1587                                       compute, [i, m, c, o])
   1588       result = r[3]
   1589     self.assertAllEqual(10100, result)
   1590 
   1591   @test_util.run_deprecated_v1
   1592   def testWhile_4(self):
   1593     with self.cached_session():
   1594 
   1595       def compute(i, m, c, o):
   1596         m, c = [array_ops.gather(x, i), array_ops.gather(x, i)]
   1597         o = math_ops.add(o, m)
   1598         o = math_ops.add(o, c)
   1599         i = math_ops.add(i, 1)
   1600         return [i, m, c, o]
   1601 
   1602       i = ops.convert_to_tensor(0)
   1603       m = ops.convert_to_tensor(0)
   1604       c = ops.convert_to_tensor(0)
   1605       o = ops.convert_to_tensor(0)
   1606       x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
   1607       s = array_ops.size(x)
   1608       r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, s),
   1609                                       compute, [i, m, c, o])
   1610       result = r[3]
   1611     self.assertAllEqual(42, result)
   1612 
   1613   @test_util.run_v1_only("b/120545219")
   1614   def testWhile_5(self):
   1615     with self.cached_session():
   1616 
   1617       def compute(i, c, o):
   1618         c = array_ops.strided_slice(x, array_ops.expand_dims(i, 0),
   1619                                     [1] + array_ops.expand_dims(i, 0))
   1620         o = array_ops.concat([o, c], 0)
   1621         i = math_ops.add(i, 1)
   1622         return [i, c, o]
   1623 
   1624       i = ops.convert_to_tensor(0)
   1625       c = ops.convert_to_tensor([0])
   1626       o = ops.convert_to_tensor([0])
   1627       x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6])
   1628       s = array_ops.size(x)
   1629       r = control_flow_ops.while_loop(lambda i, c, o: math_ops.less(i, s),
   1630                                       compute, [i, c, o], [
   1631                                           i.get_shape(),
   1632                                           tensor_shape.unknown_shape(),
   1633                                           tensor_shape.unknown_shape()
   1634                                       ])
   1635       result = r[2]
   1636     self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result)
   1637 
   1638   @test_util.run_gpu_only
   1639   @test_util.run_deprecated_v1
   1640   def testWhile_Device(self):
   1641 
   1642     # Body function defined outside of device scope
   1643     def body(x):
   1644       return math_ops.exp(x)
   1645 
   1646     with ops.device("CPU:0"):
   1647       r = control_flow_ops.while_loop(
   1648           lambda x: x < 10, body, [constant_op.constant(-10.)])
   1649       self.assertIn("cpu", r.device.lower())
   1650 
   1651     with session.Session() as sess:
   1652       options = config_pb2.RunOptions(output_partition_graphs=True)
   1653       run_metadata = config_pb2.RunMetadata()
   1654       sess.run(r, options=options, run_metadata=run_metadata)
   1655       # We expect that everything runs on CPU, even if GPU is available.
   1656       self.assertEqual(len(run_metadata.partition_graphs), 1)
   1657 
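           # Placement sketch for the test above (x0 is a hypothetical float
           # tensor): a while_loop body is traced when the loop is built, so
           # the device scope surrounding the while_loop call, not the scope
           # where the body function was defined, decides op placement:
           #
           #   def body(x):  # defined with no device scope
           #     return math_ops.exp(x)
           #   with ops.device("CPU:0"):
           #     r = control_flow_ops.while_loop(lambda x: x < 10., body, [x0])
           #   # ops created by tracing `body` land on CPU:0
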
   1658   @test_util.disable_control_flow_v2("b/116338794 (buffer_reuse)")
   1659   @test_util.run_v1_only("b/120545219")
   1660   def testBufferForwarding(self):
   1661     run_options = config_pb2.RunOptions(
   1662         trace_level=config_pb2.RunOptions.FULL_TRACE)
   1663     run_metadata = config_pb2.RunMetadata()
   1664 
   1665     with self.cached_session() as sess:
   1666       with ops.device("/cpu:0"):
   1667         c = constant_op.constant(2)
   1668         i0 = constant_op.constant(0)
   1669         r = control_flow_ops.while_loop(lambda i: i < 1000,
   1670                                         lambda i: math_ops.square(c) + i, [i0])
   1671       r_val = sess.run(r, options=run_options, run_metadata=run_metadata)
   1672       self.assertEqual(1000, r_val)
   1673       self.assertTrue(run_metadata.HasField("step_stats"))
   1674       unique_allocs = set()
   1675       for node_stat in run_metadata.step_stats.dev_stats[0].node_stats:
   1676         for output in node_stat.output:
   1677           unique_allocs.add(
   1678               output.tensor_description.allocation_description.ptr)
   1679       # Prior to cl/147536680, the number of unique allocations was about 1005.
   1680       self.assertLess(len(unique_allocs), 756)
   1681 
   1682   def _testWhile_Gpu_1(self, use_gpu):
   1683     with self.cached_session(use_gpu=use_gpu):
   1684       n = constant_op.constant(1.0)
   1685       c = lambda x: math_ops.less(x, 10.0)
   1686       b = lambda x: math_ops.add(x, 1.0)
   1687       r = control_flow_ops.while_loop(c, b, [n])
   1688       self.assertAllClose(10.0, self.evaluate(r))
   1689 
   1690   def testWhile_Gpu_1(self):
   1691     self._testWhile_Gpu_1(use_gpu=False)
   1692     self._testWhile_Gpu_1(use_gpu=True)
   1693 
   1694   def _testWhile_Gpu_2(self, use_gpu):
   1695     with self.cached_session(use_gpu=use_gpu):
   1696       n = constant_op.constant(1.0)
   1697       c = lambda x: math_ops.less(x, 10.0)
   1698 
   1699       def b(x):
   1700         with ops.device("/cpu:0"):
   1701           return math_ops.add(x, 1.0)
   1702 
   1703       r = control_flow_ops.while_loop(c, b, [n])
   1704       self.assertAllClose(10.0, self.evaluate(r))
   1705 
   1706   def testWhile_Gpu_2(self):
   1707     self._testWhile_Gpu_2(use_gpu=False)
   1708     self._testWhile_Gpu_2(use_gpu=True)
   1709 
   1710   def testWhileShape(self):
   1711     with self.cached_session():
   1712       i = constant_op.constant(0)
   1713       m = array_ops.ones([2, 2])
   1714       c = lambda i, j: math_ops.less(i, 2)
   1715 
   1716       def _b(i, j):
   1717         new_i = math_ops.add(i, 1)
   1718         new_j = array_ops.tile(j, [2, 2])
   1719         return [new_i, new_j]
   1720 
   1721       r = control_flow_ops.while_loop(
   1722           c, _b, [i, m],
   1723           [i.get_shape(), tensor_shape.unknown_shape()])
   1724       r = r[1] * array_ops.ones([8, 8])
   1725       self.assertAllEqual(np.ones((8, 8)), self.evaluate(r))
   1726 
   1727   @test_util.run_deprecated_v1
   1728   def testWhileWithNonTensorInput_Scalar(self):
   1729     with self.cached_session():
   1730       n = 0
   1731       c = lambda x: x < 10000
   1732       b = lambda x: x + 1
   1733       r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
   1734       self.assertEqual(10000, self.evaluate(r))
   1735 
   1736   def testWhileWithNonTensorInput_Vector(self):
   1737     with self.cached_session():
    1738       n = np.array([0])  # Note: [0] would not work here; that is a list.
   1739       c = lambda x: x[0] < 10000
   1740       b = lambda x: array_ops.stack([x[0] + 1])
   1741       r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
   1742       self.assertEqual([10000], self.evaluate(r))
   1743 
   1744   @test_util.run_v1_only("b/120545219")
   1745   def testWhileShapeInference(self):
   1746     with self.cached_session():
   1747       i = constant_op.constant(0)
   1748       m = array_ops.ones([2, 2])
   1749       c = lambda i, j: math_ops.less(i, 2)
   1750 
   1751       def b(i, j):
   1752         new_i = math_ops.add(i, 1)
   1753         new_j = array_ops.concat([j, j], 0)
   1754         return [new_i, new_j]
   1755 
   1756       r = control_flow_ops.while_loop(
   1757           c, b, [i, m],
   1758           [i.get_shape(), tensor_shape.TensorShape([None, 2])])
   1759       self.assertIsNone(r[1].shape.dims[0].value)
   1760       self.assertEqual(r[1].shape.dims[1], tensor_shape.Dimension(2))
   1761 
   1762       with self.assertRaisesRegexp(
   1763           ValueError,
   1764           r"Input tensor 'ones:0' enters the loop with shape \(2, 2\), but has "
   1765           r"shape \(4, 2\) after one iteration. To allow the shape to vary "
   1766           r"across iterations, use the `shape_invariants` argument of "
   1767           r"tf.while_loop to specify a less-specific shape."):
   1768         r = control_flow_ops.while_loop(c, b, [i, m])
   1769 
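           # Usage sketch for the error above: when a loop variable changes
           # shape across iterations, pass a less-specific invariant instead
           # of letting the loop infer a fixed shape, e.g. (names as in the
           # test):
           #
           #   r = control_flow_ops.while_loop(
           #       c, b, [i, m],
           #       shape_invariants=[i.get_shape(),
           #                         tensor_shape.TensorShape([None, 2])])
           #
           # which is exactly what the passing case above does.
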
   1770   @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
   1771   @test_util.run_v1_only("b/120545219")
   1772   def testWhileShapeInferenceSparseTensor(self):
   1773     values = constant_op.constant([2.0, 4.0], name="values")
   1774     indices = constant_op.constant([[0], [3]],
   1775                                    dtype=dtypes.int64,
   1776                                    name="indices")
   1777     shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
   1778     i = constant_op.constant(0)
   1779     x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
   1780 
   1781     def c(i, _):
   1782       return i < 10
   1783 
   1784     def b1(i, x):  # modifies values.  (shape of components is not changed.)
   1785       return [
   1786           i + 1,
   1787           sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape)
   1788       ]
   1789 
   1790     def b2(i, x):  # adds new values.  (shape of components is changed.)
   1791       return [
   1792           i + 1,
   1793           sparse_ops.sparse_add(
   1794               x,
   1795               sparse_tensor.SparseTensor(
   1796                   indices=math_ops.cast(
   1797                       array_ops.fill([1, 1], i), dtypes.int64),
   1798                   values=array_ops.fill([1], 1.0),
   1799                   dense_shape=x.dense_shape))
   1800       ]
   1801 
   1802     def b3(i, x):  # modifies rank.  (shape of all components is changed.)
   1803       return [
   1804           i + 1,
   1805           sparse_tensor.SparseTensor(
   1806               array_ops.concat([x.indices, [[i], [i]]], axis=1), x.values * 2.0,
   1807               array_ops.concat([x.dense_shape, [10]], axis=0))
   1808       ]
   1809 
   1810     # Default shape invariant; b1 only modifies values.
   1811     _, r = control_flow_ops.while_loop(c, b1, [i, x])
   1812     self.assertEqual(r.indices.get_shape().as_list(), [None, 1])
   1813     self.assertEqual(r.values.get_shape().as_list(), [None])
   1814     self.assertEqual(r.dense_shape.get_shape().as_list(), [1])
   1815 
   1816     # Default shape invariant; b2 adds new values
   1817     _, r = control_flow_ops.while_loop(c, b2, [i, x])
   1818     self.assertEqual(r.indices.get_shape().as_list(), [None, 1])
   1819     self.assertEqual(r.values.get_shape().as_list(), [None])
   1820     self.assertEqual(r.dense_shape.get_shape().as_list(), [1])
   1821 
   1822     # Default shape invariant; b3 modifies rank (which is not allowed).
   1823     with self.assertRaises(ValueError):
   1824       _, r = control_flow_ops.while_loop(c, b3, [i, x])
   1825 
   1826     # Explicit shape invariant, allowing any rank; b1 only modifies values.
   1827     _, r = control_flow_ops.while_loop(
   1828         c, b1, [i, x],
   1829         [i.get_shape(), tensor_shape.TensorShape([None])])
   1830     self.assertEqual(r.indices.get_shape().as_list(), [None, None])
   1831     self.assertEqual(r.values.get_shape().as_list(), [None])
   1832     self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
   1833 
   1834     # Explicit shape invariant, allowing any rank; b3 modifies rank.
   1835     _, r = control_flow_ops.while_loop(
   1836         c, b3, [i, x],
   1837         [i.get_shape(), tensor_shape.TensorShape([None])])
   1838     self.assertEqual(r.indices.get_shape().as_list(), [None, None])
   1839     self.assertEqual(r.values.get_shape().as_list(), [None])
   1840     self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
   1841 
   1842     # Shape invariant with ndims=None.  Technically, this isn't supported
   1843     # according to the docs, but we support it for backwards compatibility.
   1844     _, r = control_flow_ops.while_loop(
   1845         c, b1, [i, x],
   1846         [i.get_shape(), tensor_shape.TensorShape(None)])
   1847     self.assertEqual(r.indices.get_shape().as_list(), [None, None])
   1848     self.assertEqual(r.values.get_shape().as_list(), [None])
   1849     self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
   1850     _, r = control_flow_ops.while_loop(
   1851         c, b3, [i, x],
   1852         [i.get_shape(), tensor_shape.TensorShape(None)])
   1853     self.assertEqual(r.indices.get_shape().as_list(), [None, None])
   1854     self.assertEqual(r.values.get_shape().as_list(), [None])
   1855     self.assertEqual(r.dense_shape.get_shape().as_list(), [None])
   1856 
   1857     # Explicit shape invariant, with a specific (incompatible) rank.
   1858     with self.assertRaisesRegexp(ValueError, "is not compatible with"):
   1859       _, r = control_flow_ops.while_loop(
   1860           c, b1, [i, x],
   1861           [i.get_shape(), tensor_shape.TensorShape([5])])
   1862 
   1863   @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
   1864   @test_util.run_v1_only("b/120545219")
   1865   def testWhileShapeInferenceIndexedSlices(self):
   1866     with self.cached_session():
   1867       values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
   1868       indices = constant_op.constant([0, 3], name="indices")
   1869       shape = constant_op.constant([10, 2], name="dense_shape")
   1870       i = constant_op.constant(0)
   1871       x = ops.IndexedSlices(values, indices, dense_shape=shape)
   1872 
   1873       def c(i, _):
   1874         return i < 10
   1875 
   1876       def b(i, x):
   1877         return [
   1878             i + 1,
   1879             ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape)
   1880         ]
   1881 
   1882       _, r = control_flow_ops.while_loop(c, b, [i, x])
   1883       self.assertEqual(r.dense_shape.get_shape()[0], 2)
   1884       self.assertEqual(r.values.get_shape(), tensor_shape.TensorShape([2, 2]))
   1885 
   1886       _, r = control_flow_ops.while_loop(
   1887           c, b, [i, x],
   1888           [i.get_shape(), tensor_shape.TensorShape([None, 2])])
   1889       self.assertEqual(r.dense_shape.get_shape()[0], 2)
   1890       self.assertEqual(r.values.get_shape().as_list(), [None, 2])
   1891 
   1892       with self.assertRaisesRegexp(ValueError, "is not compatible with"):
   1893         _, r = control_flow_ops.while_loop(
   1894             c, b, [i, x],
   1895             [i.get_shape(), tensor_shape.TensorShape([None, 5])])
   1896 
   1897   @test_util.disable_control_flow_v2("b/116328420 (RaggedTensor)")
   1898   def testWhileShapeInferenceRaggedTensor(self):
   1899     if context.executing_eagerly():
   1900       self.skipTest("b/116328420")
   1901     i = constant_op.constant(0)
   1902     x = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
   1903     c = lambda i, _: i < 10
   1904 
   1905     def b1(i, x):  # Adds new values to rows (but doesn't create new rows)
   1906       return [
   1907           i + 1,
   1908           array_ops.concat([x, x], axis=1)
   1909       ]
   1910 
   1911     def b2(i, x):  # Adds new rows.
   1912       return [
   1913           i + 1,
   1914           array_ops.concat([x, x], axis=0)
   1915       ]
   1916 
   1917     # Default shape invariant; b1 adds new values to rows.
   1918     _, r = control_flow_ops.while_loop(c, b1, [i, x])
   1919     self.assertEqual(r.row_splits.shape.as_list(), [4])
   1920 
   1921     self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))
   1922 
   1923     # Default shape invariant; b2 adds new rows (not allowed).
   1924     if not context.executing_eagerly():
   1925       with self.assertRaises(ValueError):
   1926         _, r = control_flow_ops.while_loop(c, b2, [i, x])
   1927 
   1928     # Explicit shape invariant; b1 adds new values to rows.
   1929     _, r = control_flow_ops.while_loop(
   1930         c, b1, [i, x],
   1931         [i.get_shape(), tensor_shape.TensorShape([None, None])])
   1932     self.assertTrue(r.row_splits.shape.as_list() in ([4], [None]))
   1933     self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))
   1934 
   1935     # Explicit shape invariant; b2 adds new rows.
   1936     _, r = control_flow_ops.while_loop(
   1937         c, b2, [i, x],
   1938         [i.get_shape(), tensor_shape.TensorShape([None, None])])
   1939     self.assertTrue(r.row_splits.shape.as_list() in ([3 * 2**10 + 1], [None]))
   1940     self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))
   1941 
   1942   @test_util.disable_control_flow_v2("b/116328420 (RaggedTensor)")
   1943   def testWhileShapeInferenceRaggedTensorRaggedRank2(self):
   1944     if context.executing_eagerly():
   1945       self.skipTest("b/116328420")
   1946     i = constant_op.constant(0)
   1947     x = ragged_factory_ops.constant([[[1, 2], [3], [4, 5, 6]],
   1948                                      [[], [8, 9, 10]]])
   1949     c = lambda i, _: i < 10
   1950     def b(i, x):
   1951       return [
   1952           i + 1,
   1953           array_ops.concat([x, x[..., i:i+1]], axis=-1)
   1954       ]
   1955     _, r = control_flow_ops.while_loop(c, b, [i, x])
   1956     self.assertEqual(r.row_splits.shape.as_list(), [3])
   1957     self.assertTrue(r.values.row_splits.shape.as_list() in ([6], [None]))
   1958     self.assertTrue(r.values.values.shape.as_list() in ([49], [None]))
   1959 
   1960   def _testNestedWhile_1(self, use_gpu):
   1961     with self.cached_session(use_gpu=use_gpu):
   1962       n = constant_op.constant(0)
   1963 
   1964       def cpu_sum(s):
   1965         c = lambda i, s: math_ops.less(i, 10)
   1966 
   1967         def b(i, s):
   1968           i1 = math_ops.add(i, 1)
   1969           with ops.device("/cpu:0"):
   1970             s1 = math_ops.add(i, s)
   1971           return i1, s1
   1972 
   1973         _, r_s = control_flow_ops.while_loop(c, b, [n, s])
   1974         return r_s
   1975 
   1976       c = lambda x: math_ops.less(x, 200)
   1977       b = lambda x: math_ops.add(x, cpu_sum(n))
   1978       r = control_flow_ops.while_loop(c, b, [n])
   1979       self.assertEqual(225, self.evaluate(r))
   1980 
   1981   def testNestedWhile_1(self):
   1982     self._testNestedWhile_1(use_gpu=False)
   1983     self._testNestedWhile_1(use_gpu=True)
   1984 
   1985   def _testNestedWhile_2(self, use_gpu):
    1986     # Test the cases where the edges A -> Enter and Exit -> A are partitioned.
   1987     with self.cached_session(use_gpu=use_gpu):
   1988       s0 = constant_op.constant(2.0)
   1989 
   1990       def inner_loop(s):
   1991         c = lambda s: math_ops.less(s, 20.0)
   1992 
   1993         def b(s):
   1994           s1 = math_ops.add(s, s)
   1995           return s1
   1996 
   1997         r_s = control_flow_ops.while_loop(c, b, [s], parallel_iterations=1)
   1998         return r_s
   1999 
   2000       outer_c = lambda x: math_ops.less(x, 3000.0)
   2001 
   2002       def outer_b(x):
   2003         x = logging_ops.Print(x, [x])  # Edge "Print -> Enter" is partitioned
   2004         x = inner_loop(x)
   2005         with ops.device("/cpu:0"):
   2006           x = math_ops.square(x)  # Edge "Exit -> Square" is partitioned
   2007         return x
   2008 
   2009       r = control_flow_ops.while_loop(
   2010           outer_c, outer_b, [s0], parallel_iterations=1)
   2011       self.assertEqual(1048576.0, self.evaluate(r))
   2012 
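           # Background for the partitioning comment in _testNestedWhile_2:
           # when loop ops are pinned to different devices, the graph is split
           # into per-device partitions and cross-device edges such as
           # "Print -> Enter" and "Exit -> Square" become Send/Recv pairs; the
           # assertion checks the loop still computes 1048576.0 across them.
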
   2013   def testNestedWhile_2(self):
   2014     self._testNestedWhile_2(use_gpu=False)
   2015     self._testNestedWhile_2(use_gpu=True)
   2016 
   2017   @test_util.run_v1_only("b/120545219")
   2018   def testWhileWithControl_1(self):
   2019     with self.cached_session():
   2020       n = constant_op.constant(0)
   2021       r = constant_op.constant(0)
   2022       condition = lambda n_, r_: math_ops.less(n_, 10)
   2023 
   2024       def body(n_, r_):
   2025         n_ = math_ops.add(n_, 1)
   2026         with r_.graph.control_dependencies([r_]):
   2027           r_ = constant_op.constant(12)
   2028         return [n_, r_]
   2029 
   2030       res = control_flow_ops.while_loop(
   2031           condition, body, [n, r], parallel_iterations=1)
   2032       self.assertAllEqual(12, res[1])
   2033 
   2034   @test_util.run_deprecated_v1
   2035   def testWhileWithControl_2(self):
   2036     with self.cached_session():
   2037       r = constant_op.constant(0)
   2038       condition = lambda r_: math_ops.less(r_, 10)
   2039 
   2040       def body(r_):
   2041         with r_.graph.control_dependencies([r_]):
   2042           r_ = constant_op.constant(12)
   2043         return [r_]
   2044 
   2045       res = control_flow_ops.while_loop(
   2046           condition, body, [r], parallel_iterations=1)
   2047       self.assertAllEqual(12, self.evaluate(res))
   2048 
   2049   @test_util.run_v1_only("b/120545219")
   2050   def testWhileWithControl_3(self):
   2051     with self.cached_session() as sess:
   2052       b = array_ops.placeholder(dtypes.bool)
   2053       c = constant_op.constant(1)
   2054       x0 = constant_op.constant(0)
   2055       with ops.control_dependencies([b]):
   2056         r = control_flow_ops.while_loop(lambda x: x < 10, lambda x: x + c, [x0])
   2057       self.assertEqual(10, sess.run(r, {b: True}))
   2058 
   2059   @test_util.run_v1_only("b/120545219")
   2060   def testWhileWithControl_4(self):
   2061     with self.cached_session() as sess:
   2062       b = array_ops.placeholder(dtypes.bool)
   2063       c = constant_op.constant(1)
   2064       x0 = constant_op.constant(0)
   2065       with ops.control_dependencies([b]):
   2066         r = control_flow_ops.while_loop(
   2067             lambda x: x < 10, lambda x: x + array_ops.identity(c), [x0])
   2068       self.assertEqual(10, sess.run(r, {b: True}))
   2069 
   2070   @test_util.run_v1_only("b/120545219")
   2071   def testWhileWithControl_5(self):
   2072     with self.cached_session() as sess:
   2073       b = array_ops.placeholder(dtypes.bool)
   2074       c = constant_op.constant(1)
   2075       x0 = constant_op.constant(0)
   2076 
   2077       def body(x):
   2078         with ops.control_dependencies([b]):
   2079           return x + c
   2080 
   2081       r = control_flow_ops.while_loop(lambda x: x < 10, body, [x0])
   2082       self.assertEqual(10, sess.run(r, {b: True}))
   2083 
   2084   def testWhileCondWithControl(self):
    2085     # Ensure that no control edges from an outer control dependency context
    2086     # are added to nodes inside cond/while contexts.
   2087     with self.cached_session() as sess:
   2088       const_true = lambda: constant_op.constant(True)
   2089       const_false = lambda: constant_op.constant(False)
   2090       cond = lambda i: control_flow_ops.cond(i > 0, const_true, const_false)
   2091       body = lambda i: control_flow_ops.cond(i > 0, lambda: i - 1, lambda: i)
   2092 
   2093       with ops.control_dependencies([control_flow_ops.no_op()]):
   2094         loop = control_flow_ops.while_loop(cond, body,
   2095                                            (constant_op.constant(5),))
   2096       self.assertEqual(0, self.evaluate(loop))
   2097 
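           # Restating the invariant above: an outer control_dependencies
           # block should constrain only the ops that enter the loop, not
           # every node created inside the cond/while contexts; control edges
           # may not cross directly into a loop's execution frame, so the
           # no_op is attached at the loop boundary and the result is still 0.
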
   2098   @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
   2099   @test_util.run_v1_only("b/120545219")
   2100   def testWhileCondWithControl_1(self):
   2101     with self.cached_session():
   2102       v = variable_scope.get_variable(
   2103           "v", [], initializer=init_ops.constant_initializer(2))
   2104       i0 = constant_op.constant(0)
   2105       with ops.control_dependencies([i0]):
   2106 
   2107         def loop_condition(i):
   2108           return i < 4
   2109 
   2110         def loop_body(i):
   2111           some_cond = control_flow_ops.cond(
   2112               constant_op.constant(True),
   2113               lambda: state_ops.assign(v, math_ops.square(v)), lambda: v)
   2114           with ops.control_dependencies([some_cond]):
   2115             return i + 1
   2116 
   2117       r = control_flow_ops.while_loop(loop_condition, loop_body, (i0,))
   2118       self.evaluate(variables.global_variables_initializer())
   2119       self.assertEqual(4, self.evaluate(r))
   2120       self.assertAllClose(65536.0, self.evaluate(v))
   2121 
   2122   @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
   2123   @test_util.run_v1_only("b/120545219")
   2124   def testWhileCondExitControl(self):
   2125 
   2126     with self.cached_session():
   2127       v = variables.Variable(1)
   2128 
   2129       def false_branch():
   2130         cond = lambda i: i < 100
   2131 
   2132         def body(i):
   2133           x = state_ops.assign(v, i)
   2134           return x + 1
   2135 
   2136         loop = control_flow_ops.while_loop(cond, body, [0])
    2137         # Make sure the control edge from Exit to a node is handled correctly.
   2138         with ops.control_dependencies([loop]):
   2139           return constant_op.constant(6.0)
   2140 
   2141       r = control_flow_ops.cond(
   2142           constant_op.constant(False), lambda: constant_op.constant(1.0),
   2143           false_branch)
   2144       self.evaluate(variables.global_variables_initializer())
   2145       self.assertEqual(6.0, self.evaluate(r))
   2146       self.assertEqual(99, self.evaluate(v))
   2147 
   2148   def testCondWhile_1(self):
   2149 
   2150     with self.cached_session():
   2151       n = ops.convert_to_tensor(0, name="n")
   2152       c = lambda x: math_ops.less(x, 10)
   2153       b = lambda x: math_ops.add(x, 1)
   2154       r = control_flow_ops.cond(
   2155           math_ops.less(0, 1), lambda: control_flow_ops.while_loop(c, b, [n]),
   2156           lambda: n)
   2157       self.assertAllEqual(10, self.evaluate(r))
   2158 
   2159   def testCondWhile_2(self):
   2160 
   2161     with self.cached_session():
   2162       n = ops.convert_to_tensor(0)
   2163       c = lambda x: math_ops.less(x, 10)
   2164       b = lambda x: math_ops.add(x, 1)
   2165       r = control_flow_ops.cond(
   2166           math_ops.less(1, 0), lambda: math_ops.add(n, 1),
   2167           lambda: control_flow_ops.while_loop(c, b, [n]))
   2168       self.assertAllEqual(10, self.evaluate(r))
   2169 
   2170   def _testCondWhile_3(self, use_gpu):
   2171     with self.cached_session(use_gpu=use_gpu) as sess:
   2172       p = array_ops.placeholder(dtypes.bool)
   2173       n = constant_op.constant(0.0)
   2174 
   2175       def c(x):
   2176         return math_ops.less(x, 10.0)
   2177 
   2178       def b(x):
   2179         with ops.device("/cpu:0"):
   2180           x1 = math_ops.add(x, 1.0)
   2181         return x1
   2182 
   2183       r = control_flow_ops.cond(p,
   2184                                 lambda: control_flow_ops.while_loop(c, b, [n]),
   2185                                 lambda: math_ops.multiply(n, 2.0))
   2186       r1 = gradients_impl.gradients(r, [n])
   2187       self.assertEqual(10., sess.run(r, {p: True}))
   2188       self.assertEqual([1.0], sess.run(r1, {p: True}))
   2189       self.assertEqual(0.0, sess.run(r, {p: False}))
   2190       self.assertEqual([2.0], sess.run(r1, {p: False}))
   2191 
   2192   @test_util.run_deprecated_v1
   2193   def testCondWhile_3(self):
   2194     self._testCondWhile_3(use_gpu=False)
   2195     self._testCondWhile_3(use_gpu=True)
   2196 
   2197   def testWhileCond_1(self):
   2198 
   2199     with self.cached_session():
   2200       i = ops.convert_to_tensor(0, name="i")
   2201       n = ops.convert_to_tensor(10, name="n")
   2202       one = ops.convert_to_tensor(1, name="one")
   2203       c = lambda x: math_ops.less(x, n)
   2204       # pylint: disable=undefined-variable
   2205       # for OSS build
   2206       b = lambda x: control_flow_ops.cond(
   2207           constant_op.constant(True),
   2208           lambda: math_ops.add(x, one), lambda: math_ops.subtract(x, one))
   2209       # pylint: enable=undefined-variable
   2210       r = control_flow_ops.while_loop(c, b, [i])
   2211       self.assertAllEqual(10, self.evaluate(r))
   2212 
   2213   def testWhileCond_2(self):
   2214 
   2215     with self.cached_session():
   2216       n = ops.convert_to_tensor(0, name="n")
   2217       c = lambda x: math_ops.less(x, 10)
    2218       b = lambda x: control_flow_ops.cond(
                   constant_op.constant(True), lambda: math_ops.add(x, 1),
                   lambda: n)
   2219       r = control_flow_ops.while_loop(c, b, [n])
   2220       self.assertAllEqual(10, self.evaluate(r))
   2221 
   2222   def testWhileCond_3(self):
   2223 
   2224     with self.cached_session():
   2225       n = ops.convert_to_tensor(0)
   2226       c = lambda x: math_ops.less(x, 10)
   2227       # pylint: disable=undefined-variable
   2228       # for OSS build
   2229       b = lambda x: control_flow_ops.cond(math_ops.less(0, 1),
   2230                                           lambda: math_ops.add(x, 1),
   2231                                           lambda: math_ops.subtract(x, 1))
   2232       # pylint: enable=undefined-variable
   2233       r = control_flow_ops.while_loop(c, b, [n])
   2234       self.assertAllEqual(10, self.evaluate(r))
   2235 
   2236   @test_util.run_deprecated_v1
   2237   def testWhileCondGradMultiDevice(self):
   2238     config = config_pb2.ConfigProto(device_count={"CPU": 2},
   2239                                     allow_soft_placement=True)
   2240     with self.cached_session(use_gpu=True, config=config) as sess:
   2241       pred = array_ops.placeholder(dtypes.bool, [])
   2242       x_init = constant_op.constant(1.0)
   2243 
   2244       with ops.device("/cpu:0"):
   2245         z = control_flow_ops.while_loop(
   2246             lambda i, _: i < 3,
   2247             lambda i, x: (i + 1, control_flow_ops.cond(
   2248                 pred, lambda: x * 2.0, lambda: 10.0)),
   2249             [0, x_init])
   2250 
   2251       with ops.device("/cpu:1"):
   2252         grad = gradients_impl.gradients(z, x_init)[0]
   2253 
   2254       with ops.device("/cpu:0"):
   2255         grad_grad = gradients_impl.gradients(grad, x_init)[0]
   2256 
   2257       self.assertEqual(sess.run(grad, {pred: True}), 8.0)
   2258       self.assertEqual(sess.run(grad, {pred: False}), 0.0)
   2259 
   2260       if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
   2261         return
   2262 
   2263       self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0)
   2264       self.assertEqual(sess.run(grad_grad, {pred: False}), 0.0)
   2265 
   2266   # NOTE: It is ok to have parallel_iterations > 1
   2267   @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
   2268   @test_util.run_deprecated_v1
   2269   def testWhileUpdateVariable_1(self):
   2270     with self.cached_session():
   2271       select = variables.Variable([3.0, 4.0, 5.0])
   2272       n = constant_op.constant(0)
   2273 
   2274       def loop_iterator(j):
   2275         return math_ops.less(j, 3)
   2276 
   2277       def loop_body(j):
   2278         ns = state_ops.scatter_update(select, j, 10.0)
   2279         nj = math_ops.add(j, 1)
   2280         op = control_flow_ops.group(ns)
   2281         nj = control_flow_ops.with_dependencies([op], nj)
   2282         return [nj]
   2283 
   2284       r = control_flow_ops.while_loop(
   2285           loop_iterator, loop_body, [n], parallel_iterations=1)
   2286       self.evaluate(variables.global_variables_initializer())
   2287       self.assertEqual(3, self.evaluate(r))
   2288       result = self.evaluate(select)
   2289       self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
   2290 
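           # Pattern sketch for the variable-update tests in this group
           # (names as in the test above): a side effect in a loop body must
           # be tied to a loop variable, or it may be pruned or run out of
           # order:
           #
           #   def loop_body(j):
           #     ns = state_ops.scatter_update(select, j, 10.0)
           #     with ops.control_dependencies([ns]):
           #       return j + 1  # the update now runs once per iteration
           #
           # parallel_iterations=1 additionally forces sequential iterations;
           # the NOTE above suggests > 1 is also safe here, presumably because
           # each iteration writes a distinct index j.
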
   2291   @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
   2292   @test_util.run_v1_only("b/120545219")
   2293   def testWhileUpdateVariable_2(self):
   2294     with self.cached_session():
   2295       select1 = variables.Variable([3.0, 4.0, 5.0])
   2296       select2 = variables.Variable([3.0, 4.0, 5.0])
   2297       n = constant_op.constant(0)
   2298 
   2299       def loop_iterator(j):
   2300         return math_ops.less(j, 3)
   2301 
   2302       def loop_body(j):
   2303         ns1 = state_ops.scatter_update(select1, j, 10.0)
   2304         ns2 = state_ops.scatter_update(select2, j, 10.0)
   2305         nj = math_ops.add(j, 1)
   2306         op = control_flow_ops.group(ns1, ns2)
   2307         nj = control_flow_ops.with_dependencies([op], nj)
   2308         return [nj]
   2309 
   2310       r = control_flow_ops.while_loop(
   2311           loop_iterator, loop_body, [n], parallel_iterations=1)
   2312       self.evaluate(variables.global_variables_initializer())
   2313       self.assertEqual(3, self.evaluate(r))
   2314       result1 = self.evaluate(select1)
   2315       self.assertAllClose(np.array([10.0, 10.0, 10.0]), result1)
   2316       result2 = self.evaluate(select2)
   2317       self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)
   2318 
   2319   @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
   2320   @test_util.run_v1_only("b/120545219")
   2321   def testWhileUpdateVariable_3(self):
   2322     with self.cached_session():
   2323       select = variables.Variable([3.0, 4.0, 5.0])
   2324       n = constant_op.constant(0)
   2325 
   2326       def loop_iterator(j, _):
   2327         return math_ops.less(j, 3)
   2328 
   2329       def loop_body(j, _):
   2330         ns = state_ops.scatter_update(select, j, 10.0)
   2331         nj = math_ops.add(j, 1)
   2332         return [nj, ns]
   2333 
   2334       r = control_flow_ops.while_loop(
   2335           loop_iterator,
   2336           loop_body, [n, array_ops.identity(select)],
   2337           parallel_iterations=1)
   2338       self.evaluate(variables.global_variables_initializer())
   2339       result = r[1]
   2340     self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
   2341 
   2342   @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
   2343   @test_util.run_v1_only("b/120545219")
   2344   def testWhileUpdateVariable_4(self):
   2345     with self.cached_session():
   2346       var_a = variables.Variable(0, name="a")
   2347       var_b = variables.Variable(0, name="b")
   2348       self.evaluate(variables.global_variables_initializer())
   2349 
   2350       c = constant_op.constant(0, name="c")
   2351       asn1 = state_ops.assign_add(var_a, 1, name="a_add")
   2352 
   2353       # Loop condition
   2354       def pred(i):
   2355         return math_ops.less(i, 10)
   2356 
   2357       # Loop body
   2358       def loop_body(i):
   2359         asn2 = state_ops.assign_add(var_b, asn1, name="b_add")
   2360         with ops.control_dependencies([asn2]):
   2361           ni = math_ops.add(i, 1, name="i_add")
   2362         return ni
   2363 
   2364       lpa = control_flow_ops.while_loop(
   2365           pred, loop_body, [c], parallel_iterations=1)
   2366 
   2367       self.assertEqual(0, self.evaluate(var_b))
   2368       self.evaluate(lpa)  # Run the loop
   2369       self.assertEqual(10, self.evaluate(var_b))
   2370 
   2371   @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
   2372   @test_util.run_v1_only("b/120545219")
   2373   def testWhileUpdateVariable_5(self):
   2374     with self.cached_session():
   2375       # Create some variables.
   2376       var_a = variables.Variable(0, name="a")
   2377       var_b = variables.Variable(0, name="b")
   2378       self.evaluate(variables.global_variables_initializer())
   2379 
   2380       # Change condition to check var_b
   2381       def pred(_):
   2382         return math_ops.less(var_b, 10)
   2383 
   2384       # Change body to increment var_b
   2385       def loop_body(i):
   2386         asn1 = state_ops.assign_add(
   2387             var_a, constant_op.constant(1), name="a_add")
   2388         asn2 = state_ops.assign_add(
   2389             var_b, constant_op.constant(1), name="b_add")
   2390         with ops.control_dependencies([asn1, asn2]):
   2391           inc_b = array_ops.identity(var_b)
   2392         return inc_b
   2393 
   2394       lpa = control_flow_ops.while_loop(
   2395           pred, loop_body, [var_b], parallel_iterations=1, name="loop")
   2396 
   2397       self.assertEqual(0, self.evaluate(var_b))
   2398       self.evaluate(lpa)  # Run the loop
   2399       self.assertEqual(10, self.evaluate(var_a))
   2400       self.assertEqual(10, self.evaluate(var_b))
   2401 
   2402   @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
   2403   @test_util.run_v1_only("b/120545219")
   2404   def testWhileUpdateVariable_6(self):
   2405     with self.cached_session():
   2406       # Create some variables.
   2407       var_a = variables.Variable(0, name="a")
   2408       var_b = variables.Variable(0, name="b")
   2409       c = constant_op.constant(0)
   2410       self.evaluate(variables.global_variables_initializer())
   2411 
   2412       # Loop condition
   2413       def pred(i):
   2414         return math_ops.less(i, 10)
   2415 
   2416       # Loop body
   2417       def loop_body(i):
   2418         asn1 = state_ops.assign_add(var_a, 1, name="a_add")
   2419         with ops.control_dependencies([asn1]):
   2420           asn2 = state_ops.assign_add(var_b, var_a, name="b_add")
   2421         with ops.control_dependencies([asn2]):
   2422           ni = math_ops.add(i, 1, name="i_add")
   2423           return ni
   2424 
   2425       lpa = control_flow_ops.while_loop(
   2426           pred, loop_body, [c], parallel_iterations=1, name="loop")
   2427 
   2428       self.assertEqual(0, self.evaluate(var_b))
   2429       self.evaluate(lpa)  # Run the loop
   2430       self.assertEqual(55, self.evaluate(var_b))
   2431       self.assertEqual(10, self.evaluate(var_a))
   2432 
   2433   @test_util.run_v1_only("b/120545219")
   2434   def testWhileQueue_1(self):
   2435     with self.cached_session():
   2436       q = data_flow_ops.FIFOQueue(-1, dtypes.int32)
   2437       i = constant_op.constant(0)
   2438 
   2439       def c(i):
   2440         return math_ops.less(i, 10)
   2441 
   2442       def b(i):
   2443         ni = math_ops.add(i, 1)
   2444         ni = control_flow_ops.with_dependencies([q.enqueue((i,))], ni)
   2445         return ni
   2446 
   2447       r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1)
   2448       self.assertEqual([10], self.evaluate(r))
   2449       for i in xrange(10):
   2450         self.assertEqual([i], self.evaluate(q.dequeue()))
   2451 
   2452   @test_util.run_v1_only("b/120545219")
   2453   def testWhileTimeOut(self):
   2454     run_options = config_pb2.RunOptions(timeout_in_ms=1)
   2455     with self.cached_session() as sess:
   2456       n = constant_op.constant(0)
   2457       c = lambda x: True
   2458       b = lambda x: math_ops.add(x, 1)
   2459       r = control_flow_ops.while_loop(c, b, [n])
   2460       with self.assertRaises(errors_impl.DeadlineExceededError):
   2461         sess.run(r, options=run_options)
   2462 
   2463   @test_util.disable_control_flow_v2("b/117119329 (stack)")
   2464   @test_util.run_v1_only("b/120545219")
   2465   def testWhileStack_1(self):
   2466     with self.cached_session():
   2467       s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo")
   2468       i = constant_op.constant(0)
   2469 
   2470       def c(i):
   2471         return math_ops.less(i, 10)
   2472 
   2473       def b(i):
   2474         ni = math_ops.add(i, 1)
   2475         ni = control_flow_ops.with_dependencies(
   2476             [gen_data_flow_ops.stack_push_v2(s, i)], ni)
   2477         return ni
   2478 
   2479       r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1)
   2480 
   2481       x = constant_op.constant(0)
   2482 
   2483       def c1(i, _):
   2484         return math_ops.greater(i, 0)
   2485 
   2486       def b1(i, x):
   2487         ni = math_ops.subtract(i, 1)
   2488         nx = x + gen_data_flow_ops.stack_pop_v2(s, dtypes.int32)
   2489         return [ni, nx]
   2490 
   2491       _, rx = control_flow_ops.while_loop(
   2492           c1,
   2493           b1, [r, x],
   2494           [r.get_shape(), tensor_shape.unknown_shape()],
   2495           parallel_iterations=1)
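               # The first loop pushed 0..9; this loop pops all ten values, so
               # rx == 0 + 1 + ... + 9 == 45.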
   2496       self.assertEqual(45, self.evaluate(rx))
   2497 
   2498   def _testWhileGrad_ColocateGradients(self, colocate):
    2499     gpu_dev_name = (test.gpu_device_name() if test.is_gpu_available()
    2500                     else "/device:CPU:0")
   2501 
   2502     graph = ops.Graph()
   2503     with graph.as_default():
   2504       v = constant_op.constant(2.0, name="v")
   2505       c = lambda v: math_ops.less(v, 100.0)
   2506 
   2507       def b(x):
   2508         with ops.device(gpu_dev_name):
   2509           return math_ops.square(x)
   2510 
   2511       loop = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
   2512       r = gradients_impl.gradients(
   2513           loop, v, colocate_gradients_with_ops=colocate)[0]
   2514 
   2515     r_ops = graph.get_operations()
   2516     r_devices = [(op.name, op.device) for op in r_ops]
   2517 
   2518     self.assertTrue(any("Square" in op.name for op in r_ops))
   2519 
   2520     for (name, dev) in r_devices:
   2521       if not colocate and name.endswith("Square"):
    2522         # Only the forward graph's Square op is placed on the GPU device.
   2523         self.assertTrue(gpu_dev_name in dev)
   2524       elif colocate and "Square" in name:
    2525         # Forward and backward graphs place Square/Square_grad on the GPU device.
   2526         self.assertTrue(gpu_dev_name in dev)
   2527       else:
   2528         self.assertFalse(gpu_dev_name in dev)
   2529 
   2530     with self.session(graph=graph) as sess:
   2531       self.assertAllClose(1024.0, self.evaluate(r))
   2532 
   2533   @test_util.disable_control_flow_v2("b/116351701 (colocation)")
   2534   @test_util.run_v1_only("b/120545219")
   2535   def testWhileGrad_ColocateGradients(self):
   2536     self._testWhileGrad_ColocateGradients(colocate=False)
   2537     self._testWhileGrad_ColocateGradients(colocate=True)
   2538 
   2539   @test_util.run_v1_only("b/120545219")
   2540   def testWhileGrad_Square(self):
   2541     with self.cached_session():
   2542       v = constant_op.constant(2.0, name="v")
   2543       c = lambda v: math_ops.less(v, 100.0)
   2544       b = math_ops.square
   2545       r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
   2546       r = control_flow_ops.cond(math_ops.less(1, 2), lambda: r, lambda: v)
   2547 
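               # The loop squares v three times (2 -> 4 -> 16 -> 256), so
               # r == v**8 and dr/dv == 8 * v**7 == 1024.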
   2548       r = gradients_impl.gradients(r, v)[0]
   2549       self.assertAllClose(1024.0, self.evaluate(r))
   2550 
   2551   @test_util.run_v1_only("b/120545219")
   2552   def testWhileGrad_Shape(self):
   2553     with self.cached_session():
   2554       x = array_ops.placeholder(dtypes.float32, shape=[None])
   2555       v = constant_op.constant([2.0], name="v")
   2556       n = constant_op.constant(0, name="n")
   2557       c = lambda i, v: math_ops.less(i, 5)
   2558       b = lambda i, v: [i + 1, math_ops.multiply(x, v)]
   2559       r = control_flow_ops.while_loop(
   2560           c,
   2561           b, [n, v],
   2562           [n.get_shape(), tensor_shape.unknown_shape()],
   2563           parallel_iterations=1)
   2564 
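               # Five iterations give v == 2 * x**5 elementwise, so
               # dv/dx == 10 * x**4: 810 at x == 3 and 2560 at x == 4.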
   2565       r = gradients_impl.gradients(r[1], x)[0]
   2566       self.assertEqual([None], r.get_shape().as_list())
   2567       self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]}))
   2568 
   2569   @test_util.run_deprecated_v1
   2570   def testWhileGrad_BaseShape(self):
   2571     with self.cached_session() as sess:
   2572       x = array_ops.placeholder(dtypes.float32, [None])
   2573       v0 = constant_op.constant([2.0, 2.0], name="v")
   2574       c = lambda v: constant_op.constant(False)
   2575       b = lambda v: math_ops.multiply(v, x)
   2576       r = control_flow_ops.while_loop(c, b, [v0])
   2577       y = math_ops.square(x)
   2578 
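               # The loop body never runs (the condition is False), so only
               # y == x**2 contributes: dy/dx == 2 * x == [2.0, 4.0].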
   2579       r = gradients_impl.gradients([r, y], x)[0]
   2580       self.assertAllClose([2.0, 4.0], sess.run(r, feed_dict={x: [1.0, 2.0]}))
   2581 
   2582   @test_util.run_v1_only("b/120545219")
   2583   def testWhileGrad_MultipleUses(self):
   2584     with self.cached_session():
   2585       v = constant_op.constant(2.0, name="v")
   2586       c = lambda v: math_ops.less(v, 100.0)
   2587       b = math_ops.square
   2588       r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
   2589       r = math_ops.multiply(r, r)
   2590 
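               # r == v**8, so r * r == v**16 and
               # d(v**16)/dv == 16 * v**15 == 16 * 2**15 == 524288.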
   2591       r = gradients_impl.gradients(r, v)[0]
   2592       self.assertEqual(524288.0, self.evaluate(r))
   2593 
   2594   @test_util.run_v1_only("b/120545219")
   2595   def testWhileGrad_LoopAdd(self):
   2596     with self.cached_session():
   2597       v = constant_op.constant(2.0, name="v")
   2598       c = lambda v: math_ops.less(v, 100.0)
   2599       b = math_ops.square
   2600       r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
   2601       r = math_ops.add(r, r)
   2602 
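               # r + r == 2 * v**8, so the gradient is
               # 16 * v**7 == 16 * 2**7 == 2048.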
   2603       r = gradients_impl.gradients(r, v)[0]
   2604       self.assertAllClose(2048.0, self.evaluate(r))
   2605 
   2606   def _testWhileGrad_Mul(self, use_gpu, p_iters):
   2607     with self.cached_session(use_gpu=use_gpu) as sess:
   2608       a = constant_op.constant(3.0, name="a")
   2609       v = constant_op.constant(2.0, name="v")
   2610       c = lambda v: math_ops.less(v, 100.0)
   2611       b = lambda v: math_ops.multiply(v, a)
   2612       r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=p_iters)
   2613 
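               # Four iterations (2 -> 6 -> 18 -> 54 -> 162) give r == a**4 * v,
               # so dr/da == 4 * a**3 * v == 216 and dr/dv == a**4 == 81.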
   2614       grad_a, grad_v = gradients_impl.gradients(r, [a, v])
   2615       grad_a_val, grad_v_val = self.evaluate([grad_a, grad_v])
   2616       self.assertAllClose(216.0, grad_a_val)
   2617       self.assertAllClose(81.0, grad_v_val)
   2618 
   2619   @test_util.run_deprecated_v1
   2620   def testWhileGrad_Mul(self):
   2621     self._testWhileGrad_Mul(use_gpu=False, p_iters=1)
   2622     self._testWhileGrad_Mul(use_gpu=False, p_iters=10)
   2623     self._testWhileGrad_Mul(use_gpu=True, p_iters=1)
   2624     self._testWhileGrad_Mul(use_gpu=True, p_iters=10)
   2625 
   2626   def _testNestedWhileCondWhileGrad(self, use_gpu):
   2627 
   2628     with self.cached_session(use_gpu=use_gpu):
   2629       v = constant_op.constant(1.0)
   2630 
   2631       def inner_loop(s):
   2632         z = constant_op.constant(0)
   2633         c = lambda i, x: math_ops.less(i, 4)
   2634         b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
   2635         return control_flow_ops.while_loop(c, b, [z, s])
   2636 
   2637       c = lambda x: math_ops.less(x, 128.0)
   2638 
   2639       def b(x):
   2640         return control_flow_ops.cond(
   2641             constant_op.constant(True),
   2642             lambda: math_ops.square(inner_loop(x)[1]),
   2643             lambda: math_ops.multiply(x, 2.0))
   2644 
   2645       r = control_flow_ops.while_loop(c, b, [v])
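               # One outer iteration: the inner loop scales x by 2**4 == 16 and
               # the cond squares it, so r == (16 * v)**2 and
               # dr/dv == 512 * v == 512.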
   2646       r = gradients_impl.gradients(r, v)[0]
   2647       self.assertAllClose(512.0, self.evaluate(r))
   2648 
   2649   @test_util.run_deprecated_v1
   2650   def testNestedWhileCondWhileGrad(self):
   2651     self._testNestedWhileCondWhileGrad(use_gpu=False)
   2652 
   2653   @test_util.run_deprecated_v1
   2654   def testNestedWhileCondWhileGradGpu(self):
   2655     self._testNestedWhileCondWhileGrad(use_gpu=True)
   2656 
   2657   @test_util.run_v1_only("b/120545219")
   2658   def testWhileGrad_Variable(self):
   2659     with self.cached_session():
   2660       a = variables.Variable(3.0)
   2661       v = constant_op.constant(2.0, name="v")
   2662       c = lambda v: math_ops.less(v, 100.0)
   2663       b = lambda v: math_ops.multiply(v, a)
   2664       r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
   2665 
   2666       r = gradients_impl.gradients(r, a)
   2667       self.evaluate(variables.global_variables_initializer())
   2668       self.assertAllClose(216.0, r[0])
   2669 
   2670   @test_util.run_deprecated_v1
   2671   def testWhileGrad_ResourceVariable(self):
   2672     with self.cached_session():
   2673       a = resource_variable_ops.ResourceVariable(3.0)
   2674       v = constant_op.constant(2.0, name="v")
   2675       c = lambda v: math_ops.less(v, 100.0)
   2676       b = lambda v: math_ops.multiply(v, a)
   2677       r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1)
   2678 
   2679       g = gradients_impl.gradients(r, a)
   2680       self.evaluate(variables.global_variables_initializer())
   2681       self.assertAllClose(216.0, g[0])
   2682 
   2683   def testWhileGrad_EagerResourceVariable(self):
   2684     with context.eager_mode():
   2685       a = resource_variable_ops.ResourceVariable(
   2686           np.ones([2, 2], dtype=np.float32))
   2687       v = constant_op.constant(1.0)
   2688 
   2689       @eager_function.defun
   2690       def fn():
   2691         r = control_flow_ops.while_loop(
   2692             lambda i, _: i < 2,
   2693             lambda i, x: (i + 1, x * math_ops.reduce_sum(a) * v),
   2694             [0, 1.0])[1]
   2695         return gradients_impl.gradients(r, [v])[0]
   2696 
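               # reduce_sum(a) == 4, so two iterations give
               # r == (4 * v)**2 == 16 * v**2 and dr/dv == 32 * v == 32.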
   2697       self.assertEqual(self.evaluate(fn()), 32.)
   2698 
   2699   @test_util.disable_xla("b/128643381")
   2700   def testWhileGrad_ResourceVarInFunctionCall(self):
   2701 
   2702     @def_function.function
   2703     def foo(x, var):
   2704       return x + math_ops.reduce_sum(var.sparse_read([1, 3]))
   2705 
   2706     @def_function.function
   2707     def bar(var):
   2708       r = control_flow_ops.while_loop(
   2709           lambda i, _: i < 2,
   2710           lambda i, x: (i + 1, foo(x, var)),
   2711           [0, 0.0])[1]
   2712       return gradients_impl.gradients(r, var)[0]
   2713 
   2714     var = resource_variable_ops.ResourceVariable([1., 2., 3., 4.])
   2715     self.evaluate(variables.global_variables_initializer())
   2716     grad = self.evaluate(bar(var))
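             # Each of the two iterations adds var[1] + var[3] to x, so the
             # gradient is 2.0 at indices 1 and 3, returned as IndexedSlices.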
   2717     self.assertIsInstance(grad, ops.IndexedSlicesValue)
   2718     self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.])
   2719 
   2720   @test_util.disable_xla("b/128643461")
   2721   def testWhileGrad_ResourceVarInNestedFunctionCall(self):
   2722 
   2723     @def_function.function
   2724     def foo(x, var):
   2725       return x + math_ops.reduce_sum(var.sparse_read([1, 3]))
   2726 
   2727     @def_function.function
   2728     def foo2(x, var):
   2729       return foo(x, var)
   2730 
   2731     @def_function.function
   2732     def bar(var):
   2733       r = control_flow_ops.while_loop(
   2734           lambda i, _: i < 2,
   2735           lambda i, x: (i + 1, foo2(x, var)),
   2736           [0, 0.0])[1]
   2737       return gradients_impl.gradients(r, var)[0]
   2738 
   2739     var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.])
   2740     self.evaluate(variables.global_variables_initializer())
   2741     grad = self.evaluate(bar(var))
   2742     self.assertIsInstance(grad, ops.IndexedSlicesValue)
   2743     self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.])
   2744 
   2745   def testWhileGrad_ResourceVarInLoopInFunctionCall(self):
   2746     if test.is_gpu_available():
   2747       self.skipTest("b/128635252")
   2748 
   2749     @def_function.function
   2750     def foo(x, var):
   2751       return control_flow_ops.while_loop(
   2752           lambda j, _: j < 3,
   2753           lambda j, y: (j + 1,
   2754                         y + math_ops.reduce_sum(var.sparse_read([1, 2]))),
   2755           [0, x])[1]
   2756 
   2757     @def_function.function
   2758     def bar(var):
   2759       r = control_flow_ops.while_loop(
   2760           lambda i, _: i < 2,
   2761           lambda i, x: (i + 1, foo(x, var)),
   2762           [0, 0.0])[1]
   2763       return gradients_impl.gradients(r, var)[0]
   2764 
   2765     var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.])
   2766     self.evaluate(variables.global_variables_initializer())
   2767     grad = self.evaluate(bar(var))
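             # Three inner iterations times two outer iterations read
             # var[1] + var[2] six times, giving the gradient [0., 6., 6., 0.].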
   2768     self.assertIsInstance(grad, ops.IndexedSlicesValue)
   2769     self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 6., 6., 0.])
   2770 
   2771   @test_util.disable_xla("b/128639858")
   2772   def testWhileCondGrad_ResourceVarInFunctionCall(self):
   2773 
   2774     @def_function.function
   2775     def foo(x, var):
   2776       return x + var.sparse_read([1])[0]
   2777 
   2778     def body(i, x):
   2779       return (i + 1, control_flow_ops.cond(
   2780           math_ops.equal(i % 2, 0),
   2781           lambda: foo(x, var1),
   2782           lambda: foo(x, var2)))
   2783 
   2784     @def_function.function
   2785     def bar(var1, var2):
   2786       r = control_flow_ops.while_loop(
   2787           lambda i, _: i < 4, body, [0, 0.0])
   2788       return gradients_impl.gradients(r, [var1, var2])
   2789 
   2790     var1 = resource_variable_ops.ResourceVariable([1., 2., 3.])
   2791     var2 = resource_variable_ops.ResourceVariable([4., 5.])
   2792     self.evaluate(variables.global_variables_initializer())
   2793     grads = self.evaluate(bar(var1, var2))
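             # Iterations i == 0, 2 read var1[1] and i == 1, 3 read var2[1],
             # so each variable gets a gradient of 2.0 at index 1.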
   2794     self.assertAllEqual(gradient_checker_v2._to_numpy(grads[0]), [0., 2., 0.])
   2795     self.assertAllEqual(gradient_checker_v2._to_numpy(grads[1]), [0., 2.])
   2796 
   2797   @test_util.run_deprecated_v1
   2798   def testWhileGrad_ResourceVarSparseRead(self):
   2799     # NOTE(skyewm): this test is interesting because the
   2800     # ResourceVariable.sparse_read gradient function returns an IndexedSlices.
   2801     var = resource_variable_ops.ResourceVariable(np.ones(5),
   2802                                                  dtype=dtypes.float32)
   2803     r = control_flow_ops.while_loop(
   2804         lambda i, _: i < 3,
   2805         lambda i, x: (i + 1, x * math_ops.reduce_sum(var.sparse_read([1, 3]))),
   2806         [0, constant_op.constant(1.0)])[1]
   2807     grad = gradients_impl.gradients(r, var)[0]
   2808 
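             # With s == var[1] + var[3] == 2, r == s**3 and
             # dr/ds == 3 * s**2 == 12, so the gradient is 12.0 at indices
             # 1 and 3.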
   2809     self.evaluate(variables.global_variables_initializer())
   2810     grad_val = self.evaluate(grad)
   2811     self.assertIsInstance(grad_val, ops.IndexedSlicesValue)
   2812     arr = gradient_checker_v2._to_numpy(grad_val)
   2813     self.assertAllEqual(arr, [0., 12., 0., 12., 0.])
   2814 
   2815   @test_util.run_deprecated_v1
   2816   def testWhileGrad_MultiResourceVarSparseRead(self):
   2817     # NOTE(skyewm): this test is interesting because the
   2818     # ResourceVariable.sparse_read gradient function returns an IndexedSlices.
   2819     var1 = resource_variable_ops.ResourceVariable(np.ones(5),
   2820                                                   dtype=dtypes.float32)
   2821     var2 = resource_variable_ops.ResourceVariable(np.ones(3),
   2822                                                   dtype=dtypes.float32)
   2823     x1_init = constant_op.constant([0., 0.])
   2824     x2_init = constant_op.constant(1.)
   2825     x3_init = constant_op.constant(1.)
   2826 
   2827     def body(i, unused_x1, x2, x3):
   2828       y1 = var1.sparse_read([1, 3])
   2829       y2 = x2 * 2
   2830       y3 = x3 * math_ops.reduce_sum(var2.sparse_read([0]))
   2831       return i + 1, y1, y2, y3
   2832 
   2833     r = control_flow_ops.while_loop(
   2834         lambda i, x1, x2, x3: i < 3, body,
   2835         [0, x1_init, x2_init, x3_init])[1:]
   2836     var1_grad, var2_grad = gradients_impl.gradients(r, [var1, var2])
   2837 
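             # Only the last y1 == var1.sparse_read([1, 3]) reaches the output
             # (gradient 1.0 at indices 1 and 3), while y3 == var2[0]**3 gives
             # 3 * var2[0]**2 == 3 at index 0.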
   2838     self.evaluate(variables.global_variables_initializer())
   2839     var1_grad_val = self.evaluate(var1_grad)
   2840     var2_grad_val = self.evaluate(var2_grad)
   2841     self.assertIsInstance(var1_grad_val, ops.IndexedSlicesValue)
   2842     self.assertIsInstance(var2_grad_val, ops.IndexedSlicesValue)
   2843     self.assertAllEqual(gradient_checker_v2._to_numpy(var1_grad_val),
   2844                         [0., 1., 0., 1., 0.])
   2845     self.assertAllEqual(gradient_checker_v2._to_numpy(var2_grad_val),
   2846                         [3., 0., 0.])
   2847 
   2848   @test_util.run_deprecated_v1
   2849   def testWhileGrad_Gather(self):
   2850     # NOTE(skyewm): this test is interesting because the gather gradient
   2851     # function returns an IndexedSlices.
   2852     x = constant_op.constant([1., 1., 1., 1., 1.])
   2853     y = control_flow_ops.while_loop(
   2854         lambda i, _: i < 3,
   2855         lambda i, x: (i + 1, x + array_ops.gather(x, [0])),
   2856         [0, x[:1]])[1]
   2857     z = y * 3.0
   2858     grad = gradients_impl.gradients(z, x)[0]
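             # x[:1] doubles three times (1 -> 2 -> 4 -> 8), so y == 8 * x[0]
             # and dz/dx[0] == 3 * 8 == 24; the other entries get zero gradient.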
   2859     self.assertEqual(self.evaluate(y), 8.)
   2860     self.assertAllEqual(self.evaluate(grad), [24., 0., 0., 0., 0.])
   2861 
   2862   @test_util.run_deprecated_v1
   2863   def testWhileGrad_GatherNoFanOut(self):
   2864     # NOTE(skyewm): this test is interesting because the gather gradient
   2865     # function returns an IndexedSlices.
   2866     x = constant_op.constant([1., 1., 1., 1., 1.])
   2867     y = control_flow_ops.while_loop(
   2868         lambda i, _: i < 3,
   2869         lambda i, x: (i + 1, array_ops.gather(x, [0])),
   2870         [0, x[:1]])[1]
   2871     z = y * 3.0
   2872     grad = gradients_impl.gradients(z, x)[0]
   2873     self.assertEqual(self.evaluate(y), 1.)
   2874     self.assertAllEqual(self.evaluate(grad), [3., 0., 0., 0., 0.])
   2875 
   2876   @test_util.run_v1_only("b/120545219")
   2877   def testWhileGradInCond(self):
   2878 
   2879     with self.cached_session():
   2880       n = ops.convert_to_tensor(1.0, name="n")
   2881       x = array_ops.placeholder(dtypes.float32, shape=None)
   2882       c = lambda n: math_ops.less(n, 10.0)
   2883       b = lambda n: math_ops.add(n, x)
   2884 
   2885       def fn1():
   2886         r = control_flow_ops.while_loop(c, b, [n],
   2887                                         [tensor_shape.unknown_shape()])
   2888         return gradients_impl.gradients(r, x)[0]
   2889 
   2890       r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x)
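               # With x == 1.0 the loop adds x nine times (n: 1 -> 10), so
               # dr/dx == 9.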
   2891       self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
   2892 
   2893   @test_util.disable_control_flow_v2("b/116340060")
   2894   @test_util.run_v1_only("b/120545219")
   2895   def testGradInWhileWrtInitialLoopVal(self):
   2896     with self.cached_session():
   2897       x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
   2898       y = x + 1
   2899 
   2900       def body(i, v):
   2901         z = v * 2
   2902         return i + 1, gradients_impl.gradients(z, x)[0]
   2903 
   2904       with self.assertRaisesRegexp(
   2905           ValueError,
   2906           "Cannot compute gradient inside while loop with respect to op 'x'. "
   2907           "We do not support taking the gradient wrt or through the initial "
   2908           "value of a loop variable. Gradients can be computed through "
   2909           "loop invariants or wrt the input parameters to the loop body."):
   2910         control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
   2911 
   2912   @test_util.run_v1_only("b/120545219")
   2913   def testWhileGradInWhile(self):
   2914     with self.cached_session():
   2915       n = ops.convert_to_tensor(1.0, name="n")
   2916       x = array_ops.placeholder(dtypes.float32, shape=None)
   2917       c = lambda n: math_ops.less(n, 10.0)
   2918       b = lambda n: math_ops.add(n, x)
   2919 
   2920       def b1(n):
   2921         r = control_flow_ops.while_loop(c, b, [n],
   2922                                         [tensor_shape.unknown_shape()])
   2923         return gradients_impl.gradients(r, x)
   2924 
   2925       r = control_flow_ops.while_loop(lambda n: n < 6.0, b1, [n],
   2926                                       [tensor_shape.unknown_shape()])
   2927       self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
   2928 
   2929   @test_util.run_v1_only("b/120545219")
   2930   def testCondGradInNestedWhiles(self):
   2931 
   2932     def outer_body(i, x):
   2933       _, x = control_flow_ops.while_loop(
   2934           lambda j, x: j < 3, inner_body, [0, 0.0])
   2935       return i + 1, x
   2936 
   2937     def inner_body(j, x):
   2938       y = control_flow_ops.cond(math_ops.less(x, 1), lambda: 2 * x, lambda: x)
   2939       return j + 1, gradients_impl.gradients(y, x)[0]
   2940 
   2941     i, x = control_flow_ops.while_loop(lambda i, x: i < 3, outer_body, [0, 0.0])
   2942 
   2943     with self.cached_session() as sess:
   2944       i_val, x_val = self.evaluate([i, x])
   2945       self.assertEqual(i_val, 3)
   2946       self.assertAllClose(x_val, 1.0)
   2947 
   2948   @test_util.run_gpu_only
   2949   def testGpuResourceAccess(self):
   2950     with ops.device(test.gpu_device_name()):
   2951       var = resource_variable_ops.ResourceVariable(constant_op.constant(3.0))
   2952 
   2953     @def_function.function
   2954     def foo():
   2955       return control_flow_ops.while_loop(
   2956           lambda i, _: i < 3,
   2957           lambda i, x: (i + 1, control_flow_ops.cond(
   2958               constant_op.constant(True),
   2959               lambda: x + var,
   2960               lambda: x)),
   2961           [0, 0.0])[1]
   2962 
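             # Three iterations each add var == 3.0, so the loop returns 9.0.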
   2963     self.evaluate(variables.global_variables_initializer())
   2964     self.assertEqual(self.evaluate(foo()), 9.0)
   2965 
   2966   @test_util.disable_xla("b/128643398")
   2967   def testNestedResourceAccess(self):
   2968     var = resource_variable_ops.ResourceVariable(constant_op.constant(3.0))
   2969 
   2970     @eager_function.defun
   2971     def test_fn():
   2972       x = constant_op.constant(0.0)
   2973       r = control_flow_ops.while_loop(
   2974           # Outer loop condition
   2975           lambda i, y: i < 2,
   2976           # Outer loop body
   2977           lambda i, y: (i + 1, y + control_flow_ops.cond(
   2978               constant_op.constant(True),
   2979               # True branch
   2980               lambda: control_flow_ops.while_loop(
   2981                   # Inner loop condition
   2982                   lambda j, z: j < 3,
   2983                   # Inner loop body
   2984                   lambda j, z: (j + 1, z + math_ops.square(var)),
   2985                   # Inner initial loop value
   2986                   [0, y])[1],
   2987               # False branch
   2988               lambda: (0.0))),
   2989           # Outer initial loop value
   2990           [0, x])[1]
   2991 
   2992       grad = gradients_impl.gradients(r, x)[0]
   2993       return r, grad
   2994 
   2995     self.evaluate(variables.global_variables_initializer())
   2996     r, grad = self.evaluate(test_fn())
    2997     # Each outer iteration maps y -> 2*y + 27, so y goes 0 -> 27 -> 81.
    2998     self.assertEqual(r, 81.0)
    2999     # Note: v1 control flow computes the wrong gradient for this case.
   3000     # Gradient computation:
   3001     #   f(x) = x + 3^2
   3002     #   inner_loop(x) = f(f(f(x))) = x + 3*3^2 = x + 27
   3003     #   g(x) = x + inner_loop(x) = 2x + 27
   3004     #   outer_loop(x) = g(g(x)) = 4x + 81
   3005     #   outer_loop'(x) = 4
   3006     # Note that v1 control flow gets 4.0 as well if the cond is removed.
   3007     if control_flow_util.ENABLE_CONTROL_FLOW_V2:
   3008       self.assertEqual(grad, 4.0)
   3009 
   3010   def testWhile_NestedInput(self):
   3011     with self.cached_session() as sess:
   3012       named = collections.namedtuple("named", ("a", "b"))
   3013       loop_vars = [
   3014           named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
   3015           (constant_op.constant(2.0), constant_op.constant(3.0)),
   3016           constant_op.constant(4.0)
   3017       ]
   3018       c = lambda lv0, _1, _2: lv0.a < 100.0
   3019 
   3020       def b(lv0, lv1, lv2):
   3021         lv0 = named(a=lv0.a + 1, b=lv0.b)
   3022         lv1 = (lv1[0] + 1, lv1[1])
   3023         lv2 += 2
   3024         return [lv0, lv1, lv2]
   3025 
   3026       r = control_flow_ops.while_loop(c, b, loop_vars)
   3027 
   3028       self.assertTrue(isinstance(r, list))
   3029       self.assertTrue(isinstance(r[0], named))
   3030       self.assertTrue(isinstance(r[1], tuple))
   3031       self.assertTrue(isinstance(r[2], ops.Tensor))
   3032 
   3033       r_flattened = nest.flatten(r)
   3034       self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0],
   3035                        self.evaluate(r_flattened))
   3036 
   3037   @test_util.run_v1_only("b/120545219")
   3038   def testWhile_NestedBadArityFails(self):
   3039     with self.cached_session():
   3040       named = collections.namedtuple("named", ("a", "b"))
   3041       loop_vars = [
   3042           named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)),
   3043           (constant_op.constant(2.0), constant_op.constant(3.0)),
   3044           constant_op.constant(4.0)
   3045       ]
   3046       c = lambda lv0, _1, _2: lv0.a < 100.0
   3047 
   3048       def b(lv0, lv1, _):
   3049         return [lv0, lv1]
   3050 
   3051       with self.assertRaisesRegexp(ValueError, "the same number of elements"):
   3052         control_flow_ops.while_loop(c, b, loop_vars)
   3053 
   3054   @test_util.run_v1_only("b/120545219")
   3055   def testWhileGrad_ys_xs(self):
   3056     with self.cached_session():
   3057       x = constant_op.constant(3.0, name="x")
   3058       y = constant_op.constant(2.0, name="y")
   3059 
   3060       c = lambda x, y: math_ops.less(x, 100.0)
   3061 
   3062       def b(x, y):
   3063         y1 = math_ops.add(x, y)
   3064         x1 = math_ops.multiply(x, y1)
   3065         return x1, y1
   3066 
   3067       rx, ry = control_flow_ops.while_loop(c, b, [x, y], parallel_iterations=1)
   3068 
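               # Two iterations map (x, y): (3, 2) -> (15, 5) -> (300, 20).
               # Direct computation gives drx/dx == 295, dry/dx == 9,
               # drx/dy == 120 and dry/dy == 4, matching the sums below.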
   3069       r = gradients_impl.gradients([rx, ry], x)
   3070       self.assertAllClose(304.0, r[0])
   3071       r = gradients_impl.gradients([rx, ry], y)
   3072       self.assertAllClose(124.0, r[0])
   3073       r = gradients_impl.gradients([rx], x)
   3074       self.assertAllClose(295.0, r[0])
   3075       r = gradients_impl.gradients([rx], y)
   3076       self.assertAllClose(120.0, r[0])
   3077 
   3078   @test_util.run_deprecated_v1
   3079   def testWhileGrad_Dependency(self):
   3080     with self.cached_session():
   3081       i = constant_op.constant(0, name="i")
   3082       x = constant_op.constant(2.0, name="x")
   3083 
   3084       c = lambda i, x: math_ops.less(i, 10)
   3085 
   3086       def b(i, x):
   3087         x = math_ops.multiply(x, 2.0)
   3088         i = math_ops.add(i, 1)
   3089         return i, x
   3090 
   3091       ri, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
   3092 
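               # Ten doublings give rx == x * 2**10, so drx/dx == 1024 whether
               # or not the loop counter ri is included among the ys.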
   3093       r = gradients_impl.gradients([ri, rx], x)
   3094       self.assertAllClose(1024.0, r[0])
   3095       r = gradients_impl.gradients([rx], x)
   3096       self.assertAllClose(1024.0, r[0])
   3097 
   3098   @test_util.disable_control_flow_v2("b/116355153 (back_prop flag)")
   3099   @test_util.run_v1_only("b/120545219")
   3100   def testWhileGrad_NoGradient(self):
   3101     with self.cached_session():
   3102       v = constant_op.constant(2.0, name="v")
   3103       c = lambda v: math_ops.less(v, 100.0)
   3104       b = math_ops.square
   3105       r = control_flow_ops.while_loop(c, b, [v], back_prop=False)
   3106       r = math_ops.add(r, v)
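               # With back_prop=False the loop output is treated as a constant,
               # so only the explicit add contributes: dr/dv == 1.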
   3107       r = gradients_impl.gradients(r, v)
   3108       self.assertAllClose(1.0, r[0])
   3109 
   3110   @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
   3111   @test_util.run_v1_only("b/120545219")
   3112   def testWhileGrad_NoDependency(self):
   3113     with self.cached_session() as sess:
   3114       variable = variables.Variable(array_ops.ones([2, 3]))
   3115       duration = array_ops.zeros([], dtype=dtypes.int32)
   3116 
   3117       def cond(duration, tensor, _):
   3118         del tensor
   3119         return duration < 10
   3120 
   3121       def body(duration, tensor, _):
   3122         return (duration + 1, tensor, tensor)
   3123 
   3124       loop_vars = [duration, variable, variable]
   3125       tensors = control_flow_ops.while_loop(
   3126           cond=cond, body=body, loop_vars=loop_vars)
   3127       cost = math_ops.reduce_sum(tensors[2])
   3128       grad = gradients_impl.gradients(cost, [variable])
   3129       self.evaluate(variables.global_variables_initializer())
   3130       self.assertAllClose(np.ones([2, 3]), sess.run(grad[0]))
   3131 
   3132   @test_util.run_deprecated_v1
   3133   def testWhileGrad_Const(self):
   3134     with self.cached_session() as sess:
   3135       c0 = constant_op.constant(0.0, name="c0")
   3136       c1 = constant_op.constant(1.0, name="c1")
   3137       duration = constant_op.constant(0, name="t")
   3138 
   3139       def cond(duration, _):
   3140         return duration < 1
   3141 
   3142       def body(duration, _):
   3143         return duration + 1, c1
   3144 
   3145       loop_vars = [duration, c0]
   3146       tensors = control_flow_ops.while_loop(
   3147           cond=cond, body=body, loop_vars=loop_vars)
   3148       cost = math_ops.reduce_sum(tensors[1])
   3149       grad = gradients_impl.gradients(cost, [c0])
   3150       self.assertAllClose(0.0, sess.run(grad[0]))
   3151 
   3152   @test_util.run_v1_only("b/120545219")
   3153   def testWhileGrad_SerialTwoLoops(self):
   3154     with self.cached_session():
   3155       i = constant_op.constant(0, name="i")
   3156       x = constant_op.constant(2.0, name="x")
   3157 
   3158       c = lambda i, x: math_ops.less(i, 5)
   3159 
   3160       def b(i, x):
   3161         x = math_ops.multiply(x, 2.0)
   3162         i = math_ops.add(i, 1)
   3163         return i, x
   3164 
   3165       _, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
   3166       _, rx = control_flow_ops.while_loop(c, b, [i, rx], parallel_iterations=1)
   3167 
   3168       r = gradients_impl.gradients([rx], x)
   3169       self.assertAllClose(1024.0, r[0])
   3170 
   3171   @test_util.run_v1_only("b/120545219")
   3172   def testWhileGrad_ParallelTwoLoops(self):
   3173     with self.cached_session():
   3174       i = constant_op.constant(0, name="i")
   3175       x = constant_op.constant(2.0, name="x")
   3176 
   3177       c = lambda i, x: math_ops.less(i, 5)
   3178 
   3179       def b(i, x):
   3180         x = math_ops.multiply(x, 2.0)
   3181         i = math_ops.add(i, 1)
   3182         return i, x
   3183 
   3184       _, r1 = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
   3185       _, r2 = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1)
   3186       rx = math_ops.add(r1, r2)
   3187 
   3188       r = gradients_impl.gradients([rx], x)
   3189       self.assertAllClose(64.0, r[0])
   3190 
   3191   @test_util.run_v1_only("b/120545219")
   3192   def testWhileGrad_OneOutputWithControlDependencyOnSecond(self):
   3193     with self.cached_session():
   3194       i = constant_op.constant(0, name="i")
   3195       x = constant_op.constant(1.0, name="x")
   3196       y = constant_op.constant(1.0, name="y")
   3197       c = lambda i, *_: math_ops.less(i, 1, name="cond_less")
   3198 
   3199       def b(i, xi, yi):
   3200         # return (i + 1, xi, xi + yi)
   3201         return (math_ops.add(i, 1, name="inc"), array_ops.identity(
   3202             xi, name="xi"), math_ops.add(xi, yi, name="xi_plus_yi"))
   3203 
   3204       _, x_f, y_f = control_flow_ops.while_loop(c, b, [i, x, y])
   3205       with ops.control_dependencies([x_f]):
   3206         y_f_d = array_ops.identity(y_f, name="y_f_d")
   3207 
   3208       self.assertAllClose(2.0, self.evaluate(y_f_d))  # y_f_d = 1.0 + 1.0
   3209       g = gradients_impl.gradients([y_f_d], [x])[0]
   3210       self.assertTrue(g is not None)
   3211       self.assertAllClose(1.0,
   3212                           self.evaluate(g))  # y_f_d = x + 1.0, dy_f_d/dx = 1.0
   3213 
   3214   def _testNestedWhileGrad_Simple(self, use_gpu):
   3215     with self.cached_session(use_gpu=use_gpu):
   3216       v = constant_op.constant(1.0)
   3217 
   3218       def inner_loop(s):
   3219         c = lambda x: math_ops.less(x, 4.0)
   3220         b = lambda x: math_ops.multiply(x, 2.0)
   3221         return control_flow_ops.while_loop(c, b, [s])
   3222 
   3223       c = lambda x: math_ops.less(x, 2.0)
   3224       b = lambda x: math_ops.multiply(inner_loop(x), 2.0)
   3225       r = control_flow_ops.while_loop(c, b, [v])
   3226 
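               # Near v == 1 the inner loop doubles twice and the outer body
               # doubles once more, so r == 8 * v and dr/dv == 8.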
   3227       r = gradients_impl.gradients(r, v)[0]
   3228       self.assertAllClose(8.0, self.evaluate(r))
   3229 
   3230   @test_util.run_deprecated_v1
   3231   def testNestedWhileGrad_Simple(self):
   3232     self._testNestedWhileGrad_Simple(use_gpu=False)
   3233     self._testNestedWhileGrad_Simple(use_gpu=True)
   3234 
   3235   @test_util.run_v1_only("b/120545219")
   3236   def testNestedWhileGrad_SerialInner(self):
   3237     with self.cached_session():
   3238       v = constant_op.constant(1.0)
   3239 
   3240       def inner_loop1(s):
   3241         z = constant_op.constant(0)
   3242         c = lambda i, x: math_ops.less(i, 4)
   3243         b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
   3244         return control_flow_ops.while_loop(c, b, [z, s])
   3245 
   3246       def inner_loop2(s):
   3247         z = constant_op.constant(0)
   3248         c = lambda i, x: math_ops.less(i, 4)
   3249         b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
   3250         return control_flow_ops.while_loop(c, b, [z, s])
   3251 
   3252       c = lambda x: math_ops.less(x, 128.0)
   3253       b = lambda x: inner_loop2(inner_loop1(x)[1])[1]
   3254       r = control_flow_ops.while_loop(c, b, [v])
   3255 
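               # One outer iteration applies two inner loops of four doublings
               # each, so r == 2**8 * v and dr/dv == 256.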
   3256       r = gradients_impl.gradients(r, v)[0]
   3257       self.assertAllClose(256.0, self.evaluate(r))
   3258 
   3259   @test_util.run_deprecated_v1
   3260   def testNestedWhileGrad_ParallelInner(self):
   3261     with self.cached_session():
   3262       v = constant_op.constant(1.0)
   3263 
   3264       def inner_loop1(s):
   3265         z = constant_op.constant(0)
   3266         c = lambda i, x: math_ops.less(i, 4)
   3267         b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
   3268         return control_flow_ops.while_loop(c, b, [z, s])
   3269 
   3270       def inner_loop2(s):
   3271         z = constant_op.constant(0)
   3272         c = lambda i, x: math_ops.less(i, 4)
   3273         b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)]
   3274         return control_flow_ops.while_loop(c, b, [z, s])
   3275 
   3276       c = lambda x: math_ops.less(x, 128.0)
   3277       b = lambda x: math_ops.multiply(inner_loop1(x)[1], inner_loop2(x)[1])
   3278       r = control_flow_ops.while_loop(c, b, [v])
   3279 
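               # b(x) == (16 * x) * (16 * x) == 256 * x**2 and one iteration
               # suffices, so dr/dv == 512 * v == 512.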
   3280       r = gradients_impl.gradients(r, v)[0]
   3281       self.assertAllClose(512.0, self.evaluate(r))
   3282 
   3283   @test_util.run_v1_only("b/120545219")
   3284   def testNestedWhileGrad_ParallelIterations(self):
   3285     # Make sure the stack pushes and pops of an inner loop are executed in
   3286     # the sequential order of the iterations of its outer loop.
   3287     with self.cached_session() as sess:
   3288 
   3289       def inner_loop(t):
   3290         fn = lambda n: n + math_ops.square(var)
   3291         return map_fn.map_fn(fn=fn, elems=t, parallel_iterations=10)
   3292 
   3293       def outer_loop(inp):
   3294         return map_fn.map_fn(
   3295             fn=inner_loop, elems=inp, parallel_iterations=10)
   3296 
   3297       var = variables.Variable(constant_op.constant(3.0))
   3298       inp = constant_op.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
   3299       res = outer_loop(inp)
   3300       optimizer = adam.AdamOptimizer(learning_rate=0.001)
   3301       train_op = optimizer.minimize(math_ops.reduce_mean(math_ops.square(res)))
   3302       self.evaluate(variables.global_variables_initializer())
   3303       self.evaluate(train_op)
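               # Adam's first step has magnitude of roughly the learning rate
               # regardless of gradient scale, so var moves from 3.0 to
               # approximately 2.999.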
   3304       self.assertAllClose(2.999, var.read_value())
   3305 
   3306   def _testWhileCondGrad_Simple(self, use_gpu):
   3307     with self.cached_session(use_gpu=use_gpu):
   3308       v = ops.convert_to_tensor(2.0, name="v")
   3309       n = ops.convert_to_tensor(100.0, name="n")
   3310       one = ops.convert_to_tensor(1.0, name="one")
   3311       c = lambda x: math_ops.less(x, n)
   3312       # pylint: disable=undefined-variable
   3313       # for OSS build
   3314       b = lambda x: control_flow_ops.cond(constant_op.constant(True),
   3315                                           lambda: math_ops.square(x),
   3316                                           lambda: math_ops.subtract(x, one))
   3317       # pylint: enable=undefined-variable
   3318       r = control_flow_ops.while_loop(c, b, [v])
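               # The cond always takes the square branch, so r == v**8 and
               # dr/dv == 8 * v**7 == 1024 at v == 2.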
   3319       r = gradients_impl.gradients(r, v)[0]
   3320       self.assertAllClose(1024.0, self.evaluate(r))
   3321 
   3322   @test_util.run_deprecated_v1
   3323   def testWhileCondGrad_Simple(self):
   3324     self._testWhileCondGrad_Simple(use_gpu=False)
   3325     self._testWhileCondGrad_Simple(use_gpu=True)
   3326 
   3327   @test_util.run_deprecated_v1
   3328   def testWhileCondGrad_UnknownShape(self):
   3329     with self.cached_session() as sess:
   3330       v = array_ops.placeholder(dtypes.float32)
   3331       n = ops.convert_to_tensor(100.0, name="n")
   3332       one = ops.convert_to_tensor(1.0, name="one")
   3333       c = lambda x: math_ops.less(x, n)
   3334       # pylint: disable=undefined-variable
   3335       # for OSS build
   3336       b = lambda x: control_flow_ops.cond(constant_op.constant(True),
   3337                                           lambda: math_ops.square(x),
   3338                                           lambda: math_ops.subtract(x, one))
   3339       # pylint: enable=undefined-variable
   3340       r = control_flow_ops.while_loop(c, b, [v])
   3341       r = gradients_impl.gradients(r, v)[0]
   3342       r = sess.run(r, feed_dict={v: 2.0})
   3343       self.assertAllClose(1024.0, r)
   3344 
   3345   @test_util.run_deprecated_v1
   3346   def testWhileGrad_Concat(self):
   3347     with self.cached_session() as sess:
   3348       x = variable_scope.get_variable("x", initializer=[[1., 2.]])
   3349       i0 = constant_op.constant(0)
   3350       h0 = array_ops.zeros([0, 2])
   3351 
   3352       def condition(i, _):
   3353         return i < 2
   3354 
   3355       def body(i, h):
   3356         return i + 1, array_ops.concat([h, x], 0)
   3357 
   3358       _, h = control_flow_ops.while_loop(
   3359           condition, body, [i0, h0],
   3360           [i0.get_shape(), tensor_shape.TensorShape([None, 2])])
   3361       s = math_ops.reduce_sum(h)
   3362 
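               # h stacks x twice, so ds/dx == 2 everywhere and one gradient
               # descent step with lr == 0.01 lowers each entry by 0.02.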
   3363       optimizer = gradient_descent.GradientDescentOptimizer(0.01)
   3364       op = optimizer.minimize(s)
   3365 
   3366       self.evaluate(variables.global_variables_initializer())
   3367       self.evaluate(op)
   3368       self.assertAllClose([[0.98000002, 1.98000002]], self.evaluate(x))
   3369 
   3370   @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
   3371   @test_util.run_v1_only("b/120545219")
   3372   def testWhileWithRefsWithGradients_1(self):
   3373     with self.cached_session() as sess:
   3374       x = variables.VariableV1(0.)._ref()  # pylint: disable=protected-access
   3375       i = constant_op.constant(0)
   3376       c = lambda i, x: math_ops.less(i, 10)
   3377 
   3378       self.assertEqual(x.dtype, dtypes.float32_ref)
   3379 
   3380       def body(i, x):
   3381         self.assertEqual(x.dtype, dtypes.float32_ref)
   3382         return [i + 1, gen_array_ops.ref_identity(x)]
   3383 
   3384       r = control_flow_ops.while_loop(c, body, [i, x], parallel_iterations=5)
   3385 
   3386       grad_ys = [variables.VariableV1(73)._ref()]  # pylint: disable=protected-access
   3387       grad = gradients_impl.gradients([r[1]], [x], grad_ys=grad_ys)
   3388 
   3389       self.evaluate(variables.global_variables_initializer())
   3390 
   3391       self.assertEqual(r[0].dtype, dtypes.int32)
   3392       self.assertEqual(r[1].dtype, dtypes.float32_ref)
   3393 
   3394       value_i, value_x, value_x_grad = sess.run(r + grad)
   3395 
   3396     self.assertEqual(10, value_i)
   3397     self.assertEqual(0, value_x)
   3398     self.assertEqual(73, value_x_grad)
   3399 
   3400   @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
   3401   @test_util.run_v1_only("b/120545219")
   3402   def testWhileGrad_IndexedSlices(self):
   3403     with self.cached_session():
   3404       values = constant_op.constant([2.0, 4.0], name="values")
   3405       indices = constant_op.constant([0, 3], name="indices")
   3406       shape = constant_op.constant([10], name="dense_shape")
   3407       i = constant_op.constant(0)
   3408       x = ops.IndexedSlices(values, indices, dense_shape=shape)
   3409 
   3410       def c(i, _):
   3411         return i < 10
   3412 
   3413       def b(i, x):
   3414         return [
   3415             i + 1,
   3416             ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape)
   3417         ]
   3418 
   3419       _, r = control_flow_ops.while_loop(c, b, [i, x])
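               # Ten doublings of x.values give r.values == values * 2**10, so
               # the gradient wrt values is 1024 per entry.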
   3420       r = gradients_impl.gradients(r.values, values)[0]
   3421       self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))
   3422 
   3423   @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
   3424   @test_util.run_v1_only("b/120545219")
   3425   def testWhileGrad_SparseTensor(self):
   3426     with self.cached_session():
   3427       values = constant_op.constant([2.0, 4.0], name="values")
   3428       indices = constant_op.constant(
   3429           [[0], [3]], dtype=dtypes.int64, name="indices")
   3430       shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
   3431       i = constant_op.constant(0)
   3432       x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
   3433 
   3434       def c(i, _):
   3435         return i < 10
   3436 
   3437       def b(i, x):
   3438         return [
   3439             i + 1,
   3440             sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape)
   3441         ]
   3442 
   3443       _, r = control_flow_ops.while_loop(c, b, [i, x])
   3444       r = gradients_impl.gradients(r.values, values)[0]
   3445       self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r))
   3446 
   3447   @test_util.run_v1_only("b/120545219")
   3448   def testCallGradInLoop(self):
   3449     with self.cached_session() as sess:
   3450       i0 = constant_op.constant(0)
   3451       params = constant_op.constant(5.0)
   3452       params_1 = math_ops.square(params)
   3453 
   3454       def c(i, _):
   3455         return i < 10
   3456 
   3457       def b(i, x):
   3458         data = constant_op.constant([1.0, 2.0, 3.0])
   3459         data = math_ops.multiply(data, params_1)
   3460         x1 = x + gradients_impl.gradients(data, params)[0]
   3461         return i + 1, x1
   3462 
   3463       output_grad = control_flow_ops.while_loop(
   3464           c, b, [i0, constant_op.constant(0.0)])
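               # Each iteration adds d(sum(data))/d(params) == 12 * params
               # == 60, so ten iterations accumulate 600.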
   3465       self.assertAllClose(600.0, self.evaluate(output_grad)[1])
   3466 
   3467   @test_util.run_deprecated_v1
   3468   def testWhileAndTensorArray(self):
   3469     with self.cached_session() as sess:
   3470       param = constant_op.constant(2.0)
   3471       n0 = constant_op.constant(0)
   3472       y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems")
   3473 
   3474       def c(i, _):
   3475         return i < 10
   3476 
   3477       def b(i, y):
   3478         return [
   3479             i + 1,
   3480             map_fn.map_fn(lambda x: math_ops.multiply(x, param), y)
   3481         ]
   3482 
   3483       r = control_flow_ops.while_loop(c, b, [n0, y0], parallel_iterations=1)
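               # y ends as y0 * param**10, so d(sum(y))/d(param) ==
               # 10 * param**9 * sum(y0) == 10 * 512 * 21 == 107520.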
   3484       r = gradients_impl.gradients(r, param)[0]
   3485       self.assertAllClose(107520.0, self.evaluate(r))
   3486 
   3487   @test_util.run_deprecated_v1
   3488   def testNestedWhileAndTensorArray(self):
   3489     n = constant_op.constant(3.0)
   3490 
   3491     def Body(row, ta):
   3492 
   3493       def InnerBody(row, col, ta):
   3494         # Note: row and col are 1-based.
   3495         ta = ta.write(
   3496             math_ops.cast(n * (row - 1.) + col - 1., dtypes.int32), row * col)
   3497         return row, col + 1., ta
   3498 
   3499       ta = control_flow_ops.while_loop(
   3500           lambda _, col, _1: col <= n,
   3501           InnerBody, [row, constant_op.constant(1.), ta],
   3502           return_same_structure=False)[2]
   3503       return row + 1., ta
   3504 
   3505     ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=9)
   3506     ta = control_flow_ops.while_loop(
   3507         lambda row, _: row <= n,
   3508         Body, [constant_op.constant(1.), ta],
   3509         return_same_structure=False)[1]
   3510 
   3511     output = array_ops.reshape(ta.stack(), [3, 3])
   3512     self.assertAllEqual(
   3513         self.evaluate(output), [[1., 2., 3.], [2., 4., 6.], [3., 6., 9.]])
   3514     # TODO(b/117675481): This does not work with current TA. Enable with new TA.
   3515     # grad = gradients_impl.gradients(output, [n])
   3516     # self.assertEqual(self.evaluate(grad), 3.5)
   3517 
   3518   @test_util.run_deprecated_v1
   3519   def testWhileGrad_StopGrad(self):
   3520     with self.cached_session():
   3521       x = constant_op.constant(3.0, name="x")
   3522       y = constant_op.constant(2.0, name="y")
   3523 
   3524       c = lambda x, y: math_ops.less(x, 100.0)
   3525 
   3526       def b(x, y):
   3527         y1 = math_ops.square(y)
   3528         x1 = math_ops.add(math_ops.square(x), y1)
   3529         return x1, y1
   3530 
   3531       rx, ry = control_flow_ops.while_loop(c, b, [x, y])
   3532 
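               # Two iterations give ry == y**4 and
               # rx == (x**2 + y**2)**2 + y**4, so drx/dy ==
               # 2*(x**2 + y**2)*2*y + 4*y**3 == 104 + 32 == 136 and
               # dry/dy == 4 * y**3 == 32.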
   3533       r = gradients_impl.gradients(rx, y)[0]
   3534       self.assertEqual(136.0, self.evaluate(r))
   3535       r = gradients_impl.gradients(ry, y)[0]
   3536       self.assertEqual(32.0, self.evaluate(r))
   3537 
   3538       r = gradients_impl.gradients(array_ops.stop_gradient(rx), y)[0]
   3539       self.assertEqual(r, None)
   3540       r = gradients_impl.gradients(array_ops.stop_gradient(ry), y)[0]
   3541       self.assertEqual(r, None)
   3542 
   3543       r = gradients_impl.gradients(
   3544           array_ops.stop_gradient(math_ops.square(rx)), y)[0]
   3545       self.assertEqual(r, None)
   3546       r = gradients_impl.gradients(
   3547           array_ops.stop_gradient(math_ops.add(rx, ry)), x)[0]
   3548       self.assertEqual(r, None)
   3549       r = gradients_impl.gradients(
   3550           array_ops.stop_gradient(math_ops.add(rx, ry)), y)[0]
   3551       self.assertEqual(r, None)
   3552 
   3553       r = gradients_impl.gradients(math_ops.add(rx, ry), y)[0]
   3554       self.assertEqual(168.0, self.evaluate(r))
   3555       r = gradients_impl.gradients(
   3556           math_ops.add(rx, array_ops.stop_gradient(ry)), y)[0]
   3557       self.assertEqual(136.0, self.evaluate(r))
   3558       r = gradients_impl.gradients(
   3559           math_ops.add(array_ops.stop_gradient(rx), ry), y)[0]
   3560       self.assertEqual(32.0, self.evaluate(r))
   3561 
   3562   @test_util.run_deprecated_v1
   3563   @test_util.disable_control_flow_v2("b/118712257")
   3564   def testWhileGrad_StopGradInside(self):
   3565     with self.cached_session():
   3566       x = constant_op.constant(3.0, name="x")
   3567       y = constant_op.constant(2.0, name="y")
   3568 
   3569       c = lambda x, y: math_ops.less(x, 100.0)
   3570 
   3571       def b(x, y):
   3572         y1 = array_ops.stop_gradient(math_ops.square(y))
   3573         x1 = math_ops.add(math_ops.square(x), y1)
   3574         return x1, y1
   3575 
   3576       rx, _ = control_flow_ops.while_loop(c, b, [x, y])
   3577 
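               # With y1 stopped, rx behaves as (x**2 + 4)**2 + 16, so
               # drx/dy == 0 and drx/dx == 2 * (x**2 + 4) * 2 * x == 156.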
   3578       r = gradients_impl.gradients(rx, y)[0]
   3579       self.assertAllClose(0.0, self.evaluate(r))
   3580       r = gradients_impl.gradients(rx, x)[0]
   3581       self.assertAllClose(156.0, self.evaluate(r))
   3582 
   3583   @test_util.run_deprecated_v1
   3584   @test_util.disable_control_flow_v2("b/118712257")
   3585   def testWhileGrad_StopGradInsideNoShape(self):
   3586     with self.cached_session() as sess:
   3587       x = array_ops.placeholder(dtypes.float32)
   3588       y = array_ops.placeholder(dtypes.float32)
   3589 
   3590       c = lambda x, y: math_ops.less(math_ops.reduce_sum(x), 100.0)
   3591 
   3592       def b(x, y):
   3593         y1 = array_ops.stop_gradient(math_ops.square(y, name="stopped"))
   3594         x1 = math_ops.add(math_ops.square(x), y1)
   3595         return x1, y1
   3596 
   3597       rx, _ = control_flow_ops.while_loop(c, b, [x, y])
   3598 
   3599       r = gradients_impl.gradients(rx, y)[0]
   3600       feed_dict = {x: [3.0, 4.0], y: [2.0, 3.0]}
   3601       self.assertAllClose([0.0, 0.0], sess.run(r, feed_dict=feed_dict))
   3602       r = gradients_impl.gradients(rx, x)[0]
   3603       self.assertAllClose([156.0, 400.0], sess.run(r, feed_dict=feed_dict))
   3604       name = "gradients/while/stopped_grad"
   3605       all_ops = x.graph.get_operations()
   3606       self.assertFalse(any(name in op.name for op in all_ops))
   3607 
   3608   @test_util.run_deprecated_v1
   3609   def testWhileGradGradFail(self):
   3610     theta = variables.Variable(initial_value=1.)
   3611 
   3612     def fn(prev, x):
   3613       return prev + x * theta
   3614 
   3615     result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32))
   3616     grad_theta = gradients_impl.gradients(result, theta)
   3617     if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
   3618       with self.assertRaisesRegexp(TypeError, "Second-order gradient"):
   3619         gradients_impl.gradients(grad_theta, theta)
   3620     grad_theta_stopped = array_ops.stop_gradient(grad_theta)
   3621     gradients_impl.gradients(grad_theta_stopped, theta)
   3622 
   3623   @test_util.run_deprecated_v1
   3624   def testStopGradOnWhileGrad(self):
   3625     with self.cached_session():
   3626       x = constant_op.constant(2.0, name="x")
   3627       y = constant_op.constant(2.0, name="y")
   3628 
   3629       c = lambda x: math_ops.less(x, 100.0)
   3630       b = lambda x: math_ops.multiply(x, y)
   3631       rx = control_flow_ops.while_loop(c, b, [x])
   3632 
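               # Six iterations give rx == x * y**6, so
               # drx/dy == 6 * x * y**5 == 384; with rg stopped,
               # dr/dy == 2 * y + 384 == 388.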
   3633       rg = gradients_impl.gradients(rx, y)[0]
   3634       rg = array_ops.stop_gradient(rg)
   3635       r = math_ops.add(math_ops.square(y), rx)
   3636       r = math_ops.add(r, rg)
   3637       r = gradients_impl.gradients(r, y)[0]
   3638       self.assertEqual(388.0, self.evaluate(r))
   3639 
   3640   @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
   3641   @test_util.run_deprecated_v1
   3642   def testWhileGradientWithNontrainablePath1(self):
   3643     q = variables.Variable([7., 8.])
   3644 
   3645     def cond(_, y):
   3646       del y
   3647       return False
   3648 
   3649     def body(x, _):
   3650       return x, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q)
   3651 
   3652     _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
   3653     dy_dq, = gradients_impl.gradients(y, q)
   3654     self.assertIsNotNone(dy_dq)
   3655     with self.cached_session() as sess:
   3656       self.evaluate(q.initializer)
   3657       self.assertAllClose([0., 0.], self.evaluate(dy_dq))
   3658 
   3659   @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
   3660   @test_util.run_v1_only("b/120545219")
   3661   def testWhileGradientWithNontrainablePath2(self):
   3662     q = variables.Variable([7., 8.])
   3663 
   3664     def cond(_, y):
   3665       return math_ops.equal(y, 0.)
   3666 
   3667     def body(x, _):
   3668       zero = constant_op.constant(0, dtype=dtypes.int64)
   3669       return zero, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q)
   3670 
   3671     _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.))
   3672     dy_dq, = gradients_impl.gradients(y, q)
   3673     self.assertIsNotNone(dy_dq)
   3674     with self.cached_session() as sess:
   3675       self.evaluate(q.initializer)
   3676       self.assertAllClose([1., 1.], self.evaluate(dy_dq))
   3677 
   3678   @test_util.run_v1_only("b/120545219")
   3679   def testIssue16504(self):
   3680     c = constant_op.constant(np.arange(100), dtype=dtypes.float32)
   3681     w = variables.Variable(
   3682         initial_value=np.ones(100), dtype=dtypes.float32) / 100
   3683     k = variables.Variable(0, dtype=dtypes.int32)
   3684     chg_w = constant_op.constant(np.inf, dtype=dtypes.float32)
   3685 
   3686     def cond(k, _, chg_w):
   3687       return math_ops.logical_and(k < 10, chg_w > 1e-3)
   3688 
   3689     def body(k, w, chg_w):
   3690       grad, = gradients_impl.gradients(-math_ops.reduce_sum(w * c), w)
   3691       w_n = w * math_ops.exp(-0.1 * grad)
   3692       w_n /= math_ops.reduce_sum(w_n)
   3693       chg_w = (
   3694           math_ops.reduce_sum(math_ops.abs(w_n - w)) / math_ops.reduce_sum(
   3695               math_ops.abs(w)))
   3696       return k + 1, w_n, chg_w
   3697 
   3698     _, w, _ = control_flow_ops.while_loop(cond, body, [k, w, chg_w])
   3699     grad, = gradients_impl.gradients(w, c)
   3700     self.assertIsNotNone(grad)
   3701 
   3702   @test_util.run_v1_only("b/120545219")
   3703   def testStopGradMultiFlows(self):
   3704     with self.cached_session():
   3705 
   3706       def body(i, y, r):
   3707         x = variable_scope.get_variable(
   3708             "x",
   3709             shape=(),
   3710             dtype=dtypes.float32,
   3711             initializer=init_ops.ones_initializer())
   3712         y *= x
   3713         return [i + 1, y, r + math_ops.reduce_sum(y)]
   3714 
   3715       i0 = constant_op.constant(0)
   3716       y0 = array_ops.ones(5)
   3717       r0 = constant_op.constant(0.0)
   3718       cond = lambda i, y, r: i < 1
   3719       _, _, r = control_flow_ops.while_loop(
   3720           cond, body, [i0, y0, r0], back_prop=True)
   3721 
   3722       vars_ = variables.global_variables()
   3723       grads = linalg_ops.norm(gradients_impl.gradients(r, vars_)[0])
   3724       z = math_ops.add(r, array_ops.stop_gradient(math_ops.reduce_sum(grads)))
   3725       result = gradients_impl.gradients(z, vars_)[0]
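               # One iteration gives r == sum(ones(5) * x) == 5 * x; the
               # stopped grads term contributes nothing, so dz/dx == 5.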
   3726       self.evaluate(variables.global_variables_initializer())
   3727       self.assertEqual(5.0, self.evaluate(result))
   3728 
   3729   @test_util.run_v1_only("b/120545219")
   3730   def testOneValueCond(self):
   3731 
   3732     with self.cached_session():
   3733       c = array_ops.placeholder(dtypes.int32, shape=[])
   3734       one = ops.convert_to_tensor(1, name="one")
   3735       two = ops.convert_to_tensor(2, name="two")
   3736       p = math_ops.greater_equal(c, 1)
   3737       i = control_flow_ops.cond(p, lambda: one, lambda: two)
   3738       self.assertTrue(isinstance(i, ops.Tensor))
   3739 
   3740       # True case: c = 2 is >= 1
   3741       self.assertEqual([1], i.eval(feed_dict={c: 2}))
   3742 
   3743       # False case: c = 0 is not >= 1
   3744       self.assertEqual([2], i.eval(feed_dict={c: 0}))
   3745 
   3746   @test_util.run_deprecated_v1
   3747   def testExampleCond(self):
   3748 
   3749     with self.cached_session():
   3750       x = ops.convert_to_tensor([-2.0, 2.0], name="x")
   3751       d = array_ops.placeholder(dtypes.int32, shape=[])
   3752 
   3753       def l2():
   3754         return math_ops.sqrt(math_ops.reduce_sum(math_ops.square(x)))
   3755 
   3756       def l1():
   3757         return math_ops.reduce_sum(math_ops.abs(x))
   3758 
   3759       i = control_flow_ops.cond(math_ops.equal(d, 2), l2, l1)
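               # d == 1 selects the L1 branch (|-2| + |2| == 4); d == 2 selects
               # the L2 branch (sqrt(4 + 4) == 2 * sqrt(2)).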
   3760       self.assertAllClose(4.0, i.eval(feed_dict={d: 1}))
   3761       self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
   3762 
   3763   @test_util.run_v1_only("b/120545219")
   3764   def testCase(self):
   3765     with self.cached_session():
   3766       x = constant_op.constant(1)
   3767       y = constant_op.constant(2)
   3768       z = constant_op.constant(3)
   3769       f1 = lambda: constant_op.constant(17)
   3770       f2 = lambda: constant_op.constant(23)
   3771       f3 = lambda: constant_op.constant(-1)
   3772 
   3773       r1 = control_flow_ops.case(
   3774           {
   3775               x < y: f1,
   3776               x > z: f2
   3777           }, default=f3, exclusive=True)
   3778       self.assertAllEqual(r1, 17)
   3779 
   3780       r2 = control_flow_ops.case([(y > z, f1), (y > x, f2)], default=f3)
   3781       self.assertAllEqual(r2, 23)
   3782 
    3783       # Duplicate predicates are allowed; the first true branch is selected.
   3784       r3 = control_flow_ops.case([(x < y, f1), (x < y, f2)], default=f3)
   3785       self.assertAllEqual(r3, 17)
   3786 
    3787       # Duplicate predicates cause an error if exclusive=True.
   3788       r4 = control_flow_ops.case(
   3789           [(x < y, f1), (x < y, f2)], default=f3, exclusive=True)
   3790       with self.assertRaisesOpError("Input error:"):
   3791         self.evaluate(r4)
   3792 
   3793       # Check that the default is called if none of the others are
   3794       r5 = control_flow_ops.case({x > y: f1}, default=f3)
   3795       self.assertAllEqual(r5, -1)
   3796 
   3797       ran_once = [False, False, False]
   3798 
   3799       def break_run_twice(ix):
   3800 
   3801         def _break():
   3802           ran_once[ix] = True
   3803           return constant_op.constant(ix)
   3804 
   3805         return _break
   3806 
    3807       # Should not fail - each branch callable gets invoked exactly once
    3808       # except the default, which gets invoked twice: once to create an
    3809       # empty output and once for the actual cond switch.
   3810       r6 = control_flow_ops.case(
   3811           [(x < y, break_run_twice(0)), (x > y, break_run_twice(1))],
   3812           default=lambda: constant_op.constant(2))
   3813 
   3814       self.assertAllEqual(r6, 0)
   3815 
   3816   @test_util.run_v1_only("b/120545219")
   3817   def testCaseSideEffects(self):
   3818     with self.cached_session() as sess:
   3819       v0 = variables.Variable(-1)
   3820       v1 = variables.Variable(-1)
   3821       v2 = variables.Variable(-1)
   3822 
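               # Each branch assigns to exactly one variable, so the variables
               # record which branch actually ran.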
   3823       a = lambda: control_flow_ops.with_dependencies([state_ops.assign(v0, 0)], 0)
   3824       b = lambda: control_flow_ops.with_dependencies([state_ops.assign(v1, 1)], 1)
   3825       c = lambda: control_flow_ops.with_dependencies([state_ops.assign(v2, 2)], 2)
   3826 
   3827       x = constant_op.constant(1)
   3828       y = constant_op.constant(2)
   3829 
   3830       r0 = control_flow_ops.case(
   3831           ((x < y, a), (x > y, b)), default=c, exclusive=True)
   3832       r1 = control_flow_ops.case(
   3833           ((x > y, a), (x < y, b)), default=c, exclusive=True)
   3834       r2 = control_flow_ops.case(
   3835           ((x > y, a), (x > y, b)), default=c, exclusive=True)
   3836 
   3837       self.evaluate(variables.global_variables_initializer())
   3838       self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
   3839       self.assertEqual(2, self.evaluate(r2))
   3840       self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, -1, 2])
   3841 
   3842       self.evaluate(variables.global_variables_initializer())
   3843       self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
   3844       self.assertEqual(1, self.evaluate(r1))
   3845       self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, 1, -1])
   3846 
   3847       self.evaluate(variables.global_variables_initializer())
   3848       self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3)
   3849       self.assertEqual(0, self.evaluate(r0))
   3850       self.assertAllEqual(self.evaluate([v0, v1, v2]), [0, -1, -1])
   3851 
   3852   @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
   3853   @test_util.run_v1_only("b/120545219")
   3854   def testOneOpCond(self):
   3855     with self.cached_session():
   3856       v = variables.Variable(0)
   3857       c = ops.convert_to_tensor(0)
   3858       one = ops.convert_to_tensor(1)
   3859       two = ops.convert_to_tensor(2)
   3860       p = math_ops.greater_equal(c, 1)
   3861 
   3862       def a():
   3863         return state_ops.assign(v, one)
   3864 
   3865       def b():
   3866         return state_ops.assign(v, two)
   3867 
   3868       i = control_flow_ops.cond(p, a, b)
    3869       self.assertIsInstance(i, ops.Tensor)
   3870       self.evaluate(variables.global_variables_initializer())
   3871 
   3872       self.assertEqual(0, self.evaluate(v))
   3873 
   3874       # True case: c = 2 is >= 1, v is set to 1.
   3875       self.assertEqual(1, i.eval(feed_dict={c.name: 2}))
   3876       self.assertEqual(1, self.evaluate(v))
   3877 
   3878       # False case: c = 0 is not >= 1, v is set to 2.
   3879       self.assertEqual(2, i.eval(feed_dict={c.name: 0}))
   3880       self.assertEqual(2, self.evaluate(v))
   3881 
   3882   @test_util.run_v1_only("b/120545219")
   3883   def testWithOpsDependencies(self):
   3884     with self.cached_session() as sess:
   3885       v = variables.VariableV1(0.0)
   3886       c = constant_op.constant(10)
   3887 
   3888       # Fetching v directly will result in an uninitialized error
   3889       with self.assertRaisesOpError("Attempting to use uninitialized value"):
   3890         self.evaluate([c, v])
   3891 
    3892       # Use a control dependency to ensure v's initializer is run
    3893       # while asking for c
   3894       real_v = control_flow_ops.with_dependencies(
   3895           name="real_tensor",
   3896           output_tensor=v._ref(),  # pylint: disable=protected-access
   3897           dependencies=[v.initializer])
   3898       c_val, real_v_val = self.evaluate([c, real_v])
   3899 
    3900     # Ensure the result of fetching 'c' is unchanged
   3901     self.assertAllEqual(10, c_val)
   3902 
   3903     # Ensure that 'v' is initialized
   3904     self.assertAllClose(0.0, real_v_val)
   3905 
   3906   @test_util.run_v1_only("b/120545219")
   3907   def testWithTensorDependencies(self):
   3908     with self.cached_session():
   3909       v = variables.VariableV1(0.0)
   3910       c1 = constant_op.constant(10)
   3911       c2 = constant_op.constant(20)
   3912 
   3913       # c1_with_init_v depends on the init op for v
   3914       c1_with_init_v = control_flow_ops.with_dependencies(
   3915           name="c1_with_init_v", output_tensor=c1, dependencies=[v.initializer])
   3916       # c2_with_c1 depends on the value of c1_with_init_v
   3917       c2_with_c1_dep = control_flow_ops.with_dependencies(
   3918           name="c2_with_c1_dep",
   3919           output_tensor=c2,
   3920           dependencies=[c1_with_init_v])
   3921 
   3922       # Fetching v directly will result in an uninitialized error
   3923       with self.assertRaisesOpError("Attempting to use uninitialized value"):
   3924         self.evaluate(v)
   3925 
   3926       # Get the value of 'c2_with_c1_dep', which should cause 'v'
   3927       # to be initialized.
   3928       self.assertAllEqual(20, self.evaluate(c2_with_c1_dep))
   3929 
   3930       # Ensure that 'v' is initialized
   3931       self.assertAllClose(0.0, self.evaluate(v))
   3932 
   3933   @test_util.run_v1_only("b/120545219")
   3934   def testWithIndexedSlicesDependencies(self):
   3935     with self.cached_session():
   3936       v = variables.VariableV1(
   3937           np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(np.float32))
   3938       v_at_1 = ops.IndexedSlices(v, constant_op.constant([1]))
   3939       gather_v_at_1 = array_ops.gather(v_at_1.values, v_at_1.indices)
   3940       v_at_1_after_init = control_flow_ops.with_dependencies([v.initializer],
   3941                                                              v_at_1)
   3942       gather_v_at_1_after_init = array_ops.gather(v_at_1_after_init.values,
   3943                                                   v_at_1_after_init.indices)
   3944 
   3945       # Fetching gather_v_at_1 will result in an uninitialized error
   3946       with self.assertRaisesOpError("Attempting to use uninitialized value"):
   3947         self.evaluate(gather_v_at_1)
   3948 
   3949       # Getting gather_v_at_1_after_init will work, and initialize v.
   3950       self.assertAllEqual([[10.0, 11.0]],
   3951                           self.evaluate(gather_v_at_1_after_init))
   3952 
   3953       # Double check that 'v' is initialized
   3954       self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]],
   3955                           self.evaluate(v))
   3956 
   3957   def testDependenciesDevice(self):
   3958     with ops.Graph().as_default():
    3959       # Device set on tensor => same device on dep.
   3960       with ops.device("/job:ps"):
   3961         vd = variables.VariableV1([0.0])
   3962       with_vd_dep = control_flow_ops.with_dependencies([vd.initializer], vd)
   3963       self.assertTrue("/job:ps" in with_vd_dep.device)
   3964 
   3965       # No device set on tensor => no device on dep.
   3966       vnod = variables.VariableV1([0.0])
   3967       with_vnod_dep = control_flow_ops.with_dependencies([vnod.initializer],
   3968                                                          vnod)
   3969       self.assertDeviceEqual(None, with_vnod_dep.device)
   3970 
    3971       # Device set on tensor, default device on graph => default device on dep.
   3972       vdef = variables.VariableV1([0.0], name="vdef")
   3973       with ops.device("/job:worker/device:GPU:1"):
   3974         with_vdef_dep = control_flow_ops.with_dependencies([vdef.initializer],
   3975                                                            vdef)
   3976         # The device is empty, but the colocation constraint is set.
   3977         self.assertDeviceEqual("", with_vdef_dep.device)
   3978         self.assertEqual([b"loc:@vdef"], with_vdef_dep.op.colocation_groups())
   3979 
   3980   @test_util.run_v1_only("b/120545219")
   3981   def testGroup(self):
   3982     with self.cached_session() as sess:
   3983       v1 = variables.VariableV1([0.0])
   3984       v2 = variables.VariableV1([1.0])
   3985 
    3986       # Group the initializers of v1 and v2 into a single op and run it.
   3987       init = control_flow_ops.group(v1.initializer, v2.initializer)
   3988       # Fetching v1 directly will result in an uninitialized error
   3989       with self.assertRaisesOpError("Attempting to use uninitialized value"):
   3990         self.evaluate(v1)
   3991 
   3992       # Runs "init" before fetching v1 and v2.
   3993       init.run()
   3994       v1_val, v2_val = self.evaluate([v1, v2])
   3995 
   3996     # Ensure that v1 and v2 are initialized
   3997     self.assertAllClose([0.0], v1_val)
   3998     self.assertAllClose([1.0], v2_val)
   3999 
   4000   @test_util.run_v1_only("b/120545219")
   4001   def testGroupEmpty(self):
   4002     op = control_flow_ops.group()
   4003     self.assertEqual(op.type, "NoOp")
   4004     self.assertEqual(op.control_inputs, [])
   4005 
   4006   @test_util.run_deprecated_v1
   4007   def testMergeShapes(self):
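             # merge's inferred shape is the most specific shape compatible
             # with all inputs: mismatched ranks give an unknown shape, and
             # mismatched dimensions become None.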
   4008     # All inputs unknown.
   4009     p1 = array_ops.placeholder(dtypes.float32)
   4010     p2 = array_ops.placeholder(dtypes.float32)
   4011     p3 = array_ops.placeholder(dtypes.float32)
   4012     m, index = control_flow_ops.merge([p1, p2, p3])
   4013     self.assertIs(None, m.get_shape().ndims)
   4014     self.assertEqual([], index.get_shape())
   4015 
   4016     # All inputs known with different ranks.
   4017     p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
   4018     p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2, 3])
   4019     m, index = control_flow_ops.merge([p1, p2])
   4020     self.assertIs(None, m.get_shape().ndims)
   4021     self.assertEqual([], index.get_shape())
   4022 
   4023     # All inputs known with some dimensions different.
   4024     p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
   4025     p2 = array_ops.placeholder(dtypes.float32, shape=[2, 1])
   4026     m, index = control_flow_ops.merge([p1, p2])
   4027     self.assertEqual([None, None], m.get_shape().as_list())
   4028     self.assertEqual([], index.get_shape())
   4029 
   4030     p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
   4031     p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
   4032     m, index = control_flow_ops.merge([p1, p2])
   4033     self.assertEqual([None, 2], m.get_shape().as_list())
   4034     self.assertEqual([], index.get_shape())
   4035 
   4036     p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
   4037     p2 = array_ops.placeholder(dtypes.float32, shape=[2, 2])
   4038     m, index = control_flow_ops.merge([p1, p2])
   4039     self.assertEqual([None, 2], m.get_shape().as_list())
   4040     self.assertEqual([], index.get_shape())
   4041 
   4042     # All inputs known with same dimensions.
   4043     p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
   4044     p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
   4045     m, index = control_flow_ops.merge([p1, p2])
   4046     self.assertEqual([1, 2], m.get_shape().as_list())
   4047     self.assertEqual([], index.get_shape())
   4048 
   4049     p1 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
   4050     p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
   4051     m, index = control_flow_ops.merge([p1, p2])
   4052     self.assertEqual([None, 2], m.get_shape().as_list())
   4053     self.assertEqual([], index.get_shape())
   4054 
   4055     p1 = array_ops.placeholder(dtypes.float32, shape=[None, None])
   4056     p2 = array_ops.placeholder(dtypes.float32, shape=[None, None])
   4057     m, index = control_flow_ops.merge([p1, p2])
   4058     self.assertEqual([None, None], m.get_shape().as_list())
   4059     self.assertEqual([], index.get_shape())
   4060 
   4061   @test_util.run_v1_only("b/120545219")
   4062   def testRefSelect(self):
   4063     index = array_ops.placeholder(dtypes.int32)
   4064 
   4065     # All inputs unknown.
   4066     p1 = array_ops.placeholder(dtypes.float32)
   4067     p2 = array_ops.placeholder(dtypes.float32)
   4068     p3 = array_ops.placeholder(dtypes.float32)
   4069     v1 = variables.VariableV1(p1, validate_shape=False)
   4070     v2 = variables.VariableV1(p2, validate_shape=False)
   4071     v3 = variables.VariableV1(p3, validate_shape=False)
   4072     self.assertIs(None, v1.get_shape().ndims)
   4073     s = control_flow_ops.ref_select(index, [v1, v2, v3])
   4074     self.assertIs(None, s.get_shape().ndims)
   4075 
   4076     # All inputs known but different.
   4077     v1 = variables.VariableV1([[1, 2]])
   4078     v2 = variables.VariableV1([[2], [1]])
   4079     s = control_flow_ops.ref_select(index, [v1, v2])
   4080     self.assertIs(None, s.get_shape().ndims)
   4081 
   4082     # All inputs known and same.
   4083     v1 = variables.VariableV1([[1, 2]])
   4084     v2 = variables.VariableV1([[1, 2]])
   4085     s = control_flow_ops.ref_select(index, [v1, v2])
   4086     self.assertEqual([1, 2], s.get_shape())
   4087 
   4088     # Possibly the same but not guaranteed.
   4089     v1 = variables.VariableV1([[1., 2.]])
   4090     p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
   4091     v2 = variables.VariableV1(p2, validate_shape=False)
   4092     s = control_flow_ops.ref_select(index, [v1, v2])
   4093     self.assertEqual(None, s.get_shape())
   4094 
   4095   @test_util.run_deprecated_v1
   4096   def testRunLoopTensor(self):
   4097     with self.cached_session() as sess:
   4098       tensor_list = []
   4099 
   4100       def condition(t):
   4101         return t < constant_op.constant(5)
   4102 
   4103       def body(_):
   4104         tensor_list.append(constant_op.constant(5))
   4105         return constant_op.constant(10)
   4106 
   4107       result = control_flow_ops.while_loop(condition, body,
   4108                                            [constant_op.constant(4)])
   4109       self.assertEqual(10, self.evaluate(result))
   4110 
   4111       # Ensure that we cannot run a tensor that escapes the loop body
   4112       # accidentally.
   4113       with self.assertRaises(ValueError):
   4114         sess.run(tensor_list[0])
   4115 
   4116   @test_util.run_v1_only("b/120545219")
   4117   def testWhilePyFuncBasic(self):
   4118 
   4119     def func(x):
   4120       return np.square(x)
   4121 
   4122     with self.cached_session():
   4123       r = control_flow_ops.while_loop(
   4124           lambda i, v: i < 4,
   4125           lambda i, v: [i + 1, script_ops.py_func(func, [v], [dtypes.float32])[0]],
   4126           [constant_op.constant(0), constant_op.constant(2.0, dtypes.float32)],
   4127           [tensor_shape.unknown_shape(), tensor_shape.unknown_shape()])
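               # Four squarings of 2.0: 2 -> 4 -> 16 -> 256 -> 65536.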
   4128       self.assertEqual(self.evaluate(r[1]), 65536.0)
   4129 
   4130   @test_util.run_v1_only("b/120545219")
   4131   def testWhileFuncBasic(self):
   4132 
   4133     @function.Defun(dtypes.float32)
   4134     def func(x):
   4135       return math_ops.square(math_ops.square(x))
   4136 
   4137     with self.cached_session():
   4138       x = constant_op.constant(2.0, dtypes.float32)
   4139       r = control_flow_ops.while_loop(
   4140           lambda i, v: i < 2, lambda i, v: [i + 1, func(v)],
   4141           [constant_op.constant(0), x],
   4142           [tensor_shape.unknown_shape(),
   4143            tensor_shape.unknown_shape()])
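               # Two applications of func (x**4) give 2**16 = 65536, and the
               # gradient d(x**16)/dx at x = 2 is 16 * 2**15 = 524288.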
   4144       grad = gradients_impl.gradients(r, x)[0]
   4145       self.assertEqual(self.evaluate(r[1]), 65536.0)
   4146       self.assertEqual(self.evaluate(grad), 524288.0)
   4147       # while_v2 does not have stacks.
   4148       if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
   4149         self.assertEqual(
   4150             len([op for op in x.graph.get_operations() if op.type == "StackV2"
   4151                 ]), 1)
    4152
   4154   @test_util.run_v1_only("b/120545219")
   4155   def testQIntSwitchMerge(self):
   4156     with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
   4157       constant_qint = constant_op.constant(np.array([42]), dtypes.qint8)
   4158       cond = constant_op.constant(True, dtypes.bool)
   4159       v_f, v_t = control_flow_ops.switch(constant_qint, cond)
   4160       result = control_flow_ops.merge([v_f, v_t])
   4161       self.evaluate(result)
   4162 
   4163   @test_util.run_v1_only("b/120545219")
   4164   def testQIntRefSwitchMerge(self):
   4165     with self.cached_session(use_gpu=test.is_gpu_available()) as sess:
   4166       var_qint = gen_state_ops.variable(
   4167           shape=[1], dtype=dtypes.qint8, name="v", container="", shared_name="")
   4168       assign_op = state_ops.assign(
   4169           var_qint, constant_op.constant(np.array([42]), dtypes.qint8))
   4170       self.evaluate(assign_op)
   4171 
   4172       cond = constant_op.constant(True, dtypes.bool)
   4173       v_f, v_t = control_flow_ops.ref_switch(var_qint, cond)
   4174       result = control_flow_ops.ref_merge([v_f, v_t])
   4175       self.evaluate(result)
   4176 
   4177   @test_util.run_v1_only("b/120545219")
   4178   def testUInt64SwitchMerge(self):
   4179     with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
   4180       constant_uint64 = constant_op.constant(np.array([42]), dtypes.uint64)
   4181       cond = constant_op.constant(True, dtypes.bool)
   4182       v_f, v_t = control_flow_ops.switch(constant_uint64, cond)
   4183       result = control_flow_ops.merge([v_f, v_t])
   4184       self.evaluate(result)
   4185 
   4186   @test_util.run_deprecated_v1
   4187   def testQIntArgAndRet(self):
   4188 
   4189     @function.Defun(dtypes.qint8)
   4190     def func(x):
   4191       return x
   4192 
   4193     with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
   4194       qint = constant_op.constant(np.array([42]), dtypes.qint8)
   4195       result = func(qint)
   4196       self.evaluate(result)
   4197 
   4198   def testSparseIdentity(self):
    4199     st1 = sparse_tensor.SparseTensor([[0, 5]], ["x"], [10, 10])
   4200     st2 = control_flow_ops._Identity(st1)
   4201     self.assertAllEqual(st1.indices, st2.indices)
   4202     self.assertAllEqual(st1.values, st2.values)
   4203     self.assertAllEqual(st1.dense_shape, st2.dense_shape)
   4204 
   4205   def testSparseEnterExit(self):
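             # Enter and exit a control-flow frame; the SparseTensor
             # components should round-trip unchanged.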
    4206     st1 = sparse_tensor.SparseTensor([[0, 5]], ["x"], [10, 10])
   4207     st2 = control_flow_ops._Enter(st1, "foo_1")
   4208     st3 = control_flow_ops.exit(st2)
   4209     self.assertAllEqual(st1.indices, st3.indices)
   4210     self.assertAllEqual(st1.values, st3.values)
   4211     self.assertAllEqual(st1.dense_shape, st3.dense_shape)
   4212 
   4213 
   4214 class ControlFlowContextCheckTest(test.TestCase):
   4215 
   4216   def _getWhileTensor(self):
   4217     """Creates and returns a tensor from a while context."""
   4218     tensor = []
   4219 
   4220     def body(i):
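               # Stash the constant created inside the while context in the
               # enclosing Python list so it can be referenced after the loop.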
   4221       if not tensor:
   4222         tensor.append(constant_op.constant(1))
   4223       return i + tensor[0]
   4224 
   4225     control_flow_ops.while_loop(lambda i: i < 10, body, [0])
   4226     return tensor[0]
   4227 
   4228   def _getCondTensor(self):
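             """Creates and returns a tensor from a cond's true branch."""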
   4229     cond_tensor = []
   4230 
   4231     def true_fn():
   4232       if not cond_tensor:
   4233         cond_tensor.append(constant_op.constant(1))
   4234       return cond_tensor[0]
   4235 
   4236     control_flow_ops.cond(
   4237         math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0))
   4238     return cond_tensor[0]
   4239 
   4240   @test_util.run_v1_only("b/120545219")
   4241   def testInvalidContext(self):
   4242     # Accessing a while loop tensor outside of control flow is illegal.
   4243     while_tensor = self._getWhileTensor()
   4244     with self.assertRaisesRegexp(
   4245         ValueError,
   4246         "Cannot use 'while/Const_1' as input to 'Add' because 'while/Const_1' "
   4247         "is in a while loop. See info log for more details."):
   4248       math_ops.add(1, while_tensor)
   4249 
   4250   @test_util.run_v1_only("b/120545219")
   4251   def testInvalidContextInCond(self):
   4252     # Accessing a while loop tensor in cond is illegal.
   4253     while_tensor = self._getWhileTensor()
   4254     with self.assertRaisesRegexp(
   4255         ValueError, "Cannot use 'while/Const_1' as input to 'cond/Add' because "
   4256         "'while/Const_1' is in a while loop. See info log for more details."):
   4257       # TODO(skyewm): this passes if we return while_tensor directly instead
   4258       # of using it as input to another op.
   4259       control_flow_ops.cond(
   4260           math_ops.less(1, 2), lambda: math_ops.add(1, while_tensor),
   4261           lambda: constant_op.constant(0))
   4262 
   4263   @test_util.run_v1_only("b/120545219")
   4264   def testInvalidContextInWhile(self):
   4265     # Accessing a while loop tensor in a different while loop is illegal.
   4266     while_tensor = self._getWhileTensor()
   4267     with self.assertRaisesRegexp(
   4268         ValueError,
   4269         "Cannot use 'while/Const_1' as input to 'while_1/Add' because they are "
   4270         "in different while loops. See info log for more details."):
   4271       control_flow_ops.while_loop(lambda i: i < 10,
   4272                                   lambda x: math_ops.add(1, while_tensor), [0])
   4273 
   4274     with self.assertRaisesRegexp(
   4275         ValueError,
   4276         "Cannot use 'while/Const_1' as input to 'while_2/NextIteration' "
   4277         "because they are in different while loops. See info log for more "
   4278         "details."):
   4279       control_flow_ops.while_loop(lambda i: i < 10, lambda i: while_tensor, [0])
   4280 
   4281   def testValidCondContext(self):
   4282     # Accessing a tensor from a cond context is OK (although dangerous).
   4283     cond_tensor = self._getCondTensor()
   4284     math_ops.add(1, cond_tensor)
   4285 
   4286   def testValidCondContextBranches(self):
   4287     # Accessing a tensor from a cond context from the other branch's cond
   4288     # context is OK (although dangerous).
   4289     cond_tensor = []
   4290 
   4291     def branch_fn():
   4292       if not cond_tensor:
   4293         cond_tensor.append(constant_op.constant(1))
   4294       return cond_tensor[0]
   4295 
   4296     control_flow_ops.cond(math_ops.less(1, 2), branch_fn, branch_fn)
   4297 
   4298   @test_util.run_v1_only("b/120545219")
   4299   def testValidWhileContext(self):
   4300     # Accessing a tensor in a nested while is OK.
   4301     def body(_):
   4302       c = constant_op.constant(1)
   4303       return control_flow_ops.while_loop(lambda i: i < 3, lambda i: i + c, [0])
   4304 
   4305     control_flow_ops.while_loop(lambda i: i < 5, body, [0])
   4306 
   4307   @test_util.run_v1_only("b/120545219")
   4308   def testValidNestedContexts(self):
   4309     # Accessing a tensor from a cond context in a while context, all inside an
   4310     # outer while context, is OK.
   4311     def body(_):
   4312       cond_tensor = self._getCondTensor()
   4313       # Create another cond containing the while loop for good measure
   4314       return control_flow_ops.cond(
   4315           math_ops.less(1, 2),
   4316           lambda: control_flow_ops.while_loop(lambda i: i < 3,
   4317                                               lambda i: i + cond_tensor, [0]),
   4318           lambda: constant_op.constant(0))
   4319 
   4320     control_flow_ops.while_loop(lambda i: i < 5, body, [0])
   4321 
   4322   @test_util.run_v1_only("b/120545219")
   4323   def testInvalidNestedContexts(self):
   4324     # Accessing a tensor from a while context in a different while context, all
   4325     # inside a cond context, is illegal.
   4326     def true_fn():
   4327       while_tensor = self._getWhileTensor()
   4328       return control_flow_ops.while_loop(lambda i: i < 3,
   4329                                          lambda i: i + while_tensor, [0])
   4330 
   4331     with self.assertRaisesRegexp(
   4332         ValueError,
   4333         "Cannot use 'cond/while/Const_1' as input to 'cond/while_1/add' because"
   4334         " they are in different while loops. See info log for more details."):
   4335       control_flow_ops.cond(
   4336           math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0))
   4337 
   4338 
   4339 class TupleTest(test.TestCase):
   4340 
   4341   @test_util.run_v1_only("b/120545219")
   4342   def testTensors(self):
   4343     for v1_first in [True, False]:
   4344       with self.cached_session():
   4345         v1 = variables.VariableV1([1.0])
   4346         add1 = math_ops.add(
   4347             control_flow_ops.with_dependencies([v1.initializer], v1._ref()),  # pylint: disable=protected-access
   4348             2.0)
   4349         v2 = variables.VariableV1([10.0])
   4350         add2 = math_ops.add(
   4351             control_flow_ops.with_dependencies([v2.initializer], v2._ref()),  # pylint: disable=protected-access
   4352             20.0)
   4353         t1, _, t2 = control_flow_ops.tuple([add1, None, add2])
   4354 
   4355         # v1 is not initialized.
   4356         with self.assertRaisesOpError("Attempting to use uninitialized value"):
   4357           self.evaluate(v1)
   4358 
   4359         # v2 is not initialized.
   4360         with self.assertRaisesOpError("Attempting to use uninitialized value"):
   4361           self.evaluate(v2)
   4362 
   4363         if v1_first:
   4364           # Getting t1 initializes v2.
   4365           self.assertAllClose([3.0], self.evaluate(t1))
   4366           self.assertAllClose([10.0], self.evaluate(v2))
   4367         else:
   4368           # Getting t2 initializes v1.
   4369           self.assertAllClose([30.0], self.evaluate(t2))
   4370           self.assertAllClose([1.0], self.evaluate(v1))
   4371 
   4372   @test_util.run_v1_only("b/120545219")
   4373   def testIndexedSlices(self):
   4374     for v1_first in [True, False]:
   4375       with self.cached_session():
   4376         v1 = variables.VariableV1(
   4377             np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(
   4378                 np.float32))
   4379         v1_at_1 = ops.IndexedSlices(
   4380             control_flow_ops.with_dependencies([v1.initializer], v1._ref()),  # pylint: disable=protected-access
   4381             constant_op.constant([1]))
   4382 
   4383         v2 = variables.VariableV1(
   4384             np.array([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]]).astype(
   4385                 np.float32))
   4386         v2_at_1 = ops.IndexedSlices(
   4387             control_flow_ops.with_dependencies([v2.initializer], v2._ref()),  # pylint: disable=protected-access
   4388             constant_op.constant([1]))
   4389 
   4390         st1, st2 = control_flow_ops.tuple([v1_at_1, v2_at_1])
   4391         g1 = array_ops.gather(st1.values, st1.indices)
   4392         g2 = array_ops.gather(st2.values, st2.indices)
   4393 
   4394         # v1 is not initialized.
   4395         with self.assertRaisesOpError("Attempting to use uninitialized value"):
   4396           self.evaluate(v1)
   4397 
   4398         # v2 is not initialized.
   4399         with self.assertRaisesOpError("Attempting to use uninitialized value"):
   4400           self.evaluate(v2)
   4401 
   4402         if v1_first:
   4403           # Getting g1 initializes v2.
   4404           self.assertAllClose([[10.0, 11.0]], self.evaluate(g1))
   4405           self.assertAllClose([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]],
   4406                               self.evaluate(v2))
   4407         else:
   4408           # Getting g2 initializes v1.
   4409           self.assertAllClose([[10.1, 11.1]], self.evaluate(g2))
   4410           self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]],
   4411                               self.evaluate(v1))
   4412 
   4413   def testAcceptTensorsAsControlInputs(self):
   4414     with self.cached_session():
   4415       var = variables.VariableV1(0)
   4416       assign = state_ops.assign(var, 1)
   4417       t, = control_flow_ops.tuple(
   4418           [constant_op.constant(0)], control_inputs=[assign])
   4419 
   4420       # Should trigger the assign.
   4421       self.evaluate(t)
   4422 
    4423       self.assertEqual(1, self.evaluate(var))
   4424 
   4425 
   4426 class AssertTest(test.TestCase):
   4427 
   4428   @test_util.run_deprecated_v1
   4429   def testGuardedAssertDoesNotCopyWhenTrue(self):
   4430     if test_util.is_gpu_available():
   4431       self.skipTest("b/128646478 fails in opensource")
   4432 
   4433     with self.session(use_gpu=True) as sess:
   4434       with ops.device(test.gpu_device_name()):
   4435         value = constant_op.constant(1.0)
   4436       with ops.device("/cpu:0"):
   4437         true = constant_op.constant(True)
   4438         guarded_assert = control_flow_ops.Assert(true, [value], name="guarded")
   4439         unguarded_assert = gen_logging_ops._assert(
   4440             true, [value], name="unguarded")
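               # Trace both runs and compare device-to-host copy activity.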
   4441       opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
   4442       guarded_metadata = config_pb2.RunMetadata()
   4443       sess.run(guarded_assert, options=opts, run_metadata=guarded_metadata)
   4444       unguarded_metadata = config_pb2.RunMetadata()
   4445       sess.run(unguarded_assert, options=opts, run_metadata=unguarded_metadata)
   4446       guarded_nodestat_names = [
   4447           n.node_name
   4448           for d in guarded_metadata.step_stats.dev_stats
   4449           for n in d.node_stats
   4450       ]
   4451       unguarded_nodestat_names = [
   4452           n.node_name
   4453           for d in unguarded_metadata.step_stats.dev_stats
   4454           for n in d.node_stats
   4455       ]
   4456       guarded_memcpy_nodestat_names = [
   4457           n for n in guarded_nodestat_names if "MEMCPYDtoH" in n
   4458       ]
   4459       unguarded_memcpy_nodestat_names = [
   4460           n for n in unguarded_nodestat_names if "MEMCPYDtoH" in n
   4461       ]
   4462       if "GPU" in [d.device_type for d in device_lib.list_local_devices()]:
   4463         # A copy was performed for the unguarded assert
   4464         self.assertLess(0, len(unguarded_memcpy_nodestat_names),
   4465                         str(unguarded_nodestat_names))
   4466       # No copy was performed for the guarded assert
   4467       self.assertEqual([], guarded_memcpy_nodestat_names)
   4468 
   4469 
   4470 class WhileOpBenchmark(test.Benchmark):
   4471   """Evaluate the performance of while_loop op."""
   4472 
   4473   def _getInitVariables(self):
   4474     batch_size = 10
   4475     image_size = 256
   4476     kernel_size = 3
   4477     depth = 16
   4478 
   4479     init_step = constant_op.constant(-1)
   4480     image = variable_scope.get_variable(
   4481         "image",
   4482         initializer=random_ops.random_normal(
   4483             [batch_size, image_size, image_size, depth],
   4484             dtype=dtypes.float32,
   4485             stddev=1e-1))
   4486     kernel = variable_scope.get_variable(
   4487         "weights",
   4488         initializer=random_ops.truncated_normal(
   4489             [kernel_size, kernel_size, depth, depth],
   4490             dtype=dtypes.float32,
   4491             stddev=1e-1))
   4492     return init_step, image, kernel
   4493 
   4494   def _runOneBenchmark(self,
   4495                        default_device,
   4496                        num_iters=10,
   4497                        static_unroll=False,
   4498                        steps=10):
   4499     """Evaluate the while loop performance.
   4500 
   4501     Args:
   4502       default_device: The default device to run all ops except the loop_body.
   4503         loop_body is always run on GPU.
   4504       num_iters: Number of iterations to run.
   4505       static_unroll: If true, run unrolled version; otherwise, run while_loop.
   4506       steps: Total number of repeated steps to run the loop.
   4507 
   4508     Returns:
   4509       The duration of the run in seconds.
   4510     """
   4511 
   4512     def loop_body(i, x):
   4513       with ops.device("/gpu:0"):
   4514         # Always put loop body on GPU.
   4515         nx = nn_ops.conv2d(
   4516             input=x,
   4517             filter=kernel,
   4518             strides=[1, 1, 1, 1],
   4519             padding="SAME",
   4520             data_format="NHWC",
   4521             name="conv2d")
   4522         ni = math_ops.add(i, 1)
   4523         return ni, nx
   4524 
   4525     ops.reset_default_graph()
   4526     with session.Session() as sess, ops.device(default_device):
   4527       # Get the initial id i, input x, and kernel.
   4528       i, x, kernel = self._getInitVariables()
   4529       variables.global_variables_initializer().run()
   4530 
   4531       if static_unroll:
   4532         for _ in xrange(steps):
   4533           i, x = loop_body(i, x)
   4534       else:
   4535         i, x = control_flow_ops.while_loop(
   4536             lambda i, _: i < steps,
   4537             loop_body, [i, x],
   4538             parallel_iterations=steps,
   4539             swap_memory=True)
   4540 
   4541       r = math_ops.reduce_sum(x)
   4542       dx, dk = gradients_impl.gradients(r, [x, kernel])
   4543       # Use group to avoid fetching back results.
   4544       r = control_flow_ops.group(dx, dk)
   4545 
    4546       # Warm up: run a few untimed iterations first.
    4547       for _ in xrange(3):
    4548         self.evaluate(r)
   4549 
   4550       start_time = time.time()
   4551       for _ in xrange(num_iters):
   4552         self.evaluate(r)
   4553       return (time.time() - start_time) / num_iters
   4554 
   4555   def benchmarkWhileOpCrossDevicePlacement(self):
   4556     iters = 10
   4557     # Run loop body on GPU, but other ops on CPU.
   4558     duration = self._runOneBenchmark("cpu", iters, static_unroll=False)
   4559     self.report_benchmark(
   4560         name="while_op_cross_device", iters=iters, wall_time=duration)
   4561 
   4562   def benchmarkWhileOpSameDevicePlacement(self):
   4563     iters = 10
   4564     # Run all ops on the same GPU device.
   4565     duration = self._runOneBenchmark("gpu", iters, static_unroll=False)
   4566     self.report_benchmark(
   4567         name="while_op_same_device", iters=iters, wall_time=duration)
   4568 
   4569   def benchmarkWhileOpUnrollCrossDevicePlacement(self):
   4570     iters = 10
   4571     # Run loop body on GPU, but other ops on CPU.
   4572     duration = self._runOneBenchmark("cpu", iters, static_unroll=True)
   4573     self.report_benchmark(
   4574         name="unroll_cross_device_cpu", iters=iters, wall_time=duration)
   4575 
   4576   def benchmarkWhileOpUnrollSameDevicePlacement(self):
   4577     iters = 10
   4578     # Run all ops on GPU.
   4579     duration = self._runOneBenchmark("gpu", iters, static_unroll=True)
   4580     self.report_benchmark(
   4581         name="unroll_same_device", iters=iters, wall_time=duration)
   4582 
   4583 
   4584 @test_util.with_control_flow_v2
   4585 class EagerTest(test.TestCase):
   4586 
   4587   def testCond(self):
   4588     with context.eager_mode():
   4589       pred = math_ops.less(1, 2)
   4590       fn1 = lambda: [constant_op.constant(10)]
   4591       fn2 = lambda: [constant_op.constant(20)]
   4592       r = control_flow_ops.cond(pred, fn1, fn2)
   4593 
   4594       self.assertAllEqual(r.numpy(), 10)
    4595       self.assertNotIsInstance(r, list)
   4596 
   4597   # TODO(b/117279927): Re-enable once msan failure is fixed.
   4598   def DISABLED_testCondInDefun(self):
   4599     with context.eager_mode():
   4600 
   4601       @eager_function.defun
   4602       def foo(pred):
   4603         # TODO(b/111124878): this only needs to output one element.
   4604         fn1 = lambda: (constant_op.constant(10), constant_op.constant(100))
   4605         fn2 = lambda: (constant_op.constant(20), constant_op.constant(200))
   4606         return control_flow_ops.cond(constant_op.constant(pred), fn1, fn2)
   4607 
   4608       r = foo(True)
   4609       self.assertAllEqual(r[0].numpy(), 10)
   4610       self.assertNotIsInstance(r, list)
   4611 
   4612       r = foo(False)
   4613       self.assertAllEqual(r[0].numpy(), 20)
    4614       self.assertNotIsInstance(r, list)
   4615 
   4616   def testWhileLoop(self):
   4617     with context.eager_mode():
   4618       tensor = constant_op.constant([1, 2, 3, 4, 5])
   4619       self.assertAllEqual(isum(tensor).numpy(), [46, 47, 48, 49, 50])
   4620 
   4621   def testWhileLoopWithMaxIterations(self):
   4622     with context.eager_mode():
   4623       tensor = constant_op.constant([1, 2, 3, 4, 5])
   4624       self.assertAllEqual(
   4625           isum(tensor, maximum_iterations=3).numpy(),
   4626           [1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3])
   4627 
   4628   @test_util.run_v1_only("b/120545219")
   4629   def testWhileWithMaximumIterationsAndSingleArgument(self):
   4630     with context.eager_mode():
   4631       tensor = constant_op.constant(0)
   4632       r = control_flow_ops.while_loop(
   4633           lambda i: i < 3, lambda i: i + 1, [tensor], maximum_iterations=1)
   4634       self.assertEqual(1, r.numpy())
   4635 
   4636   def testWithDependencies(self):
   4637     with context.eager_mode():
   4638       t1 = constant_op.constant(1)
   4639       t2 = constant_op.constant(2)
   4640       t3 = control_flow_ops.with_dependencies(t1, t2)
   4641       self.assertAllEqual(t2.numpy(), t3.numpy())
   4642 
   4643   def testTuple(self):
   4644     with context.eager_mode():
   4645       t1 = constant_op.constant(1)
   4646       t2 = constant_op.constant(2)
   4647       tup1, tup2 = control_flow_ops.tuple([t1, t2])
   4648       self.assertAllEqual(t1.numpy(), tup1.numpy())
   4649       self.assertAllEqual(t2.numpy(), tup2.numpy())
   4650 
   4651   @test_util.run_v1_only("b/120545219")
   4652   def testCase(self):
   4653     with context.eager_mode():
   4654       x = constant_op.constant(1)
   4655       y = constant_op.constant(2)
   4656       z = constant_op.constant(3)
   4657       f1 = lambda: constant_op.constant(17)
   4658       f2 = lambda: constant_op.constant(23)
   4659       f3 = lambda: constant_op.constant(-1)
   4660 
   4661       r1 = control_flow_ops.case(
   4662           [(x < y, f1), (x > z, f2)], default=f3, exclusive=True)
   4663       self.assertAllEqual(r1.numpy(), 17)
   4664 
   4665 
   4666 if __name__ == "__main__":
   4667   test.main()
   4668