# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for control_flow_ops.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numpy as np

from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.platform import googletest
from tensorflow.python.training import momentum
from tensorflow.python.util import nest


# Two-field namedtuple used as a nested return structure in the cond /
# while_loop structure tests.
TestTuple = collections.namedtuple("TestTuple", "a b")
# One-field namedtuple used by the singleton-structure (strict mode) tests.
SingletonTestTuple = collections.namedtuple("SingletonTestTuple", "a")


@test_util.with_c_api
class GroupTestCase(test_util.TensorFlowTestCase):
  """Checks the NoOp graph structure emitted by control_flow_ops.group."""

  def _StripNode(self, nd):
    """Returns a NodeDef copy of `nd` with only name, op, input and device.

    The device field is copied only when it is non-empty, so that nodes
    without a device compare equal to expected protos that omit it.
    """
    snode = node_def_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input)
    if nd.device:
      snode.device = nd.device
    return snode

  def _StripGraph(self, gd):
    """Copy gd keeping only node.name, node.op, node.input, and node.device."""
    return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node])

  def testGroup_NoDevices(self):
    """Without device scopes, group emits one NoOp with a ^input per op."""
    with ops.Graph().as_default() as g:
      a = constant_op.constant(0, name="a")
      b = constant_op.constant(0, name="b")
      c = constant_op.constant(0, name="c")
      control_flow_ops.group(a.op, b.op, c.op, name="root")
    gd = g.as_graph_def()
    self.assertProtoEquals("""
      node { name: "a" op: "Const"}
      node { name: "b" op: "Const"}
      node { name: "c" op: "Const"}
      node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" }
    """, self._StripGraph(gd))

  def testGroup_OneDevice(self):
    """When all inputs share one device, the root NoOp is placed there too."""
    with ops.Graph().as_default() as g:
      with g.device("/task:0"):
        a = constant_op.constant(0, name="a")
        b = constant_op.constant(0, name="b")
      control_flow_ops.group(a.op, b.op, name="root")
    gd = g.as_graph_def()
    self.assertProtoEquals("""
      node { name: "a" op: "Const" device: "/task:0" }
      node { name: "b" op: "Const" device: "/task:0" }
      node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" }
    """, self._StripGraph(gd))

  def testGroup_MultiDevice(self):
    """Inputs on several devices get one NoOp per device plus a root NoOp.

    The expected proto below has a per-device NoOp under the "root/" name
    scope on each input device, and the final "root" NoOp on the device that
    is active when group() is called (/task:2).
    """
    with ops.Graph().as_default() as g:
      with g.device("/task:0"):
        a = constant_op.constant(0, name="a")
        b = constant_op.constant(0, name="b")
      with g.device("/task:1"):
        c = constant_op.constant(0, name="c")
        d = constant_op.constant(0, name="d")
      with g.device("/task:2"):
        control_flow_ops.group(a.op, b.op, c.op, d.op, name="root")
    gd = g.as_graph_def()
    self.assertProtoEquals("""
      node { name: "a" op: "Const" device: "/task:0"}
      node { name: "b" op: "Const" device: "/task:0"}
      node { name: "c" op: "Const" device: "/task:1"}
      node { name: "d" op: "Const" device: "/task:1"}
      node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b"
             device: "/task:0" }
      node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d"
             device: "/task:1" }
      node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1"
             device: "/task:2" }
    """, self._StripGraph(gd))

  def testPassingList(self):
    """group() also accepts its inputs as a single list."""
    with ops.Graph().as_default() as g:
      a = constant_op.constant(0, name="a")
      b = constant_op.constant(0, name="b")
      control_flow_ops.group([a.op, b.op], name="root")
    gd = g.as_graph_def()
    self.assertProtoEquals("""
      node { name: "a" op: "Const"}
      node { name: "b" op: "Const"}
      node { name: "root" op: "NoOp" input: "^a" input: "^b" }
    """, self._StripGraph(gd))

  def testPassingNonTensors(self):
    """Plain Python ints are rejected with TypeError."""
    with ops.Graph().as_default():
      with self.assertRaises(TypeError):
        control_flow_ops.group(1, 2)


@test_util.with_c_api
class ShapeTestCase(test_util.TensorFlowTestCase):
  """Static-shape propagation through with_dependencies."""

  def testShape(self):
    """with_dependencies keeps the static shape of the wrapped tensor."""
    with ops.Graph().as_default():
      tensor = constant_op.constant([1.0, 2.0])
      self.assertEquals([2], tensor.get_shape())
      self.assertEquals([2],
                        control_flow_ops.with_dependencies(
                            [constant_op.constant(1.0)], tensor).get_shape())


