# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.scatter_nd."""

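# A concrete example of the semantics exercised below (taken from testSimple):
# scattering updates [9, 10, 11, 12] at indices [[4], [3], [1], [7]] into a
# length-8 vector of zeros yields [0, 11, 0, 10, 9, 0, 0, 12].
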
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


def _AsType(v, vtype):
  return v.astype(vtype) if isinstance(v, np.ndarray) else vtype(v)


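# Shape helpers for the NumPy reference implementation below. With the
# default ndims=2, _FlatInnerDims collapses all but the last dimension into
# one (e.g. a (2, 3, 3) array becomes (6, 3)), while _FlatOuterDims keeps the
# first ndims - 1 dimensions and collapses the rest (e.g. (2, 3, 3) becomes
# (2, 9)).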
def _FlatInnerDims(tensor, ndims=2):
  shape = list(tensor.shape)
  return tensor.reshape([
      functools.reduce(lambda x, y: x * y, shape[:-ndims + 1], 1)
  ] + shape[-ndims + 1:])


def _FlatOuterDims(tensor, ndims=2):
  shape = list(tensor.shape)
  return tensor.reshape(shape[:ndims - 1] + [
      functools.reduce(lambda x, y: x * y, shape[ndims - 1:], 1)
  ])


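# NumPy reference for the scatter_nd-style ops tested below. Each row of
# `indices` (flattened to shape (num_updates, ixdim)) addresses one slice of
# `ref`, whose trailing dimensions are collapsed into a single axis of length
# slice_size; `op` is applied slice by slice, so repeated indices are applied
# sequentially. For example, with ref.shape = (3, 6, 9) and
# indices.shape = (2, 3) (one of the cases in _VariableRankTest), ixdim is 3
# and each index row addresses a single element.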
def _NumpyScatterNd(ref, indices, updates, op):
  ixdim = indices.shape[-1]
  num_updates = indices.size // ixdim
  total_nd = len(ref.shape)
  slice_size = 1
  for i in range(ixdim, total_nd):
    slice_size *= ref.shape[i]
  flat_indices = _FlatInnerDims(indices)
  flat_updates = updates.reshape((num_updates, slice_size))
  output_flat = _FlatOuterDims(ref, ixdim + 1)
  for ix_updates, ix_output in enumerate(flat_indices):
    ix_output = tuple(ix_output)
    output_flat[ix_output] = op(output_flat[ix_output],
                                flat_updates[ix_updates])
  return output_flat.reshape(ref.shape)


def _NumpyUpdate(ref, indices, updates):
  return _NumpyScatterNd(ref, indices, updates, lambda p, u: u)


def _NumpyAdd(ref, indices, updates):
  return _NumpyScatterNd(ref, indices, updates, lambda p, u: p + u)


def _NumpySub(ref, indices, updates):
  return _NumpyScatterNd(ref, indices, updates, lambda p, u: p - u)


def _NumpyMul(ref, indices, updates):
  return _NumpyScatterNd(ref, indices, updates, lambda p, u: p * u)


def _NumpyDiv(ref, indices, updates):
  return _NumpyScatterNd(ref, indices, updates, lambda p, u: p / u)


class StatefulScatterNdTest(test.TestCase):

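  # Builds random ref/indices/updates for several ref and indices ranks,
  # applies both the NumPy reference scatter and the corresponding TensorFlow
  # op, and checks that the results agree. With repeat_indices=True the
  # indices are drawn from only half of the sampled coordinates, so some of
  # them repeat and the op's handling of duplicate updates is exercised.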
  def _VariableRankTest(self,
                        np_scatter,
                        tf_scatter,
                        vtype,
                        itype,
                        repeat_indices=False):
    np.random.seed(8)
    ref_shapes = [(3, 6), (3, 6), (3, 6, 9), (3, 6, 9), (3, 6, 9), (3, 6, 9)]
    indices_shapes = [(2,), (2, 2), (2,), (2, 2), (2, 3), (2, 3, 3)]
    with self.test_session(use_gpu=True):
      for ref_shape, indices_shape in zip(ref_shapes, indices_shapes):
        num_updates = indices_shape[0]
        ixdim = indices_shape[-1]

        indexable_area_shape = ()
        for i in range(ixdim):
          indexable_area_shape += (ref_shape[i],)
        all_indices = [
            list(coord)
            for coord, _ in np.ndenumerate(
                np.empty(indexable_area_shape, vtype))
        ]
        np.random.shuffle(all_indices)
        indices = np.array(all_indices[:num_updates])

        if num_updates > 1 and repeat_indices:
          indices = indices[:num_updates // 2]
          for _ in range(num_updates - num_updates // 2):
            indices = np.append(
                indices, [indices[np.random.randint(num_updates // 2)]], axis=0)
          np.random.shuffle(indices)
        indices = _AsType(indices[:num_updates], itype)

        updates_shape = (num_updates,)
        for i in range(ixdim, len(ref_shape)):
          updates_shape += (ref_shape[i],)
        updates = _AsType(np.random.randn(*(updates_shape)), vtype)
        ref = _AsType(np.random.randn(*(ref_shape)), vtype)

        # Scatter via numpy
        new = ref.copy()
        np_scatter(new, indices, updates)
        # Scatter via tensorflow
        ref_var = variables.Variable(ref)
        ref_var.initializer.run()
        tf_scatter(ref_var, indices, updates).eval()

        # Compare
        self.assertAllClose(new, ref_var.eval())

  def _VariableRankTests(self, np_scatter, tf_scatter):
    for vtype in (np.float32, np.float64, np.complex64, np.complex128):
      for itype in (np.int32, np.int64):
        self._VariableRankTest(np_scatter, tf_scatter, vtype, itype)

  def testSimple(self):
    indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
    updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
    ref = variables.Variable([0, 0, 0, 0, 0, 0, 0, 0], dtype=dtypes.float32)
    expected = np.array([0, 11, 0, 10, 9, 0, 0, 12])
    scatter = state_ops.scatter_nd_update(ref, indices, updates)
    init = variables.global_variables_initializer()

    with self.test_session(use_gpu=True) as sess:
      sess.run(init)
      result = sess.run(scatter)
      self.assertAllClose(result, expected)

  def testSimpleResource(self):
    indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
    updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
    ref = resource_variable_ops.ResourceVariable(
        [0, 0, 0, 0, 0, 0, 0, 0], dtype=dtypes.float32)
    expected = np.array([0, 11, 0, 10, 9, 0, 0, 12])
    scatter = state_ops.scatter_nd_update(ref, indices, updates)
    init = variables.global_variables_initializer()

    with self.test_session(use_gpu=True) as sess:
      sess.run(init)
      sess.run(scatter)
      self.assertAllClose(ref.eval(), expected)

  def testSimple2(self):
    indices = constant_op.constant([[1, 0], [1, 1]], dtype=dtypes.int32)
    updates = constant_op.constant([11., 12.], dtype=dtypes.float32)
    ref = variables.Variable(
        [[0., 0.], [0., 0.], [0., 0.]], dtype=dtypes.float32)
    expected = np.array([[0., 0.], [11., 12.], [0., 0.]])
    scatter = state_ops.scatter_nd_update(ref, indices, updates)
    init = variables.global_variables_initializer()

    with self.test_session(use_gpu=True) as sess:
      sess.run(init)
      result = sess.run(scatter)
      self.assertAllClose(result, expected)

  def testSimple3(self):
    indices = constant_op.constant([[1]], dtype=dtypes.int32)
    updates = constant_op.constant([[11., 12.]], dtype=dtypes.float32)
    ref = variables.Variable(
        [[0., 0.], [0., 0.], [0., 0.]], dtype=dtypes.float32)
    expected = np.array([[0., 0.], [11., 12.], [0., 0.]])
    scatter = state_ops.scatter_nd_update(ref, indices, updates)
    init = variables.global_variables_initializer()

    with self.test_session(use_gpu=True) as sess:
      sess.run(init)
      result = sess.run(scatter)
      self.assertAllClose(result, expected)

  def testVariableRankUpdate(self):
    self._VariableRankTests(_NumpyUpdate, state_ops.scatter_nd_update)

  def testVariableRankAdd(self):
    self._VariableRankTests(_NumpyAdd, state_ops.scatter_nd_add)

  def testVariableRankSub(self):
    self._VariableRankTests(_NumpySub, state_ops.scatter_nd_sub)

  # TODO(ebrevdo): Re-enable when we need ScatterNdMul.
  # def testVariableRankMul(self):
  #   self._VariableRankTests(_NumpyMul, state_ops.scatter_nd_mul)

  # TODO(ebrevdo): Re-enable when we need ScatterNdDiv.
  # def testVariableRankDiv(self):
  #   self._VariableRankTests(_NumpyDiv, state_ops.scatter_nd_div)

  def _ScatterRepeatIndicesTest(self, np_scatter, tf_scatter):
    for vtype in (np.float32, np.float64):
      for itype in (np.int32, np.int64):
        self._VariableRankTest(
            np_scatter, tf_scatter, vtype, itype, repeat_indices=True)

  def testScatterRepeatIndices(self):
    227     """This tests scatter_add using indices that repeat."""
    self._ScatterRepeatIndicesTest(_NumpyAdd, state_ops.scatter_nd_add)
    self._ScatterRepeatIndicesTest(_NumpySub, state_ops.scatter_nd_sub)
    # TODO(ebrevdo): Re-enable when we need ScatterNdMul and ScatterNdDiv.
    # self._ScatterRepeatIndicesTest(_NumpyMul, state_ops.scatter_nd_mul)
    # self._ScatterRepeatIndicesTest(_NumpyDiv, state_ops.scatter_nd_div)

  # TODO(simister): Re-enable once binary size increase due to
  # extra templating is back under control and this op is re-enabled
  # def testBooleanScatterUpdate(self):
  #   with self.test_session(use_gpu=False) as session:
  #     var = tf.Variable([True, False])
  #     update0 = tf.scatter_nd_update(var, [[1]], [True])
  #     update1 = tf.scatter_nd_update(
  #         var, tf.constant(
  #             [[0]], dtype=tf.int64), [False])
  #     var.initializer.run()
  #     session.run([update0, update1])
  #     self.assertAllEqual([False, True], var.eval())

  def testScatterOutOfRangeCpu(self):
    # TODO(simister): Re-enable once binary size increase due to
    # scatter_nd ops is under control.
    #  tf.scatter_nd_mul, tf.scatter_nd_div,
    for op in (state_ops.scatter_nd_add, state_ops.scatter_nd_sub,
               state_ops.scatter_nd_update):
      params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
      updates = np.array([-3, -4, -5]).astype(np.float32)
      with self.test_session(use_gpu=False):
        ref = variables.Variable(params)
        ref.initializer.run()

        # Indices all in range, no problem.
        indices = np.array([[2], [0], [5]])
        op(ref, indices, updates).eval()

        # Test some out of range errors.
        indices = np.array([[-1], [0], [5]])
        with self.assertRaisesOpError(
            r"Invalid indices: \[0,0\] = \[-1\] does not index into \[6\]"):
          op(ref, indices, updates).eval()

        indices = np.array([[2], [0], [6]])
        with self.assertRaisesOpError(
            r"Invalid indices: \[2,0\] = \[6\] does not index into \[6\]"):
          op(ref, indices, updates).eval()

  def testRank3ValidShape(self):
    indices = array_ops.zeros([2, 2, 2], dtypes.int32)
    updates = array_ops.zeros([2, 2, 2], dtypes.int32)
    shape = np.array([2, 2, 2])
    ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
    self.assertAllEqual(
        state_ops.scatter_nd_update(ref, indices,
                                    updates).get_shape().as_list(), shape)

  def testExtraIndicesDimensions(self):
    indices = array_ops.zeros([1, 1, 2], dtypes.int32)
    updates = array_ops.zeros([1, 1], dtypes.int32)
    shape = np.array([2, 2])
    ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
    scatter_update = state_ops.scatter_nd_update(ref, indices, updates)
    self.assertAllEqual(scatter_update.get_shape().as_list(), shape)

    expected_result = np.zeros([2, 2], dtype=np.int32)
    with self.test_session():
      ref.initializer.run()
      self.assertAllEqual(expected_result, scatter_update.eval())

  def testRank3InvalidShape1(self):
    indices = array_ops.zeros([3, 2, 2], dtypes.int32)
    updates = array_ops.zeros([2, 2, 2], dtypes.int32)
    shape = np.array([2, 2, 2])
    ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
    with self.assertRaisesWithPredicateMatch(
        ValueError, "The outer \\d+ dimensions of indices\\.shape="):
      state_ops.scatter_nd_update(ref, indices, updates)

  def testRank3InvalidShape2(self):
    indices = array_ops.zeros([2, 2, 1], dtypes.int32)
    updates = array_ops.zeros([2, 2], dtypes.int32)
    shape = np.array([2, 2, 2])
    ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
    with self.assertRaisesWithPredicateMatch(
        ValueError, "The inner \\d+ dimensions of input\\.shape="):
      state_ops.scatter_nd_update(ref, indices, updates)

  def testConcurrentUpdates(self):
    num_updates = 10000
    update_values = np.random.rand(num_updates)
    ref = variables.Variable(np.zeros([2, 2]), dtype=dtypes.float64)
    indices = constant_op.constant([[0, 1]] * num_updates, dtype=dtypes.int32)
    updates = constant_op.constant(update_values, dtype=dtypes.float64)

    expected_result = np.zeros([2, 2], dtype=np.float64)
    expected_result[0, 1] = np.sum(update_values)

    scatter = state_ops.scatter_nd_add(ref, indices, updates)
    init = variables.global_variables_initializer()

    with session.Session() as sess:
      sess.run(init)
      result = sess.run(scatter)
      assert np.allclose(result, expected_result)

  # TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.
  def _disabledTestScatterOutOfRangeGpu(self):
    if not test.IsBuiltWithCuda():
      return
    # TODO(simister): Re-enable once binary size increase due to
    # scatter_nd ops is under control.
    # tf.scatter_nd_mul, tf.scatter_nd_div,
    for op in (state_ops.scatter_nd_add, state_ops.scatter_nd_sub,
               state_ops.scatter_nd_update):
      params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
      updates = np.array([-3, -4, -5]).astype(np.float32)
      # With GPU, the code ignores indices that are out of range.
      # We don't test the implementation; we just check that nothing fails.
      with self.test_session(force_gpu=True):
        ref = variables.Variable(params)
        ref.initializer.run()

        # Indices all in range, no problem.
        indices = np.array([2, 0, 5])
        op(ref, indices, updates).eval()

        # Indices out of range should not fail.
        indices = np.array([-1, 0, 5])
        op(ref, indices, updates).eval()
        indices = np.array([2, 0, 6])
        op(ref, indices, updates).eval()


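# Tests for the stateless array_ops.scatter_nd: static shape validation,
# unknown and empty shapes, gradients, and a few large smoke tests.
# ScatterNdNonAliasingAddTest below reruns the same cases through
# scatter_nd_non_aliasing_add by overriding the scatter_nd() helper.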
class ScatterNdTest(test.TestCase):
  non_aliasing_add_test = False

  def scatter_nd(self, indices, updates, shape, input_=None):
    del input_  # input_ is not used in scatter_nd
    return array_ops.scatter_nd(indices, updates, shape)

  def testRank3ValidShape(self):
    indices = array_ops.zeros([2, 2, 2], dtypes.int32)
    updates = array_ops.zeros([2, 2, 2], dtypes.int32)
    shape = np.array([2, 2, 2])
    self.assertAllEqual(
        self.scatter_nd(indices, updates, shape).get_shape().as_list(), shape)

  def testExtraIndicesDimensions(self):
    indices = array_ops.zeros([1, 1, 2], dtypes.int32)
    updates = array_ops.zeros([1, 1], dtypes.int32)
    shape = np.array([2, 2])
    scatter = self.scatter_nd(indices, updates, shape)
    self.assertAllEqual(scatter.get_shape().as_list(), shape)
    expected_result = np.zeros([2, 2], dtype=np.int32)
    with self.test_session():
      self.assertAllEqual(expected_result, scatter.eval())

  def testUndefinedIndicesShape(self):
    indices = array_ops.placeholder(dtypes.int32, shape=None)
    updates = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
    shape = constant_op.constant([2, 2, 2], dtypes.int32)
    self.scatter_nd(indices, updates, shape)

  def testUndefinedUpdatesShape(self):
    indices = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
    updates = array_ops.placeholder(dtypes.int32, shape=None)
    shape = constant_op.constant([2, 2, 2], dtypes.int32)
    self.scatter_nd(indices, updates, shape)

  def testUndefinedOutputShape(self):
    indices = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
    updates = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
    shape = array_ops.placeholder(dtypes.int32, shape=[None])
    self.scatter_nd(indices, updates, shape)

  def testEmptyOutputShape1(self):
    indices = array_ops.zeros([2, 2, 2], dtypes.int32)
    updates = array_ops.zeros([2, 2, 2], dtypes.int32)
    shape = constant_op.constant([0, 3, 2], dtypes.int32)

    with self.assertRaisesWithPredicateMatch(
        ValueError, "Indices and updates specified for empty output shape"):
      self.scatter_nd(indices, updates, shape)

  def testEmptyOutputShape2(self):
    indices = array_ops.placeholder(dtypes.int32, shape=None)
    updates = array_ops.placeholder(dtypes.int32, shape=None)
    shape = constant_op.constant([0, 3, 2], dtypes.int32)

    with self.test_session():
      with self.assertRaisesOpError(
          "Indices and updates specified for empty output"):
        self.scatter_nd(indices, updates, shape).eval(feed_dict={
            indices: np.zeros([2, 2, 2], dtype=np.int32),
            updates: np.zeros([2, 2, 2], dtype=np.int32)
        })

  def testEmptyOutputShape3(self):
    indices = array_ops.zeros([0], dtypes.int32)
    updates = array_ops.zeros([0], dtypes.int32)
    shape = constant_op.constant([0], dtypes.int32)
    scatter = self.scatter_nd(indices, updates, shape)

    with self.test_session():
      self.assertEqual(scatter.eval().size, 0)

  def testRank3InvalidShape1(self):
    indices = array_ops.zeros([3, 2, 2], dtypes.int32)
    updates = array_ops.zeros([2, 2, 2], dtypes.int32)
    shape = np.array([2, 2, 2])
    with self.assertRaisesWithPredicateMatch(
        ValueError, "The outer \\d+ dimensions of indices\\.shape="):
      self.scatter_nd(indices, updates, shape)

  def testRank3InvalidShape2(self):
    indices = array_ops.zeros([2, 2, 1], dtypes.int32)
    updates = array_ops.zeros([2, 2], dtypes.int32)
    shape = np.array([2, 2, 2])
    with self.assertRaisesWithPredicateMatch(
        ValueError, "The inner \\d+ dimensions of (input|output)\\.shape="):
      self.scatter_nd(indices, updates, shape)

  def testGradientsRank2ElementUpdate(self):
    indices = constant_op.constant([[0, 0], [1, 1]], dtype=dtypes.int32)
    updates = constant_op.constant([1, 4], dtype=dtypes.float64)
    shape = constant_op.constant([2, 2], dtype=dtypes.int32)
    input_ = array_ops.zeros(shape, dtype=dtypes.float64)
    outputs = self.scatter_nd(indices, updates, shape, input_)

    grad_vals = constant_op.constant([[1, 2], [3, 4]], dtype=dtypes.float64)
    updates_grad, input_grad = gradients_impl.gradients(
        [outputs], [updates, input_], [grad_vals])
    expected_updates_grad = np.array([1, 4], dtype=np.float64)
    expected_input_grad = np.array([[1, 2], [3, 4]], dtype=np.float64)
    with self.test_session():
      self.assertAllEqual(expected_updates_grad, updates_grad.eval())
      if self.non_aliasing_add_test:
        self.assertAllEqual(expected_input_grad, input_grad.eval())

  def testGradientsRank2SliceUpdate(self):
    indices = constant_op.constant([[1], [0]], dtype=dtypes.int32)
    updates = constant_op.constant([[3, 4], [1, 2]], dtype=dtypes.float64)
    shape = constant_op.constant([2, 2], dtype=dtypes.int32)
    input_ = array_ops.zeros(shape, dtype=dtypes.float64)
    outputs = self.scatter_nd(indices, updates, shape, input_)

    grad_vals = constant_op.constant([[3, 4], [1, 2]], dtype=dtypes.float64)
    updates_grad, input_grad = gradients_impl.gradients(
        [outputs], [updates, input_], [grad_vals])
    expected_updates_grad = np.array([[1, 2], [3, 4]], dtype=np.float64)
    expected_input_grad = np.array([[3, 4], [1, 2]], dtype=np.float64)
    with self.test_session():
      self.assertAllEqual(expected_updates_grad, updates_grad.eval())
      if self.non_aliasing_add_test:
        self.assertAllEqual(expected_input_grad, input_grad.eval())

  def testGradientsRank3SliceUpdate(self):
    indices = constant_op.constant(
        [[[0, 1], [1, 0]], [[0, 0], [1, 1]]], dtype=dtypes.int32)
    updates = constant_op.constant(
        [[[5, 7], [2, 4]], [[1, 3], [6, 8]]], dtype=dtypes.float64)
    shape = constant_op.constant([2, 2, 2], dtype=dtypes.int32)
    input_ = array_ops.zeros(shape, dtype=dtypes.float64)
    outputs = self.scatter_nd(indices, updates, shape, input_)

    grad_vals = constant_op.constant(
        [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=dtypes.float64)
    updates_grad, input_grad = gradients_impl.gradients(
        [outputs], [updates, input_], [grad_vals])
    expected_updates_grad = np.array(
        [[[3, 4], [5, 6]], [[1, 2], [7, 8]]], dtype=np.float64)
    expected_input_grad = np.array(
        [[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=np.float64)
    with self.test_session():
      self.assertAllEqual(expected_updates_grad, updates_grad.eval())
      if self.non_aliasing_add_test:
        self.assertAllEqual(expected_input_grad, input_grad.eval())

  def testGradientsRank7SliceUpdate(self):
    indices = constant_op.constant(
        [[[
            [[[[0, 0, 0, 0, 0, 1], [0, 0, 1, 0, 0, 0]]]],
            [[[[0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 1]]]]
        ]]], dtype=dtypes.int32)
    updates = constant_op.constant(
        [[[
            [[[[5, 6], [2, 4]]]],
            [[[[1, 3], [6, 8]]]]
        ]]], dtype=dtypes.float64)
    shape = constant_op.constant([1, 1, 2, 1, 1, 2, 2], dtype=dtypes.int32)
    input_ = array_ops.zeros(shape, dtype=dtypes.float64)
    outputs = self.scatter_nd(indices, updates, shape, input_)

    grad_vals = constant_op.constant(
        [[[
            [[[[1, 2], [3, 4]]]],
            [[[[5, 6], [7, 8]]]]
        ]]], dtype=dtypes.float64)
    updates_grad, input_grad = gradients_impl.gradients(
        [outputs], [updates, input_], [grad_vals])
    expected_updates_grad = np.array(
        [[[
            [[[[3, 4], [5, 6]]]],
            [[[[1, 2], [7, 8]]]]
        ]]], dtype=np.float64)
    expected_input_grad = np.array(
        [[[
            [[[[1, 2], [3, 4]]]],
            [[[[5, 6], [7, 8]]]]
        ]]], dtype=np.float64)
    with self.test_session():
      self.assertAllEqual(expected_updates_grad, updates_grad.eval())
      if self.non_aliasing_add_test:
        self.assertAllEqual(expected_input_grad, input_grad.eval())

  def testScatterNdRepeatedIndicesAdd(self):
    indices = array_ops.zeros([100000, 1], dtypes.int32)
    values = np.random.randn(100000)
    shape = [1]
    with self.test_session():
      val = self.scatter_nd(indices, values, shape).eval()
    self.assertAllClose([np.sum(values)], val)

  def testSmokeScatterNdBatch2DSliceDim2(self):
    with self.test_session():
      indices = array_ops.zeros([3, 5, 2], dtype=dtypes.int32)
      values = array_ops.zeros([3, 5, 7])
      shape = [4, 6, 7]
      self.scatter_nd(indices, values, shape).eval()

  def testSmokeScatterNdBatch1DSliceDim2(self):
    with self.test_session():
      indices = array_ops.zeros([0, 2], dtype=dtypes.int32)
      values = array_ops.zeros([0, 7])
      shape = [4, 6, 7]
      self.scatter_nd(indices, values, shape).eval()

  def testSmokeScatterNdBatch1DSliceDim3ShapeRank7(self):
    with self.test_session():
      indices = array_ops.zeros([1, 3], dtype=dtypes.int32)
      values = array_ops.zeros([1, 6, 7, 8, 9])
      shape = [3, 4, 5, 6, 7, 8, 9]
      self.scatter_nd(indices, values, shape).eval()

  def testSmokeScatterNdBatch2DSliceDim3ShapeRank7(self):
    with self.test_session():
      indices = array_ops.zeros([1, 2, 3], dtype=dtypes.int32)
      values = array_ops.zeros([1, 2, 6, 7, 8, 9])
      shape = [3, 4, 5, 6, 7, 8, 9]
      self.scatter_nd(indices, values, shape).eval()


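# Reruns every ScatterNdTest case through scatter_nd_non_aliasing_add applied
# to a zeros (or caller-supplied) input tensor. Setting non_aliasing_add_test
# to True also enables the input-gradient checks in the gradient tests above.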
class ScatterNdNonAliasingAddTest(ScatterNdTest):
  non_aliasing_add_test = True

  def scatter_nd(self, indices, updates, shape, input_=None):
    input_ = (input_ if input_ is not None else array_ops.zeros(
        shape, dtype=updates.dtype))
    return array_ops.scatter_nd_non_aliasing_add(input_, indices, updates)


if __name__ == "__main__":
  test.main()