1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for the DynamicPartition op.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import unittest 22 23 import numpy as np 24 from six.moves import xrange # pylint: disable=redefined-builtin 25 26 from tensorflow.python.framework import constant_op 27 from tensorflow.python.framework import dtypes 28 from tensorflow.python.ops import array_ops 29 from tensorflow.python.ops import data_flow_ops 30 from tensorflow.python.ops import gradients_impl 31 import tensorflow.python.ops.data_flow_grad # pylint: disable=unused-import 32 from tensorflow.python.platform import test 33 34 35 class DynamicPartitionTest(test.TestCase): 36 37 def testSimpleOneDimensional(self): 38 with self.test_session(use_gpu=True) as sess: 39 data = constant_op.constant([0, 13, 2, 39, 4, 17], dtype=dtypes.float32) 40 indices = constant_op.constant([0, 0, 2, 3, 2, 1]) 41 partitions = data_flow_ops.dynamic_partition( 42 data, indices, num_partitions=4) 43 partition_vals = sess.run(partitions) 44 45 self.assertEqual(4, len(partition_vals)) 46 self.assertAllEqual([0, 13], partition_vals[0]) 47 self.assertAllEqual([17], partition_vals[1]) 48 self.assertAllEqual([2, 4], partition_vals[2]) 49 self.assertAllEqual([39], partition_vals[3]) 50 # Vector data input to DynamicPartition results in 51 # `num_partitions` vectors of unknown length. 52 self.assertEqual([None], partitions[0].get_shape().as_list()) 53 self.assertEqual([None], partitions[1].get_shape().as_list()) 54 self.assertEqual([None], partitions[2].get_shape().as_list()) 55 self.assertEqual([None], partitions[3].get_shape().as_list()) 56 57 def testSimpleTwoDimensional(self): 58 with self.test_session(use_gpu=True) as sess: 59 data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], 60 [12, 13, 14], [15, 16, 17]], 61 dtype=dtypes.float32) 62 indices = constant_op.constant([0, 0, 2, 3, 2, 1]) 63 partitions = data_flow_ops.dynamic_partition( 64 data, indices, num_partitions=4) 65 partition_vals = sess.run(partitions) 66 67 self.assertEqual(4, len(partition_vals)) 68 self.assertAllEqual([[0, 1, 2], [3, 4, 5]], partition_vals[0]) 69 self.assertAllEqual([[15, 16, 17]], partition_vals[1]) 70 self.assertAllEqual([[6, 7, 8], [12, 13, 14]], partition_vals[2]) 71 self.assertAllEqual([[9, 10, 11]], partition_vals[3]) 72 # Vector data input to DynamicPartition results in 73 # `num_partitions` matrices with an unknown number of rows, and 3 columns. 74 self.assertEqual([None, 3], partitions[0].get_shape().as_list()) 75 self.assertEqual([None, 3], partitions[1].get_shape().as_list()) 76 self.assertEqual([None, 3], partitions[2].get_shape().as_list()) 77 self.assertEqual([None, 3], partitions[3].get_shape().as_list()) 78 79 def testLargeOneDimensional(self): 80 num = 100000 81 data_list = [x for x in range(num)] 82 indices_list = [x % 2 for x in range(num)] 83 part1 = [x for x in range(num) if x % 2 == 0] 84 part2 = [x for x in range(num) if x % 2 == 1] 85 with self.test_session(use_gpu=True) as sess: 86 data = constant_op.constant(data_list, dtype=dtypes.float32) 87 indices = constant_op.constant(indices_list, dtype=dtypes.int32) 88 partitions = data_flow_ops.dynamic_partition( 89 data, indices, num_partitions=2) 90 partition_vals = sess.run(partitions) 91 92 self.assertEqual(2, len(partition_vals)) 93 self.assertAllEqual(part1, partition_vals[0]) 94 self.assertAllEqual(part2, partition_vals[1]) 95 96 def testLargeTwoDimensional(self): 97 rows = 100000 98 cols = 100 99 data_list = [None] * rows 100 for i in range(rows): 101 data_list[i] = [i for _ in range(cols)] 102 num_partitions = 97 103 indices_list = [(i ** 2) % num_partitions for i in range(rows)] 104 parts = [[] for _ in range(num_partitions)] 105 for i in range(rows): 106 parts[(i ** 2) % num_partitions].append(data_list[i]) 107 with self.test_session(use_gpu=True) as sess: 108 data = constant_op.constant(data_list, dtype=dtypes.float32) 109 indices = constant_op.constant(indices_list, dtype=dtypes.int32) 110 partitions = data_flow_ops.dynamic_partition( 111 data, indices, num_partitions=num_partitions) 112 partition_vals = sess.run(partitions) 113 114 self.assertEqual(num_partitions, len(partition_vals)) 115 for i in range(num_partitions): 116 # reshape because of empty parts 117 parts_np = np.array(parts[i], dtype=np.float).reshape(-1, cols) 118 self.assertAllEqual(parts_np, partition_vals[i]) 119 120 def testSimpleComplex(self): 121 data_list = [1 + 2j, 3 + 4j, 5 + 6j, 7 + 8j] 122 indices_list = [1, 0, 1, 0] 123 with self.test_session(use_gpu=True) as sess: 124 data = constant_op.constant(data_list, dtype=dtypes.complex64) 125 indices = constant_op.constant(indices_list, dtype=dtypes.int32) 126 partitions = data_flow_ops.dynamic_partition( 127 data, indices, num_partitions=2) 128 partition_vals = sess.run(partitions) 129 130 self.assertEqual(2, len(partition_vals)) 131 self.assertAllEqual([3 + 4j, 7 + 8j], partition_vals[0]) 132 self.assertAllEqual([1 + 2j, 5 + 6j], partition_vals[1]) 133 134 def testScalarPartitions(self): 135 data_list = [10, 13, 12, 11] 136 with self.test_session(use_gpu=True) as sess: 137 data = constant_op.constant(data_list, dtype=dtypes.float64) 138 indices = 3 139 partitions = data_flow_ops.dynamic_partition( 140 data, indices, num_partitions=4) 141 partition_vals = sess.run(partitions) 142 143 self.assertEqual(4, len(partition_vals)) 144 self.assertAllEqual(np.array([], dtype=np.float64).reshape(-1, 4), 145 partition_vals[0]) 146 self.assertAllEqual(np.array([], dtype=np.float64).reshape(-1, 4), 147 partition_vals[1]) 148 self.assertAllEqual(np.array([], dtype=np.float64).reshape(-1, 4), 149 partition_vals[2]) 150 self.assertAllEqual(np.array([10, 13, 12, 11], 151 dtype=np.float64).reshape(-1, 4), 152 partition_vals[3]) 153 154 def testHigherRank(self): 155 np.random.seed(7) 156 with self.test_session(use_gpu=True) as sess: 157 for n in 2, 3: 158 for shape in (4,), (4, 5), (4, 5, 2): 159 partitions = np.random.randint(n, size=np.prod(shape)).reshape(shape) 160 for extra_shape in (), (6,), (6, 7): 161 data = np.random.randn(*(shape + extra_shape)) 162 partitions_t = constant_op.constant(partitions, dtype=dtypes.int32) 163 data_t = constant_op.constant(data) 164 outputs = data_flow_ops.dynamic_partition( 165 data_t, partitions_t, num_partitions=n) 166 self.assertEqual(n, len(outputs)) 167 outputs_val = sess.run(outputs) 168 for i, output in enumerate(outputs_val): 169 self.assertAllEqual(output, data[partitions == i]) 170 171 # Test gradients 172 outputs_grad = [7 * output for output in outputs_val] 173 grads = gradients_impl.gradients(outputs, [data_t, partitions_t], 174 outputs_grad) 175 self.assertEqual(grads[1], None) # Partitions has no gradients 176 self.assertAllEqual(7 * data, sess.run(grads[0])) 177 178 def testEmptyParts(self): 179 data_list = [1, 2, 3, 4] 180 indices_list = [1, 3, 1, 3] 181 with self.test_session(use_gpu=True) as sess: 182 data = constant_op.constant(data_list, dtype=dtypes.float32) 183 indices = constant_op.constant(indices_list, dtype=dtypes.int32) 184 partitions = data_flow_ops.dynamic_partition( 185 data, indices, num_partitions=4) 186 partition_vals = sess.run(partitions) 187 188 self.assertEqual(4, len(partition_vals)) 189 self.assertAllEqual([], partition_vals[0]) 190 self.assertAllEqual([1, 3], partition_vals[1]) 191 self.assertAllEqual([], partition_vals[2]) 192 self.assertAllEqual([2, 4], partition_vals[3]) 193 194 def testEmptyDataTwoDimensional(self): 195 data_list = [[], []] 196 indices_list = [0, 1] 197 with self.test_session(use_gpu=True) as sess: 198 data = constant_op.constant(data_list, dtype=dtypes.float32) 199 indices = constant_op.constant(indices_list, dtype=dtypes.int32) 200 partitions = data_flow_ops.dynamic_partition( 201 data, indices, num_partitions=3) 202 partition_vals = sess.run(partitions) 203 204 self.assertEqual(3, len(partition_vals)) 205 self.assertAllEqual([[]], partition_vals[0]) 206 self.assertAllEqual([[]], partition_vals[1]) 207 self.assertAllEqual(np.array([], dtype=np.float).reshape(0, 0), 208 partition_vals[2]) 209 210 def testEmptyPartitions(self): 211 data_list = [] 212 indices_list = [] 213 with self.test_session(use_gpu=True) as sess: 214 data = constant_op.constant(data_list, dtype=dtypes.float32) 215 indices = constant_op.constant(indices_list, dtype=dtypes.int32) 216 partitions = data_flow_ops.dynamic_partition( 217 data, indices, num_partitions=2) 218 partition_vals = sess.run(partitions) 219 220 self.assertEqual(2, len(partition_vals)) 221 self.assertAllEqual([], partition_vals[0]) 222 self.assertAllEqual([], partition_vals[1]) 223 224 @unittest.skip("Fails on windows.") 225 def testGPUTooManyParts(self): 226 # This test only makes sense on the GPU. There we do not check 227 # for errors. In this case, we should discard all but the first 228 # num_partitions indices. 229 if not test.is_gpu_available(): 230 return 231 232 data_list = [1, 2, 3, 4, 5, 6] 233 indices_list = [6, 5, 4, 3, 1, 0] 234 with self.test_session(use_gpu=True) as sess: 235 data = constant_op.constant(data_list, dtype=dtypes.float32) 236 indices = constant_op.constant(indices_list, dtype=dtypes.int32) 237 partitions = data_flow_ops.dynamic_partition( 238 data, indices, num_partitions=2) 239 partition_vals = sess.run(partitions) 240 241 self.assertEqual(2, len(partition_vals)) 242 self.assertAllEqual([6], partition_vals[0]) 243 self.assertAllEqual([5], partition_vals[1]) 244 245 @unittest.skip("Fails on windows.") 246 def testGPUPartsTooLarge(self): 247 # This test only makes sense on the GPU. There we do not check 248 # for errors. In this case, we should discard all the values 249 # larger than num_partitions. 250 if not test.is_gpu_available(): 251 return 252 253 data_list = [1, 2, 3, 4, 5, 6] 254 indices_list = [10, 11, 2, 12, 0, 1000] 255 with self.test_session(use_gpu=True) as sess: 256 data = constant_op.constant(data_list, dtype=dtypes.float32) 257 indices = constant_op.constant(indices_list, dtype=dtypes.int32) 258 partitions = data_flow_ops.dynamic_partition( 259 data, indices, num_partitions=5) 260 partition_vals = sess.run(partitions) 261 262 self.assertEqual(5, len(partition_vals)) 263 self.assertAllEqual([5], partition_vals[0]) 264 self.assertAllEqual([], partition_vals[1]) 265 self.assertAllEqual([3], partition_vals[2]) 266 self.assertAllEqual([], partition_vals[3]) 267 self.assertAllEqual([], partition_vals[4]) 268 269 @unittest.skip("Fails on windows.") 270 def testGPUAllIndicesBig(self): 271 # This test only makes sense on the GPU. There we do not check 272 # for errors. In this case, we should discard all the values 273 # and have an empty output. 274 if not test.is_gpu_available(): 275 return 276 277 data_list = [1.1, 2.1, 3.1, 4.1, 5.1, 6.1] 278 indices_list = [90, 70, 60, 100, 110, 40] 279 with self.test_session(use_gpu=True) as sess: 280 data = constant_op.constant(data_list, dtype=dtypes.float32) 281 indices = constant_op.constant(indices_list, dtype=dtypes.int32) 282 partitions = data_flow_ops.dynamic_partition( 283 data, indices, num_partitions=40) 284 partition_vals = sess.run(partitions) 285 286 self.assertEqual(40, len(partition_vals)) 287 for i in range(40): 288 self.assertAllEqual([], partition_vals[i]) 289 290 def testErrorIndexOutOfRange(self): 291 with self.test_session() as sess: 292 data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], 293 [12, 13, 14]]) 294 indices = constant_op.constant([0, 2, 99, 2, 2]) 295 partitions = data_flow_ops.dynamic_partition( 296 data, indices, num_partitions=4) 297 with self.assertRaisesOpError(r"partitions\[2\] = 99 is not in \[0, 4\)"): 298 sess.run(partitions) 299 300 def testScalarIndexOutOfRange(self): 301 with self.test_session() as sess: 302 bad = 17 303 data = np.zeros(5) 304 partitions = data_flow_ops.dynamic_partition(data, bad, num_partitions=7) 305 with self.assertRaisesOpError(r"partitions = 17 is not in \[0, 7\)"): 306 sess.run(partitions) 307 308 def testHigherRankIndexOutOfRange(self): 309 with self.test_session() as sess: 310 shape = (2, 3) 311 indices = array_ops.placeholder(shape=shape, dtype=np.int32) 312 data = np.zeros(shape + (5,)) 313 partitions = data_flow_ops.dynamic_partition( 314 data, indices, num_partitions=7) 315 for i in xrange(2): 316 for j in xrange(3): 317 bad = np.zeros(shape, dtype=np.int32) 318 bad[i, j] = 17 319 with self.assertRaisesOpError( 320 r"partitions\[%d,%d\] = 17 is not in \[0, 7\)" % (i, j)): 321 sess.run(partitions, feed_dict={indices: bad}) 322 323 def testErrorWrongDimsIndices(self): 324 data = constant_op.constant([[0], [1], [2]]) 325 indices = constant_op.constant([[0], [0]]) 326 with self.assertRaises(ValueError): 327 data_flow_ops.dynamic_partition(data, indices, num_partitions=4) 328 329 330 if __name__ == "__main__": 331 test.main() 332