Home | History | Annotate | Download | only in test
      1 # Tests for rich comparisons
      2 
      3 import unittest
      4 from test import test_support
      5 
      6 import operator
      7 
      8 class Number:
      9 
     10     def __init__(self, x):
     11         self.x = x
     12 
     13     def __lt__(self, other):
     14         return self.x < other
     15 
     16     def __le__(self, other):
     17         return self.x <= other
     18 
     19     def __eq__(self, other):
     20         return self.x == other
     21 
     22     def __ne__(self, other):
     23         return self.x != other
     24 
     25     def __gt__(self, other):
     26         return self.x > other
     27 
     28     def __ge__(self, other):
     29         return self.x >= other
     30 
     31     def __cmp__(self, other):
     32         raise test_support.TestFailed, "Number.__cmp__() should not be called"
     33 
     34     def __repr__(self):
     35         return "Number(%r)" % (self.x, )
     36 
     37 class Vector:
     38 
     39     def __init__(self, data):
     40         self.data = data
     41 
     42     def __len__(self):
     43         return len(self.data)
     44 
     45     def __getitem__(self, i):
     46         return self.data[i]
     47 
     48     def __setitem__(self, i, v):
     49         self.data[i] = v
     50 
     51     __hash__ = None # Vectors cannot be hashed
     52 
     53     def __nonzero__(self):
     54         raise TypeError, "Vectors cannot be used in Boolean contexts"
     55 
     56     def __cmp__(self, other):
     57         raise test_support.TestFailed, "Vector.__cmp__() should not be called"
     58 
     59     def __repr__(self):
     60         return "Vector(%r)" % (self.data, )
     61 
     62     def __lt__(self, other):
     63         return Vector([a < b for a, b in zip(self.data, self.__cast(other))])
     64 
     65     def __le__(self, other):
     66         return Vector([a <= b for a, b in zip(self.data, self.__cast(other))])
     67 
     68     def __eq__(self, other):
     69         return Vector([a == b for a, b in zip(self.data, self.__cast(other))])
     70 
     71     def __ne__(self, other):
     72         return Vector([a != b for a, b in zip(self.data, self.__cast(other))])
     73 
     74     def __gt__(self, other):
     75         return Vector([a > b for a, b in zip(self.data, self.__cast(other))])
     76 
     77     def __ge__(self, other):
     78         return Vector([a >= b for a, b in zip(self.data, self.__cast(other))])
     79 
     80     def __cast(self, other):
     81         if isinstance(other, Vector):
     82             other = other.data
     83         if len(self.data) != len(other):
     84             raise ValueError, "Cannot compare vectors of different length"
     85         return other
     86 
     87 opmap = {
     88     "lt": (lambda a,b: a< b, operator.lt, operator.__lt__),
     89     "le": (lambda a,b: a<=b, operator.le, operator.__le__),
     90     "eq": (lambda a,b: a==b, operator.eq, operator.__eq__),
     91     "ne": (lambda a,b: a!=b, operator.ne, operator.__ne__),
     92     "gt": (lambda a,b: a> b, operator.gt, operator.__gt__),
     93     "ge": (lambda a,b: a>=b, operator.ge, operator.__ge__)
     94 }
     95 
     96 class VectorTest(unittest.TestCase):
     97 
     98     def checkfail(self, error, opname, *args):
     99         for op in opmap[opname]:
    100             self.assertRaises(error, op, *args)
    101 
    102     def checkequal(self, opname, a, b, expres):
    103         for op in opmap[opname]:
    104             realres = op(a, b)
    105             # can't use assertEqual(realres, expres) here
    106             self.assertEqual(len(realres), len(expres))
    107             for i in xrange(len(realres)):
    108                 # results are bool, so we can use "is" here
    109                 self.assertTrue(realres[i] is expres[i])
    110 
    111     def test_mixed(self):
    112         # check that comparisons involving Vector objects
    113         # which return rich results (i.e. Vectors with itemwise
    114         # comparison results) work
    115         a = Vector(range(2))
    116         b = Vector(range(3))
    117         # all comparisons should fail for different length
    118         for opname in opmap:
    119             self.checkfail(ValueError, opname, a, b)
    120 
    121         a = range(5)
    122         b = 5 * [2]
    123         # try mixed arguments (but not (a, b) as that won't return a bool vector)
    124         args = [(a, Vector(b)), (Vector(a), b), (Vector(a), Vector(b))]
    125         for (a, b) in args:
    126             self.checkequal("lt", a, b, [True,  True,  False, False, False])
    127             self.checkequal("le", a, b, [True,  True,  True,  False, False])
    128             self.checkequal("eq", a, b, [False, False, True,  False, False])
    129             self.checkequal("ne", a, b, [True,  True,  False, True,  True ])
    130             self.checkequal("gt", a, b, [False, False, False, True,  True ])
    131             self.checkequal("ge", a, b, [False, False, True,  True,  True ])
    132 
    133             for ops in opmap.itervalues():
    134                 for op in ops:
    135                     # calls __nonzero__, which should fail
    136                     self.assertRaises(TypeError, bool, op(a, b))
    137 
    138 class NumberTest(unittest.TestCase):
    139 
    140     def test_basic(self):
    141         # Check that comparisons involving Number objects
    142         # give the same results give as comparing the
    143         # corresponding ints
    144         for a in xrange(3):
    145             for b in xrange(3):
    146                 for typea in (int, Number):
    147                     for typeb in (int, Number):
    148                         if typea==typeb==int:
    149                             continue # the combination int, int is useless
    150                         ta = typea(a)
    151                         tb = typeb(b)
    152                         for ops in opmap.itervalues():
    153                             for op in ops:
    154                                 realoutcome = op(a, b)
    155                                 testoutcome = op(ta, tb)
    156                                 self.assertEqual(realoutcome, testoutcome)
    157 
    158     def checkvalue(self, opname, a, b, expres):
    159         for typea in (int, Number):
    160             for typeb in (int, Number):
    161                 ta = typea(a)
    162                 tb = typeb(b)
    163                 for op in opmap[opname]:
    164                     realres = op(ta, tb)
    165                     realres = getattr(realres, "x", realres)
    166                     self.assertTrue(realres is expres)
    167 
    168     def test_values(self):
    169         # check all operators and all comparison results
    170         self.checkvalue("lt", 0, 0, False)
    171         self.checkvalue("le", 0, 0, True )
    172         self.checkvalue("eq", 0, 0, True )
    173         self.checkvalue("ne", 0, 0, False)
    174         self.checkvalue("gt", 0, 0, False)
    175         self.checkvalue("ge", 0, 0, True )
    176 
    177         self.checkvalue("lt", 0, 1, True )
    178         self.checkvalue("le", 0, 1, True )
    179         self.checkvalue("eq", 0, 1, False)
    180         self.checkvalue("ne", 0, 1, True )
    181         self.checkvalue("gt", 0, 1, False)
    182         self.checkvalue("ge", 0, 1, False)
    183 
    184         self.checkvalue("lt", 1, 0, False)
    185         self.checkvalue("le", 1, 0, False)
    186         self.checkvalue("eq", 1, 0, False)
    187         self.checkvalue("ne", 1, 0, True )
    188         self.checkvalue("gt", 1, 0, True )
    189         self.checkvalue("ge", 1, 0, True )
    190 
    191 class MiscTest(unittest.TestCase):
    192 
    193     def test_misbehavin(self):
    194         class Misb:
    195             def __lt__(self_, other): return 0
    196             def __gt__(self_, other): return 0
    197             def __eq__(self_, other): return 0
    198             def __le__(self_, other): self.fail("This shouldn't happen")
    199             def __ge__(self_, other): self.fail("This shouldn't happen")
    200             def __ne__(self_, other): self.fail("This shouldn't happen")
    201             def __cmp__(self_, other): raise RuntimeError, "expected"
    202         a = Misb()
    203         b = Misb()
    204         self.assertEqual(a<b, 0)
    205         self.assertEqual(a==b, 0)
    206         self.assertEqual(a>b, 0)
    207         self.assertRaises(RuntimeError, cmp, a, b)
    208 
    209     def test_not(self):
    210         # Check that exceptions in __nonzero__ are properly
    211         # propagated by the not operator
    212         import operator
    213         class Exc(Exception):
    214             pass
    215         class Bad:
    216             def __nonzero__(self):
    217                 raise Exc
    218 
    219         def do(bad):
    220             not bad
    221 
    222         for func in (do, operator.not_):
    223             self.assertRaises(Exc, func, Bad())
    224 
    225     def test_recursion(self):
    226         # Check that comparison for recursive objects fails gracefully
    227         from UserList import UserList
    228         a = UserList()
    229         b = UserList()
    230         a.append(b)
    231         b.append(a)
    232         self.assertRaises(RuntimeError, operator.eq, a, b)
    233         self.assertRaises(RuntimeError, operator.ne, a, b)
    234         self.assertRaises(RuntimeError, operator.lt, a, b)
    235         self.assertRaises(RuntimeError, operator.le, a, b)
    236         self.assertRaises(RuntimeError, operator.gt, a, b)
    237         self.assertRaises(RuntimeError, operator.ge, a, b)
    238 
    239         b.append(17)
    240         # Even recursive lists of different lengths are different,
    241         # but they cannot be ordered
    242         self.assertTrue(not (a == b))
    243         self.assertTrue(a != b)
    244         self.assertRaises(RuntimeError, operator.lt, a, b)
    245         self.assertRaises(RuntimeError, operator.le, a, b)
    246         self.assertRaises(RuntimeError, operator.gt, a, b)
    247         self.assertRaises(RuntimeError, operator.ge, a, b)
    248         a.append(17)
    249         self.assertRaises(RuntimeError, operator.eq, a, b)
    250         self.assertRaises(RuntimeError, operator.ne, a, b)
    251         a.insert(0, 11)
    252         b.insert(0, 12)
    253         self.assertTrue(not (a == b))
    254         self.assertTrue(a != b)
    255         self.assertTrue(a < b)
    256 
    257 class DictTest(unittest.TestCase):
    258 
    259     def test_dicts(self):
    260         # Verify that __eq__ and __ne__ work for dicts even if the keys and
    261         # values don't support anything other than __eq__ and __ne__ (and
    262         # __hash__).  Complex numbers are a fine example of that.
    263         import random
    264         imag1a = {}
    265         for i in range(50):
    266             imag1a[random.randrange(100)*1j] = random.randrange(100)*1j
    267         items = imag1a.items()
    268         random.shuffle(items)
    269         imag1b = {}
    270         for k, v in items:
    271             imag1b[k] = v
    272         imag2 = imag1b.copy()
    273         imag2[k] = v + 1.0
    274         self.assertTrue(imag1a == imag1a)
    275         self.assertTrue(imag1a == imag1b)
    276         self.assertTrue(imag2 == imag2)
    277         self.assertTrue(imag1a != imag2)
    278         for opname in ("lt", "le", "gt", "ge"):
    279             for op in opmap[opname]:
    280                 self.assertRaises(TypeError, op, imag1a, imag2)
    281 
    282 class ListTest(unittest.TestCase):
    283 
    284     def test_coverage(self):
    285         # exercise all comparisons for lists
    286         x = [42]
    287         self.assertIs(x<x, False)
    288         self.assertIs(x<=x, True)
    289         self.assertIs(x==x, True)
    290         self.assertIs(x!=x, False)
    291         self.assertIs(x>x, False)
    292         self.assertIs(x>=x, True)
    293         y = [42, 42]
    294         self.assertIs(x<y, True)
    295         self.assertIs(x<=y, True)
    296         self.assertIs(x==y, False)
    297         self.assertIs(x!=y, True)
    298         self.assertIs(x>y, False)
    299         self.assertIs(x>=y, False)
    300 
    301     def test_badentry(self):
    302         # make sure that exceptions for item comparison are properly
    303         # propagated in list comparisons
    304         class Exc(Exception):
    305             pass
    306         class Bad:
    307             def __eq__(self, other):
    308                 raise Exc
    309 
    310         x = [Bad()]
    311         y = [Bad()]
    312 
    313         for op in opmap["eq"]:
    314             self.assertRaises(Exc, op, x, y)
    315 
    316     def test_goodentry(self):
    317         # This test exercises the final call to PyObject_RichCompare()
    318         # in Objects/listobject.c::list_richcompare()
    319         class Good:
    320             def __lt__(self, other):
    321                 return True
    322 
    323         x = [Good()]
    324         y = [Good()]
    325 
    326         for op in opmap["lt"]:
    327             self.assertIs(op(x, y), True)
    328 
    329 def test_main():
    330     test_support.run_unittest(VectorTest, NumberTest, MiscTest, ListTest)
    331     with test_support.check_py3k_warnings(("dict inequality comparisons "
    332                                              "not supported in 3.x",
    333                                              DeprecationWarning)):
    334         test_support.run_unittest(DictTest)
    335 
    336 
    337 if __name__ == "__main__":
    338     test_main()
    339