@test_util.with_c_api
class WithDependenciesTestCase(test_util.TensorFlowTestCase):
  """with_dependencies runs its dependencies before producing its output."""

  def testTupleDependencies(self):
    """Dependencies passed as a tuple: evaluating the output bumps counter."""
    with ops.Graph().as_default():
      counter = variable_scope.get_variable(
          "my_counter", shape=[], initializer=init_ops.zeros_initializer())
      increment_counter = state_ops.assign_add(counter, 1)
      const_with_dep = control_flow_ops.with_dependencies(
          (increment_counter, constant_op.constant(42)),
          constant_op.constant(7))
      with self.test_session():
        variables.global_variables_initializer().run()
        self.assertEquals(0, counter.eval())
        self.assertEquals(7, const_with_dep.eval())
        # Evaluating const_with_dep must have run increment_counter once.
        self.assertEquals(1, counter.eval())

  def testListDependencies(self):
    """Same as testTupleDependencies but with the dependencies in a list."""
    with ops.Graph().as_default():
      counter = variable_scope.get_variable(
          "my_counter", shape=[], initializer=init_ops.zeros_initializer())
      increment_counter = state_ops.assign_add(counter, 1)
      const_with_dep = control_flow_ops.with_dependencies(
          [increment_counter, constant_op.constant(42)],
          constant_op.constant(7))
      with self.test_session():
        variables.global_variables_initializer().run()
        self.assertEquals(0, counter.eval())
        self.assertEquals(7, const_with_dep.eval())
        # Evaluating const_with_dep must have run increment_counter once.
        self.assertEquals(1, counter.eval())


@test_util.with_c_api
class SwitchTestCase(test_util.TensorFlowTestCase):
  """switch() on IndexedSlices and gradients through switch/while/cond."""

  def testIndexedSlicesWithDenseShape(self):
    """switch() forwards IndexedSlices (with dense_shape) to the true output."""
    with self.test_session():
      data = ops.IndexedSlices(
          constant_op.constant([1, 2, 3]),
          constant_op.constant([0, 1]),
          dense_shape=constant_op.constant([3]))
      zero = constant_op.constant(0)
      one = constant_op.constant(1)
      less_op = math_ops.less(zero, one)
      _, switch_true = control_flow_ops.switch(data, less_op)
      self.assertAllEqual([1, 2, 3], switch_true.values.eval())
      self.assertAllEqual([0, 1], switch_true.indices.eval())

  def testIndexedSlicesGradient(self):
    """A Momentum train op on an embedding lookup in a while_loop runs.

    Only checks that 10 training steps execute without error; no values are
    asserted.
    """
    with ops.Graph().as_default():
      embedding_matrix = variable_scope.get_variable(
          "embedding_matrix", [5, 5],
          initializer=init_ops.random_normal_initializer())

      def cond(it, _):
        return it < 5

      def body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0])
        cost += math_ops.reduce_sum(embedding)
        return it + 1, cost

      _, cost = control_flow_ops.while_loop(
          cond, body, [constant_op.constant(0),
                       constant_op.constant(0.0)])
      optimizer = momentum.MomentumOptimizer(0.1, 0.9)
      train_op = optimizer.minimize(cost)
      with self.test_session() as sess:
        sess.run(variables.global_variables_initializer())
        for _ in range(10):
          sess.run([train_op])

  def testResourceReadInLoop(self):
    """Looking up row 0 (== 2.0) of a resource variable 5 times sums to 10."""
    with ops.Graph().as_default():
      embedding_matrix = variable_scope.get_variable(
          "embedding_matrix",
          initializer=[[2.0], [3.0]],
          use_resource=True)

      def cond(it, _):
        return it < 5

      def body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
        cost += math_ops.reduce_sum(embedding)
        return it + 1, cost

      _, cost = control_flow_ops.while_loop(
          cond, body, [constant_op.constant(0),
                       constant_op.constant(0.0)])
      with self.test_session() as sess:
        sess.run(variables.global_variables_initializer())
        self.assertAllEqual(10.0, cost.eval())

  def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False):
    """Compares loop gradients against an equivalent unrolled expression.

    The loop runs for it = 0..4: iterations 0, 1, 2 and 4 add
    sum(embedding) to the cost and iteration 3 squares it, so the result
    equals square(3 * sum) + sum, which is what `static` builds explicitly.
    The IndexedSlices gradients of both are segment-summed and compared.

    Args:
      use_resource: Whether embedding_matrix is created as a resource variable.
    """
    with ops.Graph().as_default():
      embedding_matrix = variable_scope.get_variable(
          "embedding_matrix", [5, 5],
          initializer=init_ops.random_normal_initializer(),
          use_resource=use_resource)

      def cond(it, _):
        return it < 5

      def body(it, cost):
        embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
        cost = control_flow_ops.cond(
            math_ops.equal(it, 3), lambda: math_ops.square(cost),
            lambda: cost + math_ops.reduce_sum(embedding))
        return it + 1, cost

      _, cost = control_flow_ops.while_loop(
          cond, body, [constant_op.constant(0),
                       constant_op.constant(0.0)])

      dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0]
      dynamic_grads = math_ops.segment_sum(dynamic_grads.values,
                                           dynamic_grads.indices)

      embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
      static = math_ops.square(
          math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) +
          math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding)
      static_grads = gradients_impl.gradients(static, [embedding_matrix])[0]
      static_grads = math_ops.segment_sum(static_grads.values,
                                          static_grads.indices)

      with self.test_session() as sess:
        sess.run(variables.global_variables_initializer())
        self.assertAllEqual(*sess.run([static_grads, dynamic_grads]))

  def testIndexedSlicesGradientInCondInWhileLoop(self):
    """Ref-variable variant of the cond-in-while gradient check."""
    self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=False)

  def testIndexedSlicesGradientInCondInWhileLoopResource(self):
    """Resource-variable variant of the cond-in-while gradient check."""
    self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=True)

  def testIndexedSlicesWithShapeGradientInWhileLoop(self):
    """Gradient of sum(gathered inputs) w.r.t. a fixed-shape input is all 1s."""
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session() as sess:
        num_steps = 9

        inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps])
        initial_outputs = tensor_array_ops.TensorArray(
            dtype=dtype, size=num_steps)
        initial_i = constant_op.constant(0, dtype=dtypes.int32)

        def cond(i, _):
          return i < num_steps  # pylint: disable=cell-var-from-loop

        def body(i, outputs):
          x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
          outputs = outputs.write(i, x)
          return i + 1, outputs

        _, outputs = control_flow_ops.while_loop(cond, body,
                                                 [initial_i, initial_outputs])

        outputs = math_ops.reduce_sum(outputs.stack())
        r = gradients_impl.gradients([outputs], [inputs])[0]
        grad_wr_inputs = ops.convert_to_tensor(r)
        # The fed values sum to 4 + 6 + 7 + 1 + 2 == 20.
        o, grad = sess.run([outputs, grad_wr_inputs],
                           feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]})
        self.assertEquals(o, 20)
        self.assertAllEqual(grad, [1] * num_steps)

  def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self):
    """Like the fixed-shape variant, but the input has no static shape."""
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.test_session() as sess:
        inputs = array_ops.placeholder(dtype=dtype)
        initial_outputs = tensor_array_ops.TensorArray(
            dtype=dtype, dynamic_size=True, size=1)
        initial_i = constant_op.constant(0, dtype=dtypes.int32)

        def cond(i, _):
          return i < array_ops.size(inputs)  # pylint: disable=cell-var-from-loop

        def body(i, outputs):
          x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
          outputs = outputs.write(i, x)
          return i + 1, outputs

        _, outputs = control_flow_ops.while_loop(cond, body,
                                                 [initial_i, initial_outputs])

        outputs = math_ops.reduce_sum(outputs.stack())
        r = gradients_impl.gradients([outputs], [inputs])[0]
        grad_wr_inputs = ops.convert_to_tensor(r)
        o, grad = sess.run([outputs, grad_wr_inputs],
                           feed_dict={inputs: [1, 3, 2]})
        self.assertEquals(o, 6)
        self.assertAllEqual(grad, [1] * 3)

  def testGradientThroughSingleBranchOutsideOfContext(self):
    """With pred=True, d(true output)/dx is 1 and d(false output)/dx is 0."""
    with self.test_session():
      x = constant_op.constant(2.)
      s = constant_op.constant(True)
      x_false, x_true = control_flow_ops.switch(x, s)
      grad_x_true = gradients_impl.gradients(x_true, x)[0]
      grad_x_false = gradients_impl.gradients(x_false, x)[0]
      self.assertEquals(grad_x_true.eval(), 1.)
      self.assertEquals(grad_x_false.eval(), 0.)


