1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tensorflow.ops.tf.scatter_nd.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import functools 22 23 import numpy as np 24 25 from tensorflow.python.client import session 26 from tensorflow.python.framework import constant_op 27 from tensorflow.python.framework import dtypes 28 from tensorflow.python.ops import array_ops 29 from tensorflow.python.ops import gradients_impl 30 from tensorflow.python.ops import resource_variable_ops 31 from tensorflow.python.ops import state_ops 32 from tensorflow.python.ops import variables 33 from tensorflow.python.platform import test 34 35 36 def _AsType(v, vtype): 37 return v.astype(vtype) if isinstance(v, np.ndarray) else vtype(v) 38 39 40 def _FlatInnerDims(tensor, ndims=2): 41 shape = list(tensor.shape) 42 return tensor.reshape([ 43 functools.reduce(lambda x, y: x * y, shape[:-ndims + 1], 1) 44 ] + shape[-ndims + 1:]) 45 46 47 def _FlatOuterDims(tensor, ndims=2): 48 shape = list(tensor.shape) 49 return tensor.reshape(shape[:ndims - 1] + [ 50 functools.reduce(lambda x, y: x * y, shape[ndims - 1:], 1) 51 ]) 52 53 54 def _NumpyScatterNd(ref, indices, updates, op): 55 ixdim = indices.shape[-1] 56 num_updates = indices.size // ixdim 57 total_nd = len(ref.shape) 58 slice_size = 1 59 for i in range(ixdim, total_nd): 60 slice_size *= ref.shape[i] 61 flat_indices = _FlatInnerDims(indices) 62 flat_updates = updates.reshape((num_updates, slice_size)) 63 output_flat = _FlatOuterDims(ref, ixdim + 1) 64 for ix_updates, ix_output in enumerate(flat_indices): 65 ix_output = tuple(ix_output) 66 output_flat[ix_output] = op(output_flat[ix_output], 67 flat_updates[ix_updates]) 68 return output_flat.reshape(ref.shape) 69 70 71 def _NumpyUpdate(ref, indices, updates): 72 return _NumpyScatterNd(ref, indices, updates, lambda p, u: u) 73 74 75 def _NumpyAdd(ref, indices, updates): 76 return _NumpyScatterNd(ref, indices, updates, lambda p, u: p + u) 77 78 79 def _NumpySub(ref, indices, updates): 80 return _NumpyScatterNd(ref, indices, updates, lambda p, u: p - u) 81 82 83 def _NumpyMul(ref, indices, updates): 84 return _NumpyScatterNd(ref, indices, updates, lambda p, u: p * u) 85 86 87 def _NumpyDiv(ref, indices, updates): 88 return _NumpyScatterNd(ref, indices, updates, lambda p, u: p / u) 89 90 91 class StatefulScatterNdTest(test.TestCase): 92 93 def _VariableRankTest(self, 94 np_scatter, 95 tf_scatter, 96 vtype, 97 itype, 98 repeat_indices=False): 99 np.random.seed(8) 100 ref_shapes = [(3, 6), (3, 6), (3, 6, 9), (3, 6, 9), (3, 6, 9), (3, 6, 9)] 101 indices_shapes = [(2,), (2, 2), (2,), (2, 2), (2, 3), (2, 3, 3)] 102 with self.test_session(use_gpu=True): 103 for ref_shape, indices_shape in zip(ref_shapes, indices_shapes): 104 num_updates = indices_shape[0] 105 ixdim = indices_shape[-1] 106 107 indexable_area_shape = () 108 for i in range(ixdim): 109 indexable_area_shape += (ref_shape[i],) 110 all_indices = [ 111 list(coord) 112 for coord, _ in np.ndenumerate( 113 np.empty(indexable_area_shape, vtype)) 114 ] 115 np.random.shuffle(all_indices) 116 indices = np.array(all_indices[:num_updates]) 117 118 if num_updates > 1 and repeat_indices: 119 indices = indices[:num_updates // 2] 120 for _ in range(num_updates - num_updates // 2): 121 indices = np.append( 122 indices, [indices[np.random.randint(num_updates // 2)]], axis=0) 123 np.random.shuffle(indices) 124 indices = _AsType(indices[:num_updates], itype) 125 126 updates_shape = (num_updates,) 127 for i in range(ixdim, len(ref_shape)): 128 updates_shape += (ref_shape[i],) 129 updates = _AsType(np.random.randn(*(updates_shape)), vtype) 130 ref = _AsType(np.random.randn(*(ref_shape)), vtype) 131 132 # Scatter via numpy 133 new = ref.copy() 134 np_scatter(new, indices, updates) 135 # Scatter via tensorflow 136 ref_var = variables.Variable(ref) 137 ref_var.initializer.run() 138 tf_scatter(ref_var, indices, updates).eval() 139 140 # Compare 141 self.assertAllClose(new, ref_var.eval()) 142 143 def _VariableRankTests(self, np_scatter, tf_scatter): 144 for vtype in (np.float32, np.float64, np.complex64, np.complex128): 145 for itype in (np.int32, np.int64): 146 self._VariableRankTest(np_scatter, tf_scatter, vtype, itype) 147 148 def testSimple(self): 149 indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32) 150 updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32) 151 ref = variables.Variable([0, 0, 0, 0, 0, 0, 0, 0], dtype=dtypes.float32) 152 expected = np.array([0, 11, 0, 10, 9, 0, 0, 12]) 153 scatter = state_ops.scatter_nd_update(ref, indices, updates) 154 init = variables.global_variables_initializer() 155 156 with self.test_session(use_gpu=True) as sess: 157 sess.run(init) 158 result = sess.run(scatter) 159 self.assertAllClose(result, expected) 160 161 def testSimpleResource(self): 162 indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32) 163 updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32) 164 ref = resource_variable_ops.ResourceVariable( 165 [0, 0, 0, 0, 0, 0, 0, 0], dtype=dtypes.float32) 166 expected = np.array([0, 11, 0, 10, 9, 0, 0, 12]) 167 scatter = state_ops.scatter_nd_update(ref, indices, updates) 168 init = variables.global_variables_initializer() 169 170 with self.test_session(use_gpu=True) as sess: 171 sess.run(init) 172 sess.run(scatter) 173 self.assertAllClose(ref.eval(), expected) 174 175 def testSimple2(self): 176 indices = constant_op.constant([[1, 0], [1, 1]], dtype=dtypes.int32) 177 updates = constant_op.constant([11., 12.], dtype=dtypes.float32) 178 ref = variables.Variable( 179 [[0., 0.], [0., 0.], [0., 0.]], dtype=dtypes.float32) 180 expected = np.array([[0., 0.], [11., 12.], [0., 0.]]) 181 scatter = state_ops.scatter_nd_update(ref, indices, updates) 182 init = variables.global_variables_initializer() 183 184 with self.test_session(use_gpu=True) as sess: 185 sess.run(init) 186 result = sess.run(scatter) 187 self.assertAllClose(result, expected) 188 189 def testSimple3(self): 190 indices = constant_op.constant([[1]], dtype=dtypes.int32) 191 updates = constant_op.constant([[11., 12.]], dtype=dtypes.float32) 192 ref = variables.Variable( 193 [[0., 0.], [0., 0.], [0., 0.]], dtype=dtypes.float32) 194 expected = np.array([[0., 0.], [11., 12.], [0., 0.]]) 195 scatter = state_ops.scatter_nd_update(ref, indices, updates) 196 init = variables.global_variables_initializer() 197 198 with self.test_session(use_gpu=True) as sess: 199 sess.run(init) 200 result = sess.run(scatter) 201 self.assertAllClose(result, expected) 202 203 def testVariableRankUpdate(self): 204 self._VariableRankTests(_NumpyUpdate, state_ops.scatter_nd_update) 205 206 def testVariableRankAdd(self): 207 self._VariableRankTests(_NumpyAdd, state_ops.scatter_nd_add) 208 209 def testVariableRankSub(self): 210 self._VariableRankTests(_NumpySub, state_ops.scatter_nd_sub) 211 212 # TODO(ebrevdo): Re-enable when we need ScatterNdMul. 213 # def testVariableRankMul(self): 214 # self._VariableRankTests(_NumpyMul, state_ops.scatter_nd_mul) 215 216 # TODO(ebrevdo): Re-enable when we need ScatterNdDiv. 217 # def testVariableRankDiv(self): 218 # self._VariableRankTests(_NumpyDiv, state_ops.scatter_nd_div) 219 220 def _ScatterRepeatIndicesTest(self, np_scatter, tf_scatter): 221 for vtype in (np.float32, np.float64): 222 for itype in (np.int32, np.int64): 223 self._VariableRankTest( 224 np_scatter, tf_scatter, vtype, itype, repeat_indices=True) 225 226 def testScatterRepeatIndices(self): 227 """This tests scatter_add using indices that repeat.""" 228 self._ScatterRepeatIndicesTest(_NumpyAdd, state_ops.scatter_nd_add) 229 self._ScatterRepeatIndicesTest(_NumpySub, state_ops.scatter_nd_sub) 230 # TODO(ebrevdo): Re-enable when we need ScatterNdMul and ScatterNdDiv. 231 # self._ScatterRepeatIndicesTest(_NumpyMul, state_ops.scatter_nd_mul) 232 # self._ScatterRepeatIndicesTest(_NumpyDiv, state_ops.scatter_nd_div) 233 234 # TODO(simister): Re-enable once binary size increase due to 235 # extra templating is back under control and this op is re-enabled 236 # def testBooleanScatterUpdate(self): 237 # with self.test_session(use_gpu=False) as session: 238 # var = tf.Variable([True, False]) 239 # update0 = tf.scatter_nd_update(var, [[1]], [True]) 240 # update1 = tf.scatter_nd_update( 241 # var, tf.constant( 242 # [[0]], dtype=tf.int64), [False]) 243 # var.initializer.run() 244 # session.run([update0, update1]) 245 # self.assertAllEqual([False, True], var.eval()) 246 247 def testScatterOutOfRangeCpu(self): 248 # TODO(simister): Re-enable once binary size increase due to 249 # scatter_nd ops is under control. 250 # tf.scatter_nd_mul, tf.scatter_nd_div, 251 for op in (state_ops.scatter_nd_add, state_ops.scatter_nd_sub, 252 state_ops.scatter_nd_update): 253 params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32) 254 updates = np.array([-3, -4, -5]).astype(np.float32) 255 with self.test_session(use_gpu=False): 256 ref = variables.Variable(params) 257 ref.initializer.run() 258 259 # Indices all in range, no problem. 260 indices = np.array([[2], [0], [5]]) 261 op(ref, indices, updates).eval() 262 263 # Test some out of range errors. 264 indices = np.array([[-1], [0], [5]]) 265 with self.assertRaisesOpError( 266 r"Invalid indices: \[0,0\] = \[-1\] does not index into \[6\]"): 267 op(ref, indices, updates).eval() 268 269 indices = np.array([[2], [0], [6]]) 270 with self.assertRaisesOpError( 271 r"Invalid indices: \[2,0\] = \[6\] does not index into \[6\]"): 272 op(ref, indices, updates).eval() 273 274 def testRank3ValidShape(self): 275 indices = array_ops.zeros([2, 2, 2], dtypes.int32) 276 updates = array_ops.zeros([2, 2, 2], dtypes.int32) 277 shape = np.array([2, 2, 2]) 278 ref = variables.Variable(array_ops.zeros(shape, dtypes.int32)) 279 self.assertAllEqual( 280 state_ops.scatter_nd_update(ref, indices, 281 updates).get_shape().as_list(), shape) 282 283 def testExtraIndicesDimensions(self): 284 indices = array_ops.zeros([1, 1, 2], dtypes.int32) 285 updates = array_ops.zeros([1, 1], dtypes.int32) 286 shape = np.array([2, 2]) 287 ref = variables.Variable(array_ops.zeros(shape, dtypes.int32)) 288 scatter_update = state_ops.scatter_nd_update(ref, indices, updates) 289 self.assertAllEqual(scatter_update.get_shape().as_list(), shape) 290 291 expected_result = np.zeros([2, 2], dtype=np.int32) 292 with self.test_session(): 293 ref.initializer.run() 294 self.assertAllEqual(expected_result, scatter_update.eval()) 295 296 def testRank3InvalidShape1(self): 297 indices = array_ops.zeros([3, 2, 2], dtypes.int32) 298 updates = array_ops.zeros([2, 2, 2], dtypes.int32) 299 shape = np.array([2, 2, 2]) 300 ref = variables.Variable(array_ops.zeros(shape, dtypes.int32)) 301 with self.assertRaisesWithPredicateMatch( 302 ValueError, "The outer \\d+ dimensions of indices\\.shape="): 303 state_ops.scatter_nd_update(ref, indices, updates) 304 305 def testRank3InvalidShape2(self): 306 indices = array_ops.zeros([2, 2, 1], dtypes.int32) 307 updates = array_ops.zeros([2, 2], dtypes.int32) 308 shape = np.array([2, 2, 2]) 309 ref = variables.Variable(array_ops.zeros(shape, dtypes.int32)) 310 with self.assertRaisesWithPredicateMatch( 311 ValueError, "The inner \\d+ dimensions of input\\.shape="): 312 state_ops.scatter_nd_update(ref, indices, updates) 313 314 def testConcurrentUpdates(self): 315 num_updates = 10000 316 update_values = np.random.rand(num_updates) 317 ref = variables.Variable(np.zeros([2, 2]), dtype=dtypes.float64) 318 indices = constant_op.constant([[0, 1]] * num_updates, dtype=dtypes.int32) 319 updates = constant_op.constant(update_values, dtype=dtypes.float64) 320 321 expected_result = np.zeros([2, 2], dtype=np.float64) 322 expected_result[0, 1] = np.sum(update_values) 323 324 scatter = state_ops.scatter_nd_add(ref, indices, updates) 325 init = variables.global_variables_initializer() 326 327 with session.Session() as sess: 328 sess.run(init) 329 result = sess.run(scatter) 330 assert np.allclose(result, expected_result) 331 332 # TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU. 333 def _disabledTestScatterOutOfRangeGpu(self): 334 if not test.IsBuiltWithCuda(): 335 return 336 # TODO(simister): Re-enable once binary size increase due to 337 # scatter_nd ops is under control. 338 # tf.scatter_nd_mul, tf.scatter_nd_div, 339 for op in (state_ops.scatter_nd_add, state_ops.scatter_nd_sub, 340 state_ops.scatter_nd_update): 341 params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32) 342 updates = np.array([-3, -4, -5]).astype(np.float32) 343 # With GPU, the code ignores indices that are out of range. 344 # We don't test the implementation; just test there's no failures. 345 with self.test_session(force_gpu=True): 346 ref = variables.Variable(params) 347 ref.initializer.run() 348 349 # Indices all in range, no problem. 350 indices = np.array([2, 0, 5]) 351 op(ref, indices, updates).eval() 352 353 # Indices out of range should not fail. 354 indices = np.array([-1, 0, 5]) 355 op(ref, indices, updates).eval() 356 indices = np.array([2, 0, 6]) 357 op(ref, indices, updates).eval() 358 359 360 class ScatterNdTest(test.TestCase): 361 non_aliasing_add_test = False 362 363 def scatter_nd(self, indices, updates, shape, input_=None): 364 del input_ # input_ is not used in scatter_nd 365 return array_ops.scatter_nd(indices, updates, shape) 366 367 def testRank3ValidShape(self): 368 indices = array_ops.zeros([2, 2, 2], dtypes.int32) 369 updates = array_ops.zeros([2, 2, 2], dtypes.int32) 370 shape = np.array([2, 2, 2]) 371 self.assertAllEqual( 372 self.scatter_nd(indices, updates, shape).get_shape().as_list(), shape) 373 374 def testExtraIndicesDimensions(self): 375 indices = array_ops.zeros([1, 1, 2], dtypes.int32) 376 updates = array_ops.zeros([1, 1], dtypes.int32) 377 shape = np.array([2, 2]) 378 scatter = self.scatter_nd(indices, updates, shape) 379 self.assertAllEqual(scatter.get_shape().as_list(), shape) 380 expected_result = np.zeros([2, 2], dtype=np.int32) 381 with self.test_session(): 382 self.assertAllEqual(expected_result, scatter.eval()) 383 384 def testUndefinedIndicesShape(self): 385 indices = array_ops.placeholder(dtypes.int32, shape=None) 386 updates = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2]) 387 shape = constant_op.constant([2, 2, 2], dtypes.int32) 388 self.scatter_nd(indices, updates, shape) 389 390 def testUndefinedUpdatesShape(self): 391 indices = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2]) 392 updates = array_ops.placeholder(dtypes.int32, shape=None) 393 shape = constant_op.constant([2, 2, 2], dtypes.int32) 394 self.scatter_nd(indices, updates, shape) 395 396 def testUndefinedOutputShape(self): 397 indices = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2]) 398 updates = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2]) 399 shape = array_ops.placeholder(dtypes.int32, shape=[None]) 400 self.scatter_nd(indices, updates, shape) 401 402 def testEmptyOutputShape1(self): 403 indices = array_ops.zeros([2, 2, 2], dtypes.int32) 404 updates = array_ops.zeros([2, 2, 2], dtypes.int32) 405 shape = constant_op.constant([0, 3, 2], dtypes.int32) 406 407 with self.assertRaisesWithPredicateMatch( 408 ValueError, "Indices and updates specified for empty output shape"): 409 self.scatter_nd(indices, updates, shape) 410 411 def testEmptyOutputShape2(self): 412 indices = array_ops.placeholder(dtypes.int32, shape=None) 413 updates = array_ops.placeholder(dtypes.int32, shape=None) 414 shape = constant_op.constant([0, 3, 2], dtypes.int32) 415 416 with self.test_session(): 417 with self.assertRaisesOpError( 418 "Indices and updates specified for empty output"): 419 self.scatter_nd(indices, updates, shape).eval(feed_dict={ 420 indices: np.zeros([2, 2, 2], dtype=np.int32), 421 updates: np.zeros([2, 2, 2], dtype=np.int32) 422 }) 423 424 def testEmptyOutputShape3(self): 425 indices = array_ops.zeros([0], dtypes.int32) 426 updates = array_ops.zeros([0], dtypes.int32) 427 shape = constant_op.constant([0], dtypes.int32) 428 scatter = self.scatter_nd(indices, updates, shape) 429 430 with self.test_session(): 431 self.assertEqual(scatter.eval().size, 0) 432 433 def testRank3InvalidShape1(self): 434 indices = array_ops.zeros([3, 2, 2], dtypes.int32) 435 updates = array_ops.zeros([2, 2, 2], dtypes.int32) 436 shape = np.array([2, 2, 2]) 437 with self.assertRaisesWithPredicateMatch( 438 ValueError, "The outer \\d+ dimensions of indices\\.shape="): 439 self.scatter_nd(indices, updates, shape) 440 441 def testRank3InvalidShape2(self): 442 indices = array_ops.zeros([2, 2, 1], dtypes.int32) 443 updates = array_ops.zeros([2, 2], dtypes.int32) 444 shape = np.array([2, 2, 2]) 445 with self.assertRaisesWithPredicateMatch( 446 ValueError, "The inner \\d+ dimensions of (input|output)\\.shape="): 447 self.scatter_nd(indices, updates, shape) 448 449 def testGradientsRank2ElementUpdate(self): 450 indices = constant_op.constant([[0, 0], [1, 1]], dtype=dtypes.int32) 451 updates = constant_op.constant([1, 4], dtype=dtypes.float64) 452 shape = constant_op.constant([2, 2], dtype=dtypes.int32) 453 input_ = array_ops.zeros(shape, dtype=dtypes.float64) 454 outputs = self.scatter_nd(indices, updates, shape, input_) 455 456 grad_vals = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float64) 457 updates_grad, input_grad = gradients_impl.gradients( 458 [outputs], [updates, input_], [grad_vals]) 459 expected_updates_grad = np.array([1, 4], dtype=np.float64) 460 expected_input_grad = np.array([[1, 2], [3, 4]], dtype=np.float64) 461 with self.test_session(): 462 self.assertAllEqual(expected_updates_grad, updates_grad.eval()) 463 if self.non_aliasing_add_test: 464 self.assertAllEqual(expected_input_grad, input_grad.eval()) 465 466 def testGradientsRank2SliceUpdate(self): 467 indices = constant_op.constant([[1], [0]], dtype=dtypes.int32) 468 updates = constant_op.constant([[3, 4], [1, 2]], dtype=dtypes.float64) 469 shape = constant_op.constant([2, 2], dtype=dtypes.int32) 470 input_ = array_ops.zeros(shape, dtype=dtypes.float64) 471 outputs = self.scatter_nd(indices, updates, shape, input_) 472 473 grad_vals = constant_op.constant([[3, 4], [1, 2]], dtype=dtypes.float64) 474 updates_grad, input_grad = gradients_impl.gradients( 475 [outputs], [updates, input_], [grad_vals]) 476 expected_updates_grad = np.array([[1, 2], [3, 4]], dtype=np.float64) 477 expected_input_grad = np.array([[3, 4], [1, 2]], dtype=np.float64) 478 with self.test_session(): 479 self.assertAllEqual(expected_updates_grad, updates_grad.eval()) 480 if self.non_aliasing_add_test: 481 self.assertAllEqual(expected_input_grad, input_grad.eval()) 482 483 def testGradientsRank3SliceUpdate(self): 484 indices = constant_op.constant( 485 [[[0, 1], [1, 0]], [[0, 0], [1, 1]]], dtype=dtypes.int32) 486 updates = constant_op.constant( 487 [[[5, 7], [2, 4]], [[1, 3], [6, 8]]], dtype=dtypes.float64) 488 shape = constant_op.constant([2, 2, 2], dtype=dtypes.int32) 489 input_ = array_ops.zeros(shape, dtype=dtypes.float64) 490 outputs = self.scatter_nd(indices, updates, shape, input_) 491 492 grad_vals = constant_op.constant( 493 [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtypes.float64) 494 updates_grad, input_grad = gradients_impl.gradients( 495 [outputs], [updates, input_], [grad_vals]) 496 expected_updates_grad = np.array( 497 [[[3, 4], [5, 6]], [[1, 2], [7, 8]]], dtype=np.float64) 498 expected_input_grad = np.array( 499 [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.float64) 500 with self.test_session(): 501 self.assertAllEqual(expected_updates_grad, updates_grad.eval()) 502 if self.non_aliasing_add_test: 503 self.assertAllEqual(expected_input_grad, input_grad.eval()) 504 505 def testGradientsRank7SliceUpdate(self): 506 indices = constant_op.constant( 507 [[[ 508 [[[[0, 0, 0, 0, 0, 1], [0, 0, 1, 0, 0, 0]]]], 509 [[[[0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 1]]]] 510 ]]], dtype=dtypes.int32) 511 updates = constant_op.constant( 512 [[[ 513 [[[[5, 6], [2, 4]]]], 514 [[[[1, 3], [6, 8]]]] 515 ]]], dtype=dtypes.float64) 516 shape = constant_op.constant([1, 1, 2, 1, 1, 2, 2], dtype=dtypes.int32) 517 input_ = array_ops.zeros(shape, dtype=dtypes.float64) 518 outputs = self.scatter_nd(indices, updates, shape, input_) 519 520 grad_vals = constant_op.constant( 521 [[[ 522 [[[[1, 2], [3, 4]]]], 523 [[[[5, 6], [7, 8]]]] 524 ]]], dtype=dtypes.float64) 525 updates_grad, input_grad = gradients_impl.gradients( 526 [outputs], [updates, input_], [grad_vals]) 527 expected_updates_grad = np.array( 528 [[[ 529 [[[[3, 4], [5, 6]]]], 530 [[[[1, 2], [7, 8]]]] 531 ]]], dtype=np.float64) 532 expected_input_grad = np.array( 533 [[[ 534 [[[[1, 2], [3, 4]]]], 535 [[[[5, 6], [7, 8]]]] 536 ]]], dtype=np.float64) 537 with self.test_session(): 538 self.assertAllEqual(expected_updates_grad, updates_grad.eval()) 539 if self.non_aliasing_add_test: 540 self.assertAllEqual(expected_input_grad, input_grad.eval()) 541 542 def testScatterNdRepatedIndicesAdd(self): 543 indices = array_ops.zeros([100000, 1], dtypes.int32) 544 values = np.random.randn(100000) 545 shape = [1] 546 with self.test_session(): 547 val = self.scatter_nd(indices, values, shape).eval() 548 self.assertAllClose([np.sum(values)], val) 549 550 def testSmokeScatterNdBatch2DSliceDim2(self): 551 with self.test_session(): 552 indices = array_ops.zeros([3, 5, 2], dtype=dtypes.int32) 553 values = array_ops.zeros([3, 5, 7]) 554 shape = [4, 6, 7] 555 self.scatter_nd(indices, values, shape).eval() 556 557 def testSmokeScatterNdBatch1DSliceDim2(self): 558 with self.test_session(): 559 indices = array_ops.zeros([0, 2], dtype=dtypes.int32) 560 values = array_ops.zeros([0, 7]) 561 shape = [4, 6, 7] 562 self.scatter_nd(indices, values, shape).eval() 563 564 def testSmokeScatterNdBatch1DSliceDim3ShapeRank7(self): 565 with self.test_session(): 566 indices = array_ops.zeros([1, 3], dtype=dtypes.int32) 567 values = array_ops.zeros([1, 6, 7, 8, 9]) 568 shape = [3, 4, 5, 6, 7, 8, 9] 569 self.scatter_nd(indices, values, shape).eval() 570 571 def testSmokeScatterNdBatch2DSliceDim3ShapeRank7(self): 572 with self.test_session(): 573 indices = array_ops.zeros([1, 2, 3], dtype=dtypes.int32) 574 values = array_ops.zeros([1, 2, 6, 7, 8, 9]) 575 shape = [3, 4, 5, 6, 7, 8, 9] 576 self.scatter_nd(indices, values, shape).eval() 577 578 579 class ScatterNdNonAliasingAddTest(ScatterNdTest): 580 non_aliasing_add_test = True 581 582 def scatter_nd(self, indices, updates, shape, input_=None): 583 input_ = (input_ if input_ is not None else array_ops.zeros( 584 shape, dtype=updates.dtype)) 585 return array_ops.scatter_nd_non_aliasing_add(input_, indices, updates) 586 587 588 if __name__ == "__main__": 589 test.main() 590