# (Source-browser navigation residue, not part of the module: Home | History | Annotate | Download | only in test)
      1 """Unit tests for collections.defaultdict."""
      2 
      3 import os
      4 import copy
      5 import tempfile
      6 import unittest
      7 from test import test_support
      8 
      9 from collections import defaultdict
     10 
     11 def foobar():
     12     return list
     13 
     14 class TestDefaultDict(unittest.TestCase):
     15 
     16     def test_basic(self):
     17         d1 = defaultdict()
     18         self.assertEqual(d1.default_factory, None)
     19         d1.default_factory = list
     20         d1[12].append(42)
     21         self.assertEqual(d1, {12: [42]})
     22         d1[12].append(24)
     23         self.assertEqual(d1, {12: [42, 24]})
     24         d1[13]
     25         d1[14]
     26         self.assertEqual(d1, {12: [42, 24], 13: [], 14: []})
     27         self.assertTrue(d1[12] is not d1[13] is not d1[14])
     28         d2 = defaultdict(list, foo=1, bar=2)
     29         self.assertEqual(d2.default_factory, list)
     30         self.assertEqual(d2, {"foo": 1, "bar": 2})
     31         self.assertEqual(d2["foo"], 1)
     32         self.assertEqual(d2["bar"], 2)
     33         self.assertEqual(d2[42], [])
     34         self.assertIn("foo", d2)
     35         self.assertIn("foo", d2.keys())
     36         self.assertIn("bar", d2)
     37         self.assertIn("bar", d2.keys())
     38         self.assertIn(42, d2)
     39         self.assertIn(42, d2.keys())
     40         self.assertNotIn(12, d2)
     41         self.assertNotIn(12, d2.keys())
     42         d2.default_factory = None
     43         self.assertEqual(d2.default_factory, None)
     44         try:
     45             d2[15]
     46         except KeyError, err:
     47             self.assertEqual(err.args, (15,))
     48         else:
     49             self.fail("d2[15] didn't raise KeyError")
     50         self.assertRaises(TypeError, defaultdict, 1)
     51 
     52     def test_missing(self):
     53         d1 = defaultdict()
     54         self.assertRaises(KeyError, d1.__missing__, 42)
     55         d1.default_factory = list
     56         self.assertEqual(d1.__missing__(42), [])
     57 
     58     def test_repr(self):
     59         d1 = defaultdict()
     60         self.assertEqual(d1.default_factory, None)
     61         self.assertEqual(repr(d1), "defaultdict(None, {})")
     62         self.assertEqual(eval(repr(d1)), d1)
     63         d1[11] = 41
     64         self.assertEqual(repr(d1), "defaultdict(None, {11: 41})")
     65         d2 = defaultdict(int)
     66         self.assertEqual(d2.default_factory, int)
     67         d2[12] = 42
     68         self.assertEqual(repr(d2), "defaultdict(<type 'int'>, {12: 42})")
     69         def foo(): return 43
     70         d3 = defaultdict(foo)
     71         self.assertTrue(d3.default_factory is foo)
     72         d3[13]
     73         self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo))
     74 
     75     def test_print(self):
     76         d1 = defaultdict()
     77         def foo(): return 42
     78         d2 = defaultdict(foo, {1: 2})
     79         # NOTE: We can't use tempfile.[Named]TemporaryFile since this
     80         # code must exercise the tp_print C code, which only gets
     81         # invoked for *real* files.
     82         tfn = tempfile.mktemp()
     83         try:
     84             f = open(tfn, "w+")
     85             try:
     86                 print >>f, d1
     87                 print >>f, d2
     88                 f.seek(0)
     89                 self.assertEqual(f.readline(), repr(d1) + "\n")
     90                 self.assertEqual(f.readline(), repr(d2) + "\n")
     91             finally:
     92                 f.close()
     93         finally:
     94             os.remove(tfn)
     95 
     96     def test_copy(self):
     97         d1 = defaultdict()
     98         d2 = d1.copy()
     99         self.assertEqual(type(d2), defaultdict)
    100         self.assertEqual(d2.default_factory, None)
    101         self.assertEqual(d2, {})
    102         d1.default_factory = list
    103         d3 = d1.copy()
    104         self.assertEqual(type(d3), defaultdict)
    105         self.assertEqual(d3.default_factory, list)
    106         self.assertEqual(d3, {})
    107         d1[42]
    108         d4 = d1.copy()
    109         self.assertEqual(type(d4), defaultdict)
    110         self.assertEqual(d4.default_factory, list)
    111         self.assertEqual(d4, {42: []})
    112         d4[12]
    113         self.assertEqual(d4, {42: [], 12: []})
    114 
    115         # Issue 6637: Copy fails for empty default dict
    116         d = defaultdict()
    117         d['a'] = 42
    118         e = d.copy()
    119         self.assertEqual(e['a'], 42)
    120 
    121     def test_shallow_copy(self):
    122         d1 = defaultdict(foobar, {1: 1})
    123         d2 = copy.copy(d1)
    124         self.assertEqual(d2.default_factory, foobar)
    125         self.assertEqual(d2, d1)
    126         d1.default_factory = list
    127         d2 = copy.copy(d1)
    128         self.assertEqual(d2.default_factory, list)
    129         self.assertEqual(d2, d1)
    130 
    131     def test_deep_copy(self):
    132         d1 = defaultdict(foobar, {1: [1]})
    133         d2 = copy.deepcopy(d1)
    134         self.assertEqual(d2.default_factory, foobar)
    135         self.assertEqual(d2, d1)
    136         self.assertTrue(d1[1] is not d2[1])
    137         d1.default_factory = list
    138         d2 = copy.deepcopy(d1)
    139         self.assertEqual(d2.default_factory, list)
    140         self.assertEqual(d2, d1)
    141 
    142     def test_keyerror_without_factory(self):
    143         d1 = defaultdict()
    144         try:
    145             d1[(1,)]
    146         except KeyError, err:
    147             self.assertEqual(err.args[0], (1,))
    148         else:
    149             self.fail("expected KeyError")
    150 
    151     def test_recursive_repr(self):
    152         # Issue2045: stack overflow when default_factory is a bound method
    153         class sub(defaultdict):
    154             def __init__(self):
    155                 self.default_factory = self._factory
    156             def _factory(self):
    157                 return []
    158         d = sub()
    159         self.assertTrue(repr(d).startswith(
    160             "defaultdict(<bound method sub._factory of defaultdict(..."))
    161 
    162         # NOTE: printing a subclass of a builtin type does not call its
    163         # tp_print slot. So this part is essentially the same test as above.
    164         tfn = tempfile.mktemp()
    165         try:
    166             f = open(tfn, "w+")
    167             try:
    168                 print >>f, d
    169             finally:
    170                 f.close()
    171         finally:
    172             os.remove(tfn)
    173 
    174     def test_callable_arg(self):
    175         self.assertRaises(TypeError, defaultdict, {})
    176 
    177 def test_main():
    178     test_support.run_unittest(TestDefaultDict)
    179 
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    test_main()
    182