Home | History | Annotate | Download | only in test
      1 """Unittests for heapq."""
      2 
      3 import sys
      4 import random
      5 
      6 from test import test_support
      7 from unittest import TestCase, skipUnless
      8 
      9 py_heapq = test_support.import_fresh_module('heapq', blocked=['_heapq'])
     10 c_heapq = test_support.import_fresh_module('heapq', fresh=['_heapq'])
     11 
     12 # _heapq.nlargest/nsmallest are saved in heapq._nlargest/_smallest when
     13 # _heapq is imported, so check them there
     14 func_names = ['heapify', 'heappop', 'heappush', 'heappushpop',
     15               'heapreplace', '_nlargest', '_nsmallest']
     16 
     17 class TestModules(TestCase):
     18     def test_py_functions(self):
     19         for fname in func_names:
     20             self.assertEqual(getattr(py_heapq, fname).__module__, 'heapq')
     21 
     22     @skipUnless(c_heapq, 'requires _heapq')
     23     def test_c_functions(self):
     24         for fname in func_names:
     25             self.assertEqual(getattr(c_heapq, fname).__module__, '_heapq')
     26 
     27 
     28 class TestHeap(TestCase):
     29     module = None
     30 
     31     def test_push_pop(self):
     32         # 1) Push 256 random numbers and pop them off, verifying all's OK.
     33         heap = []
     34         data = []
     35         self.check_invariant(heap)
     36         for i in range(256):
     37             item = random.random()
     38             data.append(item)
     39             self.module.heappush(heap, item)
     40             self.check_invariant(heap)
     41         results = []
     42         while heap:
     43             item = self.module.heappop(heap)
     44             self.check_invariant(heap)
     45             results.append(item)
     46         data_sorted = data[:]
     47         data_sorted.sort()
     48         self.assertEqual(data_sorted, results)
     49         # 2) Check that the invariant holds for a sorted array
     50         self.check_invariant(results)
     51 
     52         self.assertRaises(TypeError, self.module.heappush, [])
     53         try:
     54             self.assertRaises(TypeError, self.module.heappush, None, None)
     55             self.assertRaises(TypeError, self.module.heappop, None)
     56         except AttributeError:
     57             pass
     58 
     59     def check_invariant(self, heap):
     60         # Check the heap invariant.
     61         for pos, item in enumerate(heap):
     62             if pos: # pos 0 has no parent
     63                 parentpos = (pos-1) >> 1
     64                 self.assertTrue(heap[parentpos] <= item)
     65 
     66     def test_heapify(self):
     67         for size in range(30):
     68             heap = [random.random() for dummy in range(size)]
     69             self.module.heapify(heap)
     70             self.check_invariant(heap)
     71 
     72         self.assertRaises(TypeError, self.module.heapify, None)
     73 
     74     def test_naive_nbest(self):
     75         data = [random.randrange(2000) for i in range(1000)]
     76         heap = []
     77         for item in data:
     78             self.module.heappush(heap, item)
     79             if len(heap) > 10:
     80                 self.module.heappop(heap)
     81         heap.sort()
     82         self.assertEqual(heap, sorted(data)[-10:])
     83 
     84     def heapiter(self, heap):
     85         # An iterator returning a heap's elements, smallest-first.
     86         try:
     87             while 1:
     88                 yield self.module.heappop(heap)
     89         except IndexError:
     90             pass
     91 
     92     def test_nbest(self):
     93         # Less-naive "N-best" algorithm, much faster (if len(data) is big
     94         # enough <wink>) than sorting all of data.  However, if we had a max
     95         # heap instead of a min heap, it could go faster still via
     96         # heapify'ing all of data (linear time), then doing 10 heappops
     97         # (10 log-time steps).
     98         data = [random.randrange(2000) for i in range(1000)]
     99         heap = data[:10]
    100         self.module.heapify(heap)
    101         for item in data[10:]:
    102             if item > heap[0]:  # this gets rarer the longer we run
    103                 self.module.heapreplace(heap, item)
    104         self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
    105 
    106         self.assertRaises(TypeError, self.module.heapreplace, None)
    107         self.assertRaises(TypeError, self.module.heapreplace, None, None)
    108         self.assertRaises(IndexError, self.module.heapreplace, [], None)
    109 
    110     def test_nbest_with_pushpop(self):
    111         data = [random.randrange(2000) for i in range(1000)]
    112         heap = data[:10]
    113         self.module.heapify(heap)
    114         for item in data[10:]:
    115             self.module.heappushpop(heap, item)
    116         self.assertEqual(list(self.heapiter(heap)), sorted(data)[-10:])
    117         self.assertEqual(self.module.heappushpop([], 'x'), 'x')
    118 
    119     def test_heappushpop(self):
    120         h = []
    121         x = self.module.heappushpop(h, 10)
    122         self.assertEqual((h, x), ([], 10))
    123 
    124         h = [10]
    125         x = self.module.heappushpop(h, 10.0)
    126         self.assertEqual((h, x), ([10], 10.0))
    127         self.assertEqual(type(h[0]), int)
    128         self.assertEqual(type(x), float)
    129 
    130         h = [10];
    131         x = self.module.heappushpop(h, 9)
    132         self.assertEqual((h, x), ([10], 9))
    133 
    134         h = [10];
    135         x = self.module.heappushpop(h, 11)
    136         self.assertEqual((h, x), ([11], 10))
    137 
    138     def test_heapsort(self):
    139         # Exercise everything with repeated heapsort checks
    140         for trial in xrange(100):
    141             size = random.randrange(50)
    142             data = [random.randrange(25) for i in range(size)]
    143             if trial & 1:     # Half of the time, use heapify
    144                 heap = data[:]
    145                 self.module.heapify(heap)
    146             else:             # The rest of the time, use heappush
    147                 heap = []
    148                 for item in data:
    149                     self.module.heappush(heap, item)
    150             heap_sorted = [self.module.heappop(heap) for i in range(size)]
    151             self.assertEqual(heap_sorted, sorted(data))
    152 
    153     def test_merge(self):
    154         inputs = []
    155         for i in xrange(random.randrange(5)):
    156             row = sorted(random.randrange(1000) for j in range(random.randrange(10)))
    157             inputs.append(row)
    158         self.assertEqual(sorted(chain(*inputs)), list(self.module.merge(*inputs)))
    159         self.assertEqual(list(self.module.merge()), [])
    160 
    161     def test_merge_does_not_suppress_index_error(self):
    162         # Issue 19018: Heapq.merge suppresses IndexError from user generator
    163         def iterable():
    164             s = list(range(10))
    165             for i in range(20):
    166                 yield s[i]       # IndexError when i > 10
    167         with self.assertRaises(IndexError):
    168             list(self.module.merge(iterable(), iterable()))
    169 
    170     def test_merge_stability(self):
    171         class Int(int):
    172             pass
    173         inputs = [[], [], [], []]
    174         for i in range(20000):
    175             stream = random.randrange(4)
    176             x = random.randrange(500)
    177             obj = Int(x)
    178             obj.pair = (x, stream)
    179             inputs[stream].append(obj)
    180         for stream in inputs:
    181             stream.sort()
    182         result = [i.pair for i in self.module.merge(*inputs)]
    183         self.assertEqual(result, sorted(result))
    184 
    185     def test_nsmallest(self):
    186         data = [(random.randrange(2000), i) for i in range(1000)]
    187         for f in (None, lambda x:  x[0] * 547 % 2000):
    188             for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
    189                 self.assertEqual(self.module.nsmallest(n, data), sorted(data)[:n])
    190                 self.assertEqual(self.module.nsmallest(n, data, key=f),
    191                                  sorted(data, key=f)[:n])
    192 
    193     def test_nlargest(self):
    194         data = [(random.randrange(2000), i) for i in range(1000)]
    195         for f in (None, lambda x:  x[0] * 547 % 2000):
    196             for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
    197                 self.assertEqual(self.module.nlargest(n, data),
    198                                  sorted(data, reverse=True)[:n])
    199                 self.assertEqual(self.module.nlargest(n, data, key=f),
    200                                  sorted(data, key=f, reverse=True)[:n])
    201 
    202     def test_comparison_operator(self):
    203         # Issue 3051: Make sure heapq works with both __lt__ and __le__
    204         def hsort(data, comp):
    205             data = map(comp, data)
    206             self.module.heapify(data)
    207             return [self.module.heappop(data).x for i in range(len(data))]
    208         class LT:
    209             def __init__(self, x):
    210                 self.x = x
    211             def __lt__(self, other):
    212                 return self.x > other.x
    213         class LE:
    214             def __init__(self, x):
    215                 self.x = x
    216             def __le__(self, other):
    217                 return self.x >= other.x
    218         data = [random.random() for i in range(100)]
    219         target = sorted(data, reverse=True)
    220         self.assertEqual(hsort(data, LT), target)
    221         self.assertEqual(hsort(data, LE), target)
    222 
    223 
    224 class TestHeapPython(TestHeap):
    225     module = py_heapq
    226 
    227 
    228 @skipUnless(c_heapq, 'requires _heapq')
    229 class TestHeapC(TestHeap):
    230     module = c_heapq
    231 
    232 
    233 #==============================================================================
    234 
    235 class LenOnly:
    236     "Dummy sequence class defining __len__ but not __getitem__."
    237     def __len__(self):
    238         return 10
    239 
    240 class GetOnly:
    241     "Dummy sequence class defining __getitem__ but not __len__."
    242     def __getitem__(self, ndx):
    243         return 10
    244 
    245 class CmpErr:
    246     "Dummy element that always raises an error during comparison"
    247     def __cmp__(self, other):
    248         raise ZeroDivisionError
    249 
    250 def R(seqn):
    251     'Regular generator'
    252     for i in seqn:
    253         yield i
    254 
    255 class G:
    256     'Sequence using __getitem__'
    257     def __init__(self, seqn):
    258         self.seqn = seqn
    259     def __getitem__(self, i):
    260         return self.seqn[i]
    261 
    262 class I:
    263     'Sequence using iterator protocol'
    264     def __init__(self, seqn):
    265         self.seqn = seqn
    266         self.i = 0
    267     def __iter__(self):
    268         return self
    269     def next(self):
    270         if self.i >= len(self.seqn): raise StopIteration
    271         v = self.seqn[self.i]
    272         self.i += 1
    273         return v
    274 
    275 class Ig:
    276     'Sequence using iterator protocol defined with a generator'
    277     def __init__(self, seqn):
    278         self.seqn = seqn
    279         self.i = 0
    280     def __iter__(self):
    281         for val in self.seqn:
    282             yield val
    283 
    284 class X:
    285     'Missing __getitem__ and __iter__'
    286     def __init__(self, seqn):
    287         self.seqn = seqn
    288         self.i = 0
    289     def next(self):
    290         if self.i >= len(self.seqn): raise StopIteration
    291         v = self.seqn[self.i]
    292         self.i += 1
    293         return v
    294 
    295 class N:
    296     'Iterator missing next()'
    297     def __init__(self, seqn):
    298         self.seqn = seqn
    299         self.i = 0
    300     def __iter__(self):
    301         return self
    302 
    303 class E:
    304     'Test propagation of exceptions'
    305     def __init__(self, seqn):
    306         self.seqn = seqn
    307         self.i = 0
    308     def __iter__(self):
    309         return self
    310     def next(self):
    311         3 // 0
    312 
    313 class S:
    314     'Test immediate stop'
    315     def __init__(self, seqn):
    316         pass
    317     def __iter__(self):
    318         return self
    319     def next(self):
    320         raise StopIteration
    321 
    322 from itertools import chain, imap
    323 def L(seqn):
    324     'Test multiple tiers of iterators'
    325     return chain(imap(lambda x:x, R(Ig(G(seqn)))))
    326 
    327 class SideEffectLT:
    328     def __init__(self, value, heap):
    329         self.value = value
    330         self.heap = heap
    331 
    332     def __lt__(self, other):
    333         self.heap[:] = []
    334         return self.value < other.value
    335 
    336 
    337 class TestErrorHandling(TestCase):
    338     module = None
    339 
    340     def test_non_sequence(self):
    341         for f in (self.module.heapify, self.module.heappop):
    342             self.assertRaises((TypeError, AttributeError), f, 10)
    343         for f in (self.module.heappush, self.module.heapreplace,
    344                   self.module.nlargest, self.module.nsmallest):
    345             self.assertRaises((TypeError, AttributeError), f, 10, 10)
    346 
    347     def test_len_only(self):
    348         for f in (self.module.heapify, self.module.heappop):
    349             self.assertRaises((TypeError, AttributeError), f, LenOnly())
    350         for f in (self.module.heappush, self.module.heapreplace):
    351             self.assertRaises((TypeError, AttributeError), f, LenOnly(), 10)
    352         for f in (self.module.nlargest, self.module.nsmallest):
    353             self.assertRaises(TypeError, f, 2, LenOnly())
    354 
    355     def test_get_only(self):
    356         seq = [CmpErr(), CmpErr(), CmpErr()]
    357         for f in (self.module.heapify, self.module.heappop):
    358             self.assertRaises(ZeroDivisionError, f, seq)
    359         for f in (self.module.heappush, self.module.heapreplace):
    360             self.assertRaises(ZeroDivisionError, f, seq, 10)
    361         for f in (self.module.nlargest, self.module.nsmallest):
    362             self.assertRaises(ZeroDivisionError, f, 2, seq)
    363 
    364     def test_arg_parsing(self):
    365         for f in (self.module.heapify, self.module.heappop,
    366                   self.module.heappush, self.module.heapreplace,
    367                   self.module.nlargest, self.module.nsmallest):
    368             self.assertRaises((TypeError, AttributeError), f, 10)
    369 
    370     def test_iterable_args(self):
    371         for f in (self.module.nlargest, self.module.nsmallest):
    372             for s in ("123", "", range(1000), ('do', 1.2), xrange(2000,2200,5)):
    373                 for g in (G, I, Ig, L, R):
    374                     with test_support.check_py3k_warnings(
    375                             ("comparing unequal types not supported",
    376                              DeprecationWarning), quiet=True):
    377                         self.assertEqual(f(2, g(s)), f(2,s))
    378                 self.assertEqual(f(2, S(s)), [])
    379                 self.assertRaises(TypeError, f, 2, X(s))
    380                 self.assertRaises(TypeError, f, 2, N(s))
    381                 self.assertRaises(ZeroDivisionError, f, 2, E(s))
    382 
    383     # Issue #17278: the heap may change size while it's being walked.
    384 
    385     def test_heappush_mutating_heap(self):
    386         heap = []
    387         heap.extend(SideEffectLT(i, heap) for i in range(200))
    388         # Python version raises IndexError, C version RuntimeError
    389         with self.assertRaises((IndexError, RuntimeError)):
    390             self.module.heappush(heap, SideEffectLT(5, heap))
    391 
    392     def test_heappop_mutating_heap(self):
    393         heap = []
    394         heap.extend(SideEffectLT(i, heap) for i in range(200))
    395         # Python version raises IndexError, C version RuntimeError
    396         with self.assertRaises((IndexError, RuntimeError)):
    397             self.module.heappop(heap)
    398 
    399 
    400 class TestErrorHandlingPython(TestErrorHandling):
    401     module = py_heapq
    402 
    403 
    404 @skipUnless(c_heapq, 'requires _heapq')
    405 class TestErrorHandlingC(TestErrorHandling):
    406     module = c_heapq
    407 
    408 
    409 #==============================================================================
    410 
    411 
    412 def test_main(verbose=None):
    413     test_classes = [TestModules, TestHeapPython, TestHeapC,
    414                     TestErrorHandlingPython, TestErrorHandlingC]
    415     test_support.run_unittest(*test_classes)
    416 
    417     # verify reference counting
    418     if verbose and hasattr(sys, "gettotalrefcount"):
    419         import gc
    420         counts = [None] * 5
    421         for i in xrange(len(counts)):
    422             test_support.run_unittest(*test_classes)
    423             gc.collect()
    424             counts[i] = sys.gettotalrefcount()
    425         print counts
    426 
    427 if __name__ == "__main__":
    428     test_main(verbose=True)
    429