@test_util.with_c_api
class SmartCondTest(test_util.TensorFlowTestCase):
  """Tests for control_flow_ops.smart_cond with Python bool predicates."""

  def testSmartCondTrue(self):
    """A literal True predicate selects the first branch (2 * 16 == 32)."""
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(2)
        y = constant_op.constant(5)
        z = control_flow_ops.smart_cond(
            True, lambda: math_ops.multiply(x, 16),
            lambda: math_ops.multiply(y, 5))
        self.assertEqual(z.eval(), 32)

  def testSmartCondFalse(self):
    """A literal False predicate selects the second branch (3 * 3 == 9)."""
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(4)
        y = constant_op.constant(3)
        z = control_flow_ops.smart_cond(
            False, lambda: math_ops.multiply(x, 16),
            lambda: math_ops.multiply(y, 3))
        self.assertEqual(z.eval(), 9)

  def testSmartCondMissingArg1(self):
    """Omitting the true branch must raise TypeError."""
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(1)
        with self.assertRaises(TypeError):
          control_flow_ops.smart_cond(True, false_fn=lambda: x)

  def testSmartCondMissingArg2(self):
    """Omitting the false branch must raise TypeError."""
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(1)
        with self.assertRaises(TypeError):
          control_flow_ops.smart_cond(True, lambda: x)


