Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tensorflow.ops.tf.gather."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.framework import constant_op
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.ops import array_ops
     27 from tensorflow.python.ops import gradients_impl
     28 from tensorflow.python.platform import test
     29 
     30 _TEST_TYPES = (dtypes.float32, dtypes.complex64, dtypes.complex128)
     31 
     32 
     33 class GatherTest(test.TestCase):
     34 
     35   def _buildParams(self, data, dtype):
     36     data = data.astype(dtype.as_numpy_dtype)
     37     # For complex types, add an index-dependent imaginary component so we can
     38     # tell we got the right value.
     39     if dtype.is_complex:
     40       return data + 10j * data
     41     return data
     42 
     43   def testScalar1D(self):
     44     with self.test_session(use_gpu=True):
     45       data = np.array([0, 1, 2, 3, 7, 5])
     46       for dtype in _TEST_TYPES:
     47         for indices in 4, [1, 2, 2, 4, 5]:
     48           params_np = self._buildParams(data, dtype)
     49           params = constant_op.constant(params_np)
     50           indices_tf = constant_op.constant(indices)
     51           gather_t = array_ops.gather(params, indices_tf)
     52           gather_val = gather_t.eval()
     53           np_val = params_np[indices]
     54           self.assertAllEqual(np_val, gather_val)
     55           self.assertEqual(np_val.shape, gather_t.get_shape())
     56 
     57   def testScalar2D(self):
     58     with self.test_session(use_gpu=True):
     59       data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8],
     60                        [9, 10, 11], [12, 13, 14]])
     61       for dtype in _TEST_TYPES:
     62         for axis in range(data.ndim):
     63           params_np = self._buildParams(data, dtype)
     64           params = constant_op.constant(params_np)
     65           indices = constant_op.constant(2)
     66           gather_t = array_ops.gather(params, indices, axis=axis)
     67           gather_val = gather_t.eval()
     68           self.assertAllEqual(np.take(params_np, 2, axis=axis), gather_val)
     69           expected_shape = data.shape[:axis] + data.shape[axis + 1:]
     70           self.assertEqual(expected_shape, gather_t.get_shape())
     71 
     72   def testSimpleTwoD32(self):
     73     with self.test_session(use_gpu=True):
     74       data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8],
     75                        [9, 10, 11], [12, 13, 14]])
     76       for dtype in _TEST_TYPES:
     77         for axis in range(data.ndim):
     78           params_np = self._buildParams(data, dtype)
     79           params = constant_op.constant(params_np)
     80           # The indices must be in bounds for any axis.
     81           indices = constant_op.constant([0, 1, 0, 2])
     82           gather_t = array_ops.gather(params, indices, axis=axis)
     83           gather_val = gather_t.eval()
     84           self.assertAllEqual(np.take(params_np, [0, 1, 0, 2], axis=axis),
     85                               gather_val)
     86           expected_shape = data.shape[:axis] + (4,) + data.shape[axis + 1:]
     87           self.assertEqual(expected_shape, gather_t.get_shape())
     88 
     89   def testHigherRank(self):
     90     # We check that scalar and empty indices shapes work as well
     91     shape = (2, 1, 3, 2)
     92     for indices_shape in (), (0,), (2, 0), (2, 3):
     93       for dtype in _TEST_TYPES:
     94         for axis in range(len(shape)):
     95           params = self._buildParams(np.random.randn(*shape), dtype)
     96           indices = np.random.randint(shape[axis], size=indices_shape)
     97           with self.test_session(use_gpu=True) as sess:
     98             tf_params = constant_op.constant(params)
     99             tf_indices = constant_op.constant(indices)
    100             # Check that both positive and negative indices for axis work.
    101             tf_axis = constant_op.constant(axis)
    102             tf_negative_axis = constant_op.constant(-len(shape) + axis)
    103             gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis)
    104             gather_negative_axis = array_ops.gather(
    105                 tf_params, tf_indices, axis=tf_negative_axis)
    106             gather_value, gather_negative_axis_value = sess.run(
    107                 [gather, gather_negative_axis])
    108             gather_np = np.take(params, indices, axis)
    109             self.assertAllEqual(gather_np, gather_value)
    110             self.assertAllEqual(gather_np, gather_negative_axis_value)
    111             expected_shape = (params.shape[:axis] + indices.shape +
    112                               params.shape[axis + 1:])
    113             self.assertEqual(expected_shape, gather.shape)
    114             self.assertEqual(expected_shape, gather_negative_axis.shape)
    115 
    116             # Test gradients
    117             gather_grad = np.random.randn(
    118                 *gather.get_shape().as_list()).astype(dtype.as_numpy_dtype)
    119             if dtype.is_complex:
    120               gather_grad -= 1j * gather_grad
    121             params_grad, indices_grad, axis_grad = gradients_impl.gradients(
    122                 gather, [tf_params, tf_indices, tf_axis], gather_grad)
    123             self.assertEqual(indices_grad, None)
    124             self.assertEqual(axis_grad, None)
    125             # For axis 0, we are able to create an efficient IndexedSlices for
    126             # the gradient.
    127             if axis == 0:
    128               self.assertEqual(type(params_grad), ops.IndexedSlices)
    129               params_grad = ops.convert_to_tensor(params_grad)
    130             correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype)
    131             outer_dims = axis
    132             inner_dims = len(shape) - axis - 1
    133             gather_grad = gather_grad.reshape(
    134                 shape[:axis] + (indices.size,) + shape[axis + 1:])
    135             for source_index, dest_index in enumerate(indices.flat):
    136               dest_slice = ((slice(None),) * outer_dims + (dest_index,) +
    137                             (slice(None),) * inner_dims)
    138               source_slice = ((slice(None),) * outer_dims + (source_index,) +
    139                               (slice(None),) * inner_dims)
    140               correct_params_grad[dest_slice] += gather_grad[source_slice]
    141             self.assertAllClose(correct_params_grad, params_grad.eval(),
    142                                 atol=2e-6, rtol=2e-6)
    143 
    144   def testString(self):
    145     params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]])
    146     with self.test_session():
    147       self.assertAllEqual([b"qwer", b"uiop"],
    148                           array_ops.gather(params, 1, axis=0).eval())
    149       self.assertAllEqual([b"asdf", b"qwer"],
    150                           array_ops.gather(params, 0, axis=1).eval())
    151 
    152   def testUnknownIndices(self):
    153     params = constant_op.constant([[0, 1, 2]])
    154     indices = array_ops.placeholder(dtypes.int32)
    155     gather_t = array_ops.gather(params, indices)
    156     self.assertEqual(None, gather_t.get_shape())
    157 
    158   def testUnknownAxis(self):
    159     params = constant_op.constant([[0, 1, 2]])
    160     indices = constant_op.constant([[0, 0], [0, 0]])
    161     axis = array_ops.placeholder(dtypes.int32)
    162     gather_t = array_ops.gather(params, indices, axis=axis)
    163     # Rank 2 params with rank 2 indices results in a rank 3 shape.
    164     self.assertEqual([None, None, None], gather_t.shape.as_list())
    165 
    166     # If indices is also unknown the result rank is unknown.
    167     indices = array_ops.placeholder(dtypes.int32)
    168     gather_t = array_ops.gather(params, indices, axis=axis)
    169     self.assertEqual(None, gather_t.shape)
    170 
    171   def testBadIndices(self):
    172     with self.test_session(use_gpu=True):
    173       params = [[0, 1, 2], [3, 4, 5]]
    174       with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"):
    175         array_ops.gather(params, [[7]], axis=0).eval()
    176       with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"):
    177         array_ops.gather(params, [[7]], axis=1).eval()
    178 
    179   def testBadAxis(self):
    180     with self.test_session(use_gpu=True):
    181       params = [0, 1, 2]
    182       params_ph = array_ops.placeholder(dtypes.int32)
    183       indices = 0
    184       for bad_axis in (1, 2, -2):
    185         # Shape inference can validate axis for known params rank.
    186         with self.assertRaisesWithPredicateMatch(
    187             ValueError, "Shape must be at least rank . but is rank 1"):
    188           array_ops.gather(params, indices, axis=bad_axis)
    189         # If params rank is unknown, an op error occurs.
    190         with self.assertRaisesOpError(
    191             r"Expected axis in the range \[-1, 1\), but got %s" % bad_axis):
    192           array_ops.gather(params_ph, indices, axis=bad_axis).eval(
    193               feed_dict={params_ph: params})
    194 
    195   def testEmptySlices(self):
    196     with self.test_session(use_gpu=True):
    197       for dtype in _TEST_TYPES:
    198         for itype in np.int32, np.int64:
    199           # Leading axis gather.
    200           params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype)
    201           indices = np.array([3, 4], dtype=itype)
    202           gather = array_ops.gather(params, indices, axis=0)
    203           self.assertAllEqual(gather.eval(), np.zeros((2, 0, 0)))
    204 
    205           # Middle axis gather.
    206           params = np.zeros((0, 7, 0), dtype=dtype.as_numpy_dtype)
    207           gather = array_ops.gather(params, indices, axis=1)
    208           self.assertAllEqual(gather.eval(), np.zeros((0, 2, 0)))
    209 
    210           # Trailing axis gather.
    211           params = np.zeros((0, 0, 7), dtype=dtype.as_numpy_dtype)
    212           gather = array_ops.gather(params, indices, axis=2)
    213           self.assertAllEqual(gather.eval(), np.zeros((0, 0, 2)))
    214 
    215 
if __name__ == "__main__":
  # Run the test cases in this module through TensorFlow's test entry point.
  test.main()
    218