Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for tensorflow.kernels.unique_op."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.ops import array_ops
     25 from tensorflow.python.ops import gen_array_ops
     26 from tensorflow.python.platform import test
     27 
     28 
     29 class UniqueTest(test.TestCase):
     30 
     31   def testInt32(self):
     32     x = np.random.randint(2, high=10, size=7000)
     33     with self.test_session() as sess:
     34       y, idx = array_ops.unique(x)
     35       tf_y, tf_idx = sess.run([y, idx])
     36 
     37     self.assertEqual(len(x), len(tf_idx))
     38     self.assertEqual(len(tf_y), len(np.unique(x)))
     39     for i in range(len(x)):
     40       self.assertEqual(x[i], tf_y[tf_idx[i]])
     41 
     42   def testInt32OutIdxInt64(self):
     43     x = np.random.randint(2, high=10, size=7000)
     44     with self.test_session() as sess:
     45       y, idx = array_ops.unique(x, out_idx=dtypes.int64)
     46       tf_y, tf_idx = sess.run([y, idx])
     47 
     48     self.assertEqual(len(x), len(tf_idx))
     49     self.assertEqual(len(tf_y), len(np.unique(x)))
     50     for i in range(len(x)):
     51       self.assertEqual(x[i], tf_y[tf_idx[i]])
     52 
     53   def testString(self):
     54     indx = np.random.randint(65, high=122, size=7000)
     55     x = [chr(i) for i in indx]
     56     with self.test_session() as sess:
     57       y, idx = array_ops.unique(x)
     58       tf_y, tf_idx = sess.run([y, idx])
     59 
     60     self.assertEqual(len(x), len(tf_idx))
     61     self.assertEqual(len(tf_y), len(np.unique(x)))
     62     for i in range(len(x)):
     63       self.assertEqual(x[i], tf_y[tf_idx[i]].decode('ascii'))
     64 
     65   def testInt32Axis(self):
     66     for dtype in [np.int32, np.int64]:
     67       x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]])
     68       with self.test_session() as sess:
     69         y0, idx0 = gen_array_ops._unique_v2(x, axis=np.array([0], dtype))
     70         tf_y0, tf_idx0 = sess.run([y0, idx0])
     71         y1, idx1 = gen_array_ops._unique_v2(x, axis=np.array([1], dtype))
     72         tf_y1, tf_idx1 = sess.run([y1, idx1])
     73       self.assertAllEqual(tf_y0, np.array([[1, 0, 0], [2, 0, 0]]))
     74       self.assertAllEqual(tf_idx0, np.array([0, 0, 1]))
     75       self.assertAllEqual(tf_y1, np.array([[1, 0], [1, 0], [2, 0]]))
     76       self.assertAllEqual(tf_idx1, np.array([0, 1, 1]))
     77 
     78   def testInt32V2(self):
     79     # This test is only temporary, once V2 is used
     80     # by default, the axis will be wrapped to allow `axis=None`.
     81     x = np.random.randint(2, high=10, size=7000)
     82     with self.test_session() as sess:
     83       y, idx = gen_array_ops._unique_v2(x, axis=np.array([], np.int32))
     84       tf_y, tf_idx = sess.run([y, idx])
     85 
     86     self.assertEqual(len(x), len(tf_idx))
     87     self.assertEqual(len(tf_y), len(np.unique(x)))
     88     for i in range(len(x)):
     89       self.assertEqual(x[i], tf_y[tf_idx[i]])
     90 
     91 
     92 class UniqueWithCountsTest(test.TestCase):
     93 
     94   def testInt32(self):
     95     x = np.random.randint(2, high=10, size=7000)
     96     with self.test_session() as sess:
     97       y, idx, count = array_ops.unique_with_counts(x)
     98       tf_y, tf_idx, tf_count = sess.run([y, idx, count])
     99 
    100     self.assertEqual(len(x), len(tf_idx))
    101     self.assertEqual(len(tf_y), len(np.unique(x)))
    102     for i in range(len(x)):
    103       self.assertEqual(x[i], tf_y[tf_idx[i]])
    104     for value, count in zip(tf_y, tf_count):
    105       self.assertEqual(count, np.sum(x == value))
    106 
    107   def testInt32OutIdxInt64(self):
    108     x = np.random.randint(2, high=10, size=7000)
    109     with self.test_session() as sess:
    110       y, idx, count = array_ops.unique_with_counts(x, out_idx=dtypes.int64)
    111       tf_y, tf_idx, tf_count = sess.run([y, idx, count])
    112 
    113     self.assertEqual(len(x), len(tf_idx))
    114     self.assertEqual(len(tf_y), len(np.unique(x)))
    115     for i in range(len(x)):
    116       self.assertEqual(x[i], tf_y[tf_idx[i]])
    117     for value, count in zip(tf_y, tf_count):
    118       self.assertEqual(count, np.sum(x == value))
    119 
    120   def testString(self):
    121     indx = np.random.randint(65, high=122, size=7000)
    122     x = [chr(i) for i in indx]
    123 
    124     with self.test_session() as sess:
    125       y, idx, count = array_ops.unique_with_counts(x)
    126       tf_y, tf_idx, tf_count = sess.run([y, idx, count])
    127 
    128     self.assertEqual(len(x), len(tf_idx))
    129     self.assertEqual(len(tf_y), len(np.unique(x)))
    130     for i in range(len(x)):
    131       self.assertEqual(x[i], tf_y[tf_idx[i]].decode('ascii'))
    132     for value, count in zip(tf_y, tf_count):
    133       v = [1 if x[i] == value.decode('ascii') else 0 for i in range(7000)]
    134       self.assertEqual(count, sum(v))
    135 
    136 
# Invoke the `tensorflow.python.platform.test` entry point only when this file
# is executed directly, not when it is imported.
if __name__ == '__main__':
  test.main()
    139