# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=g-long-lambda
"""Tests for tensorflow.ops.control_flow_ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import math
import sys
import time

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import device_lib
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as eager_function
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops import while_v2  # pylint: disable=unused-import
# pylint: disable=unused-import
from tensorflow.python.ops.ragged import ragged_factory_ops
import tensorflow.python.ops.tensor_array_grad
# pylint: enable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.training import adam
from tensorflow.python.training import gradient_descent
from tensorflow.python.util import nest


def check_consumers(graph):
  """Sanity check on the consumer list of the tensors."""

  consumer_count = {}
  for op in graph.get_operations():
    for v in op.inputs:
      cnt = consumer_count.get(v, 0)
      consumer_count[v] = cnt + 1
  for k, v in consumer_count.items():
    if len(k.consumers()) != v:
      return False
  return True


def all_fetchables():
  tensor_names = []
  graph = ops.get_default_graph()
  for op in graph.get_operations():
    for t in op.outputs:
      if graph.is_fetchable(t):
        tensor_names.append(t.name)
  return tensor_names


def all_feedables():
  feedable_tensors = []
  graph = ops.get_default_graph()
  for op in graph.get_operations():
    for t in op.inputs:
      if graph.is_feedable(t):
        feedable_tensors.append(t)
  return feedable_tensors


def opt_cfg():
  return config_pb2.ConfigProto(
      allow_soft_placement=True,
      graph_options=config_pb2.GraphOptions(
          optimizer_options=config_pb2.OptimizerOptions(
              opt_level=config_pb2.OptimizerOptions.L1,
              do_function_inlining=True,
              do_constant_folding=True)))


def isum(s, maximum_iterations=None):
  i = constant_op.constant(0, name="i")
  c = lambda i, s: math_ops.less(i, 10)
  b = lambda i, s: [math_ops.add(i, 1), math_ops.add(i, s)]
  _, r_s = control_flow_ops.while_loop(
      c, b, [i, s], maximum_iterations=maximum_iterations)
  return r_s


@test_util.with_control_flow_v2
class ControlFlowTest(test.TestCase):

  @test_util.run_v1_only("b/120545219")
  def testRefIdentity(self):
    with self.cached_session():
      v = variables.VariableV1(7)

      v = control_flow_ops._Identity(v)
      op = state_ops.assign(v, 9)
      v2 = control_flow_ops.with_dependencies([op], v)

      self.assertTrue(isinstance(v2, ops.Tensor))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(9, self.evaluate(v2))

  @test_util.run_v1_only("b/120545219")
  def testRefEnter(self):
    with self.cached_session():
      v = variables.VariableV1(7)

      enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True)
      nine = constant_op.constant(9)
      enter_nine = gen_control_flow_ops.enter(nine, "foo_1")
      op = state_ops.assign(enter_v, enter_nine)
      v2 = control_flow_ops.with_dependencies([op], enter_v)
      v3 = control_flow_ops.exit(v2)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(9, self.evaluate(v3))

  @test_util.run_v1_only("b/120545219")
  def testRefSwitch(self):
    with self.cached_session():
      v = variables.VariableV1(7)

      p = constant_op.constant(True)
      v1 = control_flow_ops._SwitchRefOrTensor(v._ref(), p)  # pylint: disable=protected-access
      v2 = state_ops.assign(v1[1], 9)
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(9, self.evaluate(v2))

  def testEnterMulExit(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      enter_data = gen_control_flow_ops.enter(data, "foo_1", False)
      five = constant_op.constant(5)
      enter_five = gen_control_flow_ops.enter(five, "foo_1", False)
      mul_op = math_ops.multiply(enter_data, enter_five)
      exit_op = control_flow_ops.exit(mul_op)

      result = self.evaluate(exit_op)
      self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)

  @test_util.run_deprecated_v1
  def testEnterShapePropagation(self):
    with self.cached_session():
      v = variables.Variable([0.0, 0.0], dtype=dtypes.float32)

      # If is_constant=True, the shape information should be propagated.
      enter_v_constant = gen_control_flow_ops.enter(
          v, "frame1", is_constant=True)
      self.assertEqual(enter_v_constant.shape, [2])

      # Otherwise, the shape should be unknown.
      enter_v_non_constant = gen_control_flow_ops.enter(
          v, "frame2", is_constant=False)
      self.assertEqual(enter_v_non_constant.shape, None)

  @test_util.run_v1_only("b/120545219")
  def testSwitchMergeIndexedSlices(self):
    with self.cached_session():
      values = constant_op.constant([1, 2, 3, 4, 5, 6])
      indices = constant_op.constant([0, 2, 4, 6, 8, 10])
      data = ops.IndexedSlices(values, indices)
      pred = ops.convert_to_tensor(True)
      switch_op = control_flow_ops.switch(data, pred)
      merge_op = control_flow_ops.merge(switch_op)[0]

      val = merge_op.values
      ind = merge_op.indices
      self.assertAllEqual(np.arange(1, 7), val)
      self.assertAllEqual(np.arange(0, 12, 2), ind)

  @test_util.run_v1_only("b/120545219")
  def testSwitchDeadBranch(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = ops.convert_to_tensor(True, name="ports")
      switch_op = control_flow_ops.switch(data, ports)
      dead_branch = array_ops.identity(switch_op[0])

      with self.assertRaisesWithPredicateMatch(
          errors_impl.InvalidArgumentError,
          lambda e: "Retval[0] does not have value" in str(e)):
        self.evaluate(dead_branch)

  @test_util.run_v1_only("b/120545219")
  def testSwitchMergeLess(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      zero = ops.convert_to_tensor(0)
      one = ops.convert_to_tensor(1)
      less_op = math_ops.less(zero, one)
      switch_op = control_flow_ops.switch(data, less_op)
      merge_op = control_flow_ops.merge(switch_op)[0]

      result = self.evaluate(merge_op)
      self.assertAllEqual(np.arange(1, 7), result)

  @test_util.run_v1_only("b/120545219")
  def testSwitchMergeAddIdentity(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = ops.convert_to_tensor(False, name="ports")
      switch_op = control_flow_ops.switch(data, ports)
      one = constant_op.constant(1)
      add_op = math_ops.add(switch_op[0], one)
      id_op = array_ops.identity(switch_op[1])
      merge_op = control_flow_ops.merge([add_op, id_op])[0]

      result = self.evaluate(merge_op)
      self.assertAllEqual(np.array([x + 1 for x in [1, 2, 3, 4, 5, 6]]), result)

  @test_util.run_v1_only("b/120545219")
  def testSwitchMergeAddMul(self):
    with self.cached_session():
      data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
      ports = ops.convert_to_tensor(True, name="ports")
      switch_op = control_flow_ops.switch(data, ports)
      one = constant_op.constant(1)
      add_op = math_ops.add(switch_op[0], one)
      five = constant_op.constant(5)
      mul_op = math_ops.multiply(switch_op[1], five)
      merge_op = control_flow_ops.merge([add_op, mul_op])[0]

      result = self.evaluate(merge_op)
      self.assertAllEqual(np.array([x * 5 for x in [1, 2, 3, 4, 5, 6]]), result)

  @test_util.run_v1_only("b/120545219")
  def testLoop_false(self):
    with self.cached_session():
      false = ops.convert_to_tensor(False)
      n = constant_op.constant(10)

      enter_false = gen_control_flow_ops.enter(false, "foo_1", False)
      enter_n = gen_control_flow_ops.enter(n, "foo_1", False)

      merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0]
      switch_n = control_flow_ops.switch(merge_n, enter_false)
      exit_n = control_flow_ops.exit(switch_n[0])
      next_n = control_flow_ops.next_iteration(switch_n[0])
      merge_n.op._update_input(1, next_n)

      result = self.evaluate(exit_n)
      self.assertAllEqual(10, result)

  @test_util.run_deprecated_v1
  def testLoop_1(self):
    with self.cached_session():
      zero = constant_op.constant(0)
      one = constant_op.constant(1)
      n = constant_op.constant(10)

      enter_i = gen_control_flow_ops.enter(zero, "foo", False)
      enter_one = gen_control_flow_ops.enter(one, "foo", True)
      enter_n = gen_control_flow_ops.enter(n, "foo", True)

      with ops.device(test.gpu_device_name()):
        merge_i = control_flow_ops.merge([enter_i, enter_i])[0]

      less_op = math_ops.less(merge_i, enter_n)
      cond_op = control_flow_ops.loop_cond(less_op)
      switch_i = control_flow_ops.switch(merge_i, cond_op)

      add_i = math_ops.add(switch_i[1], enter_one)

      next_i = control_flow_ops.next_iteration(add_i)
      merge_i.op._update_input(1, next_i)

      exit_i = control_flow_ops.exit(switch_i[0])
      result = self.evaluate(exit_i)
      self.assertAllEqual(10, result)

  @test_util.run_v1_only("b/120545219")
  def testLoop_2(self):
    with self.cached_session():
      zero = constant_op.constant(0)
      one = constant_op.constant(1)
      n = constant_op.constant(10)

      enter_i = gen_control_flow_ops.enter(zero, "foo", False)
      enter_one = gen_control_flow_ops.enter(one, "foo", True)
      enter_n = gen_control_flow_ops.enter(n, "foo", True)

      merge_i = control_flow_ops.merge([enter_i, enter_i])[0]

      less_op = math_ops.less(merge_i, enter_n)
      cond_op = control_flow_ops.loop_cond(less_op)
      switch_i = control_flow_ops.switch(merge_i, cond_op)

      add_i = math_ops.add(switch_i[1], enter_one)

      with ops.device(test.gpu_device_name()):
        next_i = control_flow_ops.next_iteration(add_i)
      merge_i.op._update_input(1, next_i)

      exit_i = control_flow_ops.exit(switch_i[0])
      result = self.evaluate(exit_i)
      self.assertAllEqual(10, result)

  @test_util.run_v1_only("b/120545219")
  def testDifferentFrame(self):
    with self.cached_session():
      data = array_ops.placeholder(dtypes.float32, shape=[])
      enter_1 = gen_control_flow_ops.enter(data, "foo_1", False)
      enter_2 = gen_control_flow_ops.enter(data, "foo_2", False)
      res = math_ops.add(enter_1, enter_2)
      with self.assertRaisesOpError("has inputs from different frames"):
        res.eval(feed_dict={data: 1.0})

  @test_util.run_deprecated_v1
  def testCondBool(self):
    values = constant_op.constant(10)
    fn1 = lambda: math_ops.add(values, 1)
    fn2 = lambda: math_ops.subtract(values, 1)
    with self.assertRaisesRegexp(TypeError, "must not be a Python bool"):
      _ = control_flow_ops.cond(False, fn1, fn2)

  @test_util.run_deprecated_v1
  def testCondInt(self):
    p = array_ops.placeholder(dtypes.bool, shape=[])
    v = constant_op.constant(10)
    fn1 = lambda: math_ops.add(v, 1)
    fn2 = lambda: math_ops.subtract(v, 1)
    y = control_flow_ops.cond(p, fn1, fn2)
    grad = gradients_impl.gradients(y, [v])
    self.assertAllEqual([None], grad)

  def testCondOutputShape(self):
    x = constant_op.constant(1.0)
    b = control_flow_ops.cond(
        constant_op.constant(True), lambda: math_ops.square(x),
        lambda: math_ops.subtract(x, 1.))
    self.assertEqual(b.shape, tensor_shape.scalar())

  @test_util.run_v1_only("b/120545219")
  def testFetchable(self):
    with self.cached_session() as sess:
      x = array_ops.placeholder(dtypes.float32)
      control_flow_ops.cond(
          constant_op.constant(True), lambda: x + 2, lambda: x + 0)
      graph = ops.get_default_graph()
      for op in graph.get_operations():
        for t in op.inputs:
          if graph.is_fetchable(t.op):
            sess.run(t, feed_dict={x: 3})
          else:
            with self.assertRaisesRegexp(ValueError,
                                         "has been marked as not fetchable"):
              sess.run(t, feed_dict={x: 3})

  @test_util.disable_control_flow_v2("Not relevant")
  @test_util.run_v1_only("b/120545219")
  def testFeedable(self):
    with self.cached_session() as sess:
      c = constant_op.constant(2)
      i0 = constant_op.constant(0)
      r = control_flow_ops.while_loop(lambda i: i < 1000,
                                      lambda i: math_ops.square(c) + i, [i0])
      self.assertEqual(1000, r.eval(feed_dict={i0: 0}))
      feedable_tensors = all_feedables()
      for t in feedable_tensors:
        sess.run(r, feed_dict={t: 3})
      graph = ops.get_default_graph()
      for op in graph.get_operations():
        for t in op.inputs:
          if t not in feedable_tensors and t.dtype is dtypes.int32:
            with self.assertRaisesRegexp(ValueError, "may not be fed"):
              sess.run(r, feed_dict={t: 3})

  @test_util.run_v1_only("b/120545219")
  def testCondIndexedSlices(self):
    with self.cached_session():
      values = constant_op.constant(10)
      indices = constant_op.constant(0)
      x = ops.IndexedSlices(values, indices)
      pred = math_ops.less(1, 2)
      fn1 = lambda: ops.IndexedSlices(math_ops.add(x.values, 1), indices)
      fn2 = lambda: ops.IndexedSlices(math_ops.subtract(x.values, 1), indices)
      r = control_flow_ops.cond(pred, fn1, fn2)

      val = r.values
      ind = r.indices
      self.assertAllEqual(11, val)
      self.assertAllEqual(0, ind)

  def testCondMismatchedIndexedSlices(self):
    @def_function.function
    def foo():
      values = constant_op.constant(10)
      indices = constant_op.constant(0)
      x = ops.IndexedSlices(values, indices)
      v1_msg = "The two structures don't have the same nested structure"
      v2_msg = ("true_fn and false_fn arguments to tf.cond must have the same "
                "number, type, and overall structure of return values.")
      with self.assertRaisesRegexp(
          TypeError,
          v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
        control_flow_ops.cond(
            constant_op.constant(True),
            lambda: ops.IndexedSlices(math_ops.add(x.values, 1), indices),
            lambda: math_ops.add(x.values, 1), indices)
    foo()

  def testCondSparseTensor(self):
    values = constant_op.constant([2.0, 4.0], name="values")
    indices = constant_op.constant([[0], [3]],
                                   dtype=dtypes.int64,
                                   name="indices")
    shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape")
    x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape)
    pred = math_ops.less(1, 2)
    fn1 = lambda: sparse_tensor.SparseTensor(
        indices + 1, x.values + 1, dense_shape=shape)
    fn2 = lambda: sparse_tensor.SparseTensor(
        indices, x.values - 1, dense_shape=shape)
    r = control_flow_ops.cond(pred, fn1, fn2)
    self.assertAllEqual([3.0, 5.0], r.values)
    self.assertAllEqual([[1], [4]], r.indices)
    self.assertAllEqual(r.values.get_shape(), (2,))

  def testCondRaggedTensor(self):
    rt = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]])
    pred = math_ops.less(1, 2)
    fn1 = lambda: array_ops.concat([rt + 2, [[100]]], axis=0)
    fn2 = lambda: rt[:2] - 2
    result = control_flow_ops.cond(pred, fn1, fn2)
    self.assertAllEqual([3, 4, 5, 6, 7, 8, 100], result.values)
    self.assertAllEqual([0, 2, 3, 6, 7], result.row_splits)

  @test_util.run_v1_only("b/120545219")
  def testCondResource(self):

    with self.cached_session():
      rv = resource_variable_ops.ResourceVariable(True)
      self.evaluate(variables.global_variables_initializer())
      t = ops.convert_to_tensor(1.0)

      def case():
        assign = resource_variable_ops.assign_variable_op(rv.handle, False)
        with ops.control_dependencies([assign]):
          return array_ops.identity(t)

      self.assertEqual(
          1.0, self.evaluate(control_flow_ops.cond(rv, case, lambda: t)))

  @test_util.run_v1_only("b/120545219")
  def testCondWithTensorArrayGrad(self):
    with self.cached_session() as sess:
      with ops.device(test.gpu_device_name()):
        pred = array_ops.placeholder(dtypes.bool, [])
        x = constant_op.constant([1.0, 2.0, 3.0])
        y = control_flow_ops.cond(
            pred, lambda: map_fn.map_fn(lambda z: z * 2.0, x),
            lambda: constant_op.constant([1.0, 1.0, 1.0]))
        g = gradients_impl.gradients(y, x)[0]

      self.assertAllEqual(sess.run(g, {pred: True}), [2.0, 2.0, 2.0])
      self.assertAllEqual(sess.run(g, {pred: False}), [0.0, 0.0, 0.0])

  @test_util.disable_control_flow_v2("b/113293074")
  @test_util.run_v1_only("b/120545219")
  def testCondIndexedSlicesDifferentTypes(self):
    with self.cached_session():
      values = constant_op.constant(10)
      i_32 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int32)
      i_64 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int64)
      x = ops.IndexedSlices(values, i_32)
      pred = math_ops.less(1, 2)
      fn1 = lambda: ops.IndexedSlices(math_ops.add(x.values, 1), i_32)
      fn2 = lambda: ops.IndexedSlices(math_ops.subtract(x.values, 1), i_64)
      r = control_flow_ops.cond(pred, fn1, fn2)

      val = r.values
      ind = r.indices
      self.assertAllEqual(11, val)
      self.assertAllEqual(0, ind)
      self.assertTrue(ind.dtype == np.int64)

  @test_util.run_v1_only("b/120545219")
  def testCondColocation(self):
    with self.session(use_gpu=True):
      with ops.device("/cpu:0"):
        v = variables.Variable(7.0)

      x = constant_op.constant(10.0)
      pred = math_ops.less(1.0, 2.0)
      fn1 = lambda: math_ops.add(v, 1.0)
      fn2 = lambda: math_ops.subtract(x, 1.0)
      r = control_flow_ops.cond(pred, fn1, fn2)

      for op in x.graph.get_operations():
        if op.name == "cond/Add/Switch":
          self.assertDeviceEqual(op.device, "/cpu:0")

  def _testCond_1(self, use_gpu):
    with self.cached_session(use_gpu=use_gpu):
      x = constant_op.constant(10)
      pred = math_ops.less(1, 2)
      fn1 = lambda: math_ops.add(x, 1)
      fn2 = lambda: math_ops.subtract(x, 1)
      r = control_flow_ops.cond(pred, fn1, fn2)

      result = self.evaluate(r)
      self.assertAllEqual(11, result)

  def testCond_1(self):

    self._testCond_1(use_gpu=False)
    # TODO(b/116526896): Enable GPU tests.
    # self._testCond_1(use_gpu=True)

  def testCond_2(self):

    with self.cached_session():
      x = constant_op.constant(10)
      r = control_flow_ops.cond(
          math_ops.less(1, 0), lambda: math_ops.add(x, 1),
          lambda: math_ops.subtract(x, 1))
      result = self.evaluate(r)
      self.assertAllEqual(9, result)

  def testCond_3(self):

    with self.cached_session():
      x = constant_op.constant(10)
      pred = math_ops.less(1, 2)
      fn1 = lambda: math_ops.add(x, 1)
      fn2 = lambda: math_ops.subtract(x, 1)
      fn3 = lambda: math_ops.add(control_flow_ops.cond(pred, fn1, fn2), 1)
      r = control_flow_ops.cond(pred, fn3, fn2)

      result = self.evaluate(r)
      self.assertAllEqual(12, result)

  @test_util.disable_xla("b/128638446")
  @test_util.run_in_graph_and_eager_modes
  def testCondPruning(self):
    v1 = variables.Variable(7)
    v2 = variables.Variable(7)
    v3 = variables.Variable(7)

    def f():
      age = constant_op.constant(3)
      max_age = constant_op.constant(2)
      pred = math_ops.greater(age, max_age)
      fn1 = lambda: [state_ops.assign(v1, 1).op, state_ops.assign(v2, 2).op]
      fn2 = lambda: [state_ops.assign(v3, 3).op, constant_op.constant(10).op]
      r = control_flow_ops.cond(pred, fn1, fn2)
      self.assertEqual(len(r), 2)
      return r[1]

    f_defun = eager_function.defun(f)

    if not context.executing_eagerly():
      with self.cached_session():
        self.evaluate(variables.global_variables_initializer())
        result = self.evaluate(f())
        self.assertEqual(True, result)
        # Only second cond result was fetched, so v1 assign shouldn't run.
        self.assertEqual(7, self.evaluate(v1))
        self.assertEqual(2, self.evaluate(v2))
        self.assertEqual(7, self.evaluate(v3))

    result = f_defun()
    self.assertEqual(True, self.evaluate(result))
    # Both v1 and v2 branch assignments should be run in defun.
    self.assertEqual(1, self.evaluate(v1))
    self.assertEqual(2, self.evaluate(v2))
    self.assertEqual(7, self.evaluate(v3))

  def testCond_5(self):
    with self.cached_session():
      alive = constant_op.constant(True, name="alive")
      count = constant_op.constant(0, name="count")

      def body(i):
        return control_flow_ops.cond(
            alive, lambda: [math_ops.less(i, 3), math_ops.add(count, 1)],
            lambda: [alive, count])

      for i in range(10):
        alive, count = body(i)
      self.assertAllEqual(4, self.evaluate(count))

  @test_util.run_v1_only("b/120545219")
  def testCond_6(self):
    with self.cached_session():
      v1 = variables.Variable([7])

      age = constant_op.constant(3)
      pred = math_ops.greater(age, 4)
      fn1 = lambda: age
      fn2 = lambda: v1
      r = control_flow_ops.cond(pred, fn1, fn2)

      self.evaluate(variables.global_variables_initializer())
      result = self.evaluate(r)
      self.assertAllEqual(np.array([7]), result)

  def testCond_7(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: [math_ops.add(x, 1), math_ops.add(x, 2)]
      fn2 = lambda: [y, y]
      r = control_flow_ops.cond(pred, fn1, fn2)
      self.assertAllEqual([11, 12], self.evaluate(r))

  @test_util.run_gpu_only
  @test_util.run_deprecated_v1
  def testCond_Device(self):
    x = constant_op.constant(-10.)

    # True branch function defined outside of device scope
    def true_fn():
      return math_ops.exp(x)

    with ops.device("CPU:0"):
      r = control_flow_ops.cond(
          constant_op.constant(True), true_fn, lambda: 0.)
      self.assertIn("cpu", r.device.lower())

    with session.Session() as sess:
      options = config_pb2.RunOptions(output_partition_graphs=True)
      run_metadata = config_pb2.RunMetadata()
      sess.run(r, options=options, run_metadata=run_metadata)
      # We expect that everything runs on CPU, even if GPU is available.
      self.assertEqual(len(run_metadata.partition_graphs), 1)

  def _count_matching_switch_nodes_on_device(self, run_metadata, device_str):
    # Returns the number of Switch nodes with type float32 placed on
    # `device_str`.
    device_graphs = [
        g for g in run_metadata.partition_graphs
        if device_str in g.node[0].device
    ]
    self.assertLen(device_graphs, 1)
    switch_nodes = [
        n for n in device_graphs[0].node if n.op == "Switch" and
        n.attr["T"].type == dtypes.float32.as_datatype_enum
    ]
    return len(switch_nodes)

  @test_util.run_gpu_only
  @test_util.run_deprecated_v1
  def testCondSwitchColocatedWithInputWhenInputOnCPU(self):
    x = array_ops.placeholder(dtypes.float32)

    # `arg` is used in the cond then branch so a Switch node is created for it.
    # We test that the Switch node gets placed on the same device as `arg`.
    # We force `arg` to be on CPU here.
    with ops.device("CPU:0"):
      arg = x + 10.

    def true_fn():
      with ops.device("CPU:0"):
        return arg + 1

    r = control_flow_ops.cond(constant_op.constant(True), true_fn, lambda: 0.)

    with session.Session() as sess:
      run_metadata = config_pb2.RunMetadata()
      options = config_pb2.RunOptions(output_partition_graphs=True)
      sess.run(
          r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
      self.assertEqual(len(run_metadata.partition_graphs), 2)
      # Check that the Switch for `arg` gets placed on CPU.
      self.assertEqual(
          self._count_matching_switch_nodes_on_device(run_metadata, "CPU"), 1)
      self.assertEqual(
          self._count_matching_switch_nodes_on_device(run_metadata, "GPU"), 0)

  @test_util.run_gpu_only
  @test_util.run_deprecated_v1
  def testCondSwitchColocatedWithInputWhenInputOnGPU(self):
    x = array_ops.placeholder(dtypes.float32)

    # `arg` is used in the cond then branch so a Switch node is created for it.
    # We test that the Switch node gets placed on the same device as `arg`.
    # Note: `arg` gets placed on GPU by default by the placer.
    arg = x + 10.

    def true_fn():
      with ops.device("CPU:0"):
        return arg + 1

    r = control_flow_ops.cond(constant_op.constant(True), true_fn, lambda: 0.)

    with session.Session() as sess:
      run_metadata = config_pb2.RunMetadata()
      options = config_pb2.RunOptions(output_partition_graphs=True)
      sess.run(
          r, feed_dict={x: -10.}, options=options, run_metadata=run_metadata)
      self.assertEqual(len(run_metadata.partition_graphs), 2)
      # Check that the Switch for `arg` gets placed on GPU.
      self.assertEqual(
          self._count_matching_switch_nodes_on_device(run_metadata, "CPU"), 0)
      self.assertEqual(
          self._count_matching_switch_nodes_on_device(run_metadata, "GPU"), 1)

  def testCondListOutput(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: [math_ops.add(x, y), math_ops.add(x, y)]
      fn2 = lambda: [y, y]
      r = control_flow_ops.cond(pred, fn1, fn2)
      test_result = self.evaluate(r)
      self.assertListEqual([210, 210], test_result)

  def testTupleOutput(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: (math_ops.add(x, y), math_ops.add(x, y))
      fn2 = lambda: (y, y)
      r = control_flow_ops.cond(pred, fn1, fn2)
      test_result = self.evaluate(r)
      self.assertTupleEqual((210, 210), test_result)

  def testDictOutput(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: {"a": math_ops.add(x, y), "b": math_ops.add(x, y)}
      fn2 = lambda: {"a": y, "b": y}
      r = control_flow_ops.cond(pred, fn1, fn2)
      test_result = self.evaluate(r)
      self.assertDictEqual({"a": 210, "b": 210}, test_result)

  @test_util.run_deprecated_v1
  def testEmbeddedListOutput(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: [[math_ops.add(x, y), math_ops.add(x, y)]]
      fn2 = lambda: [[y, y]]
      # Pass strict=True flag as cond_v2 allows for tensors to be
      # in nested output structures as singletons
      r = control_flow_ops.cond(pred, fn1, fn2, strict=True)
      test_result = self.evaluate(r)
      self.assertListEqual([[210, 210]], test_result)

  def testEmbeddedTupleOutput(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: ((math_ops.add(x, y), math_ops.add(x, y)))
      fn2 = lambda: ((y, y))
      r = control_flow_ops.cond(pred, fn1, fn2)
      test_result = self.evaluate(r)
      self.assertTupleEqual(((210, 210)), test_result)

  def testEmbeddedDictOutput(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: {"a": {"c": math_ops.add(x, y)},
                     "b": {"d": math_ops.add(x, y)}}
      fn2 = lambda: {"a": {"c": y},
                     "b": {"d": y}}
      r = control_flow_ops.cond(pred, fn1, fn2)
      test_result = self.evaluate(r)
      self.assertDictEqual({"a": {"c": 210}, "b": {"d": 210}}, test_result)

  @test_util.run_v1_only("b/120545219")
  def testCheckNestedOutputStruct(self):
    with self.cached_session() as sess:
      x = constant_op.constant(10)
      y = constant_op.constant(200)
      pred = math_ops.less(1, 2)
      fn1 = lambda: {"a": math_ops.add(x, y), "b": math_ops.add(x, y)}
      fn2 = lambda: {"c": y, "d": y}
      v1_msg = "The two structures don't have the same nested structure"
      v2_msg = ("true_fn and false_fn arguments to tf.cond must have the same "
                "number, type, and overall structure of return values.")
      with self.assertRaisesRegexp(
          TypeError if control_flow_util.ENABLE_CONTROL_FLOW_V2 else ValueError,
          v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
        control_flow_ops.cond(pred, fn1, fn2)

  @test_util.run_deprecated_v1
  def testCondRef(self):

    with self.cached_session():
      x = gen_state_ops.variable(
          shape=[1],
          dtype=dtypes.float32,
          name="x",
          container="",
          shared_name="")
      true_fn = lambda: x
      false_fn = lambda: constant_op.constant([2.0])
      r = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn)
      self.assertAllEqual([2.0], self.evaluate(r))

  @test_util.disable_control_flow_v2("b/79881896 (placeholder)")
  @test_util.run_v1_only("b/120545219")
  def testCondWithControl(self):
    with self.cached_session():
      control_holder = array_ops.placeholder(dtypes.float32, shape=())
      a = constant_op.constant(3)

      def true_branch():
        with ops.control_dependencies([control_holder]):
          _ = a + 1
        return a + 2

      r = control_flow_ops.cond(
          constant_op.constant(True), true_branch,
          lambda: constant_op.constant(1))
      self.assertEqual(5, self.evaluate(r))

  @test_util.run_v1_only("b/120545219")
  def testUninitializedRefIdentity(self):
    with self.cached_session() as sess:
      v = gen_state_ops.variable(
          shape=[1],
          dtype=dtypes.float32,
          name="v",
          container="",
          shared_name="")
      inited = state_ops.is_variable_initialized(v)
      v_f, v_t = control_flow_ops.ref_switch(v, inited)
      # Both v_f and v_t are uninitialized references. However, an actual use
      # of the reference in the 'true' branch in the 'tf.identity' op will
      # not 'fire' when v is uninitialized, so this is a valid construction.
      # This test tests that ref_identity allows uninitialized ref as input
      # so that this construction is allowed.
      v_f_op = gen_array_ops.ref_identity(v_f)
      v_t_op = gen_array_ops.ref_identity(v_t)
      with ops.control_dependencies([v_f_op]):
        assign_v = state_ops.assign(v, [1.0])
      with ops.control_dependencies([v_t_op]):
        orig_v = array_ops.identity(v)
      merged_op = control_flow_ops.merge([assign_v, orig_v])
      self.assertAllEqual([1.0], self.evaluate(merged_op.output))

  def testCondSwitchIdentity(self):
    # Make sure the recv identity is not removed by optimization.
    with session.Session(config=opt_cfg()) as sess:
      pred = constant_op.constant(True)

      def fn1():
        return control_flow_ops.no_op()

      def fn2():
        return control_flow_ops.Assert(False, ["Wrong branch!!!"])

      r = control_flow_ops.cond(pred, fn1, fn2)
      self.evaluate(r)

  def testCondRecvIdentity(self):
    # Make sure the switch identity is not removed by optimization.
    with session.Session(config=opt_cfg()) as sess:
      with ops.device(test.gpu_device_name()):
        pred = constant_op.constant(True)

      def fn1():
        return control_flow_ops.no_op()

      def fn2():
        with ops.device("/cpu:0"):
          return control_flow_ops.Assert(False, ["Wrong branch!!!"])

      r = control_flow_ops.cond(pred, fn1, fn2)
      self.evaluate(r)

  @test_util.run_v1_only("b/120545219")
  def testCondGrad_1(self):
    with self.cached_session():
      x = constant_op.constant(10.0, name="x")
      pred = math_ops.less(1, 2)
      fn1 = lambda: array_ops.identity(x)
      fn2 = lambda: array_ops.identity(x)
      r = control_flow_ops.cond(pred, fn1, fn2)

      grad = gradients_impl.gradients(r, [x])[0]
      self.assertAllEqual(1.0, self.evaluate(grad))

  @test_util.run_deprecated_v1
  def testCondGrad_2(self):
    with self.cached_session():
      c = array_ops.placeholder(dtypes.int32, shape=[])
      x = constant_op.constant(10.0)
      pred = math_ops.less(c, 2)
      fn1 = lambda: math_ops.multiply(x, 42.0)
      fn2 = lambda: math_ops.multiply(x, 3.0)
      r = control_flow_ops.cond(pred, fn1, fn2)

      grad = gradients_impl.gradients(r, [x])[0]
      self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1}))
      self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))

  @test_util.disable_control_flow_v2(
      "b/110550782 (gradient w.r.t external variable)")
  @test_util.run_deprecated_v1
  def testCondGrad_3(self):
    with self.cached_session():
      c = array_ops.placeholder(dtypes.int32, shape=[])
      ox = constant_op.constant(10.0)
      pred = math_ops.less(c, 2)

      def fn1(x):
        m = x * x
        return gradients_impl.gradients(m, [ox])[0]

      fn2 = lambda: math_ops.multiply(ox, 3.0)
      y = math_ops.multiply(7.0, ox)
      r = control_flow_ops.cond(pred, lambda: fn1(y), fn2)

      self.assertAllEqual(980.0, r.eval(feed_dict={c: 1}))
      self.assertAllEqual(30.0, r.eval(feed_dict={c: 3}))

  @test_util.run_deprecated_v1
  def testCondGradMultiDevice(self):
    config = config_pb2.ConfigProto(device_count={"CPU": 2},
                                    allow_soft_placement=True)
    with self.cached_session(use_gpu=True, config=config) as sess:
      pred = array_ops.placeholder(dtypes.bool, [])
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.placeholder(dtypes.float32)

      with ops.device("/cpu:0"):
        z = control_flow_ops.cond(pred, lambda: x * y * 2.0, lambda: 2.0)

      with ops.device("/cpu:1"):
        grad = gradients_impl.gradients(z, x)[0]

      with ops.device("/cpu:0"):
        grad_grad = gradients_impl.gradients(grad, x)[0]

      self.assertEqual(sess.run(grad, {pred: True, x: 1.0, y: 2.0}), 4.0)
      self.assertEqual(sess.run(grad, {pred: False, x: 1.0, y: 2.0}), 0.0)

      # v1 control flow gets None second derivative for some reason.
      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
        self.assertIsNone(grad_grad)
        return

      self.assertEqual(sess.run(grad_grad, {pred: True, x: 1.0, y: 2.0}), 0.0)
      self.assertEqual(sess.run(grad_grad, {pred: False, x: 1.0, y: 2.0}), 0.0)

  @test_util.run_v1_only("b/120545219")
  def testNestedCond_Simple(self):
    with self.cached_session():
      x = constant_op.constant(0., name="X")
      y = control_flow_ops.cond(
          constant_op.constant(True), lambda: x,
          lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x))
      result = gradients_impl.gradients(y, x)[0]
      self.assertEqual(1.0, self.evaluate(result))

      z = control_flow_ops.cond(
          constant_op.constant(False), lambda: x,
          lambda: control_flow_ops.cond(x < 1., lambda: x, lambda: x))
      result = gradients_impl.gradients(z, x)[0]
      self.assertEqual(1.0, self.evaluate(result))

  @test_util.disable_control_flow_v2("b/113327884")
  @test_util.run_v1_only("b/120545219")
  def testCondGrad_Gather(self):
    with self.cached_session() as sess:
      v1 = variables.Variable([1.0, 42.0])
      c = array_ops.placeholder(dtypes.int32, shape=[])
      pred = math_ops.less(c, 2)
      fn1 = lambda: array_ops.identity(v1)
      fn2 = lambda: array_ops.gather(v1, [1, 1])
      r = control_flow_ops.cond(pred, fn1, fn2)
      grad = gradients_impl.gradients(r, [v1])[0]
      self.evaluate(variables.global_variables_initializer())
      # Should just be [1, 1], but possibly a sparse representation
      gv, gi = sess.run([grad.values, grad.indices], feed_dict={c: 1})
      dense_gv = [
          sum(y for (x, y) in zip(gi, gv) if x == i) for i in range(2)
      ]
      self.assertAllEqual(dense_gv, [1.0, 1.0])
      # Should be [0, 2], as the else forwards v1[1] twice
      gv, gi = sess.run([grad.values, grad.indices], feed_dict={c: 3})
      dense_gv = [
          sum(y for (x, y) in zip(gi, gv) if x == i) for i in range(2)
      ]
      self.assertAllEqual(dense_gv, [0.0, 2.0])

  @test_util.run_deprecated_v1
  def testCondGrad_ResourceVarSparseRead(self):
    # NOTE(skyewm): this test is interesting because the
    # ResourceVariable.sparse_read gradient function returns IndexedSlices.
    var = resource_variable_ops.ResourceVariable(
        np.ones((4, 2), dtype=np.float32))
    x = constant_op.constant(1.0)
    r = control_flow_ops.cond(
        constant_op.constant(True),
        lambda: x * math_ops.reduce_sum(var.sparse_read([1, 2])),
        lambda: constant_op.constant(np.zeros((2, 3)),
                                     dtype=dtypes.float32))
    grad = gradients_impl.gradients(r, var)[0]

    self.evaluate(variables.global_variables_initializer())
    grad_val = self.evaluate(grad)
    self.assertIsInstance(grad_val, ops.IndexedSlicesValue)
    self.assertAllEqual(gradient_checker_v2._to_numpy(grad_val), [[0., 0.],
                                                                  [1., 1.],
                                                                  [1., 1.],
                                                                  [0., 0.]])

  @test_util.disable_xla("b/128643464")
  def testCondGrad_MultiGather(self):
    # NOTE(skyewm): this test is interesting because the array_ops.gather and
    # ResourceVariable.sparse_read gradient functions return IndexedSlices.
    var = resource_variable_ops.ResourceVariable(
        np.ones((4, 2), dtype=np.float32))
    x1 = constant_op.constant(np.ones((3, 3), dtype=np.float32))
    x2 = constant_op.constant(2.0)

    def true_fn():
      y1 = var.sparse_read([1, 2])
      y2 = array_ops.gather(x1, [2]) * x2
      y3 = x2 * [1., 1., 1.]
      return y1, y2, y3

    def false_fn():
      y1 = np.zeros((2, 2), dtype=np.float32)
      y2 = array_ops.gather(x1, [2]) * x2
      y3 = array_ops.gather(x1, [2])
      return y1, y2, y3

    @def_function.function
    def foo():
      r = control_flow_ops.cond(constant_op.constant(True), true_fn, false_fn)
      return gradients_impl.gradients(r, [var, x1, x2])

    grad = foo()
    self.evaluate(variables.global_variables_initializer())
    var_grad, x1_grad, x2_grad = self.evaluate(grad)
    self.assertIsInstance(var_grad, ops.IndexedSlicesValue)
    self.assertAllEqual(gradient_checker_v2._to_numpy(var_grad), [[0., 0.],
                                                                  [1., 1.],
                                                                  [1., 1.],
                                                                  [0., 0]])
    self.assertIsInstance(x1_grad, ops.IndexedSlicesValue)
    self.assertAllEqual(gradient_checker_v2._to_numpy(x1_grad), [[0., 0., 0.],
                                                                 [0., 0., 0.],
                                                                 [2., 2., 2.]])
    self.assertIsInstance(x1_grad, ops.IndexedSlicesValue)
    self.assertEqual(gradient_checker_v2._to_numpy(x2_grad), 6.)

  @test_util.run_v1_only("b/120545219")
  def testCondPredicateTensor(self):
    """Regression test for lowering predicate from non-first output of an op."""

    @eager_function.defun
    def foo():
      return constant_op.constant("foo"), constant_op.constant(True)

    r = control_flow_ops.cond(foo()[1], lambda: 1.0, lambda: 2.0)
    self.assertEqual(self.evaluate(r), 1.0)

  @test_util.run_v1_only("Tests Session.run() pruning logic.")
  def testCondFeedConstantPredicate(self):
    with self.cached_session() as sess:
      value = constant_op.constant(37.0)
      predicate = constant_op.constant(True)
      cond_output = control_flow_ops.cond(
          predicate, lambda: constant_op.constant(0.0), lambda: value)
      result = array_ops.identity(cond_output)
      self.assertEqual(37.0, sess.run(result, feed_dict={predicate: False}))
      self.assertEqual(0.0, sess.run(result, feed_dict={predicate: True}))
      self.assertEqual(0.0, sess.run(result))

  @test_util.run_v1_only("Tests Session.run() pruning logic.")
  def testCondFeedPlaceholderWithDefaultPredicate(self):
    with self.cached_session() as sess:
      value = constant_op.constant(37.0)
      predicate = array_ops.placeholder_with_default(
          constant_op.constant(True), [])
      cond_output = control_flow_ops.cond(
          predicate, lambda: constant_op.constant(0.0), lambda: value)
      result = array_ops.identity(cond_output)
      self.assertAllEqual(37.0, sess.run(result, feed_dict={predicate: False}))
      self.assertAllEqual(0.0, sess.run(result, feed_dict={predicate: True}))
      self.assertAllEqual(0.0, sess.run(result))

  @test_util.disable_xla("b/128644469 PrintV2")
  @test_util.run_in_graph_and_eager_modes
  def testCondAutoControlDeps(self):

    def branch_fn():
      logging_ops.print_v2("A")
      logging_ops.print_v2("B")
      with ops.control_dependencies([logging_ops.print_v2("C")]):
        return constant_op.constant(10)

    def build_cond():
      return control_flow_ops.cond(
          constant_op.constant(True), branch_fn, lambda: 0)

    def build_nested_cond():
      return control_flow_ops.cond(
          constant_op.constant(True), build_cond, lambda: 0)

    # In v1 graph mode, pruning should make only "C" print.
    if not context.executing_eagerly():
      with self.cached_session():
        with self.captureWritesToStream(sys.stderr) as printed:
          self.assertEqual(self.evaluate(build_cond()), 10)
        self.assertEqual(printed.contents(), "C\n")

        with self.captureWritesToStream(sys.stderr) as printed:
          self.assertEqual(self.evaluate(build_nested_cond()), 10)
        self.assertEqual(printed.contents(), "C\n")

    # In defuns, all prints should execute in program order.
    # This doesn't work with legacy control flow.
    if control_flow_util.ENABLE_CONTROL_FLOW_V2:

      @eager_function.defun
      def cond():
        return build_cond()

      with self.captureWritesToStream(sys.stderr) as printed:
        self.assertEqual(self.evaluate(cond()), 10)
      self.assertTrue(printed.contents().endswith("A\nB\nC\n"),
                      printed.contents())

      @eager_function.defun
      def nested_cond():
        return build_nested_cond()

      with self.captureWritesToStream(sys.stderr) as printed:
        self.assertEqual(self.evaluate(nested_cond()), 10)
      self.assertTrue(printed.contents().endswith("A\nB\nC\n"),
                      printed.contents())

    # wrap_function should prune.
    def pruned_cond():
      return build_cond()
    pruned_cond = wrap_function.wrap_function(pruned_cond, [])

    with self.captureWritesToStream(sys.stderr) as printed:
      self.assertEqual(self.evaluate(pruned_cond()), 10)
    self.assertEqual(printed.contents(), "C\n")

    def pruned_nested_cond():
      return build_nested_cond()
    pruned_nested_cond = wrap_function.wrap_function(pruned_nested_cond, [])

    with self.captureWritesToStream(sys.stderr) as printed:
      self.assertEqual(self.evaluate(pruned_nested_cond()), 10)
    self.assertEqual(printed.contents(), "C\n")

  @test_util.disable_xla("b/128643646 PrintV2")
  @test_util.run_in_graph_and_eager_modes
  def testWhileAutoControlDeps(self):
    # Legacy while_loop fails this test because it produces deprecation notices
    # in stderr.
    if not control_flow_util.ENABLE_CONTROL_FLOW_V2: return

    def cond(i, unused_x):
      logging_ops.print_v2("A")
      return i < 2

    def body(i, x):
      logging_ops.print_v2("B")
      with ops.control_dependencies([logging_ops.print_v2("C")]):
        x = array_ops.identity(x)
      with ops.control_dependencies([logging_ops.print_v2("D")]):
        return i + 1, x

    def build_while():
      return control_flow_ops.while_loop(
          cond, body, [constant_op.constant(0), constant_op.constant(0)])

    def build_nested_while():
      return control_flow_ops.cond(
          constant_op.constant(True), build_while, lambda: [0, 0])

    # In v1 graph mode, pruning should make only "D" print.
    if not context.executing_eagerly():
      with self.cached_session():
        with self.captureWritesToStream(sys.stderr) as printed:
          self.assertEqual(self.evaluate(build_while()[0]), 2)
        self.assertTrue(printed.contents().endswith("D\nD\n"),
                        printed.contents())

        with self.captureWritesToStream(sys.stderr) as printed:
          self.assertEqual(self.evaluate(build_nested_while()[0]), 2)
        self.assertTrue(printed.contents().endswith("D\nD\n"),
                        printed.contents())

    # In defuns, all prints should execute in program order.
    @eager_function.defun
    def while_loop():
      return build_while()[0]

    with self.captureWritesToStream(sys.stderr) as printed:
      self.assertEqual(self.evaluate(while_loop()), 2)
    self.assertTrue(printed.contents().endswith("A\nB\nC\nD\nA\nB\nC\nD\nA\n"),
                    printed.contents())

    @eager_function.defun
    def nested_while_loop():
      return build_nested_while()[0]

    # TODO(b/117840611): calling nested_while_loop fails in eager
    if not context.executing_eagerly():
      with self.captureWritesToStream(sys.stderr) as printed:
        self.assertEqual(self.evaluate(nested_while_loop()), 2)
      self.assertTrue(
          printed.contents().endswith("A\nB\nC\nD\nA\nB\nC\nD\nA\n"),
          printed.contents())

    # wrap_function should prune.
    def pruned_while():
      return build_while()[0]
    pruned_while = wrap_function.wrap_function(pruned_while, [])

    with self.captureWritesToStream(sys.stderr) as printed:
      self.assertEqual(self.evaluate(pruned_while()), 2)
    self.assertTrue(printed.contents().endswith("D\nD\n"), printed.contents())

    def pruned_nested_while():
      return build_nested_while()[0]
    pruned_nested_while = wrap_function.wrap_function(pruned_nested_while, [])

    # TODO(b/117840611): calling nested_while_loop fails in eager
    if not context.executing_eagerly():
      with self.captureWritesToStream(sys.stderr) as printed:
        self.assertEqual(self.evaluate(pruned_nested_while()), 2)
      self.assertTrue(printed.contents().endswith("D\nD\n"), printed.contents())

  # Microbenchmark: 256,000 iterations/s.
  def testWhile_1(self):
    with self.cached_session():
      n = constant_op.constant(0)
      c = lambda x: math_ops.less(x, 10000)
      b = lambda x: math_ops.add(x, 1)
      r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
      self.assertEqual(10000, self.evaluate(r))

  @test_util.run_v1_only("b/120545219")
  def testWhileExternalControlDependencies(self):
    with self.cached_session():
      v = variables.Variable(0.0)
      v.initializer.run()
      increment = v.assign_add(1.0).read_value()

      def body_fn(i):
        with ops.control_dependencies([increment]):
          return i + 1

      result = control_flow_ops.while_loop(cond=lambda i: i < 2,
                                           body=body_fn, loop_vars=[1])
      self.assertAllEqual(result, 2)
      self.assertAllEqual(v.read_value(), 1.0)

  @test_util.run_v1_only("b/120545219")
  def testWhileExternalControlDependenciesNoInput(self):
    with self.cached_session():
      v = variables.Variable(0.0)
      v.initializer.run()
      # TODO(apassos): figure out why the reading is necessary here.
      increment = v.assign_add(1.0).read_value()

      def body_fn(unused_i):
        with ops.control_dependencies([increment]):
          return constant_op.constant(5, name="five")

      result = control_flow_ops.while_loop(cond=lambda i: i < 5,
                                           body=body_fn, loop_vars=[0])
      self.evaluate(result)
      self.assertAllEqual(self.evaluate(v), 1.0)

  @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
  @test_util.run_v1_only("b/120545219")
  def testWhileWithRefs_1(self):
    with self.cached_session() as sess:
      x = variables.VariableV1(0)._ref()  # pylint: disable=protected-access
      i = constant_op.constant(0)
      c = lambda i, x: math_ops.less(i, 100)

      self.assertEqual(x.dtype, dtypes.int32_ref)

      def b(i, x):
        self.assertEqual(x.dtype, dtypes.int32_ref)
        return (i + 1, gen_array_ops.ref_identity(x))

      r = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=5)

      self.evaluate(variables.global_variables_initializer())

      self.assertEqual(r[0].dtype, dtypes.int32)
      self.assertEqual(r[1].dtype, dtypes.int32_ref)

      value_i, value_x = self.evaluate(r)

      self.assertEqual(100, value_i)
      self.assertEqual(0, value_x)

  def testWhile_2(self):
    with self.cached_session():
      s = constant_op.constant(0)
      r = isum(s)
      self.assertAllEqual(45, self.evaluate(r))

  def testWhileWithMaximumIterations(self):
    with self.cached_session():
      s = constant_op.constant([1, 2, 3, 4, 5])
      r = isum(s, maximum_iterations=3)
      self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], self.evaluate(r))

  @test_util.run_v1_only("b/120545219")
  def testWhileWithMaximumIterationsAndSingleArgument(self):
    with self.cached_session():
      r = control_flow_ops.while_loop(
          lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
      self.assertEqual(1, self.evaluate(r))

  @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
  @test_util.run_v1_only("b/120545219")
  def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
    v = constant_op.constant(1.0)

    def training_loop_with_gradient(i):
      out = control_flow_ops.while_loop(
          lambda i_, _: i_ < 3,
          lambda i_, j: [i_ + 1, j * v], [0, 1.0],
          maximum_iterations=i)
      g = gradients_impl.gradients(out, v)
      with ops.control_dependencies(g):
        return i + 1

    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    # Create training loop, ensure we can call gradient() of
    # while_loop inside the training loop.
    loop = control_flow_ops.while_loop(lambda i: i < 3,
                                       training_loop_with_gradient, [0])
    xla_context.Exit()

    loop_execute = array_ops.identity(loop)  # Because loop is not fetchable.

    # Should execute without issue.
    self.assertEqual(3, self.evaluate(loop_execute))

  @test_util.run_v1_only("b/120545219")
  def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
      self.skipTest("WhileV2 does lazy evaluation of maximum_iterations")
    v = constant_op.constant(1.0)

    def inner_body(i, x):
      out = control_flow_ops.while_loop(
          lambda i, _: i < 3,
          lambda i, j: [i + 1, j * v], [0, x],
          maximum_iterations=i)
      return out

    def create_while_loop(maximum_iterations=None):
      return control_flow_ops.while_loop(
          lambda i, _: i < 3,
          inner_body, [0, 1.0],
          maximum_iterations=maximum_iterations)

    loop_no_xla = create_while_loop(maximum_iterations=5)
    # maximum_iterations is fine outside of an XLA scope
    gs = gradients_impl.gradients(loop_no_xla, v)
    self.evaluate(gs)  # This should execute without error.

    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    loop_no_maxiter = create_while_loop()
    loop_with_maxiter = create_while_loop(maximum_iterations=2)
    xla_context.Exit()

    with self.assertRaisesRegexp(
        ValueError,
        r"Cannot create a gradient accumulator for tensor '.+' inside "
        r"XLA while_loop because maximum_iterations was not passed to "
        r"the tf.while_loop call \('.+'\)."):
      _ = gradients_impl.gradients(loop_no_maxiter, v)

    with self.assertRaisesRegexp(
        ValueError,
        r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
        r"while_loop. maximum_iterations tensor '.+' for while_loop context "
        r"'.+' must be statically known \(e.g. a constant value or known "
        r"shape dimension\), or be defined at or outside the while loop "
        r"context '.*' \(currently defined in '.*'\)"):
      _ = gradients_impl.gradients(loop_with_maxiter, v)

  @test_util.run_v1_only("b/120545219")
  def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
    v = constant_op.constant(1.0)

    def create_while_loop():
      max_iter_holder = []

      def create_mi():
        max_iter_holder.append(array_ops.placeholder(dtypes.int32, shape=()))
        return 1.0

      _ = control_flow_ops.cond(
          constant_op.constant(True), create_mi, create_mi)

      return control_flow_ops.while_loop(
          lambda i, _: i < 3,
          lambda i, x: (i + 1, v * x), (0, 1.0),
          maximum_iterations=max_iter_holder[0])

    if control_flow_util.ENABLE_CONTROL_FLOW_V2:
      xla_context = control_flow_ops.XLAControlFlowContext()
      xla_context.Enter()
      with self.assertRaisesRegexp(
          ValueError, r"Tensor.*Placeholder:0.* must be from the same graph.*"):
        loop = create_while_loop()
      xla_context.Exit()
    else:
      xla_context = control_flow_ops.XLAControlFlowContext()
      xla_context.Enter()
      loop = create_while_loop()
      xla_context.Exit()
      with self.assertRaisesRegexp(
          ValueError,
          r"Cannot create a gradient accumulator for tensor '.+' inside XLA "
          r"while_loop. maximum_iterations tensor '.*Placeholder:0' for "
          r"while_loop context '.+' must be statically known \(e.g. a constant "
          r"value or known shape dimension\), or be defined at or outside the "
          r"while loop context '' \(currently defined in 'cond/.+'\)"):
        _ = gradients_impl.gradients(loop, v)

  @test_util.run_v1_only("b/120545219")
  def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self):
    if test_util.is_gpu_available():
      self.skipTest("b/128646372, b/128645947 fails in opensource build")

    v = constant_op.constant(1.0)

    p = array_ops.placeholder(dtype=dtypes.int32)

    def mid_body_builder(iterations):

      def mid_body(i, x):
        r = control_flow_ops.while_loop(
            lambda *_: True,
            lambda i, x: (i + 1, v * x), (0, x),
            maximum_iterations=iterations,
            name="inner")
        return (i + 1, gradients_impl.gradients(x + r[1], v)[0])

      return mid_body

    def outer_body(i, x):
      iterations = array_ops.size(p, name="iterations")
      return (i + 1, x + control_flow_ops.while_loop(
          lambda *_: True,
          mid_body_builder(iterations), (0, x),
          maximum_iterations=iterations,
          name="mid")[1])

    def create_while_loop():
      with ops.device("/cpu:0"):
        r = control_flow_ops.while_loop(
            lambda *_: True,
            outer_body, (0, 1.0),
            maximum_iterations=5,
            name="outer")
        return array_ops.identity(r[1])

    xla_context = control_flow_ops.XLAControlFlowContext()
    xla_context.Enter()
    final_with_xla_context = create_while_loop()
    xla_context.Exit()

    final_without_xla_context = create_while_loop()

    with self.session(use_gpu=False) as sess:
      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata_without_xla_context = config_pb2.RunMetadata()
      run_metadata = config_pb2.RunMetadata()

      final_value_without_xla_context = sess.run(
          final_without_xla_context,
          feed_dict={p: [0, 0, 0]},
          options=opts,
          run_metadata=run_metadata_without_xla_context)

      final_value_with_xla_context = sess.run(
          final_with_xla_context,
          feed_dict={p: [0, 0, 0]},
          options=opts,
          run_metadata=run_metadata)

      if control_flow_util.ENABLE_CONTROL_FLOW_V2:
        # With while_v2 on xla, run_metadata only contains the unlowered While
        # op so node_stats does not have statistics for the pushes. So as a
        # loose check we check the pushes in the lowered version.
        node_stats = run_metadata_without_xla_context.step_stats.dev_stats[
            0].node_stats
        stack_push_op = "TensorListPushBack"
      else:
        node_stats = run_metadata.step_stats.dev_stats[0].node_stats
        stack_push_op = "StackPushV2"
      stack_push_count = len(
          [x for x in node_stats if x.node_name.endswith(stack_push_op)])
      # Pushes to the stack = product of maximum_iterations values;
      # the last two "3"s come from size(p), when p == [0, 0, 0].
      self.assertEqual(stack_push_count, 5 * 3 * 3, str(node_stats))

      self.assertAllClose(final_value_with_xla_context,
                          final_value_without_xla_context)

  # Have more than 10 parallel iterations and hence exercise k-bound
  # most of the time.
1570 @test_util.run_deprecated_v1 1571 def testWhile_3(self): 1572 with self.cached_session(): 1573 1574 def compute(i, m, c, o): 1575 m, c = [math_ops.add(m, 1), math_ops.add(c, 1)] 1576 o = math_ops.add(o, m) 1577 o = math_ops.add(o, c) 1578 i = math_ops.add(i, 1) 1579 return [i, m, c, o] 1580 1581 i = ops.convert_to_tensor(0) 1582 m = ops.convert_to_tensor(0) 1583 c = ops.convert_to_tensor(0) 1584 o = ops.convert_to_tensor(0) 1585 d = ops.convert_to_tensor(100) 1586 r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, d), 1587 compute, [i, m, c, o]) 1588 result = r[3] 1589 self.assertAllEqual(10100, result) 1590 1591 @test_util.run_deprecated_v1 1592 def testWhile_4(self): 1593 with self.cached_session(): 1594 1595 def compute(i, m, c, o): 1596 m, c = [array_ops.gather(x, i), array_ops.gather(x, i)] 1597 o = math_ops.add(o, m) 1598 o = math_ops.add(o, c) 1599 i = math_ops.add(i, 1) 1600 return [i, m, c, o] 1601 1602 i = ops.convert_to_tensor(0) 1603 m = ops.convert_to_tensor(0) 1604 c = ops.convert_to_tensor(0) 1605 o = ops.convert_to_tensor(0) 1606 x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6]) 1607 s = array_ops.size(x) 1608 r = control_flow_ops.while_loop(lambda i, m, c, o: math_ops.less(i, s), 1609 compute, [i, m, c, o]) 1610 result = r[3] 1611 self.assertAllEqual(42, result) 1612 1613 @test_util.run_v1_only("b/120545219") 1614 def testWhile_5(self): 1615 with self.cached_session(): 1616 1617 def compute(i, c, o): 1618 c = array_ops.strided_slice(x, array_ops.expand_dims(i, 0), 1619 [1] + array_ops.expand_dims(i, 0)) 1620 o = array_ops.concat([o, c], 0) 1621 i = math_ops.add(i, 1) 1622 return [i, c, o] 1623 1624 i = ops.convert_to_tensor(0) 1625 c = ops.convert_to_tensor([0]) 1626 o = ops.convert_to_tensor([0]) 1627 x = ops.convert_to_tensor([1, 2, 3, 4, 5, 6]) 1628 s = array_ops.size(x) 1629 r = control_flow_ops.while_loop(lambda i, c, o: math_ops.less(i, s), 1630 compute, [i, c, o], [ 1631 i.get_shape(), 1632 tensor_shape.unknown_shape(), 1633 tensor_shape.unknown_shape() 1634 ]) 1635 result = r[2] 1636 self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result) 1637 1638 @test_util.run_gpu_only 1639 @test_util.run_deprecated_v1 1640 def testWhile_Device(self): 1641 1642 # Body function defined outside of device scope 1643 def body(x): 1644 return math_ops.exp(x) 1645 1646 with ops.device("CPU:0"): 1647 r = control_flow_ops.while_loop( 1648 lambda x: x < 10, body, [constant_op.constant(-10.)]) 1649 self.assertIn("cpu", r.device.lower()) 1650 1651 with session.Session() as sess: 1652 options = config_pb2.RunOptions(output_partition_graphs=True) 1653 run_metadata = config_pb2.RunMetadata() 1654 sess.run(r, options=options, run_metadata=run_metadata) 1655 # We expect that everything runs on CPU, even if GPU is available. 
1656 self.assertEqual(len(run_metadata.partition_graphs), 1) 1657 1658 @test_util.disable_control_flow_v2("b/116338794 (buffer_reuse)") 1659 @test_util.run_v1_only("b/120545219") 1660 def testBufferForwarding(self): 1661 run_options = config_pb2.RunOptions( 1662 trace_level=config_pb2.RunOptions.FULL_TRACE) 1663 run_metadata = config_pb2.RunMetadata() 1664 1665 with self.cached_session() as sess: 1666 with ops.device("/cpu:0"): 1667 c = constant_op.constant(2) 1668 i0 = constant_op.constant(0) 1669 r = control_flow_ops.while_loop(lambda i: i < 1000, 1670 lambda i: math_ops.square(c) + i, [i0]) 1671 r_val = sess.run(r, options=run_options, run_metadata=run_metadata) 1672 self.assertEqual(1000, r_val) 1673 self.assertTrue(run_metadata.HasField("step_stats")) 1674 unique_allocs = set() 1675 for node_stat in run_metadata.step_stats.dev_stats[0].node_stats: 1676 for output in node_stat.output: 1677 unique_allocs.add( 1678 output.tensor_description.allocation_description.ptr) 1679 # Prior to cl/147536680, the number of unique allocations was about 1005. 1680 self.assertLess(len(unique_allocs), 756) 1681 1682 def _testWhile_Gpu_1(self, use_gpu): 1683 with self.cached_session(use_gpu=use_gpu): 1684 n = constant_op.constant(1.0) 1685 c = lambda x: math_ops.less(x, 10.0) 1686 b = lambda x: math_ops.add(x, 1.0) 1687 r = control_flow_ops.while_loop(c, b, [n]) 1688 self.assertAllClose(10.0, self.evaluate(r)) 1689 1690 def testWhile_Gpu_1(self): 1691 self._testWhile_Gpu_1(use_gpu=False) 1692 self._testWhile_Gpu_1(use_gpu=True) 1693 1694 def _testWhile_Gpu_2(self, use_gpu): 1695 with self.cached_session(use_gpu=use_gpu): 1696 n = constant_op.constant(1.0) 1697 c = lambda x: math_ops.less(x, 10.0) 1698 1699 def b(x): 1700 with ops.device("/cpu:0"): 1701 return math_ops.add(x, 1.0) 1702 1703 r = control_flow_ops.while_loop(c, b, [n]) 1704 self.assertAllClose(10.0, self.evaluate(r)) 1705 1706 def testWhile_Gpu_2(self): 1707 self._testWhile_Gpu_2(use_gpu=False) 1708 self._testWhile_Gpu_2(use_gpu=True) 1709 1710 def testWhileShape(self): 1711 with self.cached_session(): 1712 i = constant_op.constant(0) 1713 m = array_ops.ones([2, 2]) 1714 c = lambda i, j: math_ops.less(i, 2) 1715 1716 def _b(i, j): 1717 new_i = math_ops.add(i, 1) 1718 new_j = array_ops.tile(j, [2, 2]) 1719 return [new_i, new_j] 1720 1721 r = control_flow_ops.while_loop( 1722 c, _b, [i, m], 1723 [i.get_shape(), tensor_shape.unknown_shape()]) 1724 r = r[1] * array_ops.ones([8, 8]) 1725 self.assertAllEqual(np.ones((8, 8)), self.evaluate(r)) 1726 1727 @test_util.run_deprecated_v1 1728 def testWhileWithNonTensorInput_Scalar(self): 1729 with self.cached_session(): 1730 n = 0 1731 c = lambda x: x < 10000 1732 b = lambda x: x + 1 1733 r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20) 1734 self.assertEqual(10000, self.evaluate(r)) 1735 1736 def testWhileWithNonTensorInput_Vector(self): 1737 with self.cached_session(): 1738 n = np.array([0]) # Note, [0] would not work here; that is a list 1739 c = lambda x: x[0] < 10000 1740 b = lambda x: array_ops.stack([x[0] + 1]) 1741 r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20) 1742 self.assertEqual([10000], self.evaluate(r)) 1743 1744 @test_util.run_v1_only("b/120545219") 1745 def testWhileShapeInference(self): 1746 with self.cached_session(): 1747 i = constant_op.constant(0) 1748 m = array_ops.ones([2, 2]) 1749 c = lambda i, j: math_ops.less(i, 2) 1750 1751 def b(i, j): 1752 new_i = math_ops.add(i, 1) 1753 new_j = array_ops.concat([j, j], 0) 1754 return [new_i, new_j] 1755 
1756 r = control_flow_ops.while_loop( 1757 c, b, [i, m], 1758 [i.get_shape(), tensor_shape.TensorShape([None, 2])]) 1759 self.assertIsNone(r[1].shape.dims[0].value) 1760 self.assertEqual(r[1].shape.dims[1], tensor_shape.Dimension(2)) 1761 1762 with self.assertRaisesRegexp( 1763 ValueError, 1764 r"Input tensor 'ones:0' enters the loop with shape \(2, 2\), but has " 1765 r"shape \(4, 2\) after one iteration. To allow the shape to vary " 1766 r"across iterations, use the `shape_invariants` argument of " 1767 r"tf.while_loop to specify a less-specific shape."): 1768 r = control_flow_ops.while_loop(c, b, [i, m]) 1769 1770 @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)") 1771 @test_util.run_v1_only("b/120545219") 1772 def testWhileShapeInferenceSparseTensor(self): 1773 values = constant_op.constant([2.0, 4.0], name="values") 1774 indices = constant_op.constant([[0], [3]], 1775 dtype=dtypes.int64, 1776 name="indices") 1777 shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape") 1778 i = constant_op.constant(0) 1779 x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape) 1780 1781 def c(i, _): 1782 return i < 10 1783 1784 def b1(i, x): # modifies values. (shape of components is not changed.) 1785 return [ 1786 i + 1, 1787 sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape) 1788 ] 1789 1790 def b2(i, x): # adds new values. (shape of components is changed.) 1791 return [ 1792 i + 1, 1793 sparse_ops.sparse_add( 1794 x, 1795 sparse_tensor.SparseTensor( 1796 indices=math_ops.cast( 1797 array_ops.fill([1, 1], i), dtypes.int64), 1798 values=array_ops.fill([1], 1.0), 1799 dense_shape=x.dense_shape)) 1800 ] 1801 1802 def b3(i, x): # modifies rank. (shape of all components is changed.) 1803 return [ 1804 i + 1, 1805 sparse_tensor.SparseTensor( 1806 array_ops.concat([x.indices, [[i], [i]]], axis=1), x.values * 2.0, 1807 array_ops.concat([x.dense_shape, [10]], axis=0)) 1808 ] 1809 1810 # Default shape invariant; b1 only modifies values. 1811 _, r = control_flow_ops.while_loop(c, b1, [i, x]) 1812 self.assertEqual(r.indices.get_shape().as_list(), [None, 1]) 1813 self.assertEqual(r.values.get_shape().as_list(), [None]) 1814 self.assertEqual(r.dense_shape.get_shape().as_list(), [1]) 1815 1816 # Default shape invariant; b2 adds new values 1817 _, r = control_flow_ops.while_loop(c, b2, [i, x]) 1818 self.assertEqual(r.indices.get_shape().as_list(), [None, 1]) 1819 self.assertEqual(r.values.get_shape().as_list(), [None]) 1820 self.assertEqual(r.dense_shape.get_shape().as_list(), [1]) 1821 1822 # Default shape invariant; b3 modifies rank (which is not allowed). 1823 with self.assertRaises(ValueError): 1824 _, r = control_flow_ops.while_loop(c, b3, [i, x]) 1825 1826 # Explicit shape invariant, allowing any rank; b1 only modifies values. 1827 _, r = control_flow_ops.while_loop( 1828 c, b1, [i, x], 1829 [i.get_shape(), tensor_shape.TensorShape([None])]) 1830 self.assertEqual(r.indices.get_shape().as_list(), [None, None]) 1831 self.assertEqual(r.values.get_shape().as_list(), [None]) 1832 self.assertEqual(r.dense_shape.get_shape().as_list(), [None]) 1833 1834 # Explicit shape invariant, allowing any rank; b3 modifies rank. 
1835 _, r = control_flow_ops.while_loop( 1836 c, b3, [i, x], 1837 [i.get_shape(), tensor_shape.TensorShape([None])]) 1838 self.assertEqual(r.indices.get_shape().as_list(), [None, None]) 1839 self.assertEqual(r.values.get_shape().as_list(), [None]) 1840 self.assertEqual(r.dense_shape.get_shape().as_list(), [None]) 1841 1842 # Shape invariant with ndims=None. Technically, this isn't supported 1843 # according to the docs, but we support it for backwards compatibility. 1844 _, r = control_flow_ops.while_loop( 1845 c, b1, [i, x], 1846 [i.get_shape(), tensor_shape.TensorShape(None)]) 1847 self.assertEqual(r.indices.get_shape().as_list(), [None, None]) 1848 self.assertEqual(r.values.get_shape().as_list(), [None]) 1849 self.assertEqual(r.dense_shape.get_shape().as_list(), [None]) 1850 _, r = control_flow_ops.while_loop( 1851 c, b3, [i, x], 1852 [i.get_shape(), tensor_shape.TensorShape(None)]) 1853 self.assertEqual(r.indices.get_shape().as_list(), [None, None]) 1854 self.assertEqual(r.values.get_shape().as_list(), [None]) 1855 self.assertEqual(r.dense_shape.get_shape().as_list(), [None]) 1856 1857 # Explicit shape invariant, with a specific (incompatible) rank. 1858 with self.assertRaisesRegexp(ValueError, "is not compatible with"): 1859 _, r = control_flow_ops.while_loop( 1860 c, b1, [i, x], 1861 [i.get_shape(), tensor_shape.TensorShape([5])]) 1862 1863 @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)") 1864 @test_util.run_v1_only("b/120545219") 1865 def testWhileShapeInferenceIndexedSlices(self): 1866 with self.cached_session(): 1867 values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values") 1868 indices = constant_op.constant([0, 3], name="indices") 1869 shape = constant_op.constant([10, 2], name="dense_shape") 1870 i = constant_op.constant(0) 1871 x = ops.IndexedSlices(values, indices, dense_shape=shape) 1872 1873 def c(i, _): 1874 return i < 10 1875 1876 def b(i, x): 1877 return [ 1878 i + 1, 1879 ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape) 1880 ] 1881 1882 _, r = control_flow_ops.while_loop(c, b, [i, x]) 1883 self.assertEqual(r.dense_shape.get_shape()[0], 2) 1884 self.assertEqual(r.values.get_shape(), tensor_shape.TensorShape([2, 2])) 1885 1886 _, r = control_flow_ops.while_loop( 1887 c, b, [i, x], 1888 [i.get_shape(), tensor_shape.TensorShape([None, 2])]) 1889 self.assertEqual(r.dense_shape.get_shape()[0], 2) 1890 self.assertEqual(r.values.get_shape().as_list(), [None, 2]) 1891 1892 with self.assertRaisesRegexp(ValueError, "is not compatible with"): 1893 _, r = control_flow_ops.while_loop( 1894 c, b, [i, x], 1895 [i.get_shape(), tensor_shape.TensorShape([None, 5])]) 1896 1897 @test_util.disable_control_flow_v2("b/116328420 (RaggedTensor)") 1898 def testWhileShapeInferenceRaggedTensor(self): 1899 if context.executing_eagerly(): 1900 self.skipTest("b/116328420") 1901 i = constant_op.constant(0) 1902 x = ragged_factory_ops.constant([[1, 2], [3], [4, 5, 6]]) 1903 c = lambda i, _: i < 10 1904 1905 def b1(i, x): # Adds new values to rows (but doesn't create new rows) 1906 return [ 1907 i + 1, 1908 array_ops.concat([x, x], axis=1) 1909 ] 1910 1911 def b2(i, x): # Adds new rows. 1912 return [ 1913 i + 1, 1914 array_ops.concat([x, x], axis=0) 1915 ] 1916 1917 # Default shape invariant; b1 adds new values to rows. 
    _, r = control_flow_ops.while_loop(c, b1, [i, x])
    self.assertEqual(r.row_splits.shape.as_list(), [4])

    self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))

    # Default shape invariant; b2 adds new rows (not allowed).
    if not context.executing_eagerly():
      with self.assertRaises(ValueError):
        _, r = control_flow_ops.while_loop(c, b2, [i, x])

    # Explicit shape invariant; b1 adds new values to rows.
    _, r = control_flow_ops.while_loop(
        c, b1, [i, x],
        [i.get_shape(), tensor_shape.TensorShape([None, None])])
    self.assertTrue(r.row_splits.shape.as_list() in ([4], [None]))
    self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))

    # Explicit shape invariant; b2 adds new rows.
    _, r = control_flow_ops.while_loop(
        c, b2, [i, x],
        [i.get_shape(), tensor_shape.TensorShape([None, None])])
    self.assertTrue(r.row_splits.shape.as_list() in ([3 * 2**10 + 1], [None]))
    self.assertTrue(r.values.shape.as_list() in ([6 * 2**10], [None]))

  @test_util.disable_control_flow_v2("b/116328420 (RaggedTensor)")
  def testWhileShapeInferenceRaggedTensorRaggedRank2(self):
    if context.executing_eagerly():
      self.skipTest("b/116328420")
    i = constant_op.constant(0)
    x = ragged_factory_ops.constant([[[1, 2], [3], [4, 5, 6]],
                                     [[], [8, 9, 10]]])
    c = lambda i, _: i < 10
    def b(i, x):
      return [
          i + 1,
          array_ops.concat([x, x[..., i:i+1]], axis=-1)
      ]
    _, r = control_flow_ops.while_loop(c, b, [i, x])
    self.assertEqual(r.row_splits.shape.as_list(), [3])
    self.assertTrue(r.values.row_splits.shape.as_list() in ([6], [None]))
    self.assertTrue(r.values.values.shape.as_list() in ([49], [None]))

  def _testNestedWhile_1(self, use_gpu):
    with self.cached_session(use_gpu=use_gpu):
      n = constant_op.constant(0)

      def cpu_sum(s):
        c = lambda i, s: math_ops.less(i, 10)

        def b(i, s):
          i1 = math_ops.add(i, 1)
          with ops.device("/cpu:0"):
            s1 = math_ops.add(i, s)
          return i1, s1

        _, r_s = control_flow_ops.while_loop(c, b, [n, s])
        return r_s

      c = lambda x: math_ops.less(x, 200)
      b = lambda x: math_ops.add(x, cpu_sum(n))
      r = control_flow_ops.while_loop(c, b, [n])
      self.assertEqual(225, self.evaluate(r))

  def testNestedWhile_1(self):
    self._testNestedWhile_1(use_gpu=False)
    self._testNestedWhile_1(use_gpu=True)

  def _testNestedWhile_2(self, use_gpu):
    # Test the cases that A -> Enter and Exit -> A are partitioned.
1987 with self.cached_session(use_gpu=use_gpu): 1988 s0 = constant_op.constant(2.0) 1989 1990 def inner_loop(s): 1991 c = lambda s: math_ops.less(s, 20.0) 1992 1993 def b(s): 1994 s1 = math_ops.add(s, s) 1995 return s1 1996 1997 r_s = control_flow_ops.while_loop(c, b, [s], parallel_iterations=1) 1998 return r_s 1999 2000 outer_c = lambda x: math_ops.less(x, 3000.0) 2001 2002 def outer_b(x): 2003 x = logging_ops.Print(x, [x]) # Edge "Print -> Enter" is partitioned 2004 x = inner_loop(x) 2005 with ops.device("/cpu:0"): 2006 x = math_ops.square(x) # Edge "Exit -> Square" is partitioned 2007 return x 2008 2009 r = control_flow_ops.while_loop( 2010 outer_c, outer_b, [s0], parallel_iterations=1) 2011 self.assertEqual(1048576.0, self.evaluate(r)) 2012 2013 def testNestedWhile_2(self): 2014 self._testNestedWhile_2(use_gpu=False) 2015 self._testNestedWhile_2(use_gpu=True) 2016 2017 @test_util.run_v1_only("b/120545219") 2018 def testWhileWithControl_1(self): 2019 with self.cached_session(): 2020 n = constant_op.constant(0) 2021 r = constant_op.constant(0) 2022 condition = lambda n_, r_: math_ops.less(n_, 10) 2023 2024 def body(n_, r_): 2025 n_ = math_ops.add(n_, 1) 2026 with r_.graph.control_dependencies([r_]): 2027 r_ = constant_op.constant(12) 2028 return [n_, r_] 2029 2030 res = control_flow_ops.while_loop( 2031 condition, body, [n, r], parallel_iterations=1) 2032 self.assertAllEqual(12, res[1]) 2033 2034 @test_util.run_deprecated_v1 2035 def testWhileWithControl_2(self): 2036 with self.cached_session(): 2037 r = constant_op.constant(0) 2038 condition = lambda r_: math_ops.less(r_, 10) 2039 2040 def body(r_): 2041 with r_.graph.control_dependencies([r_]): 2042 r_ = constant_op.constant(12) 2043 return [r_] 2044 2045 res = control_flow_ops.while_loop( 2046 condition, body, [r], parallel_iterations=1) 2047 self.assertAllEqual(12, self.evaluate(res)) 2048 2049 @test_util.run_v1_only("b/120545219") 2050 def testWhileWithControl_3(self): 2051 with self.cached_session() as sess: 2052 b = array_ops.placeholder(dtypes.bool) 2053 c = constant_op.constant(1) 2054 x0 = constant_op.constant(0) 2055 with ops.control_dependencies([b]): 2056 r = control_flow_ops.while_loop(lambda x: x < 10, lambda x: x + c, [x0]) 2057 self.assertEqual(10, sess.run(r, {b: True})) 2058 2059 @test_util.run_v1_only("b/120545219") 2060 def testWhileWithControl_4(self): 2061 with self.cached_session() as sess: 2062 b = array_ops.placeholder(dtypes.bool) 2063 c = constant_op.constant(1) 2064 x0 = constant_op.constant(0) 2065 with ops.control_dependencies([b]): 2066 r = control_flow_ops.while_loop( 2067 lambda x: x < 10, lambda x: x + array_ops.identity(c), [x0]) 2068 self.assertEqual(10, sess.run(r, {b: True})) 2069 2070 @test_util.run_v1_only("b/120545219") 2071 def testWhileWithControl_5(self): 2072 with self.cached_session() as sess: 2073 b = array_ops.placeholder(dtypes.bool) 2074 c = constant_op.constant(1) 2075 x0 = constant_op.constant(0) 2076 2077 def body(x): 2078 with ops.control_dependencies([b]): 2079 return x + c 2080 2081 r = control_flow_ops.while_loop(lambda x: x < 10, body, [x0]) 2082 self.assertEqual(10, sess.run(r, {b: True})) 2083 2084 def testWhileCondWithControl(self): 2085 # Ensure that no control edges by an outer control dependency context are 2086 # added to nodes inside cond/while contexts. 
    with self.cached_session() as sess:
      const_true = lambda: constant_op.constant(True)
      const_false = lambda: constant_op.constant(False)
      cond = lambda i: control_flow_ops.cond(i > 0, const_true, const_false)
      body = lambda i: control_flow_ops.cond(i > 0, lambda: i - 1, lambda: i)

      with ops.control_dependencies([control_flow_ops.no_op()]):
        loop = control_flow_ops.while_loop(cond, body,
                                           (constant_op.constant(5),))
      self.assertEqual(0, self.evaluate(loop))

  @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
  @test_util.run_v1_only("b/120545219")
  def testWhileCondWithControl_1(self):
    with self.cached_session():
      v = variable_scope.get_variable(
          "v", [], initializer=init_ops.constant_initializer(2))
      i0 = constant_op.constant(0)
      with ops.control_dependencies([i0]):

        def loop_condition(i):
          return i < 4

        def loop_body(i):
          some_cond = control_flow_ops.cond(
              constant_op.constant(True),
              lambda: state_ops.assign(v, math_ops.square(v)), lambda: v)
          with ops.control_dependencies([some_cond]):
            return i + 1

        r = control_flow_ops.while_loop(loop_condition, loop_body, (i0,))
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(4, self.evaluate(r))
      self.assertAllClose(65536.0, self.evaluate(v))

  @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
  @test_util.run_v1_only("b/120545219")
  def testWhileCondExitControl(self):

    with self.cached_session():
      v = variables.Variable(1)

      def false_branch():
        cond = lambda i: i < 100

        def body(i):
          x = state_ops.assign(v, i)
          return x + 1

        loop = control_flow_ops.while_loop(cond, body, [0])
        # Make sure the control edge from Exit to a node is handled correctly.
2138 with ops.control_dependencies([loop]): 2139 return constant_op.constant(6.0) 2140 2141 r = control_flow_ops.cond( 2142 constant_op.constant(False), lambda: constant_op.constant(1.0), 2143 false_branch) 2144 self.evaluate(variables.global_variables_initializer()) 2145 self.assertEqual(6.0, self.evaluate(r)) 2146 self.assertEqual(99, self.evaluate(v)) 2147 2148 def testCondWhile_1(self): 2149 2150 with self.cached_session(): 2151 n = ops.convert_to_tensor(0, name="n") 2152 c = lambda x: math_ops.less(x, 10) 2153 b = lambda x: math_ops.add(x, 1) 2154 r = control_flow_ops.cond( 2155 math_ops.less(0, 1), lambda: control_flow_ops.while_loop(c, b, [n]), 2156 lambda: n) 2157 self.assertAllEqual(10, self.evaluate(r)) 2158 2159 def testCondWhile_2(self): 2160 2161 with self.cached_session(): 2162 n = ops.convert_to_tensor(0) 2163 c = lambda x: math_ops.less(x, 10) 2164 b = lambda x: math_ops.add(x, 1) 2165 r = control_flow_ops.cond( 2166 math_ops.less(1, 0), lambda: math_ops.add(n, 1), 2167 lambda: control_flow_ops.while_loop(c, b, [n])) 2168 self.assertAllEqual(10, self.evaluate(r)) 2169 2170 def _testCondWhile_3(self, use_gpu): 2171 with self.cached_session(use_gpu=use_gpu) as sess: 2172 p = array_ops.placeholder(dtypes.bool) 2173 n = constant_op.constant(0.0) 2174 2175 def c(x): 2176 return math_ops.less(x, 10.0) 2177 2178 def b(x): 2179 with ops.device("/cpu:0"): 2180 x1 = math_ops.add(x, 1.0) 2181 return x1 2182 2183 r = control_flow_ops.cond(p, 2184 lambda: control_flow_ops.while_loop(c, b, [n]), 2185 lambda: math_ops.multiply(n, 2.0)) 2186 r1 = gradients_impl.gradients(r, [n]) 2187 self.assertEqual(10., sess.run(r, {p: True})) 2188 self.assertEqual([1.0], sess.run(r1, {p: True})) 2189 self.assertEqual(0.0, sess.run(r, {p: False})) 2190 self.assertEqual([2.0], sess.run(r1, {p: False})) 2191 2192 @test_util.run_deprecated_v1 2193 def testCondWhile_3(self): 2194 self._testCondWhile_3(use_gpu=False) 2195 self._testCondWhile_3(use_gpu=True) 2196 2197 def testWhileCond_1(self): 2198 2199 with self.cached_session(): 2200 i = ops.convert_to_tensor(0, name="i") 2201 n = ops.convert_to_tensor(10, name="n") 2202 one = ops.convert_to_tensor(1, name="one") 2203 c = lambda x: math_ops.less(x, n) 2204 # pylint: disable=undefined-variable 2205 # for OSS build 2206 b = lambda x: control_flow_ops.cond( 2207 constant_op.constant(True), 2208 lambda: math_ops.add(x, one), lambda: math_ops.subtract(x, one)) 2209 # pylint: enable=undefined-variable 2210 r = control_flow_ops.while_loop(c, b, [i]) 2211 self.assertAllEqual(10, self.evaluate(r)) 2212 2213 def testWhileCond_2(self): 2214 2215 with self.cached_session(): 2216 n = ops.convert_to_tensor(0, name="n") 2217 c = lambda x: math_ops.less(x, 10) 2218 b = lambda x: control_flow_ops.cond(constant_op.constant(True), lambda: math_ops.add(x, 1), lambda: n) 2219 r = control_flow_ops.while_loop(c, b, [n]) 2220 self.assertAllEqual(10, self.evaluate(r)) 2221 2222 def testWhileCond_3(self): 2223 2224 with self.cached_session(): 2225 n = ops.convert_to_tensor(0) 2226 c = lambda x: math_ops.less(x, 10) 2227 # pylint: disable=undefined-variable 2228 # for OSS build 2229 b = lambda x: control_flow_ops.cond(math_ops.less(0, 1), 2230 lambda: math_ops.add(x, 1), 2231 lambda: math_ops.subtract(x, 1)) 2232 # pylint: enable=undefined-variable 2233 r = control_flow_ops.while_loop(c, b, [n]) 2234 self.assertAllEqual(10, self.evaluate(r)) 2235 2236 @test_util.run_deprecated_v1 2237 def testWhileCondGradMultiDevice(self): 2238 config = config_pb2.ConfigProto(device_count={"CPU": 
2}, 2239 allow_soft_placement=True) 2240 with self.cached_session(use_gpu=True, config=config) as sess: 2241 pred = array_ops.placeholder(dtypes.bool, []) 2242 x_init = constant_op.constant(1.0) 2243 2244 with ops.device("/cpu:0"): 2245 z = control_flow_ops.while_loop( 2246 lambda i, _: i < 3, 2247 lambda i, x: (i + 1, control_flow_ops.cond( 2248 pred, lambda: x * 2.0, lambda: 10.0)), 2249 [0, x_init]) 2250 2251 with ops.device("/cpu:1"): 2252 grad = gradients_impl.gradients(z, x_init)[0] 2253 2254 with ops.device("/cpu:0"): 2255 grad_grad = gradients_impl.gradients(grad, x_init)[0] 2256 2257 self.assertEqual(sess.run(grad, {pred: True}), 8.0) 2258 self.assertEqual(sess.run(grad, {pred: False}), 0.0) 2259 2260 if not control_flow_util.ENABLE_CONTROL_FLOW_V2: 2261 return 2262 2263 self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0) 2264 self.assertEqual(sess.run(grad_grad, {pred: False}), 0.0) 2265 2266 # NOTE: It is ok to have parallel_iterations > 1 2267 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2268 @test_util.run_deprecated_v1 2269 def testWhileUpdateVariable_1(self): 2270 with self.cached_session(): 2271 select = variables.Variable([3.0, 4.0, 5.0]) 2272 n = constant_op.constant(0) 2273 2274 def loop_iterator(j): 2275 return math_ops.less(j, 3) 2276 2277 def loop_body(j): 2278 ns = state_ops.scatter_update(select, j, 10.0) 2279 nj = math_ops.add(j, 1) 2280 op = control_flow_ops.group(ns) 2281 nj = control_flow_ops.with_dependencies([op], nj) 2282 return [nj] 2283 2284 r = control_flow_ops.while_loop( 2285 loop_iterator, loop_body, [n], parallel_iterations=1) 2286 self.evaluate(variables.global_variables_initializer()) 2287 self.assertEqual(3, self.evaluate(r)) 2288 result = self.evaluate(select) 2289 self.assertAllClose(np.array([10.0, 10.0, 10.0]), result) 2290 2291 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2292 @test_util.run_v1_only("b/120545219") 2293 def testWhileUpdateVariable_2(self): 2294 with self.cached_session(): 2295 select1 = variables.Variable([3.0, 4.0, 5.0]) 2296 select2 = variables.Variable([3.0, 4.0, 5.0]) 2297 n = constant_op.constant(0) 2298 2299 def loop_iterator(j): 2300 return math_ops.less(j, 3) 2301 2302 def loop_body(j): 2303 ns1 = state_ops.scatter_update(select1, j, 10.0) 2304 ns2 = state_ops.scatter_update(select2, j, 10.0) 2305 nj = math_ops.add(j, 1) 2306 op = control_flow_ops.group(ns1, ns2) 2307 nj = control_flow_ops.with_dependencies([op], nj) 2308 return [nj] 2309 2310 r = control_flow_ops.while_loop( 2311 loop_iterator, loop_body, [n], parallel_iterations=1) 2312 self.evaluate(variables.global_variables_initializer()) 2313 self.assertEqual(3, self.evaluate(r)) 2314 result1 = self.evaluate(select1) 2315 self.assertAllClose(np.array([10.0, 10.0, 10.0]), result1) 2316 result2 = self.evaluate(select2) 2317 self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2) 2318 2319 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2320 @test_util.run_v1_only("b/120545219") 2321 def testWhileUpdateVariable_3(self): 2322 with self.cached_session(): 2323 select = variables.Variable([3.0, 4.0, 5.0]) 2324 n = constant_op.constant(0) 2325 2326 def loop_iterator(j, _): 2327 return math_ops.less(j, 3) 2328 2329 def loop_body(j, _): 2330 ns = state_ops.scatter_update(select, j, 10.0) 2331 nj = math_ops.add(j, 1) 2332 return [nj, ns] 2333 2334 r = control_flow_ops.while_loop( 2335 loop_iterator, 2336 loop_body, [n, array_ops.identity(select)], 2337 parallel_iterations=1) 2338 
self.evaluate(variables.global_variables_initializer()) 2339 result = r[1] 2340 self.assertAllClose(np.array([10.0, 10.0, 10.0]), result) 2341 2342 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2343 @test_util.run_v1_only("b/120545219") 2344 def testWhileUpdateVariable_4(self): 2345 with self.cached_session(): 2346 var_a = variables.Variable(0, name="a") 2347 var_b = variables.Variable(0, name="b") 2348 self.evaluate(variables.global_variables_initializer()) 2349 2350 c = constant_op.constant(0, name="c") 2351 asn1 = state_ops.assign_add(var_a, 1, name="a_add") 2352 2353 # Loop condition 2354 def pred(i): 2355 return math_ops.less(i, 10) 2356 2357 # Loop body 2358 def loop_body(i): 2359 asn2 = state_ops.assign_add(var_b, asn1, name="b_add") 2360 with ops.control_dependencies([asn2]): 2361 ni = math_ops.add(i, 1, name="i_add") 2362 return ni 2363 2364 lpa = control_flow_ops.while_loop( 2365 pred, loop_body, [c], parallel_iterations=1) 2366 2367 self.assertEqual(0, self.evaluate(var_b)) 2368 self.evaluate(lpa) # Run the loop 2369 self.assertEqual(10, self.evaluate(var_b)) 2370 2371 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2372 @test_util.run_v1_only("b/120545219") 2373 def testWhileUpdateVariable_5(self): 2374 with self.cached_session(): 2375 # Create some variables. 2376 var_a = variables.Variable(0, name="a") 2377 var_b = variables.Variable(0, name="b") 2378 self.evaluate(variables.global_variables_initializer()) 2379 2380 # Change condition to check var_b 2381 def pred(_): 2382 return math_ops.less(var_b, 10) 2383 2384 # Change body to increment var_b 2385 def loop_body(i): 2386 asn1 = state_ops.assign_add( 2387 var_a, constant_op.constant(1), name="a_add") 2388 asn2 = state_ops.assign_add( 2389 var_b, constant_op.constant(1), name="b_add") 2390 with ops.control_dependencies([asn1, asn2]): 2391 inc_b = array_ops.identity(var_b) 2392 return inc_b 2393 2394 lpa = control_flow_ops.while_loop( 2395 pred, loop_body, [var_b], parallel_iterations=1, name="loop") 2396 2397 self.assertEqual(0, self.evaluate(var_b)) 2398 self.evaluate(lpa) # Run the loop 2399 self.assertEqual(10, self.evaluate(var_a)) 2400 self.assertEqual(10, self.evaluate(var_b)) 2401 2402 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 2403 @test_util.run_v1_only("b/120545219") 2404 def testWhileUpdateVariable_6(self): 2405 with self.cached_session(): 2406 # Create some variables. 
2407 var_a = variables.Variable(0, name="a") 2408 var_b = variables.Variable(0, name="b") 2409 c = constant_op.constant(0) 2410 self.evaluate(variables.global_variables_initializer()) 2411 2412 # Loop condition 2413 def pred(i): 2414 return math_ops.less(i, 10) 2415 2416 # Loop body 2417 def loop_body(i): 2418 asn1 = state_ops.assign_add(var_a, 1, name="a_add") 2419 with ops.control_dependencies([asn1]): 2420 asn2 = state_ops.assign_add(var_b, var_a, name="b_add") 2421 with ops.control_dependencies([asn2]): 2422 ni = math_ops.add(i, 1, name="i_add") 2423 return ni 2424 2425 lpa = control_flow_ops.while_loop( 2426 pred, loop_body, [c], parallel_iterations=1, name="loop") 2427 2428 self.assertEqual(0, self.evaluate(var_b)) 2429 self.evaluate(lpa) # Run the loop 2430 self.assertEqual(55, self.evaluate(var_b)) 2431 self.assertEqual(10, self.evaluate(var_a)) 2432 2433 @test_util.run_v1_only("b/120545219") 2434 def testWhileQueue_1(self): 2435 with self.cached_session(): 2436 q = data_flow_ops.FIFOQueue(-1, dtypes.int32) 2437 i = constant_op.constant(0) 2438 2439 def c(i): 2440 return math_ops.less(i, 10) 2441 2442 def b(i): 2443 ni = math_ops.add(i, 1) 2444 ni = control_flow_ops.with_dependencies([q.enqueue((i,))], ni) 2445 return ni 2446 2447 r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1) 2448 self.assertEqual([10], self.evaluate(r)) 2449 for i in xrange(10): 2450 self.assertEqual([i], self.evaluate(q.dequeue())) 2451 2452 @test_util.run_v1_only("b/120545219") 2453 def testWhileTimeOut(self): 2454 run_options = config_pb2.RunOptions(timeout_in_ms=1) 2455 with self.cached_session() as sess: 2456 n = constant_op.constant(0) 2457 c = lambda x: True 2458 b = lambda x: math_ops.add(x, 1) 2459 r = control_flow_ops.while_loop(c, b, [n]) 2460 with self.assertRaises(errors_impl.DeadlineExceededError): 2461 sess.run(r, options=run_options) 2462 2463 @test_util.disable_control_flow_v2("b/117119329 (stack)") 2464 @test_util.run_v1_only("b/120545219") 2465 def testWhileStack_1(self): 2466 with self.cached_session(): 2467 s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo") 2468 i = constant_op.constant(0) 2469 2470 def c(i): 2471 return math_ops.less(i, 10) 2472 2473 def b(i): 2474 ni = math_ops.add(i, 1) 2475 ni = control_flow_ops.with_dependencies( 2476 [gen_data_flow_ops.stack_push_v2(s, i)], ni) 2477 return ni 2478 2479 r = control_flow_ops.while_loop(c, b, [i], parallel_iterations=1) 2480 2481 x = constant_op.constant(0) 2482 2483 def c1(i, _): 2484 return math_ops.greater(i, 0) 2485 2486 def b1(i, x): 2487 ni = math_ops.subtract(i, 1) 2488 nx = x + gen_data_flow_ops.stack_pop_v2(s, dtypes.int32) 2489 return [ni, nx] 2490 2491 _, rx = control_flow_ops.while_loop( 2492 c1, 2493 b1, [r, x], 2494 [r.get_shape(), tensor_shape.unknown_shape()], 2495 parallel_iterations=1) 2496 self.assertEqual(45, self.evaluate(rx)) 2497 2498 def _testWhileGrad_ColocateGradients(self, colocate): 2499 gpu_dev_name = test.gpu_device_name() if test.is_gpu_available( 2500 ) else "/device:CPU:0" 2501 2502 graph = ops.Graph() 2503 with graph.as_default(): 2504 v = constant_op.constant(2.0, name="v") 2505 c = lambda v: math_ops.less(v, 100.0) 2506 2507 def b(x): 2508 with ops.device(gpu_dev_name): 2509 return math_ops.square(x) 2510 2511 loop = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 2512 r = gradients_impl.gradients( 2513 loop, v, colocate_gradients_with_ops=colocate)[0] 2514 2515 r_ops = graph.get_operations() 2516 r_devices = [(op.name, op.device) for op in r_ops] 2517 
2518 self.assertTrue(any("Square" in op.name for op in r_ops)) 2519 2520 for (name, dev) in r_devices: 2521 if not colocate and name.endswith("Square"): 2522 # Only forward graph contain gpu in Square device 2523 self.assertTrue(gpu_dev_name in dev) 2524 elif colocate and "Square" in name: 2525 # Forward and backward graphs contain gpu in Square/Square_grad devices 2526 self.assertTrue(gpu_dev_name in dev) 2527 else: 2528 self.assertFalse(gpu_dev_name in dev) 2529 2530 with self.session(graph=graph) as sess: 2531 self.assertAllClose(1024.0, self.evaluate(r)) 2532 2533 @test_util.disable_control_flow_v2("b/116351701 (colocation)") 2534 @test_util.run_v1_only("b/120545219") 2535 def testWhileGrad_ColocateGradients(self): 2536 self._testWhileGrad_ColocateGradients(colocate=False) 2537 self._testWhileGrad_ColocateGradients(colocate=True) 2538 2539 @test_util.run_v1_only("b/120545219") 2540 def testWhileGrad_Square(self): 2541 with self.cached_session(): 2542 v = constant_op.constant(2.0, name="v") 2543 c = lambda v: math_ops.less(v, 100.0) 2544 b = math_ops.square 2545 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 2546 r = control_flow_ops.cond(math_ops.less(1, 2), lambda: r, lambda: v) 2547 2548 r = gradients_impl.gradients(r, v)[0] 2549 self.assertAllClose(1024.0, self.evaluate(r)) 2550 2551 @test_util.run_v1_only("b/120545219") 2552 def testWhileGrad_Shape(self): 2553 with self.cached_session(): 2554 x = array_ops.placeholder(dtypes.float32, shape=[None]) 2555 v = constant_op.constant([2.0], name="v") 2556 n = constant_op.constant(0, name="n") 2557 c = lambda i, v: math_ops.less(i, 5) 2558 b = lambda i, v: [i + 1, math_ops.multiply(x, v)] 2559 r = control_flow_ops.while_loop( 2560 c, 2561 b, [n, v], 2562 [n.get_shape(), tensor_shape.unknown_shape()], 2563 parallel_iterations=1) 2564 2565 r = gradients_impl.gradients(r[1], x)[0] 2566 self.assertEqual([None], r.get_shape().as_list()) 2567 self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]})) 2568 2569 @test_util.run_deprecated_v1 2570 def testWhileGrad_BaseShape(self): 2571 with self.cached_session() as sess: 2572 x = array_ops.placeholder(dtypes.float32, [None]) 2573 v0 = constant_op.constant([2.0, 2.0], name="v") 2574 c = lambda v: constant_op.constant(False) 2575 b = lambda v: math_ops.multiply(v, x) 2576 r = control_flow_ops.while_loop(c, b, [v0]) 2577 y = math_ops.square(x) 2578 2579 r = gradients_impl.gradients([r, y], x)[0] 2580 self.assertAllClose([2.0, 4.0], sess.run(r, feed_dict={x: [1.0, 2.0]})) 2581 2582 @test_util.run_v1_only("b/120545219") 2583 def testWhileGrad_MultipleUses(self): 2584 with self.cached_session(): 2585 v = constant_op.constant(2.0, name="v") 2586 c = lambda v: math_ops.less(v, 100.0) 2587 b = math_ops.square 2588 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 2589 r = math_ops.multiply(r, r) 2590 2591 r = gradients_impl.gradients(r, v)[0] 2592 self.assertEqual(524288.0, self.evaluate(r)) 2593 2594 @test_util.run_v1_only("b/120545219") 2595 def testWhileGrad_LoopAdd(self): 2596 with self.cached_session(): 2597 v = constant_op.constant(2.0, name="v") 2598 c = lambda v: math_ops.less(v, 100.0) 2599 b = math_ops.square 2600 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 2601 r = math_ops.add(r, r) 2602 2603 r = gradients_impl.gradients(r, v)[0] 2604 self.assertAllClose(2048.0, self.evaluate(r)) 2605 2606 def _testWhileGrad_Mul(self, use_gpu, p_iters): 2607 with self.cached_session(use_gpu=use_gpu) as sess: 2608 a = constant_op.constant(3.0, 
name="a") 2609 v = constant_op.constant(2.0, name="v") 2610 c = lambda v: math_ops.less(v, 100.0) 2611 b = lambda v: math_ops.multiply(v, a) 2612 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=p_iters) 2613 2614 grad_a, grad_v = gradients_impl.gradients(r, [a, v]) 2615 grad_a_val, grad_v_val = self.evaluate([grad_a, grad_v]) 2616 self.assertAllClose(216.0, grad_a_val) 2617 self.assertAllClose(81.0, grad_v_val) 2618 2619 @test_util.run_deprecated_v1 2620 def testWhileGrad_Mul(self): 2621 self._testWhileGrad_Mul(use_gpu=False, p_iters=1) 2622 self._testWhileGrad_Mul(use_gpu=False, p_iters=10) 2623 self._testWhileGrad_Mul(use_gpu=True, p_iters=1) 2624 self._testWhileGrad_Mul(use_gpu=True, p_iters=10) 2625 2626 def _testNestedWhileCondWhileGrad(self, use_gpu): 2627 2628 with self.cached_session(use_gpu=use_gpu): 2629 v = constant_op.constant(1.0) 2630 2631 def inner_loop(s): 2632 z = constant_op.constant(0) 2633 c = lambda i, x: math_ops.less(i, 4) 2634 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 2635 return control_flow_ops.while_loop(c, b, [z, s]) 2636 2637 c = lambda x: math_ops.less(x, 128.0) 2638 2639 def b(x): 2640 return control_flow_ops.cond( 2641 constant_op.constant(True), 2642 lambda: math_ops.square(inner_loop(x)[1]), 2643 lambda: math_ops.multiply(x, 2.0)) 2644 2645 r = control_flow_ops.while_loop(c, b, [v]) 2646 r = gradients_impl.gradients(r, v)[0] 2647 self.assertAllClose(512.0, self.evaluate(r)) 2648 2649 @test_util.run_deprecated_v1 2650 def testNestedWhileCondWhileGrad(self): 2651 self._testNestedWhileCondWhileGrad(use_gpu=False) 2652 2653 @test_util.run_deprecated_v1 2654 def testNestedWhileCondWhileGradGpu(self): 2655 self._testNestedWhileCondWhileGrad(use_gpu=True) 2656 2657 @test_util.run_v1_only("b/120545219") 2658 def testWhileGrad_Variable(self): 2659 with self.cached_session(): 2660 a = variables.Variable(3.0) 2661 v = constant_op.constant(2.0, name="v") 2662 c = lambda v: math_ops.less(v, 100.0) 2663 b = lambda v: math_ops.multiply(v, a) 2664 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 2665 2666 r = gradients_impl.gradients(r, a) 2667 self.evaluate(variables.global_variables_initializer()) 2668 self.assertAllClose(216.0, r[0]) 2669 2670 @test_util.run_deprecated_v1 2671 def testWhileGrad_ResourceVariable(self): 2672 with self.cached_session(): 2673 a = resource_variable_ops.ResourceVariable(3.0) 2674 v = constant_op.constant(2.0, name="v") 2675 c = lambda v: math_ops.less(v, 100.0) 2676 b = lambda v: math_ops.multiply(v, a) 2677 r = control_flow_ops.while_loop(c, b, [v], parallel_iterations=1) 2678 2679 g = gradients_impl.gradients(r, a) 2680 self.evaluate(variables.global_variables_initializer()) 2681 self.assertAllClose(216.0, g[0]) 2682 2683 def testWhileGrad_EagerResourceVariable(self): 2684 with context.eager_mode(): 2685 a = resource_variable_ops.ResourceVariable( 2686 np.ones([2, 2], dtype=np.float32)) 2687 v = constant_op.constant(1.0) 2688 2689 @eager_function.defun 2690 def fn(): 2691 r = control_flow_ops.while_loop( 2692 lambda i, _: i < 2, 2693 lambda i, x: (i + 1, x * math_ops.reduce_sum(a) * v), 2694 [0, 1.0])[1] 2695 return gradients_impl.gradients(r, [v])[0] 2696 2697 self.assertEqual(self.evaluate(fn()), 32.) 
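  # For the test above: each of the two iterations multiplies x by
  # reduce_sum(a) * v = 4 * v, so r = (4 * v)**2 = 16 * v**2 and
  # dr/dv = 32 * v, which is 32.0 at v = 1.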
2698 2699 @test_util.disable_xla("b/128643381") 2700 def testWhileGrad_ResourceVarInFunctionCall(self): 2701 2702 @def_function.function 2703 def foo(x, var): 2704 return x + math_ops.reduce_sum(var.sparse_read([1, 3])) 2705 2706 @def_function.function 2707 def bar(var): 2708 r = control_flow_ops.while_loop( 2709 lambda i, _: i < 2, 2710 lambda i, x: (i + 1, foo(x, var)), 2711 [0, 0.0])[1] 2712 return gradients_impl.gradients(r, var)[0] 2713 2714 var = resource_variable_ops.ResourceVariable([1., 2., 3., 4.]) 2715 self.evaluate(variables.global_variables_initializer()) 2716 grad = self.evaluate(bar(var)) 2717 self.assertIsInstance(grad, ops.IndexedSlicesValue) 2718 self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.]) 2719 2720 @test_util.disable_xla("b/128643461") 2721 def testWhileGrad_ResourceVarInNestedFunctionCall(self): 2722 2723 @def_function.function 2724 def foo(x, var): 2725 return x + math_ops.reduce_sum(var.sparse_read([1, 3])) 2726 2727 @def_function.function 2728 def foo2(x, var): 2729 return foo(x, var) 2730 2731 @def_function.function 2732 def bar(var): 2733 r = control_flow_ops.while_loop( 2734 lambda i, _: i < 2, 2735 lambda i, x: (i + 1, foo2(x, var)), 2736 [0, 0.0])[1] 2737 return gradients_impl.gradients(r, var)[0] 2738 2739 var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.]) 2740 self.evaluate(variables.global_variables_initializer()) 2741 grad = self.evaluate(bar(var)) 2742 self.assertIsInstance(grad, ops.IndexedSlicesValue) 2743 self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.]) 2744 2745 def testWhileGrad_ResourceVarInLoopInFunctionCall(self): 2746 if test.is_gpu_available(): 2747 self.skipTest("b/128635252") 2748 2749 @def_function.function 2750 def foo(x, var): 2751 return control_flow_ops.while_loop( 2752 lambda j, _: j < 3, 2753 lambda j, y: (j + 1, 2754 y + math_ops.reduce_sum(var.sparse_read([1, 2]))), 2755 [0, x])[1] 2756 2757 @def_function.function 2758 def bar(var): 2759 r = control_flow_ops.while_loop( 2760 lambda i, _: i < 2, 2761 lambda i, x: (i + 1, foo(x, var)), 2762 [0, 0.0])[1] 2763 return gradients_impl.gradients(r, var)[0] 2764 2765 var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.]) 2766 self.evaluate(variables.global_variables_initializer()) 2767 grad = self.evaluate(bar(var)) 2768 self.assertIsInstance(grad, ops.IndexedSlicesValue) 2769 self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 6., 6., 0.]) 2770 2771 @test_util.disable_xla("b/128639858") 2772 def testWhileCondGrad_ResourceVarInFunctionCall(self): 2773 2774 @def_function.function 2775 def foo(x, var): 2776 return x + var.sparse_read([1])[0] 2777 2778 def body(i, x): 2779 return (i + 1, control_flow_ops.cond( 2780 math_ops.equal(i % 2, 0), 2781 lambda: foo(x, var1), 2782 lambda: foo(x, var2))) 2783 2784 @def_function.function 2785 def bar(var1, var2): 2786 r = control_flow_ops.while_loop( 2787 lambda i, _: i < 4, body, [0, 0.0]) 2788 return gradients_impl.gradients(r, [var1, var2]) 2789 2790 var1 = resource_variable_ops.ResourceVariable([1., 2., 3.]) 2791 var2 = resource_variable_ops.ResourceVariable([4., 5.]) 2792 self.evaluate(variables.global_variables_initializer()) 2793 grads = self.evaluate(bar(var1, var2)) 2794 self.assertAllEqual(gradient_checker_v2._to_numpy(grads[0]), [0., 2., 0.]) 2795 self.assertAllEqual(gradient_checker_v2._to_numpy(grads[1]), [0., 2.]) 2796 2797 @test_util.run_deprecated_v1 2798 def testWhileGrad_ResourceVarSparseRead(self): 2799 # NOTE(skyewm): this test is interesting because the 
2800 # ResourceVariable.sparse_read gradient function returns an IndexedSlices. 2801 var = resource_variable_ops.ResourceVariable(np.ones(5), 2802 dtype=dtypes.float32) 2803 r = control_flow_ops.while_loop( 2804 lambda i, _: i < 3, 2805 lambda i, x: (i + 1, x * math_ops.reduce_sum(var.sparse_read([1, 3]))), 2806 [0, constant_op.constant(1.0)])[1] 2807 grad = gradients_impl.gradients(r, var)[0] 2808 2809 self.evaluate(variables.global_variables_initializer()) 2810 grad_val = self.evaluate(grad) 2811 self.assertIsInstance(grad_val, ops.IndexedSlicesValue) 2812 arr = gradient_checker_v2._to_numpy(grad_val) 2813 self.assertAllEqual(arr, [0., 12., 0., 12., 0.]) 2814 2815 @test_util.run_deprecated_v1 2816 def testWhileGrad_MultiResourceVarSparseRead(self): 2817 # NOTE(skyewm): this test is interesting because the 2818 # ResourceVariable.sparse_read gradient function returns an IndexedSlices. 2819 var1 = resource_variable_ops.ResourceVariable(np.ones(5), 2820 dtype=dtypes.float32) 2821 var2 = resource_variable_ops.ResourceVariable(np.ones(3), 2822 dtype=dtypes.float32) 2823 x1_init = constant_op.constant([0., 0.]) 2824 x2_init = constant_op.constant(1.) 2825 x3_init = constant_op.constant(1.) 2826 2827 def body(i, unused_x1, x2, x3): 2828 y1 = var1.sparse_read([1, 3]) 2829 y2 = x2 * 2 2830 y3 = x3 * math_ops.reduce_sum(var2.sparse_read([0])) 2831 return i + 1, y1, y2, y3 2832 2833 r = control_flow_ops.while_loop( 2834 lambda i, x1, x2, x3: i < 3, body, 2835 [0, x1_init, x2_init, x3_init])[1:] 2836 var1_grad, var2_grad = gradients_impl.gradients(r, [var1, var2]) 2837 2838 self.evaluate(variables.global_variables_initializer()) 2839 var1_grad_val = self.evaluate(var1_grad) 2840 var2_grad_val = self.evaluate(var2_grad) 2841 self.assertIsInstance(var1_grad_val, ops.IndexedSlicesValue) 2842 self.assertIsInstance(var2_grad_val, ops.IndexedSlicesValue) 2843 self.assertAllEqual(gradient_checker_v2._to_numpy(var1_grad_val), 2844 [0., 1., 0., 1., 0.]) 2845 self.assertAllEqual(gradient_checker_v2._to_numpy(var2_grad_val), 2846 [3., 0., 0.]) 2847 2848 @test_util.run_deprecated_v1 2849 def testWhileGrad_Gather(self): 2850 # NOTE(skyewm): this test is interesting because the gather gradient 2851 # function returns an IndexedSlices. 2852 x = constant_op.constant([1., 1., 1., 1., 1.]) 2853 y = control_flow_ops.while_loop( 2854 lambda i, _: i < 3, 2855 lambda i, x: (i + 1, x + array_ops.gather(x, [0])), 2856 [0, x[:1]])[1] 2857 z = y * 3.0 2858 grad = gradients_impl.gradients(z, x)[0] 2859 self.assertEqual(self.evaluate(y), 8.) 2860 self.assertAllEqual(self.evaluate(grad), [24., 0., 0., 0., 0.]) 2861 2862 @test_util.run_deprecated_v1 2863 def testWhileGrad_GatherNoFanOut(self): 2864 # NOTE(skyewm): this test is interesting because the gather gradient 2865 # function returns an IndexedSlices. 2866 x = constant_op.constant([1., 1., 1., 1., 1.]) 2867 y = control_flow_ops.while_loop( 2868 lambda i, _: i < 3, 2869 lambda i, x: (i + 1, array_ops.gather(x, [0])), 2870 [0, x[:1]])[1] 2871 z = y * 3.0 2872 grad = gradients_impl.gradients(z, x)[0] 2873 self.assertEqual(self.evaluate(y), 1.) 
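    # After three gather(..., [0]) iterations y is just x[0] (value 1.), so
    # z = 3 * y depends only on x[0] and the gradient is [3., 0., 0., 0., 0.].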
2874 self.assertAllEqual(self.evaluate(grad), [3., 0., 0., 0., 0.]) 2875 2876 @test_util.run_v1_only("b/120545219") 2877 def testWhileGradInCond(self): 2878 2879 with self.cached_session(): 2880 n = ops.convert_to_tensor(1.0, name="n") 2881 x = array_ops.placeholder(dtypes.float32, shape=None) 2882 c = lambda n: math_ops.less(n, 10.0) 2883 b = lambda n: math_ops.add(n, x) 2884 2885 def fn1(): 2886 r = control_flow_ops.while_loop(c, b, [n], 2887 [tensor_shape.unknown_shape()]) 2888 return gradients_impl.gradients(r, x)[0] 2889 2890 r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x) 2891 self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0})) 2892 2893 @test_util.disable_control_flow_v2("b/116340060") 2894 @test_util.run_v1_only("b/120545219") 2895 def testGradInWhileWrtInitialLoopVal(self): 2896 with self.cached_session(): 2897 x = array_ops.placeholder(dtypes.float32, shape=(), name="x") 2898 y = x + 1 2899 2900 def body(i, v): 2901 z = v * 2 2902 return i + 1, gradients_impl.gradients(z, x)[0] 2903 2904 with self.assertRaisesRegexp( 2905 ValueError, 2906 "Cannot compute gradient inside while loop with respect to op 'x'. " 2907 "We do not support taking the gradient wrt or through the initial " 2908 "value of a loop variable. Gradients can be computed through " 2909 "loop invariants or wrt the input parameters to the loop body."): 2910 control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y]) 2911 2912 @test_util.run_v1_only("b/120545219") 2913 def testWhileGradInWhile(self): 2914 with self.cached_session(): 2915 n = ops.convert_to_tensor(1.0, name="n") 2916 x = array_ops.placeholder(dtypes.float32, shape=None) 2917 c = lambda n: math_ops.less(n, 10.0) 2918 b = lambda n: math_ops.add(n, x) 2919 2920 def b1(n): 2921 r = control_flow_ops.while_loop(c, b, [n], 2922 [tensor_shape.unknown_shape()]) 2923 return gradients_impl.gradients(r, x) 2924 2925 r = control_flow_ops.while_loop(lambda n: n < 6.0, b1, [n], 2926 [tensor_shape.unknown_shape()]) 2927 self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0})) 2928 2929 @test_util.run_v1_only("b/120545219") 2930 def testCondGradInNestedWhiles(self): 2931 2932 def outer_body(i, x): 2933 _, x = control_flow_ops.while_loop( 2934 lambda j, x: j < 3, inner_body, [0, 0.0]) 2935 return i + 1, x 2936 2937 def inner_body(j, x): 2938 y = control_flow_ops.cond(math_ops.less(x, 1), lambda: 2 * x, lambda: x) 2939 return j + 1, gradients_impl.gradients(y, x)[0] 2940 2941 i, x = control_flow_ops.while_loop(lambda i, x: i < 3, outer_body, [0, 0.0]) 2942 2943 with self.cached_session() as sess: 2944 i_val, x_val = self.evaluate([i, x]) 2945 self.assertEqual(i_val, 3) 2946 self.assertAllClose(x_val, 1.0) 2947 2948 @test_util.run_gpu_only 2949 def testGpuResourceAccess(self): 2950 with ops.device(test.gpu_device_name()): 2951 var = resource_variable_ops.ResourceVariable(constant_op.constant(3.0)) 2952 2953 @def_function.function 2954 def foo(): 2955 return control_flow_ops.while_loop( 2956 lambda i, _: i < 3, 2957 lambda i, x: (i + 1, control_flow_ops.cond( 2958 constant_op.constant(True), 2959 lambda: x + var, 2960 lambda: x)), 2961 [0, 0.0])[1] 2962 2963 self.evaluate(variables.global_variables_initializer()) 2964 self.assertEqual(self.evaluate(foo()), 9.0) 2965 2966 @test_util.disable_xla("b/128643398") 2967 def testNestedResourceAccess(self): 2968 var = resource_variable_ops.ResourceVariable(constant_op.constant(3.0)) 2969 2970 @eager_function.defun 2971 def test_fn(): 2972 x = constant_op.constant(0.0) 2973 r = control_flow_ops.while_loop( 2974 # Outer 
loop condition 2975 lambda i, y: i < 2, 2976 # Outer loop body 2977 lambda i, y: (i + 1, y + control_flow_ops.cond( 2978 constant_op.constant(True), 2979 # True branch 2980 lambda: control_flow_ops.while_loop( 2981 # Inner loop condition 2982 lambda j, z: j < 3, 2983 # Inner loop body 2984 lambda j, z: (j + 1, z + math_ops.square(var)), 2985 # Inner initial loop value 2986 [0, y])[1], 2987 # False branch 2988 lambda: (0.0))), 2989 # Outer initial loop value 2990 [0, x])[1] 2991 2992 grad = gradients_impl.gradients(r, x)[0] 2993 return r, grad 2994 2995 self.evaluate(variables.global_variables_initializer()) 2996 r, grad = self.evaluate(test_fn()) 2997 # 2 * 3 * 3^2 2998 self.assertEqual(r, 81.0) 2999 # v1 control flow gets the wrong answer!!! 3000 # Gradient computation: 3001 # f(x) = x + 3^2 3002 # inner_loop(x) = f(f(f(x))) = x + 3*3^2 = x + 27 3003 # g(x) = x + inner_loop(x) = 2x + 27 3004 # outer_loop(x) = g(g(x)) = 4x + 81 3005 # outer_loop'(x) = 4 3006 # Note that v1 control flow gets 4.0 as well if the cond is removed. 3007 if control_flow_util.ENABLE_CONTROL_FLOW_V2: 3008 self.assertEqual(grad, 4.0) 3009 3010 def testWhile_NestedInput(self): 3011 with self.cached_session() as sess: 3012 named = collections.namedtuple("named", ("a", "b")) 3013 loop_vars = [ 3014 named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)), 3015 (constant_op.constant(2.0), constant_op.constant(3.0)), 3016 constant_op.constant(4.0) 3017 ] 3018 c = lambda lv0, _1, _2: lv0.a < 100.0 3019 3020 def b(lv0, lv1, lv2): 3021 lv0 = named(a=lv0.a + 1, b=lv0.b) 3022 lv1 = (lv1[0] + 1, lv1[1]) 3023 lv2 += 2 3024 return [lv0, lv1, lv2] 3025 3026 r = control_flow_ops.while_loop(c, b, loop_vars) 3027 3028 self.assertTrue(isinstance(r, list)) 3029 self.assertTrue(isinstance(r[0], named)) 3030 self.assertTrue(isinstance(r[1], tuple)) 3031 self.assertTrue(isinstance(r[2], ops.Tensor)) 3032 3033 r_flattened = nest.flatten(r) 3034 self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0], 3035 self.evaluate(r_flattened)) 3036 3037 @test_util.run_v1_only("b/120545219") 3038 def testWhile_NestedBadArityFails(self): 3039 with self.cached_session(): 3040 named = collections.namedtuple("named", ("a", "b")) 3041 loop_vars = [ 3042 named(a=constant_op.constant(0.0), b=constant_op.constant(1.0)), 3043 (constant_op.constant(2.0), constant_op.constant(3.0)), 3044 constant_op.constant(4.0) 3045 ] 3046 c = lambda lv0, _1, _2: lv0.a < 100.0 3047 3048 def b(lv0, lv1, _): 3049 return [lv0, lv1] 3050 3051 with self.assertRaisesRegexp(ValueError, "the same number of elements"): 3052 control_flow_ops.while_loop(c, b, loop_vars) 3053 3054 @test_util.run_v1_only("b/120545219") 3055 def testWhileGrad_ys_xs(self): 3056 with self.cached_session(): 3057 x = constant_op.constant(3.0, name="x") 3058 y = constant_op.constant(2.0, name="y") 3059 3060 c = lambda x, y: math_ops.less(x, 100.0) 3061 3062 def b(x, y): 3063 y1 = math_ops.add(x, y) 3064 x1 = math_ops.multiply(x, y1) 3065 return x1, y1 3066 3067 rx, ry = control_flow_ops.while_loop(c, b, [x, y], parallel_iterations=1) 3068 3069 r = gradients_impl.gradients([rx, ry], x) 3070 self.assertAllClose(304.0, r[0]) 3071 r = gradients_impl.gradients([rx, ry], y) 3072 self.assertAllClose(124.0, r[0]) 3073 r = gradients_impl.gradients([rx], x) 3074 self.assertAllClose(295.0, r[0]) 3075 r = gradients_impl.gradients([rx], y) 3076 self.assertAllClose(120.0, r[0]) 3077 3078 @test_util.run_deprecated_v1 3079 def testWhileGrad_Dependency(self): 3080 with self.cached_session(): 3081 i = 
constant_op.constant(0, name="i") 3082 x = constant_op.constant(2.0, name="x") 3083 3084 c = lambda i, x: math_ops.less(i, 10) 3085 3086 def b(i, x): 3087 x = math_ops.multiply(x, 2.0) 3088 i = math_ops.add(i, 1) 3089 return i, x 3090 3091 ri, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1) 3092 3093 r = gradients_impl.gradients([ri, rx], x) 3094 self.assertAllClose(1024.0, r[0]) 3095 r = gradients_impl.gradients([rx], x) 3096 self.assertAllClose(1024.0, r[0]) 3097 3098 @test_util.disable_control_flow_v2("b/116355153 (back_prop flag)") 3099 @test_util.run_v1_only("b/120545219") 3100 def testWhileGrad_NoGradient(self): 3101 with self.cached_session(): 3102 v = constant_op.constant(2.0, name="v") 3103 c = lambda v: math_ops.less(v, 100.0) 3104 b = math_ops.square 3105 r = control_flow_ops.while_loop(c, b, [v], back_prop=False) 3106 r = math_ops.add(r, v) 3107 r = gradients_impl.gradients(r, v) 3108 self.assertAllClose(1.0, r[0]) 3109 3110 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 3111 @test_util.run_v1_only("b/120545219") 3112 def testWhileGrad_NoDependency(self): 3113 with self.cached_session() as sess: 3114 variable = variables.Variable(array_ops.ones([2, 3])) 3115 duration = array_ops.zeros([], dtype=dtypes.int32) 3116 3117 def cond(duration, tensor, _): 3118 del tensor 3119 return duration < 10 3120 3121 def body(duration, tensor, _): 3122 return (duration + 1, tensor, tensor) 3123 3124 loop_vars = [duration, variable, variable] 3125 tensors = control_flow_ops.while_loop( 3126 cond=cond, body=body, loop_vars=loop_vars) 3127 cost = math_ops.reduce_sum(tensors[2]) 3128 grad = gradients_impl.gradients(cost, [variable]) 3129 self.evaluate(variables.global_variables_initializer()) 3130 self.assertAllClose(np.ones([2, 3]), sess.run(grad[0])) 3131 3132 @test_util.run_deprecated_v1 3133 def testWhileGrad_Const(self): 3134 with self.cached_session() as sess: 3135 c0 = constant_op.constant(0.0, name="c0") 3136 c1 = constant_op.constant(1.0, name="c1") 3137 duration = constant_op.constant(0, name="t") 3138 3139 def cond(duration, _): 3140 return duration < 1 3141 3142 def body(duration, _): 3143 return duration + 1, c1 3144 3145 loop_vars = [duration, c0] 3146 tensors = control_flow_ops.while_loop( 3147 cond=cond, body=body, loop_vars=loop_vars) 3148 cost = math_ops.reduce_sum(tensors[1]) 3149 grad = gradients_impl.gradients(cost, [c0]) 3150 self.assertAllClose(0.0, sess.run(grad[0])) 3151 3152 @test_util.run_v1_only("b/120545219") 3153 def testWhileGrad_SerialTwoLoops(self): 3154 with self.cached_session(): 3155 i = constant_op.constant(0, name="i") 3156 x = constant_op.constant(2.0, name="x") 3157 3158 c = lambda i, x: math_ops.less(i, 5) 3159 3160 def b(i, x): 3161 x = math_ops.multiply(x, 2.0) 3162 i = math_ops.add(i, 1) 3163 return i, x 3164 3165 _, rx = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1) 3166 _, rx = control_flow_ops.while_loop(c, b, [i, rx], parallel_iterations=1) 3167 3168 r = gradients_impl.gradients([rx], x) 3169 self.assertAllClose(1024.0, r[0]) 3170 3171 @test_util.run_v1_only("b/120545219") 3172 def testWhileGrad_ParallelTwoLoops(self): 3173 with self.cached_session(): 3174 i = constant_op.constant(0, name="i") 3175 x = constant_op.constant(2.0, name="x") 3176 3177 c = lambda i, x: math_ops.less(i, 5) 3178 3179 def b(i, x): 3180 x = math_ops.multiply(x, 2.0) 3181 i = math_ops.add(i, 1) 3182 return i, x 3183 3184 _, r1 = control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1) 3185 _, r2 = 
control_flow_ops.while_loop(c, b, [i, x], parallel_iterations=1) 3186 rx = math_ops.add(r1, r2) 3187 3188 r = gradients_impl.gradients([rx], x) 3189 self.assertAllClose(64.0, r[0]) 3190 3191 @test_util.run_v1_only("b/120545219") 3192 def testWhileGrad_OneOutputWithControlDependencyOnSecond(self): 3193 with self.cached_session(): 3194 i = constant_op.constant(0, name="i") 3195 x = constant_op.constant(1.0, name="x") 3196 y = constant_op.constant(1.0, name="y") 3197 c = lambda i, *_: math_ops.less(i, 1, name="cond_less") 3198 3199 def b(i, xi, yi): 3200 # return (i + 1, xi, xi + yi) 3201 return (math_ops.add(i, 1, name="inc"), array_ops.identity( 3202 xi, name="xi"), math_ops.add(xi, yi, name="xi_plus_yi")) 3203 3204 _, x_f, y_f = control_flow_ops.while_loop(c, b, [i, x, y]) 3205 with ops.control_dependencies([x_f]): 3206 y_f_d = array_ops.identity(y_f, name="y_f_d") 3207 3208 self.assertAllClose(2.0, self.evaluate(y_f_d)) # y_f_d = 1.0 + 1.0 3209 g = gradients_impl.gradients([y_f_d], [x])[0] 3210 self.assertTrue(g is not None) 3211 self.assertAllClose(1.0, 3212 self.evaluate(g)) # y_f_d = x + 1.0, dy_f_d/dx = 1.0 3213 3214 def _testNestedWhileGrad_Simple(self, use_gpu): 3215 with self.cached_session(use_gpu=use_gpu): 3216 v = constant_op.constant(1.0) 3217 3218 def inner_loop(s): 3219 c = lambda x: math_ops.less(x, 4.0) 3220 b = lambda x: math_ops.multiply(x, 2.0) 3221 return control_flow_ops.while_loop(c, b, [s]) 3222 3223 c = lambda x: math_ops.less(x, 2.0) 3224 b = lambda x: math_ops.multiply(inner_loop(x), 2.0) 3225 r = control_flow_ops.while_loop(c, b, [v]) 3226 3227 r = gradients_impl.gradients(r, v)[0] 3228 self.assertAllClose(8.0, self.evaluate(r)) 3229 3230 @test_util.run_deprecated_v1 3231 def testNestedWhileGrad_Simple(self): 3232 self._testNestedWhileGrad_Simple(use_gpu=False) 3233 self._testNestedWhileGrad_Simple(use_gpu=True) 3234 3235 @test_util.run_v1_only("b/120545219") 3236 def testNestedWhileGrad_SerialInner(self): 3237 with self.cached_session(): 3238 v = constant_op.constant(1.0) 3239 3240 def inner_loop1(s): 3241 z = constant_op.constant(0) 3242 c = lambda i, x: math_ops.less(i, 4) 3243 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 3244 return control_flow_ops.while_loop(c, b, [z, s]) 3245 3246 def inner_loop2(s): 3247 z = constant_op.constant(0) 3248 c = lambda i, x: math_ops.less(i, 4) 3249 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 3250 return control_flow_ops.while_loop(c, b, [z, s]) 3251 3252 c = lambda x: math_ops.less(x, 128.0) 3253 b = lambda x: inner_loop2(inner_loop1(x)[1])[1] 3254 r = control_flow_ops.while_loop(c, b, [v]) 3255 3256 r = gradients_impl.gradients(r, v)[0] 3257 self.assertAllClose(256.0, self.evaluate(r)) 3258 3259 @test_util.run_deprecated_v1 3260 def testNestedWhileGrad_ParallelInner(self): 3261 with self.cached_session(): 3262 v = constant_op.constant(1.0) 3263 3264 def inner_loop1(s): 3265 z = constant_op.constant(0) 3266 c = lambda i, x: math_ops.less(i, 4) 3267 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 3268 return control_flow_ops.while_loop(c, b, [z, s]) 3269 3270 def inner_loop2(s): 3271 z = constant_op.constant(0) 3272 c = lambda i, x: math_ops.less(i, 4) 3273 b = lambda i, x: [math_ops.add(i, 1), math_ops.multiply(x, 2.0)] 3274 return control_flow_ops.while_loop(c, b, [z, s]) 3275 3276 c = lambda x: math_ops.less(x, 128.0) 3277 b = lambda x: math_ops.multiply(inner_loop1(x)[1], inner_loop2(x)[1]) 3278 r = control_flow_ops.while_loop(c, b, [v]) 3279 3280 r = 
gradients_impl.gradients(r, v)[0] 3281 self.assertAllClose(512.0, self.evaluate(r)) 3282 3283 @test_util.run_v1_only("b/120545219") 3284 def testNestedWhileGrad_ParallelIterations(self): 3285 # Make sure the stack pushes and pops of an inner loop are executed in 3286 # the sequential order of the iterations of its outer loop. 3287 with self.cached_session() as sess: 3288 3289 def inner_loop(t): 3290 fn = lambda n: n + math_ops.square(var) 3291 return map_fn.map_fn(fn=fn, elems=t, parallel_iterations=10) 3292 3293 def outer_loop(inp): 3294 return map_fn.map_fn( 3295 fn=inner_loop, elems=inp, parallel_iterations=10) 3296 3297 var = variables.Variable(constant_op.constant(3.0)) 3298 inp = constant_op.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) 3299 res = outer_loop(inp) 3300 optimizer = adam.AdamOptimizer(learning_rate=0.001) 3301 train_op = optimizer.minimize(math_ops.reduce_mean(math_ops.square(res))) 3302 self.evaluate(variables.global_variables_initializer()) 3303 self.evaluate(train_op) 3304 self.assertAllClose(2.999, var.read_value()) 3305 3306 def _testWhileCondGrad_Simple(self, use_gpu): 3307 with self.cached_session(use_gpu=use_gpu): 3308 v = ops.convert_to_tensor(2.0, name="v") 3309 n = ops.convert_to_tensor(100.0, name="n") 3310 one = ops.convert_to_tensor(1.0, name="one") 3311 c = lambda x: math_ops.less(x, n) 3312 # pylint: disable=undefined-variable 3313 # for OSS build 3314 b = lambda x: control_flow_ops.cond(constant_op.constant(True), 3315 lambda: math_ops.square(x), 3316 lambda: math_ops.subtract(x, one)) 3317 # pylint: enable=undefined-variable 3318 r = control_flow_ops.while_loop(c, b, [v]) 3319 r = gradients_impl.gradients(r, v)[0] 3320 self.assertAllClose(1024.0, self.evaluate(r)) 3321 3322 @test_util.run_deprecated_v1 3323 def testWhileCondGrad_Simple(self): 3324 self._testWhileCondGrad_Simple(use_gpu=False) 3325 self._testWhileCondGrad_Simple(use_gpu=True) 3326 3327 @test_util.run_deprecated_v1 3328 def testWhileCondGrad_UnknownShape(self): 3329 with self.cached_session() as sess: 3330 v = array_ops.placeholder(dtypes.float32) 3331 n = ops.convert_to_tensor(100.0, name="n") 3332 one = ops.convert_to_tensor(1.0, name="one") 3333 c = lambda x: math_ops.less(x, n) 3334 # pylint: disable=undefined-variable 3335 # for OSS build 3336 b = lambda x: control_flow_ops.cond(constant_op.constant(True), 3337 lambda: math_ops.square(x), 3338 lambda: math_ops.subtract(x, one)) 3339 # pylint: enable=undefined-variable 3340 r = control_flow_ops.while_loop(c, b, [v]) 3341 r = gradients_impl.gradients(r, v)[0] 3342 r = sess.run(r, feed_dict={v: 2.0}) 3343 self.assertAllClose(1024.0, r) 3344 3345 @test_util.run_deprecated_v1 3346 def testWhileGrad_Concat(self): 3347 with self.cached_session() as sess: 3348 x = variable_scope.get_variable("x", initializer=[[1., 2.]]) 3349 i0 = constant_op.constant(0) 3350 h0 = array_ops.zeros([0, 2]) 3351 3352 def condition(i, _): 3353 return i < 2 3354 3355 def body(i, h): 3356 return i + 1, array_ops.concat([h, x], 0) 3357 3358 _, h = control_flow_ops.while_loop( 3359 condition, body, [i0, h0], 3360 [i0.get_shape(), tensor_shape.TensorShape([None, 2])]) 3361 s = math_ops.reduce_sum(h) 3362 3363 optimizer = gradient_descent.GradientDescentOptimizer(0.01) 3364 op = optimizer.minimize(s) 3365 3366 self.evaluate(variables.global_variables_initializer()) 3367 self.evaluate(op) 3368 self.assertAllClose([[0.98000002, 1.98000002]], self.evaluate(x)) 3369 3370 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 3371 
@test_util.run_v1_only("b/120545219") 3372 def testWhileWithRefsWithGradients_1(self): 3373 with self.cached_session() as sess: 3374 x = variables.VariableV1(0.)._ref() # pylint: disable=protected-access 3375 i = constant_op.constant(0) 3376 c = lambda i, x: math_ops.less(i, 10) 3377 3378 self.assertEqual(x.dtype, dtypes.float32_ref) 3379 3380 def body(i, x): 3381 self.assertEqual(x.dtype, dtypes.float32_ref) 3382 return [i + 1, gen_array_ops.ref_identity(x)] 3383 3384 r = control_flow_ops.while_loop(c, body, [i, x], parallel_iterations=5) 3385 3386 grad_ys = [variables.VariableV1(73)._ref()] # pylint: disable=protected-access 3387 grad = gradients_impl.gradients([r[1]], [x], grad_ys=grad_ys) 3388 3389 self.evaluate(variables.global_variables_initializer()) 3390 3391 self.assertEqual(r[0].dtype, dtypes.int32) 3392 self.assertEqual(r[1].dtype, dtypes.float32_ref) 3393 3394 value_i, value_x, value_x_grad = sess.run(r + grad) 3395 3396 self.assertEqual(10, value_i) 3397 self.assertEqual(0, value_x) 3398 self.assertEqual(73, value_x_grad) 3399 3400 @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)") 3401 @test_util.run_v1_only("b/120545219") 3402 def testWhileGrad_IndexedSlices(self): 3403 with self.cached_session(): 3404 values = constant_op.constant([2.0, 4.0], name="values") 3405 indices = constant_op.constant([0, 3], name="indices") 3406 shape = constant_op.constant([10], name="dense_shape") 3407 i = constant_op.constant(0) 3408 x = ops.IndexedSlices(values, indices, dense_shape=shape) 3409 3410 def c(i, _): 3411 return i < 10 3412 3413 def b(i, x): 3414 return [ 3415 i + 1, 3416 ops.IndexedSlices(x.values * 2.0, x.indices, x.dense_shape) 3417 ] 3418 3419 _, r = control_flow_ops.while_loop(c, b, [i, x]) 3420 r = gradients_impl.gradients(r.values, values)[0] 3421 self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r)) 3422 3423 @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)") 3424 @test_util.run_v1_only("b/120545219") 3425 def testWhileGrad_SparseTensor(self): 3426 with self.cached_session(): 3427 values = constant_op.constant([2.0, 4.0], name="values") 3428 indices = constant_op.constant( 3429 [[0], [3]], dtype=dtypes.int64, name="indices") 3430 shape = constant_op.constant([10], dtype=dtypes.int64, name="dense_shape") 3431 i = constant_op.constant(0) 3432 x = sparse_tensor.SparseTensor(indices, values, dense_shape=shape) 3433 3434 def c(i, _): 3435 return i < 10 3436 3437 def b(i, x): 3438 return [ 3439 i + 1, 3440 sparse_tensor.SparseTensor(x.indices, x.values * 2.0, x.dense_shape) 3441 ] 3442 3443 _, r = control_flow_ops.while_loop(c, b, [i, x]) 3444 r = gradients_impl.gradients(r.values, values)[0] 3445 self.assertAllClose(np.array([1024.0, 1024.0]), self.evaluate(r)) 3446 3447 @test_util.run_v1_only("b/120545219") 3448 def testCallGradInLoop(self): 3449 with self.cached_session() as sess: 3450 i0 = constant_op.constant(0) 3451 params = constant_op.constant(5.0) 3452 params_1 = math_ops.square(params) 3453 3454 def c(i, _): 3455 return i < 10 3456 3457 def b(i, x): 3458 data = constant_op.constant([1.0, 2.0, 3.0]) 3459 data = math_ops.multiply(data, params_1) 3460 x1 = x + gradients_impl.gradients(data, params)[0] 3461 return i + 1, x1 3462 3463 output_grad = control_flow_ops.while_loop( 3464 c, b, [i0, constant_op.constant(0.0)]) 3465 self.assertAllClose(600.0, self.evaluate(output_grad)[1]) 3466 3467 @test_util.run_deprecated_v1 3468 def testWhileAndTensorArray(self): 3469 with self.cached_session() as sess: 3470 param = 
constant_op.constant(2.0) 3471 n0 = constant_op.constant(0) 3472 y0 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="elems") 3473 3474 def c(i, _): 3475 return i < 10 3476 3477 def b(i, y): 3478 return [ 3479 i + 1, 3480 map_fn.map_fn(lambda x: math_ops.multiply(x, param), y) 3481 ] 3482 3483 r = control_flow_ops.while_loop(c, b, [n0, y0], parallel_iterations=1) 3484 r = gradients_impl.gradients(r, param)[0] 3485 self.assertAllClose(107520.0, self.evaluate(r)) 3486 3487 @test_util.run_deprecated_v1 3488 def testNestedWhileAndTensorArray(self): 3489 n = constant_op.constant(3.0) 3490 3491 def Body(row, ta): 3492 3493 def InnerBody(row, col, ta): 3494 # Note: row and col are 1-based. 3495 ta = ta.write( 3496 math_ops.cast(n * (row - 1.) + col - 1., dtypes.int32), row * col) 3497 return row, col + 1., ta 3498 3499 ta = control_flow_ops.while_loop( 3500 lambda _, col, _1: col <= n, 3501 InnerBody, [row, constant_op.constant(1.), ta], 3502 return_same_structure=False)[2] 3503 return row + 1., ta 3504 3505 ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=9) 3506 ta = control_flow_ops.while_loop( 3507 lambda row, _: row <= n, 3508 Body, [constant_op.constant(1.), ta], 3509 return_same_structure=False)[1] 3510 3511 output = array_ops.reshape(ta.stack(), [3, 3]) 3512 self.assertAllEqual( 3513 self.evaluate(output), [[1., 2., 3.], [2., 4., 6.], [3., 6., 9.]]) 3514 # TODO(b/117675481): This does not work with current TA. Enable with new TA. 3515 # grad = gradients_impl.gradients(output, [n]) 3516 # self.assertEqual(self.evaluate(grad), 3.5) 3517 3518 @test_util.run_deprecated_v1 3519 def testWhileGrad_StopGrad(self): 3520 with self.cached_session(): 3521 x = constant_op.constant(3.0, name="x") 3522 y = constant_op.constant(2.0, name="y") 3523 3524 c = lambda x, y: math_ops.less(x, 100.0) 3525 3526 def b(x, y): 3527 y1 = math_ops.square(y) 3528 x1 = math_ops.add(math_ops.square(x), y1) 3529 return x1, y1 3530 3531 rx, ry = control_flow_ops.while_loop(c, b, [x, y]) 3532 3533 r = gradients_impl.gradients(rx, y)[0] 3534 self.assertEqual(136.0, self.evaluate(r)) 3535 r = gradients_impl.gradients(ry, y)[0] 3536 self.assertEqual(32.0, self.evaluate(r)) 3537 3538 r = gradients_impl.gradients(array_ops.stop_gradient(rx), y)[0] 3539 self.assertEqual(r, None) 3540 r = gradients_impl.gradients(array_ops.stop_gradient(ry), y)[0] 3541 self.assertEqual(r, None) 3542 3543 r = gradients_impl.gradients( 3544 array_ops.stop_gradient(math_ops.square(rx)), y)[0] 3545 self.assertEqual(r, None) 3546 r = gradients_impl.gradients( 3547 array_ops.stop_gradient(math_ops.add(rx, ry)), x)[0] 3548 self.assertEqual(r, None) 3549 r = gradients_impl.gradients( 3550 array_ops.stop_gradient(math_ops.add(rx, ry)), y)[0] 3551 self.assertEqual(r, None) 3552 3553 r = gradients_impl.gradients(math_ops.add(rx, ry), y)[0] 3554 self.assertEqual(168.0, self.evaluate(r)) 3555 r = gradients_impl.gradients( 3556 math_ops.add(rx, array_ops.stop_gradient(ry)), y)[0] 3557 self.assertEqual(136.0, self.evaluate(r)) 3558 r = gradients_impl.gradients( 3559 math_ops.add(array_ops.stop_gradient(rx), ry), y)[0] 3560 self.assertEqual(32.0, self.evaluate(r)) 3561 3562 @test_util.run_deprecated_v1 3563 @test_util.disable_control_flow_v2("b/118712257") 3564 def testWhileGrad_StopGradInside(self): 3565 with self.cached_session(): 3566 x = constant_op.constant(3.0, name="x") 3567 y = constant_op.constant(2.0, name="y") 3568 3569 c = lambda x, y: math_ops.less(x, 100.0) 3570 3571 def b(x, y): 3572 y1 = 
array_ops.stop_gradient(math_ops.square(y)) 3573 x1 = math_ops.add(math_ops.square(x), y1) 3574 return x1, y1 3575 3576 rx, _ = control_flow_ops.while_loop(c, b, [x, y]) 3577 3578 r = gradients_impl.gradients(rx, y)[0] 3579 self.assertAllClose(0.0, self.evaluate(r)) 3580 r = gradients_impl.gradients(rx, x)[0] 3581 self.assertAllClose(156.0, self.evaluate(r)) 3582 3583 @test_util.run_deprecated_v1 3584 @test_util.disable_control_flow_v2("b/118712257") 3585 def testWhileGrad_StopGradInsideNoShape(self): 3586 with self.cached_session() as sess: 3587 x = array_ops.placeholder(dtypes.float32) 3588 y = array_ops.placeholder(dtypes.float32) 3589 3590 c = lambda x, y: math_ops.less(math_ops.reduce_sum(x), 100.0) 3591 3592 def b(x, y): 3593 y1 = array_ops.stop_gradient(math_ops.square(y, name="stopped")) 3594 x1 = math_ops.add(math_ops.square(x), y1) 3595 return x1, y1 3596 3597 rx, _ = control_flow_ops.while_loop(c, b, [x, y]) 3598 3599 r = gradients_impl.gradients(rx, y)[0] 3600 feed_dict = {x: [3.0, 4.0], y: [2.0, 3.0]} 3601 self.assertAllClose([0.0, 0.0], sess.run(r, feed_dict=feed_dict)) 3602 r = gradients_impl.gradients(rx, x)[0] 3603 self.assertAllClose([156.0, 400.0], sess.run(r, feed_dict=feed_dict)) 3604 name = "gradients/while/stopped_grad" 3605 all_ops = x.graph.get_operations() 3606 self.assertFalse(any(name in op.name for op in all_ops)) 3607 3608 @test_util.run_deprecated_v1 3609 def testWhileGradGradFail(self): 3610 theta = variables.Variable(initial_value=1.) 3611 3612 def fn(prev, x): 3613 return prev + x * theta 3614 3615 result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32)) 3616 grad_theta = gradients_impl.gradients(result, theta) 3617 if not control_flow_util.ENABLE_CONTROL_FLOW_V2: 3618 with self.assertRaisesRegexp(TypeError, "Second-order gradient"): 3619 gradients_impl.gradients(grad_theta, theta) 3620 grad_theta_stopped = array_ops.stop_gradient(grad_theta) 3621 gradients_impl.gradients(grad_theta_stopped, theta) 3622 3623 @test_util.run_deprecated_v1 3624 def testStopGradOnWhileGrad(self): 3625 with self.cached_session(): 3626 x = constant_op.constant(2.0, name="x") 3627 y = constant_op.constant(2.0, name="y") 3628 3629 c = lambda x: math_ops.less(x, 100.0) 3630 b = lambda x: math_ops.multiply(x, y) 3631 rx = control_flow_ops.while_loop(c, b, [x]) 3632 3633 rg = gradients_impl.gradients(rx, y)[0] 3634 rg = array_ops.stop_gradient(rg) 3635 r = math_ops.add(math_ops.square(y), rx) 3636 r = math_ops.add(r, rg) 3637 r = gradients_impl.gradients(r, y)[0] 3638 self.assertEqual(388.0, self.evaluate(r)) 3639 3640 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 3641 @test_util.run_deprecated_v1 3642 def testWhileGradientWithNontrainablePath1(self): 3643 q = variables.Variable([7., 8.]) 3644 3645 def cond(_, y): 3646 del y 3647 return False 3648 3649 def body(x, _): 3650 return x, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q) 3651 3652 _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.)) 3653 dy_dq, = gradients_impl.gradients(y, q) 3654 self.assertIsNotNone(dy_dq) 3655 with self.cached_session() as sess: 3656 self.evaluate(q.initializer) 3657 self.assertAllClose([0., 0.], self.evaluate(dy_dq)) 3658 3659 @test_util.disable_control_flow_v2("b/113324949 (RefVariable)") 3660 @test_util.run_v1_only("b/120545219") 3661 def testWhileGradientWithNontrainablePath2(self): 3662 q = variables.Variable([7., 8.]) 3663 3664 def cond(_, y): 3665 return math_ops.equal(y, 0.) 
3666 3667 def body(x, _): 3668 zero = constant_op.constant(0, dtype=dtypes.int64) 3669 return zero, math_ops.cast(x, dtypes.float32) + math_ops.reduce_sum(q) 3670 3671 _, y = control_flow_ops.while_loop(cond, body, (math_ops.argmin(q), 0.)) 3672 dy_dq, = gradients_impl.gradients(y, q) 3673 self.assertIsNotNone(dy_dq) 3674 with self.cached_session() as sess: 3675 self.evaluate(q.initializer) 3676 self.assertAllClose([1., 1.], self.evaluate(dy_dq)) 3677 3678 @test_util.run_v1_only("b/120545219") 3679 def testIssue16504(self): 3680 c = constant_op.constant(np.arange(100), dtype=dtypes.float32) 3681 w = variables.Variable( 3682 initial_value=np.ones(100), dtype=dtypes.float32) / 100 3683 k = variables.Variable(0, dtype=dtypes.int32) 3684 chg_w = constant_op.constant(np.inf, dtype=dtypes.float32) 3685 3686 def cond(k, _, chg_w): 3687 return math_ops.logical_and(k < 10, chg_w > 1e-3) 3688 3689 def body(k, w, chg_w): 3690 grad, = gradients_impl.gradients(-math_ops.reduce_sum(w * c), w) 3691 w_n = w * math_ops.exp(-0.1 * grad) 3692 w_n /= math_ops.reduce_sum(w_n) 3693 chg_w = ( 3694 math_ops.reduce_sum(math_ops.abs(w_n - w)) / math_ops.reduce_sum( 3695 math_ops.abs(w))) 3696 return k + 1, w_n, chg_w 3697 3698 _, w, _ = control_flow_ops.while_loop(cond, body, [k, w, chg_w]) 3699 grad, = gradients_impl.gradients(w, c) 3700 self.assertIsNotNone(grad) 3701 3702 @test_util.run_v1_only("b/120545219") 3703 def testStopGradMultiFlows(self): 3704 with self.cached_session(): 3705 3706 def body(i, y, r): 3707 x = variable_scope.get_variable( 3708 "x", 3709 shape=(), 3710 dtype=dtypes.float32, 3711 initializer=init_ops.ones_initializer()) 3712 y *= x 3713 return [i + 1, y, r + math_ops.reduce_sum(y)] 3714 3715 i0 = constant_op.constant(0) 3716 y0 = array_ops.ones(5) 3717 r0 = constant_op.constant(0.0) 3718 cond = lambda i, y, r: i < 1 3719 _, _, r = control_flow_ops.while_loop( 3720 cond, body, [i0, y0, r0], back_prop=True) 3721 3722 vars_ = variables.global_variables() 3723 grads = linalg_ops.norm(gradients_impl.gradients(r, vars_)[0]) 3724 z = math_ops.add(r, array_ops.stop_gradient(math_ops.reduce_sum(grads))) 3725 result = gradients_impl.gradients(z, vars_)[0] 3726 self.evaluate(variables.global_variables_initializer()) 3727 self.assertEqual(5.0, self.evaluate(result)) 3728 3729 @test_util.run_v1_only("b/120545219") 3730 def testOneValueCond(self): 3731 3732 with self.cached_session(): 3733 c = array_ops.placeholder(dtypes.int32, shape=[]) 3734 one = ops.convert_to_tensor(1, name="one") 3735 two = ops.convert_to_tensor(2, name="two") 3736 p = math_ops.greater_equal(c, 1) 3737 i = control_flow_ops.cond(p, lambda: one, lambda: two) 3738 self.assertTrue(isinstance(i, ops.Tensor)) 3739 3740 # True case: c = 2 is >= 1 3741 self.assertEqual([1], i.eval(feed_dict={c: 2})) 3742 3743 # False case: c = 0 is not >= 1 3744 self.assertEqual([2], i.eval(feed_dict={c: 0})) 3745 3746 @test_util.run_deprecated_v1 3747 def testExampleCond(self): 3748 3749 with self.cached_session(): 3750 x = ops.convert_to_tensor([-2.0, 2.0], name="x") 3751 d = array_ops.placeholder(dtypes.int32, shape=[]) 3752 3753 def l2(): 3754 return math_ops.sqrt(math_ops.reduce_sum(math_ops.square(x))) 3755 3756 def l1(): 3757 return math_ops.reduce_sum(math_ops.abs(x)) 3758 3759 i = control_flow_ops.cond(math_ops.equal(d, 2), l2, l1) 3760 self.assertAllClose(4.0, i.eval(feed_dict={d: 1})) 3761 self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2})) 3762 3763 @test_util.run_v1_only("b/120545219") 3764 def testCase(self): 3765 with 
self.cached_session(): 3766 x = constant_op.constant(1) 3767 y = constant_op.constant(2) 3768 z = constant_op.constant(3) 3769 f1 = lambda: constant_op.constant(17) 3770 f2 = lambda: constant_op.constant(23) 3771 f3 = lambda: constant_op.constant(-1) 3772 3773 r1 = control_flow_ops.case( 3774 { 3775 x < y: f1, 3776 x > z: f2 3777 }, default=f3, exclusive=True) 3778 self.assertAllEqual(r1, 17) 3779 3780 r2 = control_flow_ops.case([(y > z, f1), (y > x, f2)], default=f3) 3781 self.assertAllEqual(r2, 23) 3782 3783 # Duplicate events can happen, first one is selected 3784 r3 = control_flow_ops.case([(x < y, f1), (x < y, f2)], default=f3) 3785 self.assertAllEqual(r3, 17) 3786 3787 # Duplicate events cause an error if exclusive = True 3788 r4 = control_flow_ops.case( 3789 [(x < y, f1), (x < y, f2)], default=f3, exclusive=True) 3790 with self.assertRaisesOpError("Input error:"): 3791 self.evaluate(r4) 3792 3793 # Check that the default is called if none of the others are 3794 r5 = control_flow_ops.case({x > y: f1}, default=f3) 3795 self.assertAllEqual(r5, -1) 3796 3797 ran_once = [False, False, False] 3798 3799 def break_run_twice(ix): 3800 3801 def _break(): 3802 ran_once[ix] = True 3803 return constant_op.constant(ix) 3804 3805 return _break 3806 3807 # Should not fail - each conditional gets called exactly once 3808 # except default. Default gets called twice: once to create an 3809 # empty output and once for the actual cond switch. 3810 r6 = control_flow_ops.case( 3811 [(x < y, break_run_twice(0)), (x > y, break_run_twice(1))], 3812 default=lambda: constant_op.constant(2)) 3813 3814 self.assertAllEqual(r6, 0) 3815 3816 @test_util.run_v1_only("b/120545219") 3817 def testCaseSideEffects(self): 3818 with self.cached_session() as sess: 3819 v0 = variables.Variable(-1) 3820 v1 = variables.Variable(-1) 3821 v2 = variables.Variable(-1) 3822 3823 a = lambda: control_flow_ops.with_dependencies([state_ops.assign(v0, 0)], 0) 3824 b = lambda: control_flow_ops.with_dependencies([state_ops.assign(v1, 1)], 1) 3825 c = lambda: control_flow_ops.with_dependencies([state_ops.assign(v2, 2)], 2) 3826 3827 x = constant_op.constant(1) 3828 y = constant_op.constant(2) 3829 3830 r0 = control_flow_ops.case( 3831 ((x < y, a), (x > y, b)), default=c, exclusive=True) 3832 r1 = control_flow_ops.case( 3833 ((x > y, a), (x < y, b)), default=c, exclusive=True) 3834 r2 = control_flow_ops.case( 3835 ((x > y, a), (x > y, b)), default=c, exclusive=True) 3836 3837 self.evaluate(variables.global_variables_initializer()) 3838 self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3) 3839 self.assertEqual(2, self.evaluate(r2)) 3840 self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, -1, 2]) 3841 3842 self.evaluate(variables.global_variables_initializer()) 3843 self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3) 3844 self.assertEqual(1, self.evaluate(r1)) 3845 self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1, 1, -1]) 3846 3847 self.evaluate(variables.global_variables_initializer()) 3848 self.assertAllEqual(self.evaluate([v0, v1, v2]), [-1] * 3) 3849 self.assertEqual(0, self.evaluate(r0)) 3850 self.assertAllEqual(self.evaluate([v0, v1, v2]), [0, -1, -1]) 3851 3852 @test_util.disable_control_flow_v2("b/113324949 (ref vars)") 3853 @test_util.run_v1_only("b/120545219") 3854 def testOneOpCond(self): 3855 with self.cached_session(): 3856 v = variables.Variable(0) 3857 c = ops.convert_to_tensor(0) 3858 one = ops.convert_to_tensor(1) 3859 two = ops.convert_to_tensor(2) 3860 p = math_ops.greater_equal(c, 1) 3861 3862 def a(): 
        return state_ops.assign(v, one)

      def b():
        return state_ops.assign(v, two)

      i = control_flow_ops.cond(p, a, b)
      self.assertTrue(isinstance(i, ops.Tensor))
      self.evaluate(variables.global_variables_initializer())

      self.assertEqual(0, self.evaluate(v))

      # True case: c = 2 is >= 1, v is set to 1.
      self.assertEqual(1, i.eval(feed_dict={c.name: 2}))
      self.assertEqual(1, self.evaluate(v))

      # False case: c = 0 is not >= 1, v is set to 2.
      self.assertEqual(2, i.eval(feed_dict={c.name: 0}))
      self.assertEqual(2, self.evaluate(v))

  @test_util.run_v1_only("b/120545219")
  def testWithOpsDependencies(self):
    with self.cached_session() as sess:
      v = variables.VariableV1(0.0)
      c = constant_op.constant(10)

      # Fetching v directly will result in an uninitialized error.
      with self.assertRaisesOpError("Attempting to use uninitialized value"):
        self.evaluate([c, v])

      # Use a control dependency to ensure v's initializer is run
      # while asking for c.
      real_v = control_flow_ops.with_dependencies(
          name="real_tensor",
          output_tensor=v._ref(),  # pylint: disable=protected-access
          dependencies=[v.initializer])
      c_val, real_v_val = self.evaluate([c, real_v])

      # Ensure that fetching 'c' still yields its value.
      self.assertAllEqual(10, c_val)

      # Ensure that 'v' is initialized.
      self.assertAllClose(0.0, real_v_val)

  @test_util.run_v1_only("b/120545219")
  def testWithTensorDependencies(self):
    with self.cached_session():
      v = variables.VariableV1(0.0)
      c1 = constant_op.constant(10)
      c2 = constant_op.constant(20)

      # c1_with_init_v depends on the init op for v.
      c1_with_init_v = control_flow_ops.with_dependencies(
          name="c1_with_init_v", output_tensor=c1, dependencies=[v.initializer])
      # c2_with_c1_dep depends on the value of c1_with_init_v.
      c2_with_c1_dep = control_flow_ops.with_dependencies(
          name="c2_with_c1_dep",
          output_tensor=c2,
          dependencies=[c1_with_init_v])

      # Fetching v directly will result in an uninitialized error.
      with self.assertRaisesOpError("Attempting to use uninitialized value"):
        self.evaluate(v)

      # Get the value of 'c2_with_c1_dep', which should cause 'v'
      # to be initialized.
      self.assertAllEqual(20, self.evaluate(c2_with_c1_dep))

      # Ensure that 'v' is initialized.
      self.assertAllClose(0.0, self.evaluate(v))
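
  # Added illustrative sketch (hypothetical helper, not referenced by any test
  # above): `with_dependencies(dependencies, output_tensor)` returns a tensor
  # that carries the value of `output_tensor` but only after every op in
  # `dependencies` has run, which is why fetching the result is enough to run
  # an initializer.
  def _withDependenciesSketch(self):
    v = variables.VariableV1(0.0)
    c = constant_op.constant(7.0)
    # Fetching `c_after_init` first runs `v.initializer`, then yields 7.0.
    c_after_init = control_flow_ops.with_dependencies([v.initializer], c)
    return c_after_init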

  @test_util.run_v1_only("b/120545219")
  def testWithIndexedSlicesDependencies(self):
    with self.cached_session():
      v = variables.VariableV1(
          np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(np.float32))
      v_at_1 = ops.IndexedSlices(v, constant_op.constant([1]))
      gather_v_at_1 = array_ops.gather(v_at_1.values, v_at_1.indices)
      v_at_1_after_init = control_flow_ops.with_dependencies([v.initializer],
                                                             v_at_1)
      gather_v_at_1_after_init = array_ops.gather(v_at_1_after_init.values,
                                                  v_at_1_after_init.indices)

      # Fetching gather_v_at_1 will result in an uninitialized error
      with self.assertRaisesOpError("Attempting to use uninitialized value"):
        self.evaluate(gather_v_at_1)

      # Getting gather_v_at_1_after_init will work, and initialize v.
      self.assertAllEqual([[10.0, 11.0]],
                          self.evaluate(gather_v_at_1_after_init))

      # Double check that 'v' is initialized
      self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]],
                          self.evaluate(v))

  def testDependenciesDevice(self):
    with ops.Graph().as_default():
      # device set on tensor => same device on dep.
      with ops.device("/job:ps"):
        vd = variables.VariableV1([0.0])
      with_vd_dep = control_flow_ops.with_dependencies([vd.initializer], vd)
      self.assertTrue("/job:ps" in with_vd_dep.device)

      # No device set on tensor => no device on dep.
      vnod = variables.VariableV1([0.0])
      with_vnod_dep = control_flow_ops.with_dependencies([vnod.initializer],
                                                         vnod)
      self.assertDeviceEqual(None, with_vnod_dep.device)

      # device set on tensor, default device on graph => default device on dep.
      vdef = variables.VariableV1([0.0], name="vdef")
      with ops.device("/job:worker/device:GPU:1"):
        with_vdef_dep = control_flow_ops.with_dependencies([vdef.initializer],
                                                           vdef)
        # The device is empty, but the colocation constraint is set.
        self.assertDeviceEqual("", with_vdef_dep.device)
        self.assertEqual([b"loc:@vdef"], with_vdef_dep.op.colocation_groups())

  @test_util.run_v1_only("b/120545219")
  def testGroup(self):
    with self.cached_session() as sess:
      v1 = variables.VariableV1([0.0])
      v2 = variables.VariableV1([1.0])

      # Group init1 and init2 and run.
      init = control_flow_ops.group(v1.initializer, v2.initializer)
      # Fetching v1 directly will result in an uninitialized error
      with self.assertRaisesOpError("Attempting to use uninitialized value"):
        self.evaluate(v1)

      # Runs "init" before fetching v1 and v2.
      init.run()
      v1_val, v2_val = self.evaluate([v1, v2])

      # Ensure that v1 and v2 are initialized
      self.assertAllClose([0.0], v1_val)
      self.assertAllClose([1.0], v2_val)

  @test_util.run_v1_only("b/120545219")
  def testGroupEmpty(self):
    op = control_flow_ops.group()
    self.assertEqual(op.type, "NoOp")
    self.assertEqual(op.control_inputs, [])

  @test_util.run_deprecated_v1
  def testMergeShapes(self):
    # All inputs unknown.
    p1 = array_ops.placeholder(dtypes.float32)
    p2 = array_ops.placeholder(dtypes.float32)
    p3 = array_ops.placeholder(dtypes.float32)
    m, index = control_flow_ops.merge([p1, p2, p3])
    self.assertIs(None, m.get_shape().ndims)
    self.assertEqual([], index.get_shape())

    # All inputs known with different ranks.
    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2, 3])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertIs(None, m.get_shape().ndims)
    self.assertEqual([], index.get_shape())

    # All inputs known with some dimensions different.
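    # (Added note: `merge` forwards whichever input is available first, so its
    # static shape is the most specific shape compatible with all inputs;
    # dimensions that disagree below become None, dimensions that agree are
    # kept.)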
    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[2, 1])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([None, None], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([None, 2], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[2, 2])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([None, 2], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

    # All inputs known with same dimensions.
    p1 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[1, 2])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([1, 2], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

    p1 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([None, 2], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

    p1 = array_ops.placeholder(dtypes.float32, shape=[None, None])
    p2 = array_ops.placeholder(dtypes.float32, shape=[None, None])
    m, index = control_flow_ops.merge([p1, p2])
    self.assertEqual([None, None], m.get_shape().as_list())
    self.assertEqual([], index.get_shape())

  @test_util.run_v1_only("b/120545219")
  def testRefSelect(self):
    index = array_ops.placeholder(dtypes.int32)

    # All inputs unknown.
    p1 = array_ops.placeholder(dtypes.float32)
    p2 = array_ops.placeholder(dtypes.float32)
    p3 = array_ops.placeholder(dtypes.float32)
    v1 = variables.VariableV1(p1, validate_shape=False)
    v2 = variables.VariableV1(p2, validate_shape=False)
    v3 = variables.VariableV1(p3, validate_shape=False)
    self.assertIs(None, v1.get_shape().ndims)
    s = control_flow_ops.ref_select(index, [v1, v2, v3])
    self.assertIs(None, s.get_shape().ndims)

    # All inputs known but different.
    v1 = variables.VariableV1([[1, 2]])
    v2 = variables.VariableV1([[2], [1]])
    s = control_flow_ops.ref_select(index, [v1, v2])
    self.assertIs(None, s.get_shape().ndims)

    # All inputs known and same.
    v1 = variables.VariableV1([[1, 2]])
    v2 = variables.VariableV1([[1, 2]])
    s = control_flow_ops.ref_select(index, [v1, v2])
    self.assertEqual([1, 2], s.get_shape())

    # Possibly the same but not guaranteed.
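    # (Added note: v1 below has a fully defined [1, 2] shape, but v2 wraps a
    # placeholder with validate_shape=False, so nothing guarantees the two
    # refs agree and ref_select reports no static shape at all.)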
    v1 = variables.VariableV1([[1., 2.]])
    p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
    v2 = variables.VariableV1(p2, validate_shape=False)
    s = control_flow_ops.ref_select(index, [v1, v2])
    self.assertEqual(None, s.get_shape())

  @test_util.run_deprecated_v1
  def testRunLoopTensor(self):
    with self.cached_session() as sess:
      tensor_list = []

      def condition(t):
        return t < constant_op.constant(5)

      def body(_):
        tensor_list.append(constant_op.constant(5))
        return constant_op.constant(10)

      result = control_flow_ops.while_loop(condition, body,
                                           [constant_op.constant(4)])
      self.assertEqual(10, self.evaluate(result))

      # Ensure that we cannot run a tensor that escapes the loop body
      # accidentally.
      with self.assertRaises(ValueError):
        sess.run(tensor_list[0])

  @test_util.run_v1_only("b/120545219")
  def testWhilePyFuncBasic(self):

    def func(x):
      return np.square(x)

    with self.cached_session():
      r = control_flow_ops.while_loop(
          lambda i, v: i < 4,
          lambda i, v: [i + 1, script_ops.py_func(func, [v], [dtypes.float32])[0]],
          [constant_op.constant(0), constant_op.constant(2.0, dtypes.float32)],
          [tensor_shape.unknown_shape(), tensor_shape.unknown_shape()])
      self.assertEqual(self.evaluate(r[1]), 65536.0)

  @test_util.run_v1_only("b/120545219")
  def testWhileFuncBasic(self):

    @function.Defun(dtypes.float32)
    def func(x):
      return math_ops.square(math_ops.square(x))

    with self.cached_session():
      x = constant_op.constant(2.0, dtypes.float32)
      r = control_flow_ops.while_loop(
          lambda i, v: i < 2, lambda i, v: [i + 1, func(v)],
          [constant_op.constant(0), x],
          [tensor_shape.unknown_shape(),
           tensor_shape.unknown_shape()])
      grad = gradients_impl.gradients(r, x)[0]
      self.assertEqual(self.evaluate(r[1]), 65536.0)
      self.assertEqual(self.evaluate(grad), 524288.0)
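      # (Added note on the expected values: func squares twice, so two loop
      # iterations compute ((2**4)**4) = 2**16 = 65536, and d(x**16)/dx at
      # x = 2 is 16 * 2**15 = 524288.)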
      # while_v2 does not have stacks.
      if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
        self.assertEqual(
            len([op for op in x.graph.get_operations() if op.type == "StackV2"
                ]), 1)

  @test_util.run_v1_only("b/120545219")
  def testQIntSwitchMerge(self):
    with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
      constant_qint = constant_op.constant(np.array([42]), dtypes.qint8)
      cond = constant_op.constant(True, dtypes.bool)
      v_f, v_t = control_flow_ops.switch(constant_qint, cond)
      result = control_flow_ops.merge([v_f, v_t])
      self.evaluate(result)

  @test_util.run_v1_only("b/120545219")
  def testQIntRefSwitchMerge(self):
    with self.cached_session(use_gpu=test.is_gpu_available()) as sess:
      var_qint = gen_state_ops.variable(
          shape=[1], dtype=dtypes.qint8, name="v", container="", shared_name="")
      assign_op = state_ops.assign(
          var_qint, constant_op.constant(np.array([42]), dtypes.qint8))
      self.evaluate(assign_op)

      cond = constant_op.constant(True, dtypes.bool)
      v_f, v_t = control_flow_ops.ref_switch(var_qint, cond)
      result = control_flow_ops.ref_merge([v_f, v_t])
      self.evaluate(result)

  @test_util.run_v1_only("b/120545219")
  def testUInt64SwitchMerge(self):
    with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
      constant_uint64 = constant_op.constant(np.array([42]), dtypes.uint64)
      cond = constant_op.constant(True, dtypes.bool)
      v_f, v_t = control_flow_ops.switch(constant_uint64, cond)
      result = control_flow_ops.merge([v_f, v_t])
      self.evaluate(result)

  @test_util.run_deprecated_v1
  def testQIntArgAndRet(self):

    @function.Defun(dtypes.qint8)
    def func(x):
      return x

    with self.cached_session(force_gpu=test.is_gpu_available()) as sess:
      qint = constant_op.constant(np.array([42]), dtypes.qint8)
      result = func(qint)
      self.evaluate(result)

  def testSparseIdentity(self):
    st1 = sparse_tensor.SparseTensor([[0, 5]], ['x'], [10, 10])
    st2 = control_flow_ops._Identity(st1)
    self.assertAllEqual(st1.indices, st2.indices)
    self.assertAllEqual(st1.values, st2.values)
    self.assertAllEqual(st1.dense_shape, st2.dense_shape)

  def testSparseEnterExit(self):
    st1 = sparse_tensor.SparseTensor([[0, 5]], ['x'], [10, 10])
    st2 = control_flow_ops._Enter(st1, "foo_1")
    st3 = control_flow_ops.exit(st2)
    self.assertAllEqual(st1.indices, st3.indices)
    self.assertAllEqual(st1.values, st3.values)
    self.assertAllEqual(st1.dense_shape, st3.dense_shape)


class ControlFlowContextCheckTest(test.TestCase):

  def _getWhileTensor(self):
    """Creates and returns a tensor from a while context."""
    tensor = []

    def body(i):
      if not tensor:
        tensor.append(constant_op.constant(1))
      return i + tensor[0]

    control_flow_ops.while_loop(lambda i: i < 10, body, [0])
    return tensor[0]

  def _getCondTensor(self):
    cond_tensor = []

    def true_fn():
      if not cond_tensor:
        cond_tensor.append(constant_op.constant(1))
      return cond_tensor[0]

    control_flow_ops.cond(
        math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0))
    return cond_tensor[0]

  @test_util.run_v1_only("b/120545219")
  def testInvalidContext(self):
    # Accessing a while loop tensor outside of control flow is illegal.
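    # (Added note: _getWhileTensor leaks a constant created inside the loop
    # body through a Python list, so the tensor used below belongs to the
    # while context even though it is referenced from outside the loop.)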
4243 while_tensor = self._getWhileTensor() 4244 with self.assertRaisesRegexp( 4245 ValueError, 4246 "Cannot use 'while/Const_1' as input to 'Add' because 'while/Const_1' " 4247 "is in a while loop. See info log for more details."): 4248 math_ops.add(1, while_tensor) 4249 4250 @test_util.run_v1_only("b/120545219") 4251 def testInvalidContextInCond(self): 4252 # Accessing a while loop tensor in cond is illegal. 4253 while_tensor = self._getWhileTensor() 4254 with self.assertRaisesRegexp( 4255 ValueError, "Cannot use 'while/Const_1' as input to 'cond/Add' because " 4256 "'while/Const_1' is in a while loop. See info log for more details."): 4257 # TODO(skyewm): this passes if we return while_tensor directly instead 4258 # of using it as input to another op. 4259 control_flow_ops.cond( 4260 math_ops.less(1, 2), lambda: math_ops.add(1, while_tensor), 4261 lambda: constant_op.constant(0)) 4262 4263 @test_util.run_v1_only("b/120545219") 4264 def testInvalidContextInWhile(self): 4265 # Accessing a while loop tensor in a different while loop is illegal. 4266 while_tensor = self._getWhileTensor() 4267 with self.assertRaisesRegexp( 4268 ValueError, 4269 "Cannot use 'while/Const_1' as input to 'while_1/Add' because they are " 4270 "in different while loops. See info log for more details."): 4271 control_flow_ops.while_loop(lambda i: i < 10, 4272 lambda x: math_ops.add(1, while_tensor), [0]) 4273 4274 with self.assertRaisesRegexp( 4275 ValueError, 4276 "Cannot use 'while/Const_1' as input to 'while_2/NextIteration' " 4277 "because they are in different while loops. See info log for more " 4278 "details."): 4279 control_flow_ops.while_loop(lambda i: i < 10, lambda i: while_tensor, [0]) 4280 4281 def testValidCondContext(self): 4282 # Accessing a tensor from a cond context is OK (although dangerous). 4283 cond_tensor = self._getCondTensor() 4284 math_ops.add(1, cond_tensor) 4285 4286 def testValidCondContextBranches(self): 4287 # Accessing a tensor from a cond context from the other branch's cond 4288 # context is OK (although dangerous). 4289 cond_tensor = [] 4290 4291 def branch_fn(): 4292 if not cond_tensor: 4293 cond_tensor.append(constant_op.constant(1)) 4294 return cond_tensor[0] 4295 4296 control_flow_ops.cond(math_ops.less(1, 2), branch_fn, branch_fn) 4297 4298 @test_util.run_v1_only("b/120545219") 4299 def testValidWhileContext(self): 4300 # Accessing a tensor in a nested while is OK. 4301 def body(_): 4302 c = constant_op.constant(1) 4303 return control_flow_ops.while_loop(lambda i: i < 3, lambda i: i + c, [0]) 4304 4305 control_flow_ops.while_loop(lambda i: i < 5, body, [0]) 4306 4307 @test_util.run_v1_only("b/120545219") 4308 def testValidNestedContexts(self): 4309 # Accessing a tensor from a cond context in a while context, all inside an 4310 # outer while context, is OK. 4311 def body(_): 4312 cond_tensor = self._getCondTensor() 4313 # Create another cond containing the while loop for good measure 4314 return control_flow_ops.cond( 4315 math_ops.less(1, 2), 4316 lambda: control_flow_ops.while_loop(lambda i: i < 3, 4317 lambda i: i + cond_tensor, [0]), 4318 lambda: constant_op.constant(0)) 4319 4320 control_flow_ops.while_loop(lambda i: i < 5, body, [0]) 4321 4322 @test_util.run_v1_only("b/120545219") 4323 def testInvalidNestedContexts(self): 4324 # Accessing a tensor from a while context in a different while context, all 4325 # inside a cond context, is illegal. 
4326 def true_fn(): 4327 while_tensor = self._getWhileTensor() 4328 return control_flow_ops.while_loop(lambda i: i < 3, 4329 lambda i: i + while_tensor, [0]) 4330 4331 with self.assertRaisesRegexp( 4332 ValueError, 4333 "Cannot use 'cond/while/Const_1' as input to 'cond/while_1/add' because" 4334 " they are in different while loops. See info log for more details."): 4335 control_flow_ops.cond( 4336 math_ops.less(1, 2), true_fn, lambda: constant_op.constant(0)) 4337 4338 4339 class TupleTest(test.TestCase): 4340 4341 @test_util.run_v1_only("b/120545219") 4342 def testTensors(self): 4343 for v1_first in [True, False]: 4344 with self.cached_session(): 4345 v1 = variables.VariableV1([1.0]) 4346 add1 = math_ops.add( 4347 control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access 4348 2.0) 4349 v2 = variables.VariableV1([10.0]) 4350 add2 = math_ops.add( 4351 control_flow_ops.with_dependencies([v2.initializer], v2._ref()), # pylint: disable=protected-access 4352 20.0) 4353 t1, _, t2 = control_flow_ops.tuple([add1, None, add2]) 4354 4355 # v1 is not initialized. 4356 with self.assertRaisesOpError("Attempting to use uninitialized value"): 4357 self.evaluate(v1) 4358 4359 # v2 is not initialized. 4360 with self.assertRaisesOpError("Attempting to use uninitialized value"): 4361 self.evaluate(v2) 4362 4363 if v1_first: 4364 # Getting t1 initializes v2. 4365 self.assertAllClose([3.0], self.evaluate(t1)) 4366 self.assertAllClose([10.0], self.evaluate(v2)) 4367 else: 4368 # Getting t2 initializes v1. 4369 self.assertAllClose([30.0], self.evaluate(t2)) 4370 self.assertAllClose([1.0], self.evaluate(v1)) 4371 4372 @test_util.run_v1_only("b/120545219") 4373 def testIndexedSlices(self): 4374 for v1_first in [True, False]: 4375 with self.cached_session(): 4376 v1 = variables.VariableV1( 4377 np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype( 4378 np.float32)) 4379 v1_at_1 = ops.IndexedSlices( 4380 control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access 4381 constant_op.constant([1])) 4382 4383 v2 = variables.VariableV1( 4384 np.array([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]]).astype( 4385 np.float32)) 4386 v2_at_1 = ops.IndexedSlices( 4387 control_flow_ops.with_dependencies([v2.initializer], v2._ref()), # pylint: disable=protected-access 4388 constant_op.constant([1])) 4389 4390 st1, st2 = control_flow_ops.tuple([v1_at_1, v2_at_1]) 4391 g1 = array_ops.gather(st1.values, st1.indices) 4392 g2 = array_ops.gather(st2.values, st2.indices) 4393 4394 # v1 is not initialized. 4395 with self.assertRaisesOpError("Attempting to use uninitialized value"): 4396 self.evaluate(v1) 4397 4398 # v2 is not initialized. 4399 with self.assertRaisesOpError("Attempting to use uninitialized value"): 4400 self.evaluate(v2) 4401 4402 if v1_first: 4403 # Getting g1 initializes v2. 4404 self.assertAllClose([[10.0, 11.0]], self.evaluate(g1)) 4405 self.assertAllClose([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]], 4406 self.evaluate(v2)) 4407 else: 4408 # Getting g2 initializes v1. 4409 self.assertAllClose([[10.1, 11.1]], self.evaluate(g2)) 4410 self.assertAllClose([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]], 4411 self.evaluate(v1)) 4412 4413 def testAcceptTensorsAsControlInputs(self): 4414 with self.cached_session(): 4415 var = variables.VariableV1(0) 4416 assign = state_ops.assign(var, 1) 4417 t, = control_flow_ops.tuple( 4418 [constant_op.constant(0)], control_inputs=[assign]) 4419 4420 # Should trigger the assign. 
      self.evaluate(t)

      self.assertEqual(1, self.evaluate(var))


class AssertTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testGuardedAssertDoesNotCopyWhenTrue(self):
    if test_util.is_gpu_available():
      self.skipTest("b/128646478 fails in opensource")

    with self.session(use_gpu=True) as sess:
      with ops.device(test.gpu_device_name()):
        value = constant_op.constant(1.0)
      with ops.device("/cpu:0"):
        true = constant_op.constant(True)
        guarded_assert = control_flow_ops.Assert(true, [value], name="guarded")
        unguarded_assert = gen_logging_ops._assert(
            true, [value], name="unguarded")
      opts = config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)
      guarded_metadata = config_pb2.RunMetadata()
      sess.run(guarded_assert, options=opts, run_metadata=guarded_metadata)
      unguarded_metadata = config_pb2.RunMetadata()
      sess.run(unguarded_assert, options=opts, run_metadata=unguarded_metadata)
      guarded_nodestat_names = [
          n.node_name
          for d in guarded_metadata.step_stats.dev_stats
          for n in d.node_stats
      ]
      unguarded_nodestat_names = [
          n.node_name
          for d in unguarded_metadata.step_stats.dev_stats
          for n in d.node_stats
      ]
      guarded_memcpy_nodestat_names = [
          n for n in guarded_nodestat_names if "MEMCPYDtoH" in n
      ]
      unguarded_memcpy_nodestat_names = [
          n for n in unguarded_nodestat_names if "MEMCPYDtoH" in n
      ]
      if "GPU" in [d.device_type for d in device_lib.list_local_devices()]:
        # A copy was performed for the unguarded assert
        self.assertLess(0, len(unguarded_memcpy_nodestat_names),
                        str(unguarded_nodestat_names))
      # No copy was performed for the guarded assert
      self.assertEqual([], guarded_memcpy_nodestat_names)


class WhileOpBenchmark(test.Benchmark):
  """Evaluate the performance of while_loop op."""

  def _getInitVariables(self):
    batch_size = 10
    image_size = 256
    kernel_size = 3
    depth = 16

    init_step = constant_op.constant(-1)
    image = variable_scope.get_variable(
        "image",
        initializer=random_ops.random_normal(
            [batch_size, image_size, image_size, depth],
            dtype=dtypes.float32,
            stddev=1e-1))
    kernel = variable_scope.get_variable(
        "weights",
        initializer=random_ops.truncated_normal(
            [kernel_size, kernel_size, depth, depth],
            dtype=dtypes.float32,
            stddev=1e-1))
    return init_step, image, kernel

  def _runOneBenchmark(self,
                       default_device,
                       num_iters=10,
                       static_unroll=False,
                       steps=10):
    """Evaluate the while loop performance.

    Args:
      default_device: The default device to run all ops except the loop_body.
        loop_body is always run on GPU.
      num_iters: Number of iterations to run.
      static_unroll: If true, run unrolled version; otherwise, run while_loop.
      steps: Total number of repeated steps to run the loop.

    Returns:
      The duration of the run in seconds.
    """

    def loop_body(i, x):
      with ops.device("/gpu:0"):
        # Always put loop body on GPU.
        nx = nn_ops.conv2d(
            input=x,
            filter=kernel,
            strides=[1, 1, 1, 1],
            padding="SAME",
            data_format="NHWC",
            name="conv2d")
        ni = math_ops.add(i, 1)
        return ni, nx

    ops.reset_default_graph()
    with session.Session() as sess, ops.device(default_device):
      # Get the initial id i, input x, and kernel.
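      # (Added note: _getInitVariables returns a step counter initialized to
      # -1 plus a random NHWC image batch and conv kernel, so the loop below
      # repeatedly applies conv2d on the GPU while the remaining ops stay on
      # default_device.)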
4528 i, x, kernel = self._getInitVariables() 4529 variables.global_variables_initializer().run() 4530 4531 if static_unroll: 4532 for _ in xrange(steps): 4533 i, x = loop_body(i, x) 4534 else: 4535 i, x = control_flow_ops.while_loop( 4536 lambda i, _: i < steps, 4537 loop_body, [i, x], 4538 parallel_iterations=steps, 4539 swap_memory=True) 4540 4541 r = math_ops.reduce_sum(x) 4542 dx, dk = gradients_impl.gradients(r, [x, kernel]) 4543 # Use group to avoid fetching back results. 4544 r = control_flow_ops.group(dx, dk) 4545 4546 for _ in xrange(3): 4547 # exclude warm up time 4548 self.evaluate(r) 4549 4550 start_time = time.time() 4551 for _ in xrange(num_iters): 4552 self.evaluate(r) 4553 return (time.time() - start_time) / num_iters 4554 4555 def benchmarkWhileOpCrossDevicePlacement(self): 4556 iters = 10 4557 # Run loop body on GPU, but other ops on CPU. 4558 duration = self._runOneBenchmark("cpu", iters, static_unroll=False) 4559 self.report_benchmark( 4560 name="while_op_cross_device", iters=iters, wall_time=duration) 4561 4562 def benchmarkWhileOpSameDevicePlacement(self): 4563 iters = 10 4564 # Run all ops on the same GPU device. 4565 duration = self._runOneBenchmark("gpu", iters, static_unroll=False) 4566 self.report_benchmark( 4567 name="while_op_same_device", iters=iters, wall_time=duration) 4568 4569 def benchmarkWhileOpUnrollCrossDevicePlacement(self): 4570 iters = 10 4571 # Run loop body on GPU, but other ops on CPU. 4572 duration = self._runOneBenchmark("cpu", iters, static_unroll=True) 4573 self.report_benchmark( 4574 name="unroll_cross_device_cpu", iters=iters, wall_time=duration) 4575 4576 def benchmarkWhileOpUnrollSameDevicePlacement(self): 4577 iters = 10 4578 # Run all ops on GPU. 4579 duration = self._runOneBenchmark("gpu", iters, static_unroll=True) 4580 self.report_benchmark( 4581 name="unroll_same_device", iters=iters, wall_time=duration) 4582 4583 4584 @test_util.with_control_flow_v2 4585 class EagerTest(test.TestCase): 4586 4587 def testCond(self): 4588 with context.eager_mode(): 4589 pred = math_ops.less(1, 2) 4590 fn1 = lambda: [constant_op.constant(10)] 4591 fn2 = lambda: [constant_op.constant(20)] 4592 r = control_flow_ops.cond(pred, fn1, fn2) 4593 4594 self.assertAllEqual(r.numpy(), 10) 4595 self.assertFalse(isinstance(r, list)) 4596 4597 # TODO(b/117279927): Re-enable once msan failure is fixed. 4598 def DISABLED_testCondInDefun(self): 4599 with context.eager_mode(): 4600 4601 @eager_function.defun 4602 def foo(pred): 4603 # TODO(b/111124878): this only needs to output one element. 
        fn1 = lambda: (constant_op.constant(10), constant_op.constant(100))
        fn2 = lambda: (constant_op.constant(20), constant_op.constant(200))
        return control_flow_ops.cond(constant_op.constant(pred), fn1, fn2)

      r = foo(True)
      self.assertAllEqual(r[0].numpy(), 10)
      self.assertNotIsInstance(r, list)

      r = foo(False)
      self.assertAllEqual(r[0].numpy(), 20)
      self.assertFalse(isinstance(r, list))

  def testWhileLoop(self):
    with context.eager_mode():
      tensor = constant_op.constant([1, 2, 3, 4, 5])
      self.assertAllEqual(isum(tensor).numpy(), [46, 47, 48, 49, 50])

  def testWhileLoopWithMaxIterations(self):
    with context.eager_mode():
      tensor = constant_op.constant([1, 2, 3, 4, 5])
      self.assertAllEqual(
          isum(tensor, maximum_iterations=3).numpy(),
          [1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3])

  @test_util.run_v1_only("b/120545219")
  def testWhileWithMaximumIterationsAndSingleArgument(self):
    with context.eager_mode():
      tensor = constant_op.constant(0)
      r = control_flow_ops.while_loop(
          lambda i: i < 3, lambda i: i + 1, [tensor], maximum_iterations=1)
      self.assertEqual(1, r.numpy())

  def testWithDependencies(self):
    with context.eager_mode():
      t1 = constant_op.constant(1)
      t2 = constant_op.constant(2)
      t3 = control_flow_ops.with_dependencies(t1, t2)
      self.assertAllEqual(t2.numpy(), t3.numpy())

  def testTuple(self):
    with context.eager_mode():
      t1 = constant_op.constant(1)
      t2 = constant_op.constant(2)
      tup1, tup2 = control_flow_ops.tuple([t1, t2])
      self.assertAllEqual(t1.numpy(), tup1.numpy())
      self.assertAllEqual(t2.numpy(), tup2.numpy())

  @test_util.run_v1_only("b/120545219")
  def testCase(self):
    with context.eager_mode():
      x = constant_op.constant(1)
      y = constant_op.constant(2)
      z = constant_op.constant(3)
      f1 = lambda: constant_op.constant(17)
      f2 = lambda: constant_op.constant(23)
      f3 = lambda: constant_op.constant(-1)

      r1 = control_flow_ops.case(
          [(x < y, f1), (x > z, f2)], default=f3, exclusive=True)
      self.assertAllEqual(r1.numpy(), 17)


if __name__ == "__main__":
  test.main()