@test_util.with_c_api
class CondTest(test_util.TensorFlowTestCase):
  """Tests for control_flow_ops.cond branch selection and argument checks."""

  def testCondTrue(self):
    """When x < y the true branch runs: 2 * 17 == 34."""
    # Create new Graph and Session for each test so we pick up _USE_C_API
    # correctly.
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(2)
        y = constant_op.constant(5)
        z = control_flow_ops.cond(
            math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
            lambda: math_ops.add(y, 23))
        self.assertEquals(z.eval(), 34)

  def testCondFalse(self):
    """When x >= y the false branch runs: 1 + 23 == 24."""
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(2)
        y = constant_op.constant(1)
        z = control_flow_ops.cond(
            math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
            lambda: math_ops.add(y, 23))
        self.assertEquals(z.eval(), 24)

  def testCondTrueLegacy(self):
    """Same as testCondTrue but using the legacy fn1/fn2 keyword names."""
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(2)
        y = constant_op.constant(5)
        z = control_flow_ops.cond(
            math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
            fn2=lambda: math_ops.add(y, 23))
        self.assertEquals(z.eval(), 34)

  def testCondFalseLegacy(self):
    """Same as testCondFalse but using the legacy fn1/fn2 keyword names."""
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(2)
        y = constant_op.constant(1)
        z = control_flow_ops.cond(
            math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
            fn2=lambda: math_ops.add(y, 23))
        self.assertEquals(z.eval(), 24)

  def testCondModifyBoolPred(self):
    """A true branch that assigns False to its own predicate variable.

    The first run takes the true branch (which returns the assigned False);
    the second run then sees the updated predicate and returns True.
    """
    # NOTE(review): this test reportedly used to fail only when running on
    # GPU. An earlier version of this comment referred to use_gpu=True, but
    # the code below builds a plain session.Session() and does not request a
    # GPU explicitly -- confirm whether GPU coverage is still intended.
    with ops.Graph().as_default():
      with session.Session() as sess:
        bool_var = variable_scope.get_variable("bool_var", dtype=dtypes.bool,
                                               initializer=True)
        cond_on_bool_var = control_flow_ops.cond(
            pred=bool_var,
            true_fn=lambda: state_ops.assign(bool_var, False),
            false_fn=lambda: True)
        sess.run(bool_var.initializer)
        self.assertEquals(sess.run(cond_on_bool_var), False)
        self.assertEquals(sess.run(cond_on_bool_var), True)

  def testCondMissingArg1(self):
    """Omitting true_fn must raise TypeError."""
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(1)
        with self.assertRaises(TypeError):
          control_flow_ops.cond(True, false_fn=lambda: x)

  def testCondMissingArg2(self):
    """Omitting false_fn must raise TypeError."""
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(1)
        with self.assertRaises(TypeError):
          control_flow_ops.cond(True, lambda: x)

  def testCondDuplicateArg1(self):
    """Passing a positional true branch and legacy fn1 raises TypeError."""
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(1)
        with self.assertRaises(TypeError):
          control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)

  def testCondDuplicateArg2(self):
    """Passing a positional false branch and legacy fn2 raises TypeError."""
    with ops.Graph().as_default():
      with session.Session():
        x = constant_op.constant(1)
        with self.assertRaises(TypeError):
          control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)


@test_util.with_c_api
class ContextTest(test_util.TensorFlowTestCase):
  """Proto serialization round trips of control-flow contexts."""

  def testCondContext(self):
    """Every op's context survives a CondContext to_proto/from_proto trip."""
    with self.test_session() as sess:
      x = constant_op.constant(2)
      y = constant_op.constant(5)
      control_flow_ops.cond(
          math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
          lambda: math_ops.add(y, 23))
      for op in sess.graph.get_operations():
        c = op._get_control_flow_context()
        # Ops outside any cond have no context and are skipped.
        if c:
          self.assertProtoEquals(
              c.to_proto(),
              control_flow_ops.CondContext.from_proto(c.to_proto()).to_proto())

  def _testWhileContextHelper(self, maximum_iterations=None):
    """Builds a while_loop and round-trips every op's WhileContext proto.

    Args:
      maximum_iterations: Forwarded to while_loop; None means unbounded.
    """
    with self.test_session() as sess:
      i = constant_op.constant(0)
      c = lambda i: math_ops.less(i, 10)
      b = lambda i: math_ops.add(i, 1)
      control_flow_ops.while_loop(
          c, b, [i], maximum_iterations=maximum_iterations)
      for op in sess.graph.get_operations():
        control_flow_context = op._get_control_flow_context()
        # Ops outside the loop have no context and are skipped.
        if control_flow_context:
          self.assertProtoEquals(
              control_flow_context.to_proto(),
              control_flow_ops.WhileContext.from_proto(
                  control_flow_context.to_proto()).to_proto())

  def testWhileContext(self):
    """WhileContext round trip for an unbounded loop."""
    self._testWhileContextHelper()

  def testWhileContextWithMaximumIterations(self):
    """WhileContext round trip when maximum_iterations is set."""
    self._testWhileContextHelper(maximum_iterations=10)

  def testControlContextImportScope(self):
    """Import/export scopes are applied to a context's values_def.

    Tensors "a"/"b" and "test_scope/a"/"test_scope/b" are created so that
    names in the imported values_def resolve to real tensors.
    """
    with self.test_session():
      constant_op.constant(0, name="a")
      constant_op.constant(2, name="test_scope/a")
      b1 = constant_op.constant(1, name="b")
      b2 = constant_op.constant(3, name="test_scope/b")

      c = control_flow_ops.ControlFlowContext()
      c._values = ["a", "b"]
      c._external_values = {"a": b1}

      c_with_scope = control_flow_ops.ControlFlowContext(
          values_def=c._to_values_def(), import_scope="test_scope")

      # _values and _external_values should have the import scope prepended.
      self.assertEquals(
          c_with_scope._values, set(["test_scope/a", "test_scope/b"]))
      self.assertEquals(
          c_with_scope._external_values, {"test_scope/a": b2})

      # Calling _to_values_def() with export_scope should remove "test_scope".
      self.assertProtoEquals(
          c._to_values_def(),
          c_with_scope._to_values_def(export_scope="test_scope"))


