# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.data_flow_ops.{,parallel_}dynamic_stitch."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients_impl
import tensorflow.python.ops.data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test


class DynamicStitchTestBase(object):

  def __init__(self, stitch_op):
    self.stitch_op = stitch_op

  def testScalar(self):
    with self.test_session(use_gpu=True):
      indices = [constant_op.constant(0), constant_op.constant(1)]
      data = [constant_op.constant(40), constant_op.constant(60)]
      for step in -1, 1:
        stitched_t = self.stitch_op(indices[::step], data)
        stitched_val = stitched_t.eval()
        self.assertAllEqual([40, 60][::step], stitched_val)
        # Dimension 0 is max(flatten(indices))+1.
        self.assertEqual([2], stitched_t.get_shape().as_list())

  def testShapeInferenceForScalarWithNonConstantIndices(self):
    with self.test_session(use_gpu=True):
      indices = [
          array_ops.placeholder(dtype=dtypes.int32),
          constant_op.constant(1)
      ]
      data = [constant_op.constant(40), constant_op.constant(60)]
      for step in -1, 1:
        stitched_t = self.stitch_op(indices[::step], data)
        # Dimension 0 is max(flatten(indices))+1, but one of the indices
        # inputs is not a constant tensor, so we can only infer the result
        # as a vector of unknown length.
        self.assertEqual([None], stitched_t.get_shape().as_list())

  def testSimpleOneDimensional(self):
    with self.test_session(use_gpu=True):
      indices = [
          constant_op.constant([0, 4, 7]),
          constant_op.constant([1, 6, 2, 3, 5])
      ]
      data = [
          constant_op.constant([0, 40, 70]),
          constant_op.constant([10, 60, 20, 30, 50])
      ]
      stitched_t = self.stitch_op(indices, data)
      stitched_val = stitched_t.eval()
      self.assertAllEqual([0, 10, 20, 30, 40, 50, 60, 70], stitched_val)
      # Dimension 0 is max(flatten(indices))+1.
      self.assertEqual([8], stitched_t.get_shape().as_list())

  def testOneListOneDimensional(self):
    with self.test_session(use_gpu=True):
      indices = [constant_op.constant([1, 6, 2, 3, 5, 0, 4, 7])]
      data = [constant_op.constant([10, 60, 20, 30, 50, 0, 40, 70])]
      stitched_t = self.stitch_op(indices, data)
      stitched_val = stitched_t.eval()
      self.assertAllEqual([0, 10, 20, 30, 40, 50, 60, 70], stitched_val)
      # Dimension 0 is max(flatten(indices))+1.
      self.assertEqual([8], stitched_t.get_shape().as_list())

  def testSimpleTwoDimensional(self):
    with self.test_session(use_gpu=True):
      indices = [
          constant_op.constant([0, 4, 7]),
          constant_op.constant([1, 6]),
          constant_op.constant([2, 3, 5])
      ]
      data = [
          constant_op.constant([[0, 1], [40, 41], [70, 71]]),
          constant_op.constant([[10, 11], [60, 61]]),
          constant_op.constant([[20, 21], [30, 31], [50, 51]])
      ]
      stitched_t = self.stitch_op(indices, data)
      stitched_val = stitched_t.eval()
      self.assertAllEqual([[0, 1], [10, 11], [20, 21], [30, 31], [40, 41],
                           [50, 51], [60, 61], [70, 71]], stitched_val)
      # Dimension 0 is max(flatten(indices))+1.
      self.assertEqual([8, 2], stitched_t.get_shape().as_list())

  def testHigherRank(self):
    with self.test_session(use_gpu=True) as sess:
      indices = [
          constant_op.constant(6),
          constant_op.constant([4, 1]),
          constant_op.constant([[5, 2], [0, 3]])
      ]
      data = [
          constant_op.constant([61, 62]),
          constant_op.constant([[41, 42], [11, 12]]),
          constant_op.constant([[[51, 52], [21, 22]], [[1, 2], [31, 32]]])
      ]
      stitched_t = self.stitch_op(indices, data)
      stitched_val = stitched_t.eval()
      correct = 10 * np.arange(7)[:, None] + [1, 2]
      self.assertAllEqual(correct, stitched_val)
      self.assertEqual([7, 2], stitched_t.get_shape().as_list())
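
      # The gradients registered for the stitch ops are expected to return no
      # gradient for the indices inputs and, for each data input, the output
      # gradient gathered at that input's indices. The indices above cover
      # every output slot exactly once, so an upstream gradient of
      # 7 * stitched_val should come back as 7 * data[m] for each data input.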
      # Test gradients
      stitched_grad = 7 * stitched_val
      grads = gradients_impl.gradients(stitched_t, indices + data,
                                       stitched_grad)
      self.assertEqual(grads[:3], [None] * 3)  # Indices have no gradients
      for datum, grad in zip(data, sess.run(grads[3:])):
        self.assertAllEqual(7 * datum.eval(), grad)

  def testErrorIndicesMultiDimensional(self):
    indices = [
        constant_op.constant([0, 4, 7]),
        constant_op.constant([[1, 6, 2, 3, 5]])
    ]
    data = [
        constant_op.constant([[0, 40, 70]]),
        constant_op.constant([10, 60, 20, 30, 50])
    ]
    with self.assertRaises(ValueError):
      self.stitch_op(indices, data)

  def testErrorDataNumDimsMismatch(self):
    indices = [
        constant_op.constant([0, 4, 7]),
        constant_op.constant([1, 6, 2, 3, 5])
    ]
    data = [
        constant_op.constant([0, 40, 70]),
        constant_op.constant([[10, 60, 20, 30, 50]])
    ]
    with self.assertRaises(ValueError):
      self.stitch_op(indices, data)

  def testErrorDataDimSizeMismatch(self):
    indices = [
        constant_op.constant([0, 4, 5]),
        constant_op.constant([1, 6, 2, 3])
    ]
    data = [
        constant_op.constant([[0], [40], [70]]),
        constant_op.constant([[10, 11], [60, 61], [20, 21], [30, 31]])
    ]
    with self.assertRaises(ValueError):
      self.stitch_op(indices, data)

  def testErrorDataAndIndicesSizeMismatch(self):
    indices = [
        constant_op.constant([0, 4, 7]),
        constant_op.constant([1, 6, 2, 3, 5])
    ]
    data = [
        constant_op.constant([0, 40, 70]),
        constant_op.constant([10, 60, 20, 30])
    ]
    with self.assertRaises(ValueError):
      self.stitch_op(indices, data)


class DynamicStitchTest(DynamicStitchTestBase, test.TestCase):

  def __init__(self, *test_case_args):
    test.TestCase.__init__(self, *test_case_args)
    DynamicStitchTestBase.__init__(self, data_flow_ops.dynamic_stitch)


class ParallelDynamicStitchTest(DynamicStitchTestBase, test.TestCase):

  def __init__(self, *test_case_args):
    test.TestCase.__init__(self, *test_case_args)
    DynamicStitchTestBase.__init__(self, data_flow_ops.parallel_dynamic_stitch)

  def testScalar(self):
    with self.test_session(use_gpu=True):
      indices = [constant_op.constant(0), constant_op.constant(1)]
      data = [constant_op.constant(40.0), constant_op.constant(60.0)]
      for step in -1, 1:
        stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data)
        stitched_val = stitched_t.eval()
        self.assertAllEqual([40.0, 60.0][::step], stitched_val)
        # Dimension 0 is max(flatten(indices))+1.
        self.assertEqual([2], stitched_t.get_shape().as_list())

  def testHigherRank(self):
    with self.test_session(use_gpu=True) as sess:
      indices = [
          constant_op.constant(6),
          constant_op.constant([4, 1]),
          constant_op.constant([[5, 2], [0, 3]])
      ]
      data = [
          constant_op.constant([61, 62], dtype=dtypes.float32),
          constant_op.constant([[41, 42], [11, 12]], dtype=dtypes.float32),
          constant_op.constant(
              [[[51, 52], [21, 22]], [[1, 2], [31, 32]]], dtype=dtypes.float32)
      ]
      stitched_t = data_flow_ops.dynamic_stitch(indices, data)
      stitched_val = stitched_t.eval()
      correct = 10 * np.arange(7)[:, None] + [1.0, 2.0]
      self.assertAllEqual(correct, stitched_val)
      self.assertEqual([7, 2], stitched_t.get_shape().as_list())
      # Test gradients
      stitched_grad = 7 * stitched_val
      grads = gradients_impl.gradients(stitched_t, indices + data,
                                       stitched_grad)
      self.assertEqual(grads[:3], [None] * 3)  # Indices have no gradients
      for datum, grad in zip(data, sess.run(grads[3:])):
        self.assertAllEqual(7.0 * datum.eval(), grad)

  # GPU version unit tests
  def testScalarGPU(self):
    with self.test_session():
      indices = [constant_op.constant(0), constant_op.constant(1)]
      data = [constant_op.constant(40.0), constant_op.constant(60.0)]
      for step in -1, 1:
        stitched_t = data_flow_ops.dynamic_stitch(indices[::step], data)
        stitched_val = stitched_t.eval()
        self.assertAllEqual([40.0, 60.0][::step], stitched_val)
        # Dimension 0 is max(flatten(indices))+1.
        self.assertEqual([2], stitched_t.get_shape().as_list())

  def testHigherRankGPU(self):
    with self.test_session() as sess:
      indices = [
          constant_op.constant(6),
          constant_op.constant([4, 1]),
          constant_op.constant([[5, 2], [0, 3]])
      ]
      data = [
          constant_op.constant([61, 62], dtype=dtypes.float32),
          constant_op.constant([[41, 42], [11, 12]], dtype=dtypes.float32),
          constant_op.constant(
              [[[51, 52], [21, 22]], [[1, 2], [31, 32]]], dtype=dtypes.float32)
      ]
      stitched_t = data_flow_ops.dynamic_stitch(indices, data)
      stitched_val = stitched_t.eval()
      correct = 10 * np.arange(7)[:, None] + [1.0, 2.0]
      self.assertAllEqual(correct, stitched_val)
      self.assertEqual([7, 2], stitched_t.get_shape().as_list())
      # Test gradients
      stitched_grad = 7 * stitched_val
      grads = gradients_impl.gradients(stitched_t, indices + data,
                                       stitched_grad)
      self.assertEqual(grads[:3], [None] * 3)  # Indices have no gradients
      for datum, grad in zip(data, sess.run(grads[3:])):
        self.assertAllEqual(7.0 * datum.eval(), grad)


if __name__ == "__main__":
  test.main()