# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.gather_nd."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class GatherNdTest(test.TestCase):
  """Shape, value, error and gradient checks for `array_ops.gather_nd`."""

  def _testSimpleDtype(self, dtype):
    """Gathers three scalars from a rank-1 `params` built with `dtype`.

    Args:
      dtype: A numpy dtype (or dtype string) used to build both `params` and
        the expected result.
    """
    with self.test_session(use_gpu=True):
      params = constant_op.constant(np.array([8, 1, 2, 3, 7, 5], dtype=dtype))
      indices = constant_op.constant([[4], [4], [0]])
      gather_nd_t = array_ops.gather_nd(params, indices)
      gather_nd_val = gather_nd_t.eval()

    self.assertAllEqual(np.array([7, 7, 8], dtype=dtype), gather_nd_val)
    self.assertEqual([3], gather_nd_t.get_shape())

  def testSimpleDtype(self):
    """Runs the simple gather for float, int, complex and string dtypes."""
    self._testSimpleDtype(np.float32)
    self._testSimpleDtype(np.float64)
    self._testSimpleDtype(np.int32)
    self._testSimpleDtype(np.int64)
    self._testSimpleDtype(np.complex64)
    self._testSimpleDtype(np.complex128)
    self._testSimpleDtype("|S")  # byte strings in python2 + 3

  def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self):
    """Empty indices always succeed; non-empty indices on empty params fail."""
    with self.test_session(use_gpu=True):
      params = np.ones((3, 3), dtype=np.float32)

      # Zero full-depth indices into non-empty params: empty vector result.
      indices_empty = np.empty((0, 2), dtype=np.int32)
      gather_nd_ok_t = array_ops.gather_nd(params, indices_empty)
      gather_nd_ok_val = gather_nd_ok_t.eval()
      self.assertEqual([0], gather_nd_ok_t.get_shape())
      self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val)

      # Zero depth-1 indices into non-empty params: zero row slices.
      indices_empty = np.empty((0, 1), dtype=np.int32)
      gather_nd_ok_t = array_ops.gather_nd(params, indices_empty)
      gather_nd_ok_val = gather_nd_ok_t.eval()
      self.assertEqual([0, 3], gather_nd_ok_t.get_shape())
      self.assertAllClose(np.empty((0, 3), dtype=np.float32), gather_nd_ok_val)

      # Zero indices into empty params is still fine.
      params_empty = np.empty((0, 3), dtype=np.float32)
      indices_empty = np.empty((0, 2), dtype=np.int32)
      gather_nd_ok_t = array_ops.gather_nd(params_empty, indices_empty)
      gather_nd_ok_val = gather_nd_ok_t.eval()
      self.assertEqual([0], gather_nd_ok_t.get_shape())
      self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val)

      # Requesting an entry from empty params must raise at run time.
      params_empty = np.empty((0, 3), dtype=np.float32)
      indices_nonempty = np.zeros((1, 2), dtype=np.int32)
      gather_nd_break_t = array_ops.gather_nd(params_empty, indices_nonempty)
      with self.assertRaisesOpError(
          r"Requested more than 0 entries, but params is empty."):
        gather_nd_break_t.eval()
      # Re-checks the value obtained from the last successful gather above.
      self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val)

  def testIndexScalar(self):
    """A rank-1, full-depth index into rank-2 params yields a scalar."""
    with self.test_session(use_gpu=True):
      params = np.array(
          [[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T
      indices = constant_op.constant([4, 1])
      gather_nd_t = array_ops.gather_nd(params, indices)
      gather_nd_val = gather_nd_t.eval()
      self.assertEqual([], gather_nd_t.get_shape())
      self.assertAllEqual(np.array(7), gather_nd_val)

  def testParamsRankLargerThanIndexIndexScalarSlices(self):
    """A rank-1, depth-1 index into rank-2 params yields one row slice."""
    with self.test_session(use_gpu=True):
      params = np.array(
          [[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T
      indices = constant_op.constant([4])
      gather_nd_t = array_ops.gather_nd(params, indices)
      gather_nd_val = gather_nd_t.eval()
      self.assertEqual([2], gather_nd_t.get_shape())
      self.assertAllEqual(np.array([-7, 7]), gather_nd_val)

  def testParamsRankLargerThanIndexSlices(self):
    """Depth-1 indices into rank-2 params yield a stack of row slices."""
    with self.test_session(use_gpu=True):
      params = np.array(
          [[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T
      indices = constant_op.constant([[4], [4], [0]])
      gather_nd_t = array_ops.gather_nd(params, indices)
      gather_nd_val = gather_nd_t.eval()

    self.assertEqual([3, 2], gather_nd_t.get_shape())
    self.assertAllEqual(np.array([[-7, 7], [-7, 7], [-8, 8]]), gather_nd_val)

  def testHigherRankParamsLargerThanIndexSlices(self):
    """Depth-1 indices into rank-3 params match numpy fancy indexing."""
    with self.test_session(use_gpu=True):
      params = np.array(
          [[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]],
           [[-80, -10, -20, -30, -70, -50], [80, 10, 20, 30, 70, 50]]],
          dtype=np.float32).T
      params_t = constant_op.constant(params)
      indices = constant_op.constant([[4], [4], [0]])
      gather_nd_t = array_ops.gather_nd(params_t, indices)
      gather_nd_val = gather_nd_t.eval()

    self.assertEqual([3, 2, 2], gather_nd_t.get_shape())
    self.assertAllEqual(params[[4, 4, 0]], gather_nd_val)

  def testEmptyIndicesLastRankMeansCopyEntireTensor(self):
    """Indices of shape (2, 0) return two stacked copies of `params`."""
    with self.test_session(use_gpu=True):
      params = np.array(
          [[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]],
           [[-80, -10, -20, -30, -70, -50], [80, 10, 20, 30, 70, 50]]],
          dtype=np.float32).T
      params_t = constant_op.constant(params)
      indices = constant_op.constant(
          [[], []], dtype=dtypes.int32)  # Size (2, 0)
      gather_nd_t = array_ops.gather_nd(params_t, indices)
      gather_nd_val = gather_nd_t.eval()

    self.assertEqual([2, 6, 2, 2], gather_nd_t.get_shape())
    self.assertAllEqual(
        np.vstack((params[np.newaxis, :], params[np.newaxis, :])),
        gather_nd_val)

  def testHigherRankParamsAndIndicesLargerThanIndexSlices(self):
    """Rank-3 depth-1 indices into rank-3 params keep the index batch shape."""
    with self.test_session(use_gpu=True):
      params = np.array(
          [[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]],
           [[-80, -10, -20, -30, -70, -50], [80, 10, 20, 30, 70, 50]]],
          dtype=np.float32).T
      params_t = constant_op.constant(params)
      indices = constant_op.constant([[[3], [2], [1]], [[4], [4], [0]]])
      gather_nd_t = array_ops.gather_nd(params_t, indices)
      gather_nd_val = gather_nd_t.eval()

    self.assertEqual([2, 3, 2, 2], gather_nd_t.get_shape())
    self.assertAllEqual(params[[3, 2, 1, 4, 4, 0]].reshape(2, 3, 2, 2),
                        gather_nd_val)

  def testHigherRankParams(self):
    """Full-depth random indices into rank-5 params match numpy indexing."""
    with self.test_session(use_gpu=True):
      shape = (10, 20, 5, 1, 17)
      params = np.random.rand(*shape)
      # One column of in-range random indices per params dimension.
      indices = np.vstack([np.random.randint(0, s, size=2000) for s in shape]).T
      gather_nd_t = array_ops.gather_nd(params, indices)
      gather_nd_val = gather_nd_t.eval()

    expected = params[tuple(indices.T)]
    self.assertAllEqual(expected, gather_nd_val)
    self.assertEqual([2000], gather_nd_t.get_shape())

  def testHigherRankParamsAndIndices(self):
    """Same as `testHigherRankParams` with indices reshaped to rank 4."""
    with self.test_session(use_gpu=True):
      shape = (10, 20, 5, 1, 17)
      params = np.random.rand(*shape)
      indices = np.vstack([np.random.randint(0, s, size=2000) for s in shape]).T
      # 2000 index rows of depth 5 viewed as a [10, 10, 20] batch of indices.
      indices_reshaped = indices.reshape([10, 10, 20, 5])
      gather_nd_t = array_ops.gather_nd(params, indices_reshaped)
      gather_nd_val = gather_nd_t.eval()

    expected = params[tuple(indices.T)]
    self.assertAllEqual(expected.reshape([10, 10, 20]), gather_nd_val)
    self.assertEqual([10, 10, 20], gather_nd_t.get_shape())

  def assertIndexedSlices(self, t):
    """Asserts that `t` is an `ops.IndexedSlices` instance."""
    self.assertIsInstance(t, ops.IndexedSlices)

  def testUnknownIndices(self):
    """A placeholder of unknown shape as indices gives an unknown shape."""
    params = constant_op.constant([[0, 1, 2]])
    indices = array_ops.placeholder(dtypes.int32)
    gather_nd_t = array_ops.gather_nd(params, indices)
    shape = gather_nd_t.get_shape()
    self.assertIsNone(shape.ndims)
    self.assertIsNone(shape[0].value)

  def testBadIndices(self):
    """An out-of-range index into rank-1 params raises an op error."""
    with self.test_session(use_gpu=True):
      params = [0, 1, 2]
      indices = [[[0], [7]]]  # Make this one higher rank
      gather_nd = array_ops.gather_nd(params, indices)
      with self.assertRaisesOpError(
          r"flat indices\[1, :\] = \[7\] does not index into param "
          r"\(shape: \[3\]\)"):
        gather_nd.eval()

  def testBadIndicesWithSlices(self):
    """An out-of-range slice index into rank-2 params raises an op error."""
    with self.test_session(use_gpu=True):
      params = [[0, 1, 2]]
      indices = [[[0], [0], [1]]]  # Make this one higher rank
      gather_nd = array_ops.gather_nd(params, indices)
      with self.assertRaisesOpError(
          r"flat indices\[2, :\] = \[1\] does not index into param "
          r"\(shape: \[1,3\]\)"):
        gather_nd.eval()

  def testGradientsRank2Elements(self):
    """Gradient of an element gather lands on the gathered positions."""
    indices = constant_op.constant([[0, 0], [1, 1]], dtype=dtypes.int32)
    inputs = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float64)
    outputs = array_ops.gather_nd(inputs, indices)

    grad_vals = constant_op.constant([1, 2], dtype=dtypes.float64)
    grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
    expected_grads = np.array([[1, 0], [0, 2]], dtype=np.float64)
    with self.test_session(use_gpu=True):
      # Uses the TestCase assertion rather than a bare `assert`, which would
      # be stripped under `python -O` and reports no diff on failure.
      self.assertAllEqual(expected_grads, grads.eval())

  def testGradientsRank2Slices(self):
    """Gradient of a row-slice gather is an `IndexedSlices` of the rows."""
    indices = constant_op.constant([[1], [0]], dtype=dtypes.int32)
    inputs = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float64)
    outputs = array_ops.gather_nd(inputs, indices)

    grad_vals = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float64)
    grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
    expected_grads = np.array([[3, 4], [1, 2]], dtype=np.float64)
    with self.test_session(use_gpu=True):
      self.assertIndexedSlices(grads)
      self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads).eval())

  def testGradientsRank3Elements(self):
    """Gradient of a depth-2 gather from rank-3 params, int32 indices."""
    indices = constant_op.constant(
        [[[0, 1], [1, 0]], [[0, 0], [1, 1]]], dtype=dtypes.int32)
    inputs = constant_op.constant(
        [[[1, 3], [5, 7]], [[2, 4], [6, 8]]], dtype=dtypes.float64)
    outputs = array_ops.gather_nd(inputs, indices)

    grad_vals = constant_op.constant(
        [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtypes.float64)
    grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
    expected_grads = np.array(
        [[[5, 6], [1, 2]], [[3, 4], [7, 8]]], dtype=np.float64)
    with self.test_session(use_gpu=True):
      self.assertAllEqual(expected_grads, grads.eval())

  def testGradientsRank7Elements(self):
    """Gradient of a depth-6 gather from rank-7 params."""
    # Shape [1,1,2,1,1,2,2]
    indices = constant_op.constant(
        [[[
            [[[[0, 0, 0, 0, 0, 1], [0, 0, 1, 0, 0, 0]]]],
            [[[[0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 1]]]]
        ]]],
        dtype=dtypes.int32)
    inputs = constant_op.constant(
        [[[
            [[[[1, 3], [5, 7]]]],
            [[[[2, 4], [6, 8]]]]
        ]]], dtype=dtypes.float64)
    outputs = array_ops.gather_nd(inputs, indices)

    grad_vals = constant_op.constant(
        [[[
            [[[[1, 2], [3, 4]]]],
            [[[[5, 6], [7, 8]]]]
        ]]], dtype=dtypes.float64)
    grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
    expected_grads = np.array(
        [[[
            [[[[5, 6], [1, 2]]]],
            [[[[3, 4], [7, 8]]]]
        ]]], dtype=np.float64)
    with self.test_session(use_gpu=True):
      self.assertAllEqual(expected_grads, grads.eval())

  def testGradientsInt64Indices(self):
    """Same as `testGradientsRank3Elements` but with int64 indices."""
    indices = constant_op.constant(
        [[[0, 1], [1, 0]], [[0, 0], [1, 1]]], dtype=dtypes.int64)
    inputs = constant_op.constant(
        [[[1, 3], [5, 7]], [[2, 4], [6, 8]]], dtype=dtypes.float64)
    outputs = array_ops.gather_nd(inputs, indices)

    grad_vals = constant_op.constant(
        [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtypes.float64)
    grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
    expected_grads = np.array(
        [[[5, 6], [1, 2]], [[3, 4], [7, 8]]], dtype=np.float64)
    with self.test_session(use_gpu=True):
      self.assertAllEqual(expected_grads, grads.eval())

  def testGradientsRank2SlicesWithEmptySpace(self):
    """Rows that were never gathered receive a zero gradient."""
    indices = constant_op.constant([[2], [0], [5]], dtype=dtypes.int32)
    inputs = constant_op.constant(
        [[1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9],
         [1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9],
         [1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9]],
        dtype=dtypes.float64)
    outputs = array_ops.gather_nd(inputs, indices)
    grad_vals = constant_op.constant(
        [[1, 1, 1, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2, 2, 2, 2],
         [3, 3, 3, 3, 3, 3, 3, 3, 3]],
        dtype=dtypes.float64)
    grads = gradients_impl.gradients([outputs], [inputs], [grad_vals])[0]
    expected_grads = np.array(
        [[2, 2, 2, 2, 2, 2, 2, 2, 2], [0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0], [3, 3, 3, 3, 3, 3, 3, 3, 3]],
        dtype=np.float64)
    with self.test_session(use_gpu=True):
      self.assertIndexedSlices(grads)
      self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads).eval())


class GatherNdOpBenchmark(test.Benchmark):
  """Wall-time benchmark for `array_ops.gather_nd` on rank-5 params."""

  def benchmark_gather_nd_op(self):
    """Times 1000 evaluations of a 10000-index gather after 10 warm-ups."""
    shape = (100, 47, 18, 170, 13)
    # Single source of truth for the timed iteration count, so the reported
    # `iters` and the per-iteration divisor cannot drift apart.
    num_iters = 1000
    np.random.seed(127)
    params = np.random.rand(*shape)
    indices = np.vstack([np.random.randint(0, s, size=10000) for s in shape]).T

    with session.Session():
      t_params = variables.Variable(params)
      t_indices = variables.Variable(indices)
      gather_op = array_ops.gather_nd(t_params, t_indices)
      variables.global_variables_initializer().run()
      # Warm-up evaluations, excluded from the timed region.
      for _ in range(10):
        gather_op.eval()
      t1 = time.time()
      for _ in range(num_iters):
        gather_op.eval()
      t2 = time.time()
      self.report_benchmark(
          iters=num_iters, wall_time=(t2 - t1) / float(num_iters))


if __name__ == "__main__":
  test.main()