def _get_nested_shape(nested):
  """Maps each leaf of a cond/case output structure to a shape-like value.

  TensorArray leaves map to the TensorArray class itself, IndexedSlices
  leaves to their `dense_shape` tensor, and all other leaves to
  `get_shape()`.
  """

  def _get_shape(tensor):
    if isinstance(tensor, tensor_array_ops.TensorArray):
      return tensor_array_ops.TensorArray
    elif isinstance(tensor, ops.IndexedSlices):
      return tensor.dense_shape
    else:
      return tensor.get_shape()

  return nest.map_structure(_get_shape, nested)


def _create_tensor_array(size, shape):
  """Returns a float32 TensorArray of `size` zero tensors of `shape`.

  clear_after_read=False so elements can be read more than once.
  """
  ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=size,
                                    clear_after_read=False)
  for i in range(size):
    ta = ta.write(i, array_ops.zeros(shape))
  return ta


def _raw_nested_shape(nested_shape):
  """Converts each TensorShape leaf into a plain list of dimension values.

  Leaves that are not TensorShapes, or are shapes of unknown rank, become
  None; unknown dimensions of a known-rank shape stay None inside the list.
  """

  def _raw_shape(shape):
    if isinstance(shape, tensor_shape.TensorShape) and shape.ndims is not None:
      return [x.value for x in shape]
    else:
      return None

  return nest.map_structure(_raw_shape, nested_shape)


