1 """Unit tests for collections.defaultdict.""" 2 3 import os 4 import copy 5 import tempfile 6 import unittest 7 from test import test_support 8 9 from collections import defaultdict 10 11 def foobar(): 12 return list 13 14 class TestDefaultDict(unittest.TestCase): 15 16 def test_basic(self): 17 d1 = defaultdict() 18 self.assertEqual(d1.default_factory, None) 19 d1.default_factory = list 20 d1[12].append(42) 21 self.assertEqual(d1, {12: [42]}) 22 d1[12].append(24) 23 self.assertEqual(d1, {12: [42, 24]}) 24 d1[13] 25 d1[14] 26 self.assertEqual(d1, {12: [42, 24], 13: [], 14: []}) 27 self.assertTrue(d1[12] is not d1[13] is not d1[14]) 28 d2 = defaultdict(list, foo=1, bar=2) 29 self.assertEqual(d2.default_factory, list) 30 self.assertEqual(d2, {"foo": 1, "bar": 2}) 31 self.assertEqual(d2["foo"], 1) 32 self.assertEqual(d2["bar"], 2) 33 self.assertEqual(d2[42], []) 34 self.assertIn("foo", d2) 35 self.assertIn("foo", d2.keys()) 36 self.assertIn("bar", d2) 37 self.assertIn("bar", d2.keys()) 38 self.assertIn(42, d2) 39 self.assertIn(42, d2.keys()) 40 self.assertNotIn(12, d2) 41 self.assertNotIn(12, d2.keys()) 42 d2.default_factory = None 43 self.assertEqual(d2.default_factory, None) 44 try: 45 d2[15] 46 except KeyError, err: 47 self.assertEqual(err.args, (15,)) 48 else: 49 self.fail("d2[15] didn't raise KeyError") 50 self.assertRaises(TypeError, defaultdict, 1) 51 52 def test_missing(self): 53 d1 = defaultdict() 54 self.assertRaises(KeyError, d1.__missing__, 42) 55 d1.default_factory = list 56 self.assertEqual(d1.__missing__(42), []) 57 58 def test_repr(self): 59 d1 = defaultdict() 60 self.assertEqual(d1.default_factory, None) 61 self.assertEqual(repr(d1), "defaultdict(None, {})") 62 self.assertEqual(eval(repr(d1)), d1) 63 d1[11] = 41 64 self.assertEqual(repr(d1), "defaultdict(None, {11: 41})") 65 d2 = defaultdict(int) 66 self.assertEqual(d2.default_factory, int) 67 d2[12] = 42 68 self.assertEqual(repr(d2), "defaultdict(<type 'int'>, {12: 42})") 69 def foo(): return 43 70 d3 = defaultdict(foo) 71 self.assertTrue(d3.default_factory is foo) 72 d3[13] 73 self.assertEqual(repr(d3), "defaultdict(%s, {13: 43})" % repr(foo)) 74 75 def test_print(self): 76 d1 = defaultdict() 77 def foo(): return 42 78 d2 = defaultdict(foo, {1: 2}) 79 # NOTE: We can't use tempfile.[Named]TemporaryFile since this 80 # code must exercise the tp_print C code, which only gets 81 # invoked for *real* files. 82 tfn = tempfile.mktemp() 83 try: 84 f = open(tfn, "w+") 85 try: 86 print >>f, d1 87 print >>f, d2 88 f.seek(0) 89 self.assertEqual(f.readline(), repr(d1) + "\n") 90 self.assertEqual(f.readline(), repr(d2) + "\n") 91 finally: 92 f.close() 93 finally: 94 os.remove(tfn) 95 96 def test_copy(self): 97 d1 = defaultdict() 98 d2 = d1.copy() 99 self.assertEqual(type(d2), defaultdict) 100 self.assertEqual(d2.default_factory, None) 101 self.assertEqual(d2, {}) 102 d1.default_factory = list 103 d3 = d1.copy() 104 self.assertEqual(type(d3), defaultdict) 105 self.assertEqual(d3.default_factory, list) 106 self.assertEqual(d3, {}) 107 d1[42] 108 d4 = d1.copy() 109 self.assertEqual(type(d4), defaultdict) 110 self.assertEqual(d4.default_factory, list) 111 self.assertEqual(d4, {42: []}) 112 d4[12] 113 self.assertEqual(d4, {42: [], 12: []}) 114 115 # Issue 6637: Copy fails for empty default dict 116 d = defaultdict() 117 d['a'] = 42 118 e = d.copy() 119 self.assertEqual(e['a'], 42) 120 121 def test_shallow_copy(self): 122 d1 = defaultdict(foobar, {1: 1}) 123 d2 = copy.copy(d1) 124 self.assertEqual(d2.default_factory, foobar) 125 self.assertEqual(d2, d1) 126 d1.default_factory = list 127 d2 = copy.copy(d1) 128 self.assertEqual(d2.default_factory, list) 129 self.assertEqual(d2, d1) 130 131 def test_deep_copy(self): 132 d1 = defaultdict(foobar, {1: [1]}) 133 d2 = copy.deepcopy(d1) 134 self.assertEqual(d2.default_factory, foobar) 135 self.assertEqual(d2, d1) 136 self.assertTrue(d1[1] is not d2[1]) 137 d1.default_factory = list 138 d2 = copy.deepcopy(d1) 139 self.assertEqual(d2.default_factory, list) 140 self.assertEqual(d2, d1) 141 142 def test_keyerror_without_factory(self): 143 d1 = defaultdict() 144 try: 145 d1[(1,)] 146 except KeyError, err: 147 self.assertEqual(err.args[0], (1,)) 148 else: 149 self.fail("expected KeyError") 150 151 def test_recursive_repr(self): 152 # Issue2045: stack overflow when default_factory is a bound method 153 class sub(defaultdict): 154 def __init__(self): 155 self.default_factory = self._factory 156 def _factory(self): 157 return [] 158 d = sub() 159 self.assertTrue(repr(d).startswith( 160 "defaultdict(<bound method sub._factory of defaultdict(...")) 161 162 # NOTE: printing a subclass of a builtin type does not call its 163 # tp_print slot. So this part is essentially the same test as above. 164 tfn = tempfile.mktemp() 165 try: 166 f = open(tfn, "w+") 167 try: 168 print >>f, d 169 finally: 170 f.close() 171 finally: 172 os.remove(tfn) 173 174 def test_callable_arg(self): 175 self.assertRaises(TypeError, defaultdict, {}) 176 177 def test_main(): 178 test_support.run_unittest(TestDefaultDict) 179 180 if __name__ == "__main__": 181 test_main() 182