Home | History | Annotate | Download | only in ops
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for control_flow_ops.py."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 import numpy as np
     23 
     24 from tensorflow.core.framework import graph_pb2
     25 from tensorflow.core.framework import node_def_pb2
     26 from tensorflow.python.client import session
     27 from tensorflow.python.framework import constant_op
     28 from tensorflow.python.framework import dtypes
     29 from tensorflow.python.framework import errors
     30 from tensorflow.python.framework import ops
     31 from tensorflow.python.framework import sparse_tensor
     32 from tensorflow.python.framework import tensor_shape
     33 from tensorflow.python.framework import test_util
     34 from tensorflow.python.ops import array_ops
     35 from tensorflow.python.ops import control_flow_ops
     36 from tensorflow.python.ops import embedding_ops
     37 from tensorflow.python.ops import gradients_impl
     38 from tensorflow.python.ops import init_ops
     39 from tensorflow.python.ops import math_ops
     40 from tensorflow.python.ops import state_ops
     41 from tensorflow.python.ops import tensor_array_ops
     42 from tensorflow.python.ops import variable_scope
     43 from tensorflow.python.ops import variables
     44 import tensorflow.python.ops.tensor_array_grad  # pylint: disable=unused-import
     45 from tensorflow.python.platform import googletest
     46 from tensorflow.python.training import momentum
     47 from tensorflow.python.util import nest
     48 
     49 
     50 TestTuple = collections.namedtuple("TestTuple", "a b")
     51 SingletonTestTuple = collections.namedtuple("SingletonTestTuple", "a")
     52 
     53 
     54 @test_util.with_c_api
     55 class GroupTestCase(test_util.TensorFlowTestCase):
     56 
     57   def _StripNode(self, nd):
     58     snode = node_def_pb2.NodeDef(name=nd.name, op=nd.op, input=nd.input)
     59     if nd.device:
     60       snode.device = nd.device
     61     return snode
     62 
     63   def _StripGraph(self, gd):
     64     """Copy gd keeping only, node.name, node.op, node.input, and node.device."""
     65     return graph_pb2.GraphDef(node=[self._StripNode(nd) for nd in gd.node])
     66 
     67   def testGroup_NoDevices(self):
     68     with ops.Graph().as_default() as g:
     69       a = constant_op.constant(0, name="a")
     70       b = constant_op.constant(0, name="b")
     71       c = constant_op.constant(0, name="c")
     72       control_flow_ops.group(a.op, b.op, c.op, name="root")
     73     gd = g.as_graph_def()
     74     self.assertProtoEquals("""
     75       node { name: "a" op: "Const"}
     76       node { name: "b" op: "Const"}
     77       node { name: "c" op: "Const"}
     78       node { name: "root" op: "NoOp" input: "^a" input: "^b" input: "^c" }
     79     """, self._StripGraph(gd))
     80 
     81   def testGroup_OneDevice(self):
     82     with ops.Graph().as_default() as g:
     83       with g.device("/task:0"):
     84         a = constant_op.constant(0, name="a")
     85         b = constant_op.constant(0, name="b")
     86       control_flow_ops.group(a.op, b.op, name="root")
     87     gd = g.as_graph_def()
     88     self.assertProtoEquals("""
     89       node { name: "a" op: "Const" device: "/task:0" }
     90       node { name: "b" op: "Const" device: "/task:0" }
     91       node { name: "root" op: "NoOp" input: "^a" input: "^b" device: "/task:0" }
     92     """, self._StripGraph(gd))
     93 
     94   def testGroup_MultiDevice(self):
     95     with ops.Graph().as_default() as g:
     96       with g.device("/task:0"):
     97         a = constant_op.constant(0, name="a")
     98         b = constant_op.constant(0, name="b")
     99       with g.device("/task:1"):
    100         c = constant_op.constant(0, name="c")
    101         d = constant_op.constant(0, name="d")
    102       with g.device("/task:2"):
    103         control_flow_ops.group(a.op, b.op, c.op, d.op, name="root")
    104     gd = g.as_graph_def()
    105     self.assertProtoEquals("""
    106       node { name: "a" op: "Const" device: "/task:0"}
    107       node { name: "b" op: "Const" device: "/task:0"}
    108       node { name: "c" op: "Const" device: "/task:1"}
    109       node { name: "d" op: "Const" device: "/task:1"}
    110       node { name: "root/NoOp" op: "NoOp" input: "^a" input: "^b"
    111              device: "/task:0" }
    112       node { name: "root/NoOp_1" op: "NoOp" input: "^c" input: "^d"
    113              device: "/task:1" }
    114       node { name: "root" op: "NoOp" input: "^root/NoOp" input: "^root/NoOp_1"
    115              device: "/task:2" }
    116     """, self._StripGraph(gd))
    117 
    118   def testPassingList(self):
    119     with ops.Graph().as_default() as g:
    120       a = constant_op.constant(0, name="a")
    121       b = constant_op.constant(0, name="b")
    122       control_flow_ops.group([a.op, b.op], name="root")
    123     gd = g.as_graph_def()
    124     self.assertProtoEquals("""
    125       node { name: "a" op: "Const"}
    126       node { name: "b" op: "Const"}
    127       node { name: "root" op: "NoOp" input: "^a" input: "^b" }
    128     """, self._StripGraph(gd))
    129 
    130   def testPassingNonTensors(self):
    131     with ops.Graph().as_default():
    132       with self.assertRaises(TypeError):
    133         control_flow_ops.group(1, 2)
    134 
    135 
    136 @test_util.with_c_api
    137 class ShapeTestCase(test_util.TensorFlowTestCase):
    138 
    139   def testShape(self):
    140     with ops.Graph().as_default():
    141       tensor = constant_op.constant([1.0, 2.0])
    142       self.assertEquals([2], tensor.get_shape())
    143       self.assertEquals([2],
    144                         control_flow_ops.with_dependencies(
    145                             [constant_op.constant(1.0)], tensor).get_shape())
    146 
    147 
    148 @test_util.with_c_api
    149 class WithDependenciesTestCase(test_util.TensorFlowTestCase):
    150 
    151   def testTupleDependencies(self):
    152     with ops.Graph().as_default():
    153       counter = variable_scope.get_variable(
    154           "my_counter", shape=[], initializer=init_ops.zeros_initializer())
    155       increment_counter = state_ops.assign_add(counter, 1)
    156       const_with_dep = control_flow_ops.with_dependencies(
    157           (increment_counter, constant_op.constant(42)),
    158           constant_op.constant(7))
    159       with self.test_session():
    160         variables.global_variables_initializer().run()
    161         self.assertEquals(0, counter.eval())
    162         self.assertEquals(7, const_with_dep.eval())
    163         self.assertEquals(1, counter.eval())
    164 
    165   def testListDependencies(self):
    166     with ops.Graph().as_default():
    167       counter = variable_scope.get_variable(
    168           "my_counter", shape=[], initializer=init_ops.zeros_initializer())
    169       increment_counter = state_ops.assign_add(counter, 1)
    170       const_with_dep = control_flow_ops.with_dependencies(
    171           [increment_counter, constant_op.constant(42)],
    172           constant_op.constant(7))
    173       with self.test_session():
    174         variables.global_variables_initializer().run()
    175         self.assertEquals(0, counter.eval())
    176         self.assertEquals(7, const_with_dep.eval())
    177         self.assertEquals(1, counter.eval())
    178 
    179 
    180 @test_util.with_c_api
    181 class SwitchTestCase(test_util.TensorFlowTestCase):
    182 
    183   def testIndexedSlicesWithDenseShape(self):
    184     with self.test_session():
    185       data = ops.IndexedSlices(
    186           constant_op.constant([1, 2, 3]),
    187           constant_op.constant([0, 1]),
    188           dense_shape=constant_op.constant([3]))
    189       zero = constant_op.constant(0)
    190       one = constant_op.constant(1)
    191       less_op = math_ops.less(zero, one)
    192       _, switch_true = control_flow_ops.switch(data, less_op)
    193       self.assertAllEqual([1, 2, 3], switch_true.values.eval())
    194       self.assertAllEqual([0, 1], switch_true.indices.eval())
    195 
    196   def testIndexedSlicesGradient(self):
    197     with ops.Graph().as_default():
    198       embedding_matrix = variable_scope.get_variable(
    199           "embedding_matrix", [5, 5],
    200           initializer=init_ops.random_normal_initializer())
    201 
    202       def cond(it, _):
    203         return it < 5
    204 
    205       def body(it, cost):
    206         embedding = embedding_ops.embedding_lookup(embedding_matrix + 0.0, [0])
    207         cost += math_ops.reduce_sum(embedding)
    208         return it + 1, cost
    209 
    210       _, cost = control_flow_ops.while_loop(
    211           cond, body, [constant_op.constant(0),
    212                        constant_op.constant(0.0)])
    213       optimizer = momentum.MomentumOptimizer(0.1, 0.9)
    214       train_op = optimizer.minimize(cost)
    215       with self.test_session() as sess:
    216         sess.run(variables.global_variables_initializer())
    217         for _ in range(10):
    218           sess.run([train_op])
    219 
    220   def testResourceReadInLoop(self):
    221     with ops.Graph().as_default():
    222       embedding_matrix = variable_scope.get_variable(
    223           "embedding_matrix",
    224           initializer=[[2.0], [3.0]],
    225           use_resource=True)
    226 
    227       def cond(it, _):
    228         return it < 5
    229 
    230       def body(it, cost):
    231         embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
    232         cost += math_ops.reduce_sum(embedding)
    233         return it + 1, cost
    234 
    235       _, cost = control_flow_ops.while_loop(
    236           cond, body, [constant_op.constant(0),
    237                        constant_op.constant(0.0)])
    238       with self.test_session() as sess:
    239         sess.run(variables.global_variables_initializer())
    240         self.assertAllEqual(10.0, cost.eval())
    241 
    242   def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False):
    243     with ops.Graph().as_default():
    244       embedding_matrix = variable_scope.get_variable(
    245           "embedding_matrix", [5, 5],
    246           initializer=init_ops.random_normal_initializer(),
    247           use_resource=use_resource)
    248 
    249       def cond(it, _):
    250         return it < 5
    251 
    252       def body(it, cost):
    253         embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
    254         cost = control_flow_ops.cond(
    255             math_ops.equal(it, 3), lambda: math_ops.square(cost),
    256             lambda: cost + math_ops.reduce_sum(embedding))
    257         return it + 1, cost
    258 
    259       _, cost = control_flow_ops.while_loop(
    260           cond, body, [constant_op.constant(0),
    261                        constant_op.constant(0.0)])
    262 
    263       dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0]
    264       dynamic_grads = math_ops.segment_sum(dynamic_grads.values,
    265                                            dynamic_grads.indices)
    266 
    267       embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
    268       static = math_ops.square(
    269           math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) +
    270           math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding)
    271       static_grads = gradients_impl.gradients(static, [embedding_matrix])[0]
    272       static_grads = math_ops.segment_sum(static_grads.values,
    273                                           static_grads.indices)
    274 
    275       with self.test_session() as sess:
    276         sess.run(variables.global_variables_initializer())
    277         self.assertAllEqual(*sess.run([static_grads, dynamic_grads]))
    278 
    279   def testIndexedSlicesGradientInCondInWhileLoop(self):
    280     self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=False)
    281 
    282   def testIndexedSlicesGradientInCondInWhileLoopResource(self):
    283     self.doTestIndexedSlicesGradientInCondInWhileLoop(use_resource=True)
    284 
    285   def testIndexedSlicesWithShapeGradientInWhileLoop(self):
    286     for dtype in [dtypes.float32, dtypes.float64]:
    287       with self.test_session() as sess:
    288         num_steps = 9
    289 
    290         inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps])
    291         initial_outputs = tensor_array_ops.TensorArray(
    292             dtype=dtype, size=num_steps)
    293         initial_i = constant_op.constant(0, dtype=dtypes.int32)
    294 
    295         def cond(i, _):
    296           return i < num_steps  # pylint: disable=cell-var-from-loop
    297 
    298         def body(i, outputs):
    299           x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
    300           outputs = outputs.write(i, x)
    301           return i + 1, outputs
    302 
    303         _, outputs = control_flow_ops.while_loop(cond, body,
    304                                                  [initial_i, initial_outputs])
    305 
    306         outputs = math_ops.reduce_sum(outputs.stack())
    307         r = gradients_impl.gradients([outputs], [inputs])[0]
    308         grad_wr_inputs = ops.convert_to_tensor(r)
    309         o, grad = sess.run([outputs, grad_wr_inputs],
    310                            feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]})
    311         self.assertEquals(o, 20)
    312         self.assertAllEqual(grad, [1] * num_steps)
    313 
    314   def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self):
    315     for dtype in [dtypes.float32, dtypes.float64]:
    316       with self.test_session() as sess:
    317         inputs = array_ops.placeholder(dtype=dtype)
    318         initial_outputs = tensor_array_ops.TensorArray(
    319             dtype=dtype, dynamic_size=True, size=1)
    320         initial_i = constant_op.constant(0, dtype=dtypes.int32)
    321 
    322         def cond(i, _):
    323           return i < array_ops.size(inputs)  # pylint: disable=cell-var-from-loop
    324 
    325         def body(i, outputs):
    326           x = array_ops.gather(inputs, i)  # pylint: disable=cell-var-from-loop
    327           outputs = outputs.write(i, x)
    328           return i + 1, outputs
    329 
    330         _, outputs = control_flow_ops.while_loop(cond, body,
    331                                                  [initial_i, initial_outputs])
    332 
    333         outputs = math_ops.reduce_sum(outputs.stack())
    334         r = gradients_impl.gradients([outputs], [inputs])[0]
    335         grad_wr_inputs = ops.convert_to_tensor(r)
    336         o, grad = sess.run([outputs, grad_wr_inputs],
    337                            feed_dict={inputs: [1, 3, 2]})
    338         self.assertEquals(o, 6)
    339         self.assertAllEqual(grad, [1] * 3)
    340 
    341   def testGradientThroughSingleBranchOutsideOfContext(self):
    342     with self.test_session():
    343       x = constant_op.constant(2.)
    344       s = constant_op.constant(True)
    345       x_false, x_true = control_flow_ops.switch(x, s)
    346       grad_x_true = gradients_impl.gradients(x_true, x)[0]
    347       grad_x_false = gradients_impl.gradients(x_false, x)[0]
    348       self.assertEquals(grad_x_true.eval(), 1.)
    349       self.assertEquals(grad_x_false.eval(), 0.)
    350 
    351 
    352 @test_util.with_c_api
    353 class SmartCondTest(test_util.TensorFlowTestCase):
    354 
    355   def testSmartCondTrue(self):
    356     with ops.Graph().as_default():
    357       with session.Session():
    358         x = constant_op.constant(2)
    359         y = constant_op.constant(5)
    360         z = control_flow_ops.smart_cond(
    361             True, lambda: math_ops.multiply(x, 16),
    362             lambda: math_ops.multiply(y, 5))
    363         self.assertEqual(z.eval(), 32)
    364 
    365   def testSmartCondFalse(self):
    366     with ops.Graph().as_default():
    367       with session.Session():
    368         x = constant_op.constant(4)
    369         y = constant_op.constant(3)
    370         z = control_flow_ops.smart_cond(
    371             False, lambda: math_ops.multiply(x, 16),
    372             lambda: math_ops.multiply(y, 3))
    373         self.assertEqual(z.eval(), 9)
    374 
    375   def testSmartCondMissingArg1(self):
    376     with ops.Graph().as_default():
    377       with session.Session():
    378         x = constant_op.constant(1)
    379         with self.assertRaises(TypeError):
    380           control_flow_ops.smart_cond(True, false_fn=lambda: x)
    381 
    382   def testSmartCondMissingArg2(self):
    383     with ops.Graph().as_default():
    384       with session.Session():
    385         x = constant_op.constant(1)
    386         with self.assertRaises(TypeError):
    387           control_flow_ops.smart_cond(True, lambda: x)
    388 
    389 
    390 @test_util.with_c_api
    391 class CondTest(test_util.TensorFlowTestCase):
    392 
    393   def testCondTrue(self):
    394     # Create new Graph and Session for each test so we pick up _USE_C_API
    395     # correctly.
    396     with ops.Graph().as_default():
    397       with session.Session():
    398         x = constant_op.constant(2)
    399         y = constant_op.constant(5)
    400         z = control_flow_ops.cond(
    401             math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
    402             lambda: math_ops.add(y, 23))
    403         self.assertEquals(z.eval(), 34)
    404 
    405   def testCondFalse(self):
    406     with ops.Graph().as_default():
    407       with session.Session():
    408         x = constant_op.constant(2)
    409         y = constant_op.constant(1)
    410         z = control_flow_ops.cond(
    411             math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
    412             lambda: math_ops.add(y, 23))
    413         self.assertEquals(z.eval(), 24)
    414 
    415   def testCondTrueLegacy(self):
    416     with ops.Graph().as_default():
    417       with session.Session():
    418         x = constant_op.constant(2)
    419         y = constant_op.constant(5)
    420         z = control_flow_ops.cond(
    421             math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
    422             fn2=lambda: math_ops.add(y, 23))
    423         self.assertEquals(z.eval(), 34)
    424 
    425   def testCondFalseLegacy(self):
    426     with ops.Graph().as_default():
    427       with session.Session():
    428         x = constant_op.constant(2)
    429         y = constant_op.constant(1)
    430         z = control_flow_ops.cond(
    431             math_ops.less(x, y), fn1=lambda: math_ops.multiply(x, 17),
    432             fn2=lambda: math_ops.add(y, 23))
    433         self.assertEquals(z.eval(), 24)
    434 
    435   def testCondModifyBoolPred(self):
    436     # This test in particular used to fail only when running in GPU, hence
    437     # use_gpu=True.
    438     with ops.Graph().as_default():
    439       with session.Session() as sess:
    440         bool_var = variable_scope.get_variable("bool_var", dtype=dtypes.bool,
    441                                                initializer=True)
    442         cond_on_bool_var = control_flow_ops.cond(
    443             pred=bool_var,
    444             true_fn=lambda: state_ops.assign(bool_var, False),
    445             false_fn=lambda: True)
    446         sess.run(bool_var.initializer)
    447         self.assertEquals(sess.run(cond_on_bool_var), False)
    448         self.assertEquals(sess.run(cond_on_bool_var), True)
    449 
    450   def testCondMissingArg1(self):
    451     with ops.Graph().as_default():
    452       with session.Session():
    453         x = constant_op.constant(1)
    454         with self.assertRaises(TypeError):
    455           control_flow_ops.cond(True, false_fn=lambda: x)
    456 
    457   def testCondMissingArg2(self):
    458     with ops.Graph().as_default():
    459       with session.Session():
    460         x = constant_op.constant(1)
    461         with self.assertRaises(TypeError):
    462           control_flow_ops.cond(True, lambda: x)
    463 
    464   def testCondDuplicateArg1(self):
    465     with ops.Graph().as_default():
    466       with session.Session():
    467         x = constant_op.constant(1)
    468         with self.assertRaises(TypeError):
    469           control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)
    470 
    471   def testCondDuplicateArg2(self):
    472     with ops.Graph().as_default():
    473       with session.Session():
    474         x = constant_op.constant(1)
    475         with self.assertRaises(TypeError):
    476           control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
    477 
    478 
    479 @test_util.with_c_api
    480 class ContextTest(test_util.TensorFlowTestCase):
    481 
    482   def testCondContext(self):
    483     with self.test_session() as sess:
    484       x = constant_op.constant(2)
    485       y = constant_op.constant(5)
    486       control_flow_ops.cond(
    487           math_ops.less(x, y), lambda: math_ops.multiply(x, 17),
    488           lambda: math_ops.add(y, 23))
    489       for op in sess.graph.get_operations():
    490         c = op._get_control_flow_context()
    491         if c:
    492           self.assertProtoEquals(
    493               c.to_proto(),
    494               control_flow_ops.CondContext.from_proto(c.to_proto()).to_proto())
    495 
    496   def _testWhileContextHelper(self, maximum_iterations=None):
    497     with self.test_session() as sess:
    498       i = constant_op.constant(0)
    499       c = lambda i: math_ops.less(i, 10)
    500       b = lambda i: math_ops.add(i, 1)
    501       control_flow_ops.while_loop(
    502           c, b, [i], maximum_iterations=maximum_iterations)
    503       for op in sess.graph.get_operations():
    504         control_flow_context = op._get_control_flow_context()
    505         if control_flow_context:
    506           self.assertProtoEquals(
    507               control_flow_context.to_proto(),
    508               control_flow_ops.WhileContext.from_proto(
    509                   control_flow_context.to_proto()).to_proto())
    510 
    511   def testWhileContext(self):
    512     self._testWhileContextHelper()
    513 
    514   def testWhileContextWithMaximumIterations(self):
    515     self._testWhileContextHelper(maximum_iterations=10)
    516 
    517   def testControlContextImportScope(self):
    518     with self.test_session():
    519       constant_op.constant(0, name="a")
    520       constant_op.constant(2, name="test_scope/a")
    521       b1 = constant_op.constant(1, name="b")
    522       b2 = constant_op.constant(3, name="test_scope/b")
    523 
    524       c = control_flow_ops.ControlFlowContext()
    525       c._values = ["a", "b"]
    526       c._external_values = {"a": b1}
    527 
    528       c_with_scope = control_flow_ops.ControlFlowContext(
    529           values_def=c._to_values_def(), import_scope="test_scope")
    530 
    531       # _values and _external_values should be have scope prepended.
    532       self.assertEquals(
    533           c_with_scope._values, set(["test_scope/a", "test_scope/b"]))
    534       self.assertEquals(
    535           c_with_scope._external_values, {"test_scope/a": b2})
    536 
    537       # Calling _to_proto() with export_scope should remove "test_scope".
    538       self.assertProtoEquals(
    539           c._to_values_def(),
    540           c_with_scope._to_values_def(export_scope="test_scope"))
    541 
    542 
    543 def _get_nested_shape(nested):
    544 
    545   def _get_shape(tensor):
    546     if isinstance(tensor, tensor_array_ops.TensorArray):
    547       return tensor_array_ops.TensorArray
    548     elif isinstance(tensor, ops.IndexedSlices):
    549       return tensor.dense_shape
    550     else:
    551       return tensor.get_shape()
    552 
    553   return nest.map_structure(_get_shape, nested)
    554 
    555 
    556 def _create_tensor_array(size, shape):
    557   ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=size,
    558                                     clear_after_read=False)
    559   for i in range(size):
    560     ta = ta.write(i, array_ops.zeros(shape))
    561   return ta
    562 
    563 
    564 def _raw_nested_shape(nested_shape):
    565 
    566   def _raw_shape(shape):
    567     if isinstance(shape, tensor_shape.TensorShape) and shape.ndims is not None:
    568       return [x.value for x in shape]
    569     else:
    570       return None
    571 
    572   return nest.map_structure(_raw_shape, nested_shape)
    573 
    574 
    575 # TODO(yori): Add tests for indexed slices.
    576 @test_util.with_c_api
    577 class DataTypesTest(test_util.TensorFlowTestCase):
    578 
    579   def assertAllEqualNested(self, a, b):
    580     if isinstance(a, (list, tuple)):
    581       for entry_a, entry_b in zip(a, b):
    582         self.assertAllEqualNested(entry_a, entry_b)
    583     else:
    584       self.assertAllEqual(a, b)
    585 
    586   def _testShape(self, fn_true, fn_false, expected_shape,
    587                  strict=False):
    588     condition = array_ops.placeholder(dtypes.bool)
    589     output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
    590                                         strict=strict)
    591     self.assertEqual(
    592         _raw_nested_shape(_get_nested_shape(output_cond)),
    593         _raw_nested_shape(expected_shape))
    594 
    595     output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
    596                                         strict=strict)
    597     self.assertEqual(
    598         _raw_nested_shape(_get_nested_shape(output_case)),
    599         _raw_nested_shape(expected_shape))
    600 
    601   def _testReturnValues(self, fn_true, fn_false, expected_value_true,
    602                         expected_value_false, strict=False,
    603                         check_cond=True, feed_dict=None):
    604     if feed_dict is None: feed_dict = {}
    605 
    606     condition = array_ops.placeholder(dtypes.bool)
    607     output_cond = control_flow_ops.cond(condition, fn_true, fn_false,
    608                                         strict=strict)
    609     output_case = control_flow_ops.case([(condition, fn_true)], fn_false,
    610                                         strict=strict)
    611 
    612     with self.test_session() as sess:
    613       variables.global_variables_initializer().run()
    614       true_feed_dict = {condition: True}
    615       true_feed_dict.update(feed_dict)
    616       result_cond, result_case = sess.run([output_cond, output_case],
    617                                           feed_dict=true_feed_dict)
    618       self.assertAllEqualNested(result_cond, expected_value_true)
    619       if check_cond:
    620         self.assertAllEqualNested(result_case, expected_value_true)
    621       false_feed_dict = {condition: False}
    622       false_feed_dict.update(feed_dict)
    623       result_cond, result_case = sess.run([output_cond, output_case],
    624                                           feed_dict=false_feed_dict)
    625       self.assertAllEqualNested(result_cond, expected_value_false)
    626       if check_cond:
    627         self.assertAllEqualNested(result_case, expected_value_false)
    628 
    629   def test_int(self):
    630     shape = tensor_shape.TensorShape([])
    631     fn_true = lambda: 1
    632     fn_false = lambda: 2
    633     self._testShape(fn_true, fn_false, shape)
    634     self._testReturnValues(fn_true, fn_false, 1, 2)
    635     self._testShape(fn_true, fn_false, shape, strict=True)
    636     self._testReturnValues(fn_true, fn_false, 1, 2, strict=True)
    637 
    638   def test_float(self):
    639     shape = tensor_shape.TensorShape([])
    640     fn_true = lambda: 1.0
    641     fn_false = lambda: 2.0
    642     self._testShape(fn_true, fn_false, shape)
    643     self._testReturnValues(fn_true, fn_false, 1.0, 2.0)
    644 
    645   def test_noop(self):
    646     shape = tensor_shape.TensorShape(None)
    647     self._testShape(control_flow_ops.no_op, control_flow_ops.no_op, shape)
    648     self._testReturnValues(control_flow_ops.no_op, control_flow_ops.no_op,
    649                            True, False, check_cond=False)
    650 
    651   def test_string(self):
    652     shape = tensor_shape.TensorShape([])
    653     fn_true = lambda: "abc"
    654     fn_false = lambda: "xyz"
    655     self._testShape(fn_true, fn_false, shape)
    656     self._testReturnValues(fn_true, fn_false, b"abc", b"xyz")
    657 
    658   def test_variable(self):
    659     shape = tensor_shape.TensorShape([])
    660     fn_true = lambda: variables.Variable(3.0)
    661     fn_false = lambda: variables.Variable(4.0)
    662     self._testShape(fn_true, fn_false, shape)
    663     self._testReturnValues(fn_true, fn_false, 3.0, 4.0)
    664 
    665   def test_none(self):
    666     fn_none = lambda: None
    667     fn_tensor = lambda: constant_op.constant(1)
    668 
    669     with self.assertRaises(ValueError):
    670       control_flow_ops.cond(constant_op.constant(True), fn_none, fn_tensor)
    671 
    672     with self.assertRaises(ValueError):
    673       control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_none)
    674 
    675   def test_tensors(self):
    676 
    677     def _build_true_branch(dtype):
    678 
    679       def _build():
    680         return (array_ops.zeros([2, 2], dtype=dtype),
    681                 array_ops.ones([3, 3], dtype=dtype))
    682 
    683       return _build
    684 
    685     def _build_false_branch(dtype):
    686 
    687       def _build():
    688         return (array_ops.ones([2, 2], dtype=dtype),
    689                 array_ops.zeros([3, 3], dtype=dtype))
    690 
    691       return _build
    692 
    693     for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
    694       shape = (tensor_shape.TensorShape([2, 2]),
    695                tensor_shape.TensorShape([3, 3]))
    696       fn_true = _build_true_branch(dtype)
    697       fn_false = _build_false_branch(dtype)
    698       self._testShape(fn_true, fn_false, shape)
    699       self._testReturnValues(fn_true, fn_false,
    700                              (np.zeros([2, 2]), np.ones([3, 3])),
    701                              (np.ones([2, 2]), np.zeros([3, 3])))
    702 
    703   def test_tensors_unknown_shape(self):
    704 
    705     def _build_true_branch(dtype):
    706       tensor = array_ops.placeholder(dtype=dtype, shape=None)
    707 
    708       def _build():
    709         return tensor
    710 
    711       return _build, tensor
    712 
    713     def _build_false_branch(dtype):
    714       tensor = array_ops.placeholder(dtype=dtype, shape=None)
    715 
    716       def _build():
    717         return tensor
    718 
    719       return _build, tensor
    720 
    721     for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
    722       shape = tensor_shape.TensorShape(None)
    723       fn_true, true_tensor = _build_true_branch(dtype)
    724       fn_false, false_tensor = _build_false_branch(dtype)
    725       self._testShape(fn_true, fn_false, shape)
    726       self._testReturnValues(fn_true, fn_false,
    727                              np.zeros([2, 2]), np.ones([2, 2]),
    728                              feed_dict={true_tensor: np.zeros([2, 2]),
    729                                         false_tensor: np.ones([2, 2])})
    730 
    731   def test_sparse_tensors(self):
    732     shape = tensor_shape.TensorShape([None, None])
    733 
    734     def true_fn():
    735       return [sparse_tensor.SparseTensor(indices=[[0, 0], [1, 2]],
    736                                          values=[1, 2], dense_shape=[3, 4])]
    737 
    738     def false_fn():
    739       return [sparse_tensor.SparseTensor(indices=[[0, 0], [2, 1]],
    740                                          values=[3, 4], dense_shape=[3, 4])]
    741 
    742     value1 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [1, 2]],
    743                                              values=[1, 2], dense_shape=[3, 4])
    744     value2 = sparse_tensor.SparseTensorValue(indices=[[0, 0], [2, 1]],
    745                                              values=[3, 4], dense_shape=[3, 4])
    746     self._testShape(true_fn, false_fn, shape)
    747     self._testReturnValues(true_fn, false_fn, value1, value2)
    748     self._testShape(true_fn, false_fn, [shape], strict=True)
    749     self._testReturnValues(true_fn, false_fn, [value1], [value2], strict=True)
    750 
    751   def test_tensors_with_partially_specified_shapes(self):
    752 
    753     def _build_branch(dtype, shape):
    754       a = array_ops.placeholder(dtype=dtype, shape=shape[0])
    755       b = array_ops.placeholder(dtype=dtype, shape=shape[1])
    756       c = array_ops.placeholder(dtype=dtype, shape=shape[2])
    757 
    758       def _build():
    759         return a, b, c
    760 
    761       return _build, (a, b, c)
    762 
    763     for dtype in (dtypes.float16, dtypes.int8, dtypes.int32, dtypes.uint8):
    764       shape = (tensor_shape.TensorShape([None, 2]),
    765                tensor_shape.TensorShape([None]),
    766                tensor_shape.TensorShape([3, None]))
    767       fn_true, true_tensors = _build_branch(dtype, shape)
    768       fn_false, false_tensors = _build_branch(dtype, shape)
    769       self._testShape(fn_true, fn_false, shape)
    770       self._testReturnValues(fn_true, fn_false,
    771                              (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])),
    772                              (np.zeros([2, 2]), np.zeros(5), np.ones([3, 3])),
    773                              feed_dict={true_tensors[0]: np.zeros([2, 2]),
    774                                         false_tensors[0]: np.zeros([2, 2]),
    775                                         true_tensors[1]: np.zeros([5]),
    776                                         false_tensors[1]: np.zeros([5]),
    777                                         true_tensors[2]: np.ones([3, 3]),
    778                                         false_tensors[2]: np.ones([3, 3])})
    779 
    780   def test_tensor_arrays(self):
    781     element_shape = tensor_shape.TensorShape([2])
    782     ta1 = _create_tensor_array(4, element_shape)
    783     ta2 = _create_tensor_array(4, element_shape)
    784     shape = tensor_array_ops.TensorArray
    785     fn_true = lambda: ta1
    786     fn_false = lambda: ta2
    787     self._testShape(fn_true, fn_false, shape)
    788 
    789   def test_tensor_array_reads(self):
    790     shape = tensor_shape.TensorShape([2])
    791     ta = _create_tensor_array(4, shape)
    792     fn_true = lambda: ta.read(0)
    793     fn_false = lambda: ta.read(1)
    794     self._testShape(fn_true, fn_false, shape)
    795 
    796   def test_list(self):
    797     shape = [tensor_shape.TensorShape([]), tensor_shape.TensorShape([]),
    798              tensor_shape.TensorShape([])]
    799     fn_true = lambda: [constant_op.constant(1), 2, variables.Variable(3.0)]
    800     fn_false = lambda: [constant_op.constant(3), 4, variables.Variable(5.0)]
    801     self._testShape(fn_true, fn_false, shape)
    802     self._testReturnValues(fn_true, fn_false, [1, 2, 3.0], [3, 4, 5.0])
    803 
    804   def test_non_strict(self):
    805     shape = tensor_shape.TensorShape([])
    806     fn_tensor = lambda: constant_op.constant(1)
    807     fn_list = lambda: [constant_op.constant(2)]
    808     fn_tuple = lambda: (constant_op.constant(3),)
    809     self._testShape(fn_tensor, fn_list, shape)
    810     self._testShape(fn_tensor, fn_tuple, shape)
    811     self._testShape(fn_list, fn_tuple, shape)
    812     self._testReturnValues(fn_tensor, fn_list, 1, 2)
    813     self._testReturnValues(fn_tensor, fn_tuple, 1, 3)
    814     self._testReturnValues(fn_list, fn_tuple, 2, 3)
    815 
    816   def test_singleton_strict(self):
    817     fn_tensor = lambda: constant_op.constant(1)
    818     fn_list = lambda: [constant_op.constant(2)]
    819     fn_tuple = lambda: (constant_op.constant(3),)
    820 
    821     with self.assertRaises(ValueError):
    822       control_flow_ops.cond(constant_op.constant(True), fn_tensor, fn_list,
    823                             strict=True)
    824 
    825     with self.assertRaises(TypeError):
    826       control_flow_ops.cond(constant_op.constant(True), fn_list, fn_tuple,
    827                             strict=True)
    828 
    829     with self.assertRaises(ValueError):
    830       control_flow_ops.case([(constant_op.constant(True), fn_tensor)], fn_list,
    831                             strict=True)
    832 
    833     with self.assertRaises(TypeError):
    834       control_flow_ops.case([(constant_op.constant(True), fn_list)], fn_tuple,
    835                             strict=True)
    836 
    837   def test_singleton_list(self):
    838     shape = tensor_shape.TensorShape([])
    839     fn_true = lambda: [constant_op.constant(1)]
    840     fn_false = lambda: [constant_op.constant(3)]
    841     self._testShape(fn_true, fn_false, shape)
    842     self._testReturnValues(fn_true, fn_false, 1, 3)
    843     self._testShape(fn_true, fn_false, [shape], strict=True)
    844     self._testReturnValues(fn_true, fn_false, [1], [3], strict=True)
    845 
    846   def test_singleton_tuple(self):
    847     shape = tensor_shape.TensorShape([])
    848     fn_true = lambda: (constant_op.constant(1),)
    849     fn_false = lambda: (constant_op.constant(3),)
    850     self._testShape(fn_true, fn_false, shape)
    851     self._testReturnValues(fn_true, fn_false, 1, 3)
    852     self._testShape(fn_true, fn_false, (shape,), strict=True)
    853     self._testReturnValues(fn_true, fn_false, (1,), (3,),
    854                            strict=True)
    855 
    856   def test_singleton_namedtuple(self):
    857     shape = tensor_shape.TensorShape([])
    858     fn_true = lambda: SingletonTestTuple(constant_op.constant(1))
    859     fn_false = lambda: SingletonTestTuple(constant_op.constant(3))
    860     self._testShape(fn_true, fn_false, shape)
    861     self._testReturnValues(fn_true, fn_false, 1, 3)
    862     self._testShape(fn_true, fn_false, SingletonTestTuple(shape),
    863                     strict=True)
    864     self._testReturnValues(fn_true, fn_false, SingletonTestTuple(1),
    865                            SingletonTestTuple(3), strict=True)
    866 
    867   def test_tuple(self):
    868     shape = (tensor_shape.TensorShape([]), tensor_shape.TensorShape([]))
    869     fn_true = lambda: (constant_op.constant(1), 2)
    870     fn_false = lambda: (constant_op.constant(3), 4)
    871     self._testShape(fn_true, fn_false, shape)
    872     self._testReturnValues(fn_true, fn_false, (1, 2), (3, 4))
    873 
    874   def test_namedtuple(self):
    875     shape = TestTuple(tensor_shape.TensorShape([]),
    876                       tensor_shape.TensorShape([]))
    877     fn_true = lambda: TestTuple(constant_op.constant(1), 2)
    878     fn_false = lambda: TestTuple(constant_op.constant(3), 4)
    879     self._testShape(fn_true, fn_false, shape)
    880     self._testReturnValues(fn_true, fn_false, TestTuple(1, 2), TestTuple(3, 4))
    881 
    882   def test_nested(self):
    883     shape = [tensor_shape.TensorShape([]),
    884              TestTuple(tensor_shape.TensorShape([]),
    885                        [tensor_shape.TensorShape([]),
    886                         tensor_shape.TensorShape([])]),
    887              tensor_shape.TensorShape([5, 5]),
    888              tensor_shape.TensorShape([])]
    889 
    890     def true_fn():
    891       return [constant_op.constant(1),
    892               TestTuple(constant_op.constant(2), [3, 4]),
    893               array_ops.zeros([5, 5]), 6]
    894 
    895     def false_fn():
    896       return [constant_op.constant(11),
    897               TestTuple(constant_op.constant(12), [13, 14]),
    898               array_ops.ones([5, 5]), 16]
    899 
    900     self._testShape(true_fn, false_fn, shape)
    901     self._testReturnValues(
    902         true_fn, false_fn,
    903         [1, TestTuple(2, [3, 4]), np.zeros([5, 5]), 6],
    904         [11, TestTuple(12, [13, 14]),
    905          np.ones([5, 5]), 16])
    906 
    907   def test_cond_inside_while_loop(self):
    908 
    909     def body(i, matrix):
    910       result_tuple, unused_matrix = control_flow_ops.cond(
    911           constant_op.constant(True),
    912           lambda: (TestTuple(matrix * 2, matrix * 4), matrix),
    913           lambda: (TestTuple(matrix * 4, matrix * 2), matrix))
    914       return [i+1, result_tuple.a]
    915 
    916     iteration, matrix = control_flow_ops.while_loop(
    917         lambda i, matrix: i < 10,
    918         body,
    919         loop_vars=[constant_op.constant(0),
    920                    array_ops.ones([2, 2])])
    921 
    922     self.assertEqual(iteration.get_shape(), tensor_shape.TensorShape([]))
    923     self.assertEqual(matrix.get_shape(), tensor_shape.TensorShape([2, 2]))
    924 
    925 
    926 @test_util.with_c_api
    927 class CaseTest(test_util.TensorFlowTestCase):
    928 
    929   def testCase_withDefault(self):
    930     x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    931     conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)),
    932                   (math_ops.equal(x, 2), lambda: constant_op.constant(4))]
    933     default = lambda: constant_op.constant(6)
    934     output = control_flow_ops.case(conditions, default, exclusive=True)
    935     with self.test_session() as sess:
    936       self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
    937       self.assertEqual(sess.run(output, feed_dict={x: 2}), 4)
    938       self.assertEqual(sess.run(output, feed_dict={x: 3}), 6)
    939 
    940   def testCase_multiple_matches_exclusive(self):
    941     x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    942     conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)),
    943                   (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
    944                   (math_ops.equal(x, 2), lambda: constant_op.constant(6))]
    945     default = lambda: constant_op.constant(8)
    946     output = control_flow_ops.case(conditions, default, exclusive=True)
    947     with self.test_session() as sess:
    948       self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
    949       self.assertEqual(sess.run(output, feed_dict={x: 3}), 8)
    950       with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
    951         sess.run(output, feed_dict={x: 2})
    952 
    953   def testCase_multiple_matches_non_exclusive(self):
    954     x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    955     conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)),
    956                   (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
    957                   (math_ops.equal(x, 2), lambda: constant_op.constant(6))]
    958     default = lambda: constant_op.constant(8)
    959     output = control_flow_ops.case(conditions, default, exclusive=False)
    960     with self.test_session() as sess:
    961       self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
    962       self.assertEqual(sess.run(output, feed_dict={x: 2}), 4)
    963       self.assertEqual(sess.run(output, feed_dict={x: 3}), 8)
    964 
    965   def testCase_withoutDefault(self):
    966     x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    967     conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2)),
    968                   (math_ops.equal(x, 2), lambda: constant_op.constant(4)),
    969                   (math_ops.equal(x, 3), lambda: constant_op.constant(6))]
    970     output = control_flow_ops.case(conditions, exclusive=True)
    971     with self.test_session() as sess:
    972       self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
    973       self.assertEqual(sess.run(output, feed_dict={x: 2}), 4)
    974       self.assertEqual(sess.run(output, feed_dict={x: 3}), 6)
    975       with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
    976         sess.run(output, feed_dict={x: 4})
    977 
    978   def testCase_withoutDefault_oneCondition(self):
    979     x = array_ops.placeholder(dtype=dtypes.int32, shape=[])
    980     conditions = [(math_ops.equal(x, 1), lambda: constant_op.constant(2))]
    981     output = control_flow_ops.case(conditions, exclusive=True)
    982     with self.test_session() as sess:
    983       self.assertEqual(sess.run(output, feed_dict={x: 1}), 2)
    984       with self.assertRaisesRegexp(errors.InvalidArgumentError, "Input error:"):
    985         sess.run(output, feed_dict={x: 4})
    986 
    987 
if __name__ == "__main__":
  # Run the tests in this module when it is executed as a script.
  googletest.main()
    990