# TODO(yori): Add tests for indexed slices.
@test_util.with_c_api
class DataTypesTest(test_util.TensorFlowTestCase):
  """Output structures, shapes and values of cond and case across types."""

  def assertAllEqualNested(self, a, b):
    """Recursively asserts that `a` and `b` hold equal values.

    Lists and tuples (including namedtuples) in `a` are compared
    element-wise with `b`, recursing into nested sequences; any other leaf is
    compared with assertAllEqual.

    Args:
      a: Actual value, typically the result of `sess.run`.
      b: Expected value with the same nesting as `a`.
    """
    if isinstance(a, (list, tuple)):
      # zip() silently truncates to the shorter input, so compare lengths
      # first; otherwise extra or missing trailing entries would go unnoticed.
      self.assertEqual(len(a), len(b))
      for entry_a, entry_b in zip(a, b):
        self.assertAllEqualNested(entry_a, entry_b)
    else:
      self.assertAllEqual(a, b)

  def _testShape(self, fn_true, fn_false, expected_shape,
                 strict=False):
    """Checks the static output shapes of both cond and case.

    Both ops are built from the same branch functions with a bool placeholder
    as the predicate. Their output shapes are compared with `expected_shape`
    after normalizing through _raw_nested_shape.

    Args:
      fn_true: Branch taken when the predicate is True.
      fn_false: Branch taken otherwise (the default branch for case).
      expected_shape: Nested structure of expected TensorShapes.
      strict: Forwarded to cond/case.
    """
    condition = array_ops.placeholder(dtypes.bool)
    output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
                                        strict=strict)
    self.assertEqual(
        _raw_nested_shape(_get_nested_shape(output_cond)),
        _raw_nested_shape(expected_shape))

    output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
                                        strict=strict)
    self.assertEqual(
        _raw_nested_shape(_get_nested_shape(output_case)),
        _raw_nested_shape(expected_shape))

  def _testReturnValues(self, fn_true, fn_false, expected_value_true,
                        expected_value_false, strict=False,
                        check_cond=True, feed_dict=None):
    """Runs cond and case for both predicate values and checks the results.

    Args:
      fn_true: Branch taken when the predicate is True.
      fn_false: Branch taken otherwise (the default branch for case).
      expected_value_true: Expected output when the predicate is fed True.
      expected_value_false: Expected output when the predicate is fed False.
      strict: Forwarded to cond/case.
      check_cond: Despite its name, controls whether the `case` output is
        checked; the `cond` output is always checked.
      feed_dict: Extra feeds (e.g. placeholders used by the branches), merged
        into the feed for the predicate.
    """
    if feed_dict is None: feed_dict = {}

    condition = array_ops.placeholder(dtypes.bool)
    output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
                                        strict=strict)
    output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
                                        strict=strict)

    with self.test_session() as sess:
      # Branches may create Variables (see test_variable / test_list).
      variables.global_variables_initializer().run()
      true_feed_dict = {condition: True}
      true_feed_dict.update(feed_dict)
      result_cond, result_case = sess.run([output_cond, output_case],
                                          feed_dict=true_feed_dict)
      self.assertAllEqualNested(result_cond, expected_value_true)
      if check_cond:
        self.assertAllEqualNested(result_case, expected_value_true)
      false_feed_dict = {condition: False}
      false_feed_dict.update(feed_dict)
      result_cond, result_case = sess.run([output_cond, output_case],
                                          feed_dict=false_feed_dict)
      self.assertAllEqualNested(result_cond, expected_value_false)
      if check_cond:
        self.assertAllEqualNested(result_case, expected_value_false)

  def test_int(self):
    """Python int branch outputs, in both non-strict and strict mode."""
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: 1
    fn_false = lambda: 2
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1, 2)
    self._testShape(fn_true, fn_false, shape, strict=True)
    self._testReturnValues(fn_true, fn_false, 1, 2, strict=True)

  def test_float(self):
    """Python float branch outputs."""
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: 1.0
    fn_false = lambda: 2.0
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1.0, 2.0)

  def test_noop(self):
    """no_op branches: cond yields the predicate value; case is not checked."""
    shape = tensor_shape.TensorShape(None)
    self._testShape(control_flow_ops.no_op, control_flow_ops.no_op, shape)
    self._testReturnValues(control_flow_ops.no_op, control_flow_ops.no_op,
                           True, False, check_cond=False)

  def test_string(self):
    """Python str branch outputs come back as bytes."""
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: "abc"
    fn_false = lambda: "xyz"
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, b"abc", b"xyz")

  def test_variable(self):
    """Branches that create Variables return the variables' values."""
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: variables.Variable(3.0)
    fn_false = lambda: variables.Variable(4.0)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 3.0, 4.0)

  def test_none(self):
    """A branch returning None makes cond raise ValueError."""
    fn_none = lambda: None
    fn_tensor = lambda: constant_op.constant(1)

    with self.assertRaises(ValueError):
      control_flow_ops.cond(constant_op.constant(True), fn_none, fn_tensor)

    with self.assertRaises(ValueError):
      control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_none)

  def test_tensors(self):
    """Tuple-of-tensor outputs across several dtypes."""

    def _build_true_branch(dtype):

      def _build():
        return (array_ops.zeros([2, 2], dtype=dtype),
                array_ops.ones([3, 3], dtype=dtype))

      return _build

    def _build_false_branch(dtype):

      def _build():
        return (array_ops.ones([2, 2], dtype=dtype),
                array_ops.zeros([3, 3], dtype=dtype))

      return _build

    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
      shape = (tensor_shape.TensorShape([2, 2]),
               tensor_shape.TensorShape([3, 3]))
      fn_true = _build_true_branch(dtype)
      fn_false = _build_false_branch(dtype)
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false,
                             (np.zeros([2, 2]), np.ones([3, 3])),
                             (np.ones([2, 2]), np.zeros([3, 3])))

  def test_tensors_unknown_shape(self):
    """Placeholders of unknown shape pass through; values are fed in."""

    def _build_true_branch(dtype):
      tensor = array_ops.placeholder(dtype=dtype, shape=None)

      def _build():
        return tensor

      return _build, tensor

    def _build_false_branch(dtype):
      tensor = array_ops.placeholder(dtype=dtype, shape=None)

      def _build():
        return tensor

      return _build, tensor

    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
      shape = tensor_shape.TensorShape(None)
      fn_true, true_tensor = _build_true_branch(dtype)
      fn_false, false_tensor = _build_false_branch(dtype)
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false,
                             np.zeros([2, 2]), np.ones([2, 2]),
                             feed_dict={true_tensor: np.zeros([2, 2]),
                                        false_tensor: np.ones([2, 2])})

  def test_sparse_tensors(self):
    """Singleton-list SparseTensor outputs.

    Non-strict mode expects the bare SparseTensorValue; strict mode keeps the
    enclosing list.
    """
    shape = tensor_shape.TensorShape([None, None])

    def true_fn():
      return [sparse_tensor.SparseTensor(indices=[[0, 0], [1, 2]],
                                         values=[1, 2], dense_shape=[3, 4])]

    def false_fn():
      return [sparse_tensor.SparseTensor(indices=[[0, 0], [2, 1]],
                                         values=[3, 4], dense_shape=[3, 4])]

    value1 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 2]],
                                             values=[1, 2], dense_shape=[3, 4])
    value2 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [2, 1]],
                                             values=[3, 4], dense_shape=[3, 4])
    self._testShape(true_fn, false_fn, shape)
    self._testReturnValues(true_fn, false_fn, value1, value2)
    self._testShape(true_fn, false_fn, [shape], strict=True)
    self._testReturnValues(true_fn, false_fn, [value1], [value2], strict=True)

  def test_tensors_with_partially_specified_shapes(self):
    """Outputs with partially known static shapes keep those shapes."""

    def _build_branch(dtype, shape):
      a = array_ops.placeholder(dtype=dtype, shape=shape[0])
      b = array_ops.placeholder(dtype=dtype, shape=shape[1])
      c = array_ops.placeholder(dtype=dtype, shape=shape[2])

      def _build():
        return a, b, c

      return _build, (a, b, c)

    for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
      shape = (tensor_shape.TensorShape([None, 2]),
               tensor_shape.TensorShape([None]),
               tensor_shape.TensorShape([3, None]))
      fn_true, true_tensors = _build_branch(dtype, shape)
      fn_false, false_tensors = _build_branch(dtype, shape)
      self._testShape(fn_true, fn_false, shape)
      self._testReturnValues(fn_true, fn_false,
                             (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])),
                             (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])),
                             feed_dict={true_tensors[0]: np.zeros([2, 2]),
                                        false_tensors[0]: np.zeros([2, 2]),
                                        true_tensors[1]: np.zeros([5]),
                                        false_tensors[1]: np.zeros([5]),
                                        true_tensors[2]: np.ones([3, 3]),
                                        false_tensors[2]: np.ones([3, 3])})

  def test_tensor_arrays(self):
    """TensorArray outputs.

    Only _testShape runs. TensorArray leaves normalize to None, so this
    effectively checks the output structure.
    """
    element_shape = tensor_shape.TensorShape([2])
    ta1 = _create_tensor_array(4, element_shape)
    ta2 = _create_tensor_array(4, element_shape)
    shape = tensor_array_ops.TensorArray
    fn_true = lambda: ta1
    fn_false = lambda: ta2
    self._testShape(fn_true, fn_false, shape)

  def test_tensor_array_reads(self):
    """Outputs read from a TensorArray keep the element shape."""
    shape = tensor_shape.TensorShape([2])
    ta = _create_tensor_array(4, shape)
    fn_true = lambda: ta.read(0)
    fn_false = lambda: ta.read(1)
    self._testShape(fn_true, fn_false, shape)

  def test_list(self):
    """List output mixing a tensor, a Python int and a Variable."""
    shape = [tensor_shape.TensorShape([]), tensor_shape.TensorShape([]),
             tensor_shape.TensorShape([])]
    fn_true = lambda: [constant_op.constant(1), 2, variables.Variable(3.0)]
    fn_false = lambda: [constant_op.constant(3), 4, variables.Variable(5.0)]
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, [1, 2, 3.0], [3, 4, 5.0])

  def test_non_strict(self):
    """Non-strict: a tensor, singleton list and singleton tuple all mix."""
    shape = tensor_shape.TensorShape([])
    fn_tensor = lambda: constant_op.constant(1)
    fn_list = lambda: [constant_op.constant(2)]
    fn_tuple = lambda: (constant_op.constant(3),)
    self._testShape(fn_tensor, fn_list, shape)
    self._testShape(fn_tensor, fn_tuple, shape)
    self._testShape(fn_list, fn_tuple, shape)
    self._testReturnValues(fn_tensor, fn_list, 1, 2)
    self._testReturnValues(fn_tensor, fn_tuple, 1, 3)
    self._testReturnValues(fn_list, fn_tuple, 2, 3)

  def test_singleton_strict(self):
    """Strict mode rejects mismatched singleton structures.

    Tensor vs. list raises ValueError; list vs. tuple raises TypeError. This
    holds for both cond and case.
    """
    fn_tensor = lambda: constant_op.constant(1)
    fn_list = lambda: [constant_op.constant(2)]
    fn_tuple = lambda: (constant_op.constant(3),)

    with self.assertRaises(ValueError):
      control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_list,
                            strict=True)

    with self.assertRaises(TypeError):
      control_flow_ops.cond(constant_op.constant(True), fn_list, fn_tuple,
                            strict=True)

    with self.assertRaises(ValueError):
      control_flow_ops.case([(constant_op.constant(True), fn_tensor)], fn_list,
                            strict=True)

    with self.assertRaises(TypeError):
      control_flow_ops.case([(constant_op.constant(True), fn_list)], fn_tuple,
                            strict=True)

  def test_singleton_list(self):
    """Singleton list: unpacked when non-strict, kept as a list when strict."""
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: [constant_op.constant(1)]
    fn_false = lambda: [constant_op.constant(3)]
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1, 3)
    self._testShape(fn_true, fn_false, [shape], strict=True)
    self._testReturnValues(fn_true, fn_false, [1], [3], strict=True)

  def test_singleton_tuple(self):
    """Singleton tuple: unpacked when non-strict, kept as a tuple when strict."""
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: (constant_op.constant(1),)
    fn_false = lambda: (constant_op.constant(3),)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1, 3)
    self._testShape(fn_true, fn_false, (shape,), strict=True)
    self._testReturnValues(fn_true, fn_false, (1,), (3,),
                           strict=True)

  def test_singleton_namedtuple(self):
    """Singleton namedtuple: unpacked when non-strict, kept when strict."""
    shape = tensor_shape.TensorShape([])
    fn_true = lambda: SingletonTestTuple(constant_op.constant(1))
    fn_false = lambda: SingletonTestTuple(constant_op.constant(3))
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, 1, 3)
    self._testShape(fn_true, fn_false, SingletonTestTuple(shape),
                    strict=True)
    self._testReturnValues(fn_true, fn_false, SingletonTestTuple(1),
                           SingletonTestTuple(3), strict=True)

  def test_tuple(self):
    """Two-element tuple output mixing a tensor and a Python int."""
    shape = (tensor_shape.TensorShape([]), tensor_shape.TensorShape([]))
    fn_true = lambda: (constant_op.constant(1), 2)
    fn_false = lambda: (constant_op.constant(3), 4)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, (1, 2), (3, 4))

  def test_namedtuple(self):
    """Namedtuple output is preserved as a namedtuple."""
    shape = TestTuple(tensor_shape.TensorShape([]),
                      tensor_shape.TensorShape([]))
    fn_true = lambda: TestTuple(constant_op.constant(1), 2)
    fn_false = lambda: TestTuple(constant_op.constant(3), 4)
    self._testShape(fn_true, fn_false, shape)
    self._testReturnValues(fn_true, fn_false, TestTuple(1, 2), TestTuple(3, 4))

  def test_nested(self):
    """Deeply nested list/namedtuple output structure."""
    shape = [tensor_shape.TensorShape([]),
             TestTuple(tensor_shape.TensorShape([]),
                       [tensor_shape.TensorShape([]),
                        tensor_shape.TensorShape([])]),
             tensor_shape.TensorShape([5, 5]),
             tensor_shape.TensorShape([])]

    def true_fn():
      return [constant_op.constant(1),
              TestTuple(constant_op.constant(2), [3, 4]),
              array_ops.zeros([5, 5]), 6]

    def false_fn():
      return [constant_op.constant(11),
              TestTuple(constant_op.constant(12), [13, 14]),
              array_ops.ones([5, 5]), 16]

    self._testShape(true_fn, false_fn, shape)
    self._testReturnValues(
        true_fn, false_fn,
        [1, TestTuple(2, [3, 4]), np.zeros([5, 5]), 6],
        [11, TestTuple(12, [13, 14]),
         np.ones([5, 5]), 16])

  def test_cond_inside_while_loop(self):
    """A structured cond inside a while_loop body keeps static loop shapes."""

    def body(i, matrix):
      result_tuple, unused_matrix = control_flow_ops.cond(
          constant_op.constant(True),
          lambda: (TestTuple(matrix * 2, matrix * 4), matrix),
          lambda: (TestTuple(matrix * 4, matrix * 2), matrix))
      return [i+1, result_tuple.a]

    iteration, matrix = control_flow_ops.while_loop(
        lambda i, matrix: i < 10,
        body,
        loop_vars=[constant_op.constant(0),
                   array_ops.ones([2, 2])])

    self.assertEqual(iteration.get_shape(), tensor_shape.TensorShape([]))
    self.assertEqual(matrix.get_shape(), tensor_shape.TensorShape([2, 2]))


