Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tensorflow.ops.tf.gather_nd."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import time
     22 
     23 import numpy as np
     24 
     25 from tensorflow.python.client import session
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import dtypes
     28 from tensorflow.python.framework import ops
     29 from tensorflow.python.ops import array_ops
     30 from tensorflow.python.ops import gradients_impl
     31 from tensorflow.python.ops import variables
     32 from tensorflow.python.platform import test
     33 
     34 
     35 class GatherNdTest(test.TestCase):
     36 
     37   def _testSimpleDtype(self, dtype):
     38     with self.test_session(use_gpu=True):
     39       params = constant_op.constant(np.array([8, 1, 2, 3, 7, 5], dtype=dtype))
     40       indices = constant_op.constant([[4], [4], [0]])
     41       gather_nd_t = array_ops.gather_nd(params, indices)
     42       gather_nd_val = gather_nd_t.eval()
     43 
     44     self.assertAllEqual(np.array([7, 7, 8], dtype=dtype), gather_nd_val)
     45     self.assertEqual([3], gather_nd_t.get_shape())
     46 
     47   def testSimpleDtype(self):
     48     self._testSimpleDtype(np.float32)
     49     self._testSimpleDtype(np.float64)
     50     self._testSimpleDtype(np.int32)
     51     self._testSimpleDtype(np.int64)
     52     self._testSimpleDtype(np.complex64)
     53     self._testSimpleDtype(np.complex128)
     54     self._testSimpleDtype("|S")  # byte strings in python2 + 3
     55 
     56   def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self):
     57     with self.test_session(use_gpu=True):
     58       params = np.ones((3, 3), dtype=np.float32)
     59 
     60       indices_empty = np.empty((0, 2), dtype=np.int32)
     61       gather_nd_ok_t = array_ops.gather_nd(params, indices_empty)
     62       gather_nd_ok_val = gather_nd_ok_t.eval()
     63       self.assertEqual([0], gather_nd_ok_t.get_shape())
     64       self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val)
     65 
     66       indices_empty = np.empty((0, 1), dtype=np.int32)
     67       gather_nd_ok_t = array_ops.gather_nd(params, indices_empty)
     68       gather_nd_ok_val = gather_nd_ok_t.eval()
     69       self.assertEqual([0, 3], gather_nd_ok_t.get_shape())
     70       self.assertAllClose(np.empty((0, 3), dtype=np.float32), gather_nd_ok_val)
     71 
     72       params_empty = np.empty((0, 3), dtype=np.float32)
     73       indices_empty = np.empty((0, 2), dtype=np.int32)
     74       gather_nd_ok_t = array_ops.gather_nd(params_empty, indices_empty)
     75       gather_nd_ok_val = gather_nd_ok_t.eval()
     76       self.assertEqual([0], gather_nd_ok_t.get_shape())
     77       self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val)
     78 
     79       params_empty = np.empty((0, 3), dtype=np.float32)
     80       indices_nonempty = np.zeros((1, 2), dtype=np.int32)
     81       gather_nd_break_t = array_ops.gather_nd(params_empty, indices_nonempty)
     82       with self.assertRaisesOpError(
     83           r"Requested more than 0 entries, but params is empty."):
     84         gather_nd_break_t.eval()
     85       self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val)
     86 
     87   def testIndexScalar(self):
     88     with self.test_session(use_gpu=True):
     89       params = np.array(
     90           [[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T
     91       indices = constant_op.constant([4, 1])
     92       gather_nd_t = array_ops.gather_nd(params, indices)
     93       gather_nd_val = gather_nd_t.eval()
     94       self.assertEqual([], gather_nd_t.get_shape())
     95       self.assertAllEqual(np.array(7), gather_nd_val)
     96 
     97   def testParamsRankLargerThanIndexIndexScalarSlices(self):
     98     with self.test_session(use_gpu=True):
     99       params = np.array(
    100           [[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T
    101       indices = constant_op.constant([4])
    102       gather_nd_t = array_ops.gather_nd(params, indices)
    103       gather_nd_val = gather_nd_t.eval()
    104       self.assertEqual([2], gather_nd_t.get_shape())
    105       self.assertAllEqual(np.array([-7, 7]), gather_nd_val)
    106 
    107   def testParamsRankLargerThanIndexSlices(self):
    108     with self.test_session(use_gpu=True):
    109       params = np.array(
    110           [[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T
    111       indices = constant_op.constant([[4], [4], [0]])
    112       gather_nd_t = array_ops.gather_nd(params, indices)
    113       gather_nd_val = gather_nd_t.eval()
    114 
    115     self.assertEqual([3, 2], gather_nd_t.get_shape())
    116     self.assertAllEqual(np.array([[-7, 7], [-7, 7], [-8, 8]]), gather_nd_val)
    117 
    118   def testHigherRankParamsLargerThanIndexSlices(self):
    119     with self.test_session(use_gpu=True):
    120       params = np.array(
    121           [[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]],
    122            [[-80, -10, -20, -30, -70, -50], [80, 10, 20, 30, 70, 50]]],
    123           dtype=np.float32).T
    124       params_t = constant_op.constant(params)
    125       indices = constant_op.constant([[4], [4], [0]])
    126       gather_nd_t = array_ops.gather_nd(params_t, indices)
    127       gather_nd_val = gather_nd_t.eval()
    128 
    129     self.assertEqual([3, 2, 2], gather_nd_t.get_shape())
    130     self.assertAllEqual(params[[4, 4, 0]], gather_nd_val)
    131 
    132   def testEmptyIndicesLastRankMeansCopyEntireTensor(self):
    133     with self.test_session(use_gpu=True):
    134       params = np.array(
    135           [[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]],
    136            [[-80, -10, -20, -30, -70, -50], [80, 10, 20, 30, 70, 50]]],
    137           dtype=np.float32).T
    138       params_t = constant_op.constant(params)
    139       indices = constant_op.constant(
    140           [[], []], dtype=dtypes.int32)  # Size (2, 0)
    141       gather_nd_t = array_ops.gather_nd(params_t, indices)
    142       gather_nd_val = gather_nd_t.eval()
    143 
    144     self.assertEqual([2, 6, 2, 2], gather_nd_t.get_shape())
    145     self.assertAllEqual(
    146         np.vstack((params[np.newaxis, :], params[np.newaxis, :])),
    147         gather_nd_val)
    148 
    149   def testHigherRankParamsAndIndicesLargerThanIndexSlices(self):
    150     with self.test_session(use_gpu=True):
    151       params = np.array(
    152           [[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]],
    153            [[-80, -10, -20, -30, -70, -50], [80, 10, 20, 30, 70, 50]]],
    154           dtype=np.float32).T
    155       params_t = constant_op.constant(params)
    156       indices = constant_op.constant([[[3], [2], [1]], [[4], [4], [0]]])
    157       gather_nd_t = array_ops.gather_nd(params_t, indices)
    158       gather_nd_val = gather_nd_t.eval()
    159 
    160     self.assertEqual([2, 3, 2, 2], gather_nd_t.get_shape())
    161     self.assertAllEqual(params[[3, 2, 1, 4, 4, 0]].reshape(2, 3, 2, 2),
    162                         gather_nd_val)
    163 
    164   def testHigherRankParams(self):
    165     with self.test_session(use_gpu=True):
    166       shape = (10, 20, 5, 1, 17)
    167       params = np.random.rand(*shape)
    168       indices = np.vstack([np.random.randint(0, s, size=2000) for s in shape]).T
    169       gather_nd_t = array_ops.gather_nd(params, indices)
    170       gather_nd_val = gather_nd_t.eval()
    171 
    172     expected = params[tuple(indices.T)]
    173     self.assertAllEqual(expected, gather_nd_val)
    174     self.assertEqual([2000], gather_nd_t.get_shape())
    175 
    176   def testHigherRankParamsAndIndices(self):
    177     with self.test_session(use_gpu=True):
    178       shape = (10, 20, 5, 1, 17)
    179       params = np.random.rand(*shape)
    180       indices = np.vstack([np.random.randint(0, s, size=2000) for s in shape]).T
    181       indices_reshaped = indices.reshape([10, 10, 20, 5])
    182       gather_nd_t = array_ops.gather_nd(params, indices_reshaped)
    183       gather_nd_val = gather_nd_t.eval()
    184 
    185     expected = params[tuple(indices.T)]
    186     self.assertAllEqual(expected.reshape([10, 10, 20]), gather_nd_val)
    187     self.assertEqual([10, 10, 20], gather_nd_t.get_shape())
    188 
    189   def assertIndexedSlices(self, t):
    190     self.assertIsInstance(t, ops.IndexedSlices)
    191 
    192   def testUnknownIndices(self):
    193     params = constant_op.constant([[0, 1, 2]])
    194     indices = array_ops.placeholder(dtypes.int32)
    195     gather_nd_t = array_ops.gather_nd(params, indices)
    196     shape = gather_nd_t.get_shape()
    197     self.assertEqual(None, shape.ndims)
    198     self.assertEqual(None, shape[0].value)
    199 
    200   def testBadIndices(self):
    201     with self.test_session(use_gpu=True):
    202       params = [0, 1, 2]
    203       indices = [[[0], [7]]]  # Make this one higher rank
    204       gather_nd = array_ops.gather_nd(params, indices)
    205       with self.assertRaisesOpError(
    206           r"flat indices\[1, :\] = \[7\] does not index into param "
    207           r"\(shape: \[3\]\)"):
    208         gather_nd.eval()
    209 
    210   def testBadIndicesWithSlices(self):
    211     with self.test_session(use_gpu=True):
    212       params = [[0, 1, 2]]
    213       indices = [[[0], [0], [1]]]  # Make this one higher rank
    214       gather_nd = array_ops.gather_nd(params, indices)
    215       with self.assertRaisesOpError(
    216           r"flat indices\[2, :\] = \[1\] does not index into param "
    217           r"\(shape: \[1,3\]\)"):
    218         gather_nd.eval()
    219 
    220   def testGradientsRank2Elements(self):
    221     indices = constant_op.constant([[0, 0], [1, 1]], dtype=dtypes.int32)
    222     inputs = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float64)
    223     outputs = array_ops.gather_nd(inputs, indices)
    224 
    225     grad_vals = constant_op.constant([1, 2], dtype=dtypes.float64)
    226     grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
    227     expected_grads = np.array([[1, 0], [0, 2]], dtype=np.float64)
    228     with self.test_session(use_gpu=True):
    229       assert np.array_equal(expected_grads, grads.eval())
    230 
    231   def testGradientsRank2Slices(self):
    232     indices = constant_op.constant([[1], [0]], dtype=dtypes.int32)
    233     inputs = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float64)
    234     outputs = array_ops.gather_nd(inputs, indices)
    235 
    236     grad_vals = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float64)
    237     grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
    238     expected_grads = np.array([[3, 4], [1, 2]], dtype=np.float64)
    239     with self.test_session(use_gpu=True):
    240       self.assertIndexedSlices(grads)
    241       self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads).eval())
    242 
    243   def testGradientsRank3Elements(self):
    244     indices = constant_op.constant(
    245         [[[0, 1], [1, 0]], [[0, 0], [1, 1]]], dtype=dtypes.int32)
    246     inputs = constant_op.constant(
    247         [[[1, 3], [5, 7]], [[2, 4], [6, 8]]], dtype=dtypes.float64)
    248     outputs = array_ops.gather_nd(inputs, indices)
    249 
    250     grad_vals = constant_op.constant(
    251         [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtypes.float64)
    252     grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
    253     expected_grads = np.array(
    254         [[[5, 6], [1, 2]], [[3, 4], [7, 8]]], dtype=np.float64)
    255     with self.test_session(use_gpu=True):
    256       self.assertAllEqual(expected_grads, grads.eval())
    257 
    258   def testGradientsRank7Elements(self):
    259     # Shape [1,1,2,1,1,2,2]
    260     indices = constant_op.constant(
    261         [[[
    262             [[[[0, 0, 0, 0, 0, 1], [0, 0, 1, 0, 0, 0]]]],
    263             [[[[0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 1]]]]
    264         ]]],
    265         dtype=dtypes.int32)
    266     inputs = constant_op.constant(
    267         [[[
    268             [[[[1, 3], [5, 7]]]],
    269             [[[[2, 4], [6, 8]]]]
    270         ]]], dtype=dtypes.float64)
    271     outputs = array_ops.gather_nd(inputs, indices)
    272 
    273     grad_vals = constant_op.constant(
    274         [[[
    275             [[[[1, 2], [3, 4]]]],
    276             [[[[5, 6], [7, 8]]]]
    277         ]]], dtype=dtypes.float64)
    278     grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
    279     expected_grads = np.array(
    280         [[[
    281             [[[[5, 6], [1, 2]]]],
    282             [[[[3, 4], [7, 8]]]]
    283         ]]], dtype=np.float64)
    284     with self.test_session(use_gpu=True):
    285       self.assertAllEqual(expected_grads, grads.eval())
    286 
    287   def testGradientsInt64Indices(self):
    288     indices = constant_op.constant(
    289         [[[0, 1], [1, 0]], [[0, 0], [1, 1]]], dtype=dtypes.int64)
    290     inputs = constant_op.constant(
    291         [[[1, 3], [5, 7]], [[2, 4], [6, 8]]], dtype=dtypes.float64)
    292     outputs = array_ops.gather_nd(inputs, indices)
    293 
    294     grad_vals = constant_op.constant(
    295         [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtypes.float64)
    296     grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
    297     expected_grads = np.array(
    298         [[[5, 6], [1, 2]], [[3, 4], [7, 8]]], dtype=np.float64)
    299     with self.test_session(use_gpu=True):
    300       self.assertAllEqual(expected_grads, grads.eval())
    301 
    302   def testGradientsRank2SlicesWithEmptySpace(self):
    303     indices = constant_op.constant([[2], [0], [5]], dtype=dtypes.int32)
    304     inputs = constant_op.constant(
    305         [[1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9],
    306          [1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9],
    307          [1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9]],
    308         dtype=dtypes.float64)
    309     outputs = array_ops.gather_nd(inputs, indices)
    310     grad_vals = constant_op.constant(
    311         [[1, 1, 1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2, 2, 2],
    312          [3, 3, 3, 3, 3, 3, 3, 3, 3]],
    313         dtype=dtypes.float64)
    314     grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
    315     expected_grads = np.array(
    316         [[2, 2, 2, 2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 0, 0, 0],
    317          [1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0],
    318          [0, 0, 0, 0, 0, 0, 0, 0, 0], [3, 3, 3, 3, 3, 3, 3, 3, 3]],
    319         dtype=np.float64)
    320     with self.test_session(use_gpu=True):
    321       self.assertIndexedSlices(grads)
    322       self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads).eval())
    323 
    324 
    325 class GatherNdOpBenchmark(test.Benchmark):
    326 
    327   def benchmark_gather_nd_op(self):
    328     shape = (100, 47, 18, 170, 13)
    329     np.random.seed(127)
    330     params = np.random.rand(*shape)
    331     indices = np.vstack([np.random.randint(0, s, size=10000) for s in shape]).T
    332 
    333     with session.Session():
    334       t_params = variables.Variable(params)
    335       t_indices = variables.Variable(indices)
    336       gather_op = array_ops.gather_nd(t_params, t_indices)
    337       variables.global_variables_initializer().run()
    338       for _ in range(10):
    339         gather_op.eval()
    340       t1 = time.time()
    341       for _ in range(1000):
    342         gather_op.eval()
    343       t2 = time.time()
    344       self.report_benchmark(iters=1000, wall_time=(t2 - t1) / 1000.0)
    345 
    346 
    347 if __name__ == "__main__":
    348   test.main()
    349