# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the DynamicPartition op."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients_impl
import tensorflow.python.ops.data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test

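
# A minimal NumPy sketch of the semantics exercised below, kept only as a
# reference and not part of the TensorFlow API (the helper name is ours):
# element i of `data` goes to output `partitions[i]`, and every output keeps
# the relative order of the elements assigned to it.
def _reference_dynamic_partition(data, partitions, num_partitions):
  data = np.asarray(data)
  partitions = np.asarray(partitions)
  return [data[partitions == i] for i in range(num_partitions)]
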

class DynamicPartitionTest(test.TestCase):

  def testSimpleOneDimensional(self):
    with self.test_session(use_gpu=True) as sess:
      data = constant_op.constant([0, 13, 2, 39, 4, 17], dtype=dtypes.float32)
      indices = constant_op.constant([0, 0, 2, 3, 2, 1])
      partitions = data_flow_ops.dynamic_partition(
          data, indices, num_partitions=4)
      partition_vals = sess.run(partitions)

    self.assertEqual(4, len(partition_vals))
    self.assertAllEqual([0, 13], partition_vals[0])
    self.assertAllEqual([17], partition_vals[1])
    self.assertAllEqual([2, 4], partition_vals[2])
    self.assertAllEqual([39], partition_vals[3])
    # Vector data input to DynamicPartition results in
    # `num_partitions` vectors of unknown length.
    self.assertEqual([None], partitions[0].get_shape().as_list())
    self.assertEqual([None], partitions[1].get_shape().as_list())
    self.assertEqual([None], partitions[2].get_shape().as_list())
    self.assertEqual([None], partitions[3].get_shape().as_list())

  def testSimpleTwoDimensional(self):
    with self.test_session(use_gpu=True) as sess:
      data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
                                   [12, 13, 14], [15, 16, 17]],
                                  dtype=dtypes.float32)
      indices = constant_op.constant([0, 0, 2, 3, 2, 1])
      partitions = data_flow_ops.dynamic_partition(
          data, indices, num_partitions=4)
      partition_vals = sess.run(partitions)

    self.assertEqual(4, len(partition_vals))
    self.assertAllEqual([[0, 1, 2], [3, 4, 5]], partition_vals[0])
    self.assertAllEqual([[15, 16, 17]], partition_vals[1])
    self.assertAllEqual([[6, 7, 8], [12, 13, 14]], partition_vals[2])
    self.assertAllEqual([[9, 10, 11]], partition_vals[3])
    # Matrix data input to DynamicPartition results in
    # `num_partitions` matrices with an unknown number of rows and 3 columns.
    self.assertEqual([None, 3], partitions[0].get_shape().as_list())
    self.assertEqual([None, 3], partitions[1].get_shape().as_list())
    self.assertEqual([None, 3], partitions[2].get_shape().as_list())
    self.assertEqual([None, 3], partitions[3].get_shape().as_list())

  def testLargeOneDimensional(self):
    num = 100000
    data_list = [x for x in range(num)]
    indices_list = [x % 2 for x in range(num)]
    part1 = [x for x in range(num) if x % 2 == 0]
    part2 = [x for x in range(num) if x % 2 == 1]
    with self.test_session(use_gpu=True) as sess:
      data = constant_op.constant(data_list, dtype=dtypes.float32)
      indices = constant_op.constant(indices_list, dtype=dtypes.int32)
      partitions = data_flow_ops.dynamic_partition(
          data, indices, num_partitions=2)
      partition_vals = sess.run(partitions)

    self.assertEqual(2, len(partition_vals))
    self.assertAllEqual(part1, partition_vals[0])
    self.assertAllEqual(part2, partition_vals[1])

  def testLargeTwoDimensional(self):
    rows = 100000
    cols = 100
    data_list = [None] * rows
    for i in range(rows):
      data_list[i] = [i for _ in range(cols)]
    num_partitions = 97
    indices_list = [(i ** 2) % num_partitions for i in range(rows)]
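    # i ** 2 % 97 only hits the quadratic residues of 97, so roughly half of
    # the 97 partitions receive no rows at all (hence the reshape further
    # below when comparing against the expected parts).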
    parts = [[] for _ in range(num_partitions)]
    for i in range(rows):
      parts[(i ** 2) % num_partitions].append(data_list[i])
    with self.test_session(use_gpu=True) as sess:
      data = constant_op.constant(data_list, dtype=dtypes.float32)
      indices = constant_op.constant(indices_list, dtype=dtypes.int32)
      partitions = data_flow_ops.dynamic_partition(
          data, indices, num_partitions=num_partitions)
      partition_vals = sess.run(partitions)

    self.assertEqual(num_partitions, len(partition_vals))
    for i in range(num_partitions):
      # Reshape because an empty part yields shape (0,) instead of (0, cols).
      parts_np = np.array(parts[i], dtype=np.float).reshape(-1, cols)
      self.assertAllEqual(parts_np, partition_vals[i])

  def testSimpleComplex(self):
    data_list = [1 + 2j, 3 + 4j, 5 + 6j, 7 + 8j]
    indices_list = [1, 0, 1, 0]
    with self.test_session(use_gpu=True) as sess:
      data = constant_op.constant(data_list, dtype=dtypes.complex64)
      indices = constant_op.constant(indices_list, dtype=dtypes.int32)
      partitions = data_flow_ops.dynamic_partition(
          data, indices, num_partitions=2)
      partition_vals = sess.run(partitions)

    self.assertEqual(2, len(partition_vals))
    self.assertAllEqual([3 + 4j, 7 + 8j], partition_vals[0])
    self.assertAllEqual([1 + 2j, 5 + 6j], partition_vals[1])

  def testScalarPartitions(self):
    data_list = [10, 13, 12, 11]
    with self.test_session(use_gpu=True) as sess:
      data = constant_op.constant(data_list, dtype=dtypes.float64)
      indices = 3
      partitions = data_flow_ops.dynamic_partition(
          data, indices, num_partitions=4)
      partition_vals = sess.run(partitions)

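    # With a scalar partition index of 3, the whole `data` vector lands in
    # partition 3 as a single row; the other partitions are empty with shape
    # (0, 4).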
    self.assertEqual(4, len(partition_vals))
    self.assertAllEqual(np.array([], dtype=np.float64).reshape(-1, 4),
                        partition_vals[0])
    self.assertAllEqual(np.array([], dtype=np.float64).reshape(-1, 4),
                        partition_vals[1])
    self.assertAllEqual(np.array([], dtype=np.float64).reshape(-1, 4),
                        partition_vals[2])
    self.assertAllEqual(np.array([10, 13, 12, 11],
                                 dtype=np.float64).reshape(-1, 4),
                        partition_vals[3])

  def testHigherRank(self):
    np.random.seed(7)
    with self.test_session(use_gpu=True) as sess:
      for n in 2, 3:
        for shape in (4,), (4, 5), (4, 5, 2):
          partitions = np.random.randint(n, size=np.prod(shape)).reshape(shape)
          for extra_shape in (), (6,), (6, 7):
            data = np.random.randn(*(shape + extra_shape))
            partitions_t = constant_op.constant(partitions, dtype=dtypes.int32)
            data_t = constant_op.constant(data)
            outputs = data_flow_ops.dynamic_partition(
                data_t, partitions_t, num_partitions=n)
            self.assertEqual(n, len(outputs))
            outputs_val = sess.run(outputs)
            for i, output in enumerate(outputs_val):
              self.assertAllEqual(output, data[partitions == i])

            # Test gradients
            outputs_grad = [7 * output for output in outputs_val]
            grads = gradients_impl.gradients(outputs, [data_t, partitions_t],
                                             outputs_grad)
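            # The data gradient routes each incoming output gradient back to
            # the input element it came from (see data_flow_grad), so scaling
            # every output by 7 must yield a gradient of 7 * data; the integer
            # `partitions` tensor gets no gradient.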
            self.assertEqual(grads[1], None)  # Partitions has no gradients
            self.assertAllEqual(7 * data, sess.run(grads[0]))

  def testEmptyParts(self):
    data_list = [1, 2, 3, 4]
    indices_list = [1, 3, 1, 3]
    with self.test_session(use_gpu=True) as sess:
      data = constant_op.constant(data_list, dtype=dtypes.float32)
      indices = constant_op.constant(indices_list, dtype=dtypes.int32)
      partitions = data_flow_ops.dynamic_partition(
          data, indices, num_partitions=4)
      partition_vals = sess.run(partitions)

    self.assertEqual(4, len(partition_vals))
    self.assertAllEqual([], partition_vals[0])
    self.assertAllEqual([1, 3], partition_vals[1])
    self.assertAllEqual([], partition_vals[2])
    self.assertAllEqual([2, 4], partition_vals[3])

  def testEmptyDataTwoDimensional(self):
    data_list = [[], []]
    indices_list = [0, 1]
    with self.test_session(use_gpu=True) as sess:
      data = constant_op.constant(data_list, dtype=dtypes.float32)
      indices = constant_op.constant(indices_list, dtype=dtypes.int32)
      partitions = data_flow_ops.dynamic_partition(
          data, indices, num_partitions=3)
      partition_vals = sess.run(partitions)

    self.assertEqual(3, len(partition_vals))
    self.assertAllEqual([[]], partition_vals[0])
    self.assertAllEqual([[]], partition_vals[1])
    self.assertAllEqual(np.array([], dtype=np.float).reshape(0, 0),
                        partition_vals[2])

  def testEmptyPartitions(self):
    data_list = []
    indices_list = []
    with self.test_session(use_gpu=True) as sess:
      data = constant_op.constant(data_list, dtype=dtypes.float32)
      indices = constant_op.constant(indices_list, dtype=dtypes.int32)
      partitions = data_flow_ops.dynamic_partition(
          data, indices, num_partitions=2)
      partition_vals = sess.run(partitions)

    self.assertEqual(2, len(partition_vals))
    self.assertAllEqual([], partition_vals[0])
    self.assertAllEqual([], partition_vals[1])

  @unittest.skip("Fails on windows.")
  def testGPUTooManyParts(self):
    # This test only makes sense on the GPU. There we do not check
    # for errors. In this case, we should discard all but the first
    # num_partitions indices.
    if not test.is_gpu_available():
      return

    data_list = [1, 2, 3, 4, 5, 6]
    indices_list = [6, 5, 4, 3, 1, 0]
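    # Only the indices smaller than num_partitions=2 survive: index 1 at
    # position 4 (data value 5) and index 0 at position 5 (data value 6).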
    with self.test_session(use_gpu=True) as sess:
      data = constant_op.constant(data_list, dtype=dtypes.float32)
      indices = constant_op.constant(indices_list, dtype=dtypes.int32)
      partitions = data_flow_ops.dynamic_partition(
          data, indices, num_partitions=2)
      partition_vals = sess.run(partitions)

    self.assertEqual(2, len(partition_vals))
    self.assertAllEqual([6], partition_vals[0])
    self.assertAllEqual([5], partition_vals[1])

  @unittest.skip("Fails on windows.")
  def testGPUPartsTooLarge(self):
    # This test only makes sense on the GPU. There we do not check
    # for errors. In this case, we should discard all the values
    # larger than num_partitions.
    if not test.is_gpu_available():
      return

    data_list = [1, 2, 3, 4, 5, 6]
    indices_list = [10, 11, 2, 12, 0, 1000]
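    # Only the indices inside [0, 5) survive: index 2 at position 2 (data
    # value 3) and index 0 at position 4 (data value 5); everything else is
    # dropped.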
    with self.test_session(use_gpu=True) as sess:
      data = constant_op.constant(data_list, dtype=dtypes.float32)
      indices = constant_op.constant(indices_list, dtype=dtypes.int32)
      partitions = data_flow_ops.dynamic_partition(
          data, indices, num_partitions=5)
      partition_vals = sess.run(partitions)

    self.assertEqual(5, len(partition_vals))
    self.assertAllEqual([5], partition_vals[0])
    self.assertAllEqual([], partition_vals[1])
    self.assertAllEqual([3], partition_vals[2])
    self.assertAllEqual([], partition_vals[3])
    self.assertAllEqual([], partition_vals[4])

  @unittest.skip("Fails on windows.")
  def testGPUAllIndicesBig(self):
    # This test only makes sense on the GPU. There we do not check
    # for errors. In this case, we should discard all the values
    # and have an empty output.
    if not test.is_gpu_available():
      return

    data_list = [1.1, 2.1, 3.1, 4.1, 5.1, 6.1]
    indices_list = [90, 70, 60, 100, 110, 40]
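    # Every index is >= num_partitions=40, so every value is dropped and all
    # 40 outputs come back empty.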
    with self.test_session(use_gpu=True) as sess:
      data = constant_op.constant(data_list, dtype=dtypes.float32)
      indices = constant_op.constant(indices_list, dtype=dtypes.int32)
      partitions = data_flow_ops.dynamic_partition(
          data, indices, num_partitions=40)
      partition_vals = sess.run(partitions)

    self.assertEqual(40, len(partition_vals))
    for i in range(40):
      self.assertAllEqual([], partition_vals[i])

  def testErrorIndexOutOfRange(self):
    with self.test_session() as sess:
      data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
                                   [12, 13, 14]])
      indices = constant_op.constant([0, 2, 99, 2, 2])
      partitions = data_flow_ops.dynamic_partition(
          data, indices, num_partitions=4)
      with self.assertRaisesOpError(r"partitions\[2\] = 99 is not in \[0, 4\)"):
        sess.run(partitions)

  def testScalarIndexOutOfRange(self):
    with self.test_session() as sess:
      bad = 17
      data = np.zeros(5)
      partitions = data_flow_ops.dynamic_partition(data, bad, num_partitions=7)
      with self.assertRaisesOpError(r"partitions = 17 is not in \[0, 7\)"):
        sess.run(partitions)

  def testHigherRankIndexOutOfRange(self):
    with self.test_session() as sess:
      shape = (2, 3)
      indices = array_ops.placeholder(shape=shape, dtype=np.int32)
      data = np.zeros(shape + (5,))
      partitions = data_flow_ops.dynamic_partition(
          data, indices, num_partitions=7)
      for i in xrange(2):
        for j in xrange(3):
          bad = np.zeros(shape, dtype=np.int32)
          bad[i, j] = 17
          with self.assertRaisesOpError(
              r"partitions\[%d,%d\] = 17 is not in \[0, 7\)" % (i, j)):
            sess.run(partitions, feed_dict={indices: bad})

  def testErrorWrongDimsIndices(self):
    data = constant_op.constant([[0], [1], [2]])
    indices = constant_op.constant([[0], [0]])
    with self.assertRaises(ValueError):
      data_flow_ops.dynamic_partition(data, indices, num_partitions=4)


if __name__ == "__main__":
  test.main()