@test_util.with_c_api
class CaseTest(test_util.TensorFlowTestCase):
  """Tests for control_flow_ops.case predicate matching and defaults."""

  def testCase_withDefault(self):
    """Matching predicates pick their branch; no match falls to default."""
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)),
                  (math_ops.equal(x, 2), lambda: constant_op.constant(4))]
    default = lambda: constant_op.constant(6)
    output = control_flow_ops.case(conditions, default, exclusive=True)
    with self.test_session() as sess:
      self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
      self.assertEqual(sess.run(output, feed_dict={x: 2}), 4)
      self.assertEqual(sess.run(output, feed_dict={x: 3}), 6)

  def testCase_multiple_matches_exclusive(self):
    """exclusive=True: two true predicates raise InvalidArgumentError."""
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)),
                  (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
                  (math_ops.equal(x, 2), lambda: constant_op.constant(6))]
    default = lambda: constant_op.constant(8)
    output = control_flow_ops.case(conditions, default, exclusive=True)
    with self.test_session() as sess:
      self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
      self.assertEqual(sess.run(output, feed_dict={x: 3}), 8)
      with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
        sess.run(output, feed_dict={x: 2})

  def testCase_multiple_matches_non_exclusive(self):
    """exclusive=False: with two true predicates the first listed one wins."""
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)),
                  (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
                  (math_ops.equal(x, 2), lambda: constant_op.constant(6))]
    default = lambda: constant_op.constant(8)
    output = control_flow_ops.case(conditions, default, exclusive=False)
    with self.test_session() as sess:
      self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
      self.assertEqual(sess.run(output, feed_dict={x: 2}), 4)
      self.assertEqual(sess.run(output, feed_dict={x: 3}), 8)

  def testCase_withoutDefault(self):
    """Without a default, an input matching no predicate raises an error."""
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)),
                  (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
                  (math_ops.equal(x, 3), lambda: constant_op.constant(6))]
    output = control_flow_ops.case(conditions, exclusive=True)
    with self.test_session() as sess:
      self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
      self.assertEqual(sess.run(output, feed_dict={x: 2}), 4)
      self.assertEqual(sess.run(output, feed_dict={x: 3}), 6)
      with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
        sess.run(output, feed_dict={x: 4})

  def testCase_withoutDefault_oneCondition(self):
    """A single predicate with no default still selects its branch."""
    x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2))]
    output = control_flow_ops.case(conditions, exclusive=True)
    with self.test_session() as sess:
      self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"): 985 sess.run(output, feed_dict={x: 4}) 986 987 988 if __name__ == "__main__": 989 googletest.main() 990