Home | History | Annotate | Download | only in test
      1 """Unittests for heapq."""
      2 
      3 import sys
      4 import random
      5 
      6 from test import test_support
      7 from unittest import TestCase, skipUnless
      8 
      9 py_heapq = test_support.import_fresh_module('heapq', blocked=['_heapq'])
     10 c_heapq = test_support.import_fresh_module('heapq', fresh=['_heapq'])
     11 
     12 # _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when
     13 # _heapq is imported, so check them there
     14 func_names = ['heapify', 'heappop', 'heappush', 'heappushpop',
     15               'heapreplace', '_nlargest', '_nsmallest']
     16 
     17 class TestModules(TestCase):
     18     def test_py_functions(self):
     19         for fname in func_names:
     20             self.assertEqual(getattr(py_heapq, fname).__module__, 'heapq')
     21 
     22     @skipUnless(c_heapq, 'requires _heapq')
     23     def test_c_functions(self):
     24         for fname in func_names:
     25             self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
     26 
     27 
     28 class TestHeap(TestCase):
     29     module = None
     30 
     31     def test_push_pop(self):
     32         # 1) Push 256 random numbers and pop them off, verifying all's OK.
     33         heap = []
     34         data = []
     35         self.check_invariant(heap)
     36         for i in range(256):
     37             item = random.random()
     38             data.append(item)
     39             self.module.heappush(heap, item)
     40             self.check_invariant(heap)
     41         results = []
     42         while heap:
     43             item = self.module.heappop(heap)
     44             self.check_invariant(heap)
     45             results.append(item)
     46         data_sorted = data[:]
     47         data_sorted.sort()
     48         self.assertEqual(data_sorted, results)
     49         # 2) Check that the invariant holds for a sorted array
     50         self.check_invariant(results)
     51 
     52         self.assertRaises(TypeError, self.module.heappush, [])
     53         try:
     54             self.assertRaises(TypeError, self.module.heappush, None, None)
     55             self.assertRaises(TypeError, self.module.heappop, None)
     56         except AttributeError:
     57             pass
     58 
     59     def check_invariant(self, heap):
     60         # Check the heap invariant.
     61         for pos, item in enumerate(heap):
     62             if pos: # pos 0 has no parent
     63                 parentpos = (pos-1) >> 1
     64                 self.assertTrue(heap[parentpos] <= item)
     65 
     66     def test_heapify(self):
     67         for size in range(30):
     68             heap = [random.random() for dummy in range(size)]
     69             self.module.heapify(heap)
     70             self.check_invariant(heap)
     71 
     72         self.assertRaises(TypeError, self.module.heapify, None)
     73 
     74     def test_naive_nbest(self):
     75         data = [random.randrange(2000) for i in range(1000)]
     76         heap = []
     77         for item in data:
     78             self.module.heappush(heap, item)
     79             if len(heap) > 10:
     80                 self.module.heappop(heap)
     81         heap.sort()
     82         self.assertEqual(heap, sorted(data)[-10:])
     83 
     84     def heapiter(self, heap):
     85         # An iterator returning a heap's elements, smallest-first.
     86         try:
     87             while 1:
     88                 yield self.module.heappop(heap)
     89         except IndexError:
     90             pass
     91 
     92     def test_nbest(self):
     93         # Less-naive "N-best" algorithm, much faster (if len(data) is big
     94         # enough <wink>) than sorting all of data.  However, if we had a max
     95         # heap instead of a min heap, it could go faster still via
     96         # heapify'ing all of data (linear time), then doing 10 heappops
     97         # (10 log-time steps).
     98         data = [random.randrange(2000) for i in range(1000)]
     99         heap = data[:10]
    100         self.module.heapify(heap)
    101         for item in data[10:]:
    102             if item > heap[0]:  # this gets rarer the longer we run
    103                 self.module.heapreplace(heap, item)
    104         self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
    105 
    106         self.assertRaises(TypeError, self.module.heapreplace, None)
    107         self.assertRaises(TypeError, self.module.heapreplace, None, None)
    108         self.assertRaises(IndexError, self.module.heapreplace, [], None)
    109 
    110     def test_nbest_with_pushpop(self):
    111         data = [random.randrange(2000) for i in range(1000)]
    112         heap = data[:10]
    113         self.module.heapify(heap)
    114         for item in data[10:]:
    115             self.module.heappushpop(heap, item)
    116         self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
    117         self.assertEqual(self.module.heappushpop([], 'x'), 'x')
    118 
    119     def test_heappushpop(self):
    120         h = []
    121         x = self.module.heappushpop(h, 10)
    122         self.assertEqual((h, x), ([], 10))
    123 
    124         h = [10]
    125         x = self.module.heappushpop(h, 10.0)
    126         self.assertEqual((h, x), ([10], 10.0))
    127         self.assertEqual(type(h[0]), int)
    128         self.assertEqual(type(x), float)
    129 
    130         h = [10];
    131         x = self.module.heappushpop(h, 9)
    132         self.assertEqual((h, x), ([10], 9))
    133 
    134         h = [10];
    135         x = self.module.heappushpop(h, 11)
    136         self.assertEqual((h, x), ([11], 10))
    137 
    138     def test_heapsort(self):
    139         # Exercise everything with repeated heapsort checks
    140         for trial in xrange(100):
    141             size = random.randrange(50)
    142             data = [random.randrange(25) for i in range(size)]
    143             if trial & 1:     # Half of the time, use heapify
    144                 heap = data[:]
    145                 self.module.heapify(heap)
    146             else:             # The rest of the time, use heappush
    147                 heap = []
    148                 for item in data:
    149                     self.module.heappush(heap, item)
    150             heap_sorted = [self.module.heappop(heap) for i in range(size)]
    151             self.assertEqual(heap_sorted, sorted(data))
    152 
    153     def test_merge(self):
    154         inputs = []
    155         for i in xrange(random.randrange(5)):
    156             row = sorted(random.randrange(1000) for j in range(random.randrange(10)))
    157             inputs.append(row)
    158         self.assertEqual(sorted(chain(*inputs)), list(self.module.merge(*inputs)))
    159         self.assertEqual(list(self.module.merge()), [])
    160 
    161     def test_merge_stability(self):
    162         class Int(int):
    163             pass
    164         inputs = [[], [], [], []]
    165         for i in range(20000):
    166             stream = random.randrange(4)
    167             x = random.randrange(500)
    168             obj = Int(x)
    169             obj.pair = (x, stream)
    170             inputs[stream].append(obj)
    171         for stream in inputs:
    172             stream.sort()
    173         result = [i.pair for i in self.module.merge(*inputs)]
    174         self.assertEqual(result, sorted(result))
    175 
    176     def test_nsmallest(self):
    177         data = [(random.randrange(2000), i) for i in range(1000)]
    178         for f in (None, lambda x:  x[0] * 547 % 2000):
    179             for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
    180                 self.assertEqual(self.module.nsmallest(n, data), sorted(data)[:n])
    181                 self.assertEqual(self.module.nsmallest(n, data, key=f),
    182                                  sorted(data, key=f)[:n])
    183 
    184     def test_nlargest(self):
    185         data = [(random.randrange(2000), i) for i in range(1000)]
    186         for f in (None, lambda x:  x[0] * 547 % 2000):
    187             for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
    188                 self.assertEqual(self.module.nlargest(n, data),
    189                                  sorted(data, reverse=True)[:n])
    190                 self.assertEqual(self.module.nlargest(n, data, key=f),
    191                                  sorted(data, key=f, reverse=True)[:n])
    192 
    193     def test_comparison_operator(self):
    194         # Issue 3051: Make sure heapq works with both __lt__ and __le__
    195         def hsort(data, comp):
    196             data = map(comp, data)
    197             self.module.heapify(data)
    198             return [self.module.heappop(data).x for i in range(len(data))]
    199         class LT:
    200             def __init__(self, x):
    201                 self.x = x
    202             def __lt__(self, other):
    203                 return self.x > other.x
    204         class LE:
    205             def __init__(self, x):
    206                 self.x = x
    207             def __le__(self, other):
    208                 return self.x >= other.x
    209         data = [random.random() for i in range(100)]
    210         target = sorted(data, reverse=True)
    211         self.assertEqual(hsort(data, LT), target)
    212         self.assertEqual(hsort(data, LE), target)
    213 
    214 
    215 class TestHeapPython(TestHeap):
    216     module = py_heapq
    217 
    218 
    219 @skipUnless(c_heapq, 'requires _heapq')
    220 class TestHeapC(TestHeap):
    221     module = c_heapq
    222 
    223 
    224 #==============================================================================
    225 
    226 class LenOnly:
    227     "Dummy sequence class defining __len__ but not __getitem__."
    228     def __len__(self):
    229         return 10
    230 
    231 class GetOnly:
    232     "Dummy sequence class defining __getitem__ but not __len__."
    233     def __getitem__(self, ndx):
    234         return 10
    235 
    236 class CmpErr:
    237     "Dummy element that always raises an error during comparison"
    238     def __cmp__(self, other):
    239         raise ZeroDivisionError
    240 
    241 def R(seqn):
    242     'Regular generator'
    243     for i in seqn:
    244         yield i
    245 
    246 class G:
    247     'Sequence using __getitem__'
    248     def __init__(self, seqn):
    249         self.seqn = seqn
    250     def __getitem__(self, i):
    251         return self.seqn[i]
    252 
    253 class I:
    254     'Sequence using iterator protocol'
    255     def __init__(self, seqn):
    256         self.seqn = seqn
    257         self.i = 0
    258     def __iter__(self):
    259         return self
    260     def next(self):
    261         if self.i >= len(self.seqn): raise StopIteration
    262         v = self.seqn[self.i]
    263         self.i += 1
    264         return v
    265 
    266 class Ig:
    267     'Sequence using iterator protocol defined with a generator'
    268     def __init__(self, seqn):
    269         self.seqn = seqn
    270         self.i = 0
    271     def __iter__(self):
    272         for val in self.seqn:
    273             yield val
    274 
    275 class X:
    276     'Missing __getitem__ and __iter__'
    277     def __init__(self, seqn):
    278         self.seqn = seqn
    279         self.i = 0
    280     def next(self):
    281         if self.i >= len(self.seqn): raise StopIteration
    282         v = self.seqn[self.i]
    283         self.i += 1
    284         return v
    285 
    286 class N:
    287     'Iterator missing next()'
    288     def __init__(self, seqn):
    289         self.seqn = seqn
    290         self.i = 0
    291     def __iter__(self):
    292         return self
    293 
    294 class E:
    295     'Test propagation of exceptions'
    296     def __init__(self, seqn):
    297         self.seqn = seqn
    298         self.i = 0
    299     def __iter__(self):
    300         return self
    301     def next(self):
    302         3 // 0
    303 
    304 class S:
    305     'Test immediate stop'
    306     def __init__(self, seqn):
    307         pass
    308     def __iter__(self):
    309         return self
    310     def next(self):
    311         raise StopIteration
    312 
    313 from itertools import chain, imap
    314 def L(seqn):
    315     'Test multiple tiers of iterators'
    316     return chain(imap(lambda x:x, R(Ig(G(seqn)))))
    317 
    318 class SideEffectLT:
    319     def __init__(self, value, heap):
    320         self.value = value
    321         self.heap = heap
    322 
    323     def __lt__(self, other):
    324         self.heap[:] = []
    325         return self.value < other.value
    326 
    327 
    328 class TestErrorHandling(TestCase):
    329     module = None
    330 
    331     def test_non_sequence(self):
    332         for f in (self.module.heapify, self.module.heappop):
    333             self.assertRaises((TypeError, AttributeError), f, 10)
    334         for f in (self.module.heappush, self.module.heapreplace,
    335                   self.module.nlargest, self.module.nsmallest):
    336             self.assertRaises((TypeError, AttributeError), f, 10, 10)
    337 
    338     def test_len_only(self):
    339         for f in (self.module.heapify, self.module.heappop):
    340             self.assertRaises((TypeError, AttributeError), f, LenOnly())
    341         for f in (self.module.heappush, self.module.heapreplace):
    342             self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10)
    343         for f in (self.module.nlargest, self.module.nsmallest):
    344             self.assertRaises(TypeError, f, 2, LenOnly())
    345 
    346     def test_get_only(self):
    347         seq = [CmpErr(), CmpErr(), CmpErr()]
    348         for f in (self.module.heapify, self.module.heappop):
    349             self.assertRaises(ZeroDivisionError, f, seq)
    350         for f in (self.module.heappush, self.module.heapreplace):
    351             self.assertRaises(ZeroDivisionError, f, seq, 10)
    352         for f in (self.module.nlargest, self.module.nsmallest):
    353             self.assertRaises(ZeroDivisionError, f, 2, seq)
    354 
    355     def test_arg_parsing(self):
    356         for f in (self.module.heapify, self.module.heappop,
    357                   self.module.heappush, self.module.heapreplace,
    358                   self.module.nlargest, self.module.nsmallest):
    359             self.assertRaises((TypeError, AttributeError), f, 10)
    360 
    361     def test_iterable_args(self):
    362         for f in (self.module.nlargest, self.module.nsmallest):
    363             for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
    364                 for g in (G, I, Ig, L, R):
    365                     with test_support.check_py3k_warnings(
    366                             ("comparing unequal types not supported",
    367                              DeprecationWarning), quiet=True):
    368                         self.assertEqual(f(2, g(s)), f(2,s))
    369                 self.assertEqual(f(2, S(s)), [])
    370                 self.assertRaises(TypeError, f, 2, X(s))
    371                 self.assertRaises(TypeError, f, 2, N(s))
    372                 self.assertRaises(ZeroDivisionError, f, 2, E(s))
    373 
    374     # Issue #17278: the heap may change size while it's being walked.
    375 
    376     def test_heappush_mutating_heap(self):
    377         heap = []
    378         heap.extend(SideEffectLT(i, heap) for i in range(200))
    379         # Python version raises IndexError, C version RuntimeError
    380         with self.assertRaises((IndexError, RuntimeError)):
    381             self.module.heappush(heap, SideEffectLT(5, heap))
    382 
    383     def test_heappop_mutating_heap(self):
    384         heap = []
    385         heap.extend(SideEffectLT(i, heap) for i in range(200))
    386         # Python version raises IndexError, C version RuntimeError
    387         with self.assertRaises((IndexError, RuntimeError)):
    388             self.module.heappop(heap)
    389 
    390 
    391 class TestErrorHandlingPython(TestErrorHandling):
    392     module = py_heapq
    393 
    394 
    395 @skipUnless(c_heapq, 'requires _heapq')
    396 class TestErrorHandlingC(TestErrorHandling):
    397     module = c_heapq
    398 
    399 
    400 #==============================================================================
    401 
    402 
    403 def test_main(verbose=None):
    404     test_classes = [TestModules, TestHeapPython, TestHeapC,
    405                     TestErrorHandlingPython, TestErrorHandlingC]
    406     test_support.run_unittest(*test_classes)
    407 
    408     # verify reference counting
    409     if verbose and hasattr(sys, "gettotalrefcount"):
    410         import gc
    411         counts = [None] * 5
    412         for i in xrange(len(counts)):
    413             test_support.run_unittest(*test_classes)
    414             gc.collect()
    415             counts[i] = sys.gettotalrefcount()
    416         print counts
    417 
    418 if __name__ == "__main__":
    419     test_main(verbose=True)
    420