1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for tensorflow.ops.tf.gather.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.python.framework import constant_op 24 from tensorflow.python.framework import dtypes 25 from tensorflow.python.framework import ops 26 from tensorflow.python.ops import array_ops 27 from tensorflow.python.ops import gradients_impl 28 from tensorflow.python.platform import test 29 30 _TEST_TYPES = (dtypes.float32, dtypes.complex64, dtypes.complex128) 31 32 33 class GatherTest(test.TestCase): 34 35 def _buildParams(self, data, dtype): 36 data = data.astype(dtype.as_numpy_dtype) 37 # For complex types, add an index-dependent imaginary component so we can 38 # tell we got the right value. 39 if dtype.is_complex: 40 return data + 10j * data 41 return data 42 43 def testScalar1D(self): 44 with self.test_session(use_gpu=True): 45 data = np.array([0, 1, 2, 3, 7, 5]) 46 for dtype in _TEST_TYPES: 47 for indices in 4, [1, 2, 2, 4, 5]: 48 params_np = self._buildParams(data, dtype) 49 params = constant_op.constant(params_np) 50 indices_tf = constant_op.constant(indices) 51 gather_t = array_ops.gather(params, indices_tf) 52 gather_val = gather_t.eval() 53 np_val = params_np[indices] 54 self.assertAllEqual(np_val, gather_val) 55 self.assertEqual(np_val.shape, gather_t.get_shape()) 56 57 def testScalar2D(self): 58 with self.test_session(use_gpu=True): 59 data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], 60 [9, 10, 11], [12, 13, 14]]) 61 for dtype in _TEST_TYPES: 62 for axis in range(data.ndim): 63 params_np = self._buildParams(data, dtype) 64 params = constant_op.constant(params_np) 65 indices = constant_op.constant(2) 66 gather_t = array_ops.gather(params, indices, axis=axis) 67 gather_val = gather_t.eval() 68 self.assertAllEqual(np.take(params_np, 2, axis=axis), gather_val) 69 expected_shape = data.shape[:axis] + data.shape[axis + 1:] 70 self.assertEqual(expected_shape, gather_t.get_shape()) 71 72 def testSimpleTwoD32(self): 73 with self.test_session(use_gpu=True): 74 data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], 75 [9, 10, 11], [12, 13, 14]]) 76 for dtype in _TEST_TYPES: 77 for axis in range(data.ndim): 78 params_np = self._buildParams(data, dtype) 79 params = constant_op.constant(params_np) 80 # The indices must be in bounds for any axis. 81 indices = constant_op.constant([0, 1, 0, 2]) 82 gather_t = array_ops.gather(params, indices, axis=axis) 83 gather_val = gather_t.eval() 84 self.assertAllEqual(np.take(params_np, [0, 1, 0, 2], axis=axis), 85 gather_val) 86 expected_shape = data.shape[:axis] + (4,) + data.shape[axis + 1:] 87 self.assertEqual(expected_shape, gather_t.get_shape()) 88 89 def testHigherRank(self): 90 # We check that scalar and empty indices shapes work as well 91 shape = (2, 1, 3, 2) 92 for indices_shape in (), (0,), (2, 0), (2, 3): 93 for dtype in _TEST_TYPES: 94 for axis in range(len(shape)): 95 params = self._buildParams(np.random.randn(*shape), dtype) 96 indices = np.random.randint(shape[axis], size=indices_shape) 97 with self.test_session(use_gpu=True) as sess: 98 tf_params = constant_op.constant(params) 99 tf_indices = constant_op.constant(indices) 100 # Check that both positive and negative indices for axis work. 101 tf_axis = constant_op.constant(axis) 102 tf_negative_axis = constant_op.constant(-len(shape) + axis) 103 gather = array_ops.gather(tf_params, tf_indices, axis=tf_axis) 104 gather_negative_axis = array_ops.gather( 105 tf_params, tf_indices, axis=tf_negative_axis) 106 gather_value, gather_negative_axis_value = sess.run( 107 [gather, gather_negative_axis]) 108 gather_np = np.take(params, indices, axis) 109 self.assertAllEqual(gather_np, gather_value) 110 self.assertAllEqual(gather_np, gather_negative_axis_value) 111 expected_shape = (params.shape[:axis] + indices.shape + 112 params.shape[axis + 1:]) 113 self.assertEqual(expected_shape, gather.shape) 114 self.assertEqual(expected_shape, gather_negative_axis.shape) 115 116 # Test gradients 117 gather_grad = np.random.randn( 118 *gather.get_shape().as_list()).astype(dtype.as_numpy_dtype) 119 if dtype.is_complex: 120 gather_grad -= 1j * gather_grad 121 params_grad, indices_grad, axis_grad = gradients_impl.gradients( 122 gather, [tf_params, tf_indices, tf_axis], gather_grad) 123 self.assertEqual(indices_grad, None) 124 self.assertEqual(axis_grad, None) 125 # For axis 0, we are able to create an efficient IndexedSlices for 126 # the gradient. 127 if axis == 0: 128 self.assertEqual(type(params_grad), ops.IndexedSlices) 129 params_grad = ops.convert_to_tensor(params_grad) 130 correct_params_grad = np.zeros(shape).astype(dtype.as_numpy_dtype) 131 outer_dims = axis 132 inner_dims = len(shape) - axis - 1 133 gather_grad = gather_grad.reshape( 134 shape[:axis] + (indices.size,) + shape[axis + 1:]) 135 for source_index, dest_index in enumerate(indices.flat): 136 dest_slice = ((slice(None),) * outer_dims + (dest_index,) + 137 (slice(None),) * inner_dims) 138 source_slice = ((slice(None),) * outer_dims + (source_index,) + 139 (slice(None),) * inner_dims) 140 correct_params_grad[dest_slice] += gather_grad[source_slice] 141 self.assertAllClose(correct_params_grad, params_grad.eval(), 142 atol=2e-6, rtol=2e-6) 143 144 def testString(self): 145 params = np.array([[b"asdf", b"zxcv"], [b"qwer", b"uiop"]]) 146 with self.test_session(): 147 self.assertAllEqual([b"qwer", b"uiop"], 148 array_ops.gather(params, 1, axis=0).eval()) 149 self.assertAllEqual([b"asdf", b"qwer"], 150 array_ops.gather(params, 0, axis=1).eval()) 151 152 def testUnknownIndices(self): 153 params = constant_op.constant([[0, 1, 2]]) 154 indices = array_ops.placeholder(dtypes.int32) 155 gather_t = array_ops.gather(params, indices) 156 self.assertEqual(None, gather_t.get_shape()) 157 158 def testUnknownAxis(self): 159 params = constant_op.constant([[0, 1, 2]]) 160 indices = constant_op.constant([[0, 0], [0, 0]]) 161 axis = array_ops.placeholder(dtypes.int32) 162 gather_t = array_ops.gather(params, indices, axis=axis) 163 # Rank 2 params with rank 2 indices results in a rank 3 shape. 164 self.assertEqual([None, None, None], gather_t.shape.as_list()) 165 166 # If indices is also unknown the result rank is unknown. 167 indices = array_ops.placeholder(dtypes.int32) 168 gather_t = array_ops.gather(params, indices, axis=axis) 169 self.assertEqual(None, gather_t.shape) 170 171 def testBadIndices(self): 172 with self.test_session(use_gpu=True): 173 params = [[0, 1, 2], [3, 4, 5]] 174 with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 2\)"): 175 array_ops.gather(params, [[7]], axis=0).eval() 176 with self.assertRaisesOpError(r"indices\[0,0\] = 7 is not in \[0, 3\)"): 177 array_ops.gather(params, [[7]], axis=1).eval() 178 179 def testBadAxis(self): 180 with self.test_session(use_gpu=True): 181 params = [0, 1, 2] 182 params_ph = array_ops.placeholder(dtypes.int32) 183 indices = 0 184 for bad_axis in (1, 2, -2): 185 # Shape inference can validate axis for known params rank. 186 with self.assertRaisesWithPredicateMatch( 187 ValueError, "Shape must be at least rank . but is rank 1"): 188 array_ops.gather(params, indices, axis=bad_axis) 189 # If params rank is unknown, an op error occurs. 190 with self.assertRaisesOpError( 191 r"Expected axis in the range \[-1, 1\), but got %s" % bad_axis): 192 array_ops.gather(params_ph, indices, axis=bad_axis).eval( 193 feed_dict={params_ph: params}) 194 195 def testEmptySlices(self): 196 with self.test_session(use_gpu=True): 197 for dtype in _TEST_TYPES: 198 for itype in np.int32, np.int64: 199 # Leading axis gather. 200 params = np.zeros((7, 0, 0), dtype=dtype.as_numpy_dtype) 201 indices = np.array([3, 4], dtype=itype) 202 gather = array_ops.gather(params, indices, axis=0) 203 self.assertAllEqual(gather.eval(), np.zeros((2, 0, 0))) 204 205 # Middle axis gather. 206 params = np.zeros((0, 7, 0), dtype=dtype.as_numpy_dtype) 207 gather = array_ops.gather(params, indices, axis=1) 208 self.assertAllEqual(gather.eval(), np.zeros((0, 2, 0))) 209 210 # Trailing axis gather. 211 params = np.zeros((0, 0, 7), dtype=dtype.as_numpy_dtype) 212 gather = array_ops.gather(params, indices, axis=2) 213 self.assertAllEqual(gather.eval(), np.zeros((0, 0, 2))) 214 215 216 if __name__ == "__main__": 217 test.main() 218