Home | History | Annotate | Download | only in test
      1 # Deliberately use "from dataclasses import *".  Every name in __all__
      2 # is tested, so they all must be present.  This is a way to catch
      3 # missing ones.
      4 
      5 from dataclasses import *
      6 
      7 import pickle
      8 import inspect
      9 import builtins
     10 import unittest
     11 from unittest.mock import Mock
     12 from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
     13 from collections import deque, OrderedDict, namedtuple
     14 from functools import total_ordering
     15 
     16 import typing       # Needed for the string "typing.ClassVar[int]" to work as an annotation.
     17 import dataclasses  # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
     18 
     19 # Just any custom exception we can catch.
     20 class CustomError(Exception): pass
     21 
     22 class TestCase(unittest.TestCase):
     23     def test_no_fields(self):
     24         @dataclass
     25         class C:
     26             pass
     27 
     28         o = C()
     29         self.assertEqual(len(fields(C)), 0)
     30 
     31     def test_no_fields_but_member_variable(self):
     32         @dataclass
     33         class C:
     34             i = 0
     35 
     36         o = C()
     37         self.assertEqual(len(fields(C)), 0)
     38 
     39     def test_one_field_no_default(self):
     40         @dataclass
     41         class C:
     42             x: int
     43 
     44         o = C(42)
     45         self.assertEqual(o.x, 42)
     46 
     47     def test_named_init_params(self):
     48         @dataclass
     49         class C:
     50             x: int
     51 
     52         o = C(x=32)
     53         self.assertEqual(o.x, 32)
     54 
     55     def test_two_fields_one_default(self):
     56         @dataclass
     57         class C:
     58             x: int
     59             y: int = 0
     60 
     61         o = C(3)
     62         self.assertEqual((o.x, o.y), (3, 0))
     63 
     64         # Non-defaults following defaults.
     65         with self.assertRaisesRegex(TypeError,
     66                                     "non-default argument 'y' follows "
     67                                     "default argument"):
     68             @dataclass
     69             class C:
     70                 x: int = 0
     71                 y: int
     72 
     73         # A derived class adds a non-default field after a default one.
     74         with self.assertRaisesRegex(TypeError,
     75                                     "non-default argument 'y' follows "
     76                                     "default argument"):
     77             @dataclass
     78             class B:
     79                 x: int = 0
     80 
     81             @dataclass
     82             class C(B):
     83                 y: int
     84 
     85         # Override a base class field and add a default to
     86         #  a field which didn't use to have a default.
     87         with self.assertRaisesRegex(TypeError,
     88                                     "non-default argument 'y' follows "
     89                                     "default argument"):
     90             @dataclass
     91             class B:
     92                 x: int
     93                 y: int
     94 
     95             @dataclass
     96             class C(B):
     97                 x: int = 0
     98 
     99     def test_overwrite_hash(self):
    100         # Test that declaring this class isn't an error.  It should
    101         #  use the user-provided __hash__.
    102         @dataclass(frozen=True)
    103         class C:
    104             x: int
    105             def __hash__(self):
    106                 return 301
    107         self.assertEqual(hash(C(100)), 301)
    108 
    109         # Test that declaring this class isn't an error.  It should
    110         #  use the generated __hash__.
    111         @dataclass(frozen=True)
    112         class C:
    113             x: int
    114             def __eq__(self, other):
    115                 return False
    116         self.assertEqual(hash(C(100)), hash((100,)))
    117 
    118         # But this one should generate an exception, because with
    119         #  unsafe_hash=True, it's an error to have a __hash__ defined.
    120         with self.assertRaisesRegex(TypeError,
    121                                     'Cannot overwrite attribute __hash__'):
    122             @dataclass(unsafe_hash=True)
    123             class C:
    124                 def __hash__(self):
    125                     pass
    126 
    127         # Creating this class should not generate an exception,
    128         #  because even though __hash__ exists before @dataclass is
    129         #  called, (due to __eq__ being defined), since it's None
    130         #  that's okay.
    131         @dataclass(unsafe_hash=True)
    132         class C:
    133             x: int
    134             def __eq__(self):
    135                 pass
    136         # The generated hash function works as we'd expect.
    137         self.assertEqual(hash(C(10)), hash((10,)))
    138 
    139         # Creating this class should generate an exception, because
    140         #  __hash__ exists and is not None, which it would be if it
    141         #  had been auto-generated due to __eq__ being defined.
    142         with self.assertRaisesRegex(TypeError,
    143                                     'Cannot overwrite attribute __hash__'):
    144             @dataclass(unsafe_hash=True)
    145             class C:
    146                 x: int
    147                 def __eq__(self):
    148                     pass
    149                 def __hash__(self):
    150                     pass
    151 
    152     def test_overwrite_fields_in_derived_class(self):
    153         # Note that x from C1 replaces x in Base, but the order remains
    154         #  the same as defined in Base.
    155         @dataclass
    156         class Base:
    157             x: Any = 15.0
    158             y: int = 0
    159 
    160         @dataclass
    161         class C1(Base):
    162             z: int = 10
    163             x: int = 15
    164 
    165         o = Base()
    166         self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
    167 
    168         o = C1()
    169         self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
    170 
    171         o = C1(x=5)
    172         self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
    173 
    174     def test_field_named_self(self):
    175         @dataclass
    176         class C:
    177             self: str
    178         c=C('foo')
    179         self.assertEqual(c.self, 'foo')
    180 
    181         # Make sure the first parameter is not named 'self'.
    182         sig = inspect.signature(C.__init__)
    183         first = next(iter(sig.parameters))
    184         self.assertNotEqual('self', first)
    185 
    186         # But we do use 'self' if no field named self.
    187         @dataclass
    188         class C:
    189             selfx: str
    190 
    191         # Make sure the first parameter is named 'self'.
    192         sig = inspect.signature(C.__init__)
    193         first = next(iter(sig.parameters))
    194         self.assertEqual('self', first)
    195 
    196     def test_field_named_object(self):
    197         @dataclass
    198         class C:
    199             object: str
    200         c = C('foo')
    201         self.assertEqual(c.object, 'foo')
    202 
    203     def test_field_named_object_frozen(self):
    204         @dataclass(frozen=True)
    205         class C:
    206             object: str
    207         c = C('foo')
    208         self.assertEqual(c.object, 'foo')
    209 
    210     def test_field_named_like_builtin(self):
    211         # Attribute names can shadow built-in names
    212         # since code generation is used.
    213         # Ensure that this is not happening.
    214         exclusions = {'None', 'True', 'False'}
    215         builtins_names = sorted(
    216             b for b in builtins.__dict__.keys()
    217             if not b.startswith('__') and b not in exclusions
    218         )
    219         attributes = [(name, str) for name in builtins_names]
    220         C = make_dataclass('C', attributes)
    221 
    222         c = C(*[name for name in builtins_names])
    223 
    224         for name in builtins_names:
    225             self.assertEqual(getattr(c, name), name)
    226 
    227     def test_field_named_like_builtin_frozen(self):
    228         # Attribute names can shadow built-in names
    229         # since code generation is used.
    230         # Ensure that this is not happening
    231         # for frozen data classes.
    232         exclusions = {'None', 'True', 'False'}
    233         builtins_names = sorted(
    234             b for b in builtins.__dict__.keys()
    235             if not b.startswith('__') and b not in exclusions
    236         )
    237         attributes = [(name, str) for name in builtins_names]
    238         C = make_dataclass('C', attributes, frozen=True)
    239 
    240         c = C(*[name for name in builtins_names])
    241 
    242         for name in builtins_names:
    243             self.assertEqual(getattr(c, name), name)
    244 
    245     def test_0_field_compare(self):
    246         # Ensure that order=False is the default.
    247         @dataclass
    248         class C0:
    249             pass
    250 
    251         @dataclass(order=False)
    252         class C1:
    253             pass
    254 
    255         for cls in [C0, C1]:
    256             with self.subTest(cls=cls):
    257                 self.assertEqual(cls(), cls())
    258                 for idx, fn in enumerate([lambda a, b: a < b,
    259                                           lambda a, b: a <= b,
    260                                           lambda a, b: a > b,
    261                                           lambda a, b: a >= b]):
    262                     with self.subTest(idx=idx):
    263                         with self.assertRaisesRegex(TypeError,
    264                                                     f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
    265                             fn(cls(), cls())
    266 
    267         @dataclass(order=True)
    268         class C:
    269             pass
    270         self.assertLessEqual(C(), C())
    271         self.assertGreaterEqual(C(), C())
    272 
    273     def test_1_field_compare(self):
    274         # Ensure that order=False is the default.
    275         @dataclass
    276         class C0:
    277             x: int
    278 
    279         @dataclass(order=False)
    280         class C1:
    281             x: int
    282 
    283         for cls in [C0, C1]:
    284             with self.subTest(cls=cls):
    285                 self.assertEqual(cls(1), cls(1))
    286                 self.assertNotEqual(cls(0), cls(1))
    287                 for idx, fn in enumerate([lambda a, b: a < b,
    288                                           lambda a, b: a <= b,
    289                                           lambda a, b: a > b,
    290                                           lambda a, b: a >= b]):
    291                     with self.subTest(idx=idx):
    292                         with self.assertRaisesRegex(TypeError,
    293                                                     f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
    294                             fn(cls(0), cls(0))
    295 
    296         @dataclass(order=True)
    297         class C:
    298             x: int
    299         self.assertLess(C(0), C(1))
    300         self.assertLessEqual(C(0), C(1))
    301         self.assertLessEqual(C(1), C(1))
    302         self.assertGreater(C(1), C(0))
    303         self.assertGreaterEqual(C(1), C(0))
    304         self.assertGreaterEqual(C(1), C(1))
    305 
    306     def test_simple_compare(self):
    307         # Ensure that order=False is the default.
    308         @dataclass
    309         class C0:
    310             x: int
    311             y: int
    312 
    313         @dataclass(order=False)
    314         class C1:
    315             x: int
    316             y: int
    317 
    318         for cls in [C0, C1]:
    319             with self.subTest(cls=cls):
    320                 self.assertEqual(cls(0, 0), cls(0, 0))
    321                 self.assertEqual(cls(1, 2), cls(1, 2))
    322                 self.assertNotEqual(cls(1, 0), cls(0, 0))
    323                 self.assertNotEqual(cls(1, 0), cls(1, 1))
    324                 for idx, fn in enumerate([lambda a, b: a < b,
    325                                           lambda a, b: a <= b,
    326                                           lambda a, b: a > b,
    327                                           lambda a, b: a >= b]):
    328                     with self.subTest(idx=idx):
    329                         with self.assertRaisesRegex(TypeError,
    330                                                     f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
    331                             fn(cls(0, 0), cls(0, 0))
    332 
    333         @dataclass(order=True)
    334         class C:
    335             x: int
    336             y: int
    337 
    338         for idx, fn in enumerate([lambda a, b: a == b,
    339                                   lambda a, b: a <= b,
    340                                   lambda a, b: a >= b]):
    341             with self.subTest(idx=idx):
    342                 self.assertTrue(fn(C(0, 0), C(0, 0)))
    343 
    344         for idx, fn in enumerate([lambda a, b: a < b,
    345                                   lambda a, b: a <= b,
    346                                   lambda a, b: a != b]):
    347             with self.subTest(idx=idx):
    348                 self.assertTrue(fn(C(0, 0), C(0, 1)))
    349                 self.assertTrue(fn(C(0, 1), C(1, 0)))
    350                 self.assertTrue(fn(C(1, 0), C(1, 1)))
    351 
    352         for idx, fn in enumerate([lambda a, b: a > b,
    353                                   lambda a, b: a >= b,
    354                                   lambda a, b: a != b]):
    355             with self.subTest(idx=idx):
    356                 self.assertTrue(fn(C(0, 1), C(0, 0)))
    357                 self.assertTrue(fn(C(1, 0), C(0, 1)))
    358                 self.assertTrue(fn(C(1, 1), C(1, 0)))
    359 
    360     def test_compare_subclasses(self):
    361         # Comparisons fail for subclasses, even if no fields
    362         #  are added.
    363         @dataclass
    364         class B:
    365             i: int
    366 
    367         @dataclass
    368         class C(B):
    369             pass
    370 
    371         for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
    372                                               (lambda a, b: a != b, True)]):
    373             with self.subTest(idx=idx):
    374                 self.assertEqual(fn(B(0), C(0)), expected)
    375 
    376         for idx, fn in enumerate([lambda a, b: a < b,
    377                                   lambda a, b: a <= b,
    378                                   lambda a, b: a > b,
    379                                   lambda a, b: a >= b]):
    380             with self.subTest(idx=idx):
    381                 with self.assertRaisesRegex(TypeError,
    382                                             "not supported between instances of 'B' and 'C'"):
    383                     fn(B(0), C(0))
    384 
    385     def test_eq_order(self):
    386         # Test combining eq and order.
    387         for (eq,    order, result   ) in [
    388             (False, False, 'neither'),
    389             (False, True,  'exception'),
    390             (True,  False, 'eq_only'),
    391             (True,  True,  'both'),
    392         ]:
    393             with self.subTest(eq=eq, order=order):
    394                 if result == 'exception':
    395                     with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
    396                         @dataclass(eq=eq, order=order)
    397                         class C:
    398                             pass
    399                 else:
    400                     @dataclass(eq=eq, order=order)
    401                     class C:
    402                         pass
    403 
    404                     if result == 'neither':
    405                         self.assertNotIn('__eq__', C.__dict__)
    406                         self.assertNotIn('__lt__', C.__dict__)
    407                         self.assertNotIn('__le__', C.__dict__)
    408                         self.assertNotIn('__gt__', C.__dict__)
    409                         self.assertNotIn('__ge__', C.__dict__)
    410                     elif result == 'both':
    411                         self.assertIn('__eq__', C.__dict__)
    412                         self.assertIn('__lt__', C.__dict__)
    413                         self.assertIn('__le__', C.__dict__)
    414                         self.assertIn('__gt__', C.__dict__)
    415                         self.assertIn('__ge__', C.__dict__)
    416                     elif result == 'eq_only':
    417                         self.assertIn('__eq__', C.__dict__)
    418                         self.assertNotIn('__lt__', C.__dict__)
    419                         self.assertNotIn('__le__', C.__dict__)
    420                         self.assertNotIn('__gt__', C.__dict__)
    421                         self.assertNotIn('__ge__', C.__dict__)
    422                     else:
    423                         assert False, f'unknown result {result!r}'
    424 
    425     def test_field_no_default(self):
    426         @dataclass
    427         class C:
    428             x: int = field()
    429 
    430         self.assertEqual(C(5).x, 5)
    431 
    432         with self.assertRaisesRegex(TypeError,
    433                                     r"__init__\(\) missing 1 required "
    434                                     "positional argument: 'x'"):
    435             C()
    436 
    437     def test_field_default(self):
    438         default = object()
    439         @dataclass
    440         class C:
    441             x: object = field(default=default)
    442 
    443         self.assertIs(C.x, default)
    444         c = C(10)
    445         self.assertEqual(c.x, 10)
    446 
    447         # If we delete the instance attribute, we should then see the
    448         #  class attribute.
    449         del c.x
    450         self.assertIs(c.x, default)
    451 
    452         self.assertIs(C().x, default)
    453 
    454     def test_not_in_repr(self):
    455         @dataclass
    456         class C:
    457             x: int = field(repr=False)
    458         with self.assertRaises(TypeError):
    459             C()
    460         c = C(10)
    461         self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
    462 
    463         @dataclass
    464         class C:
    465             x: int = field(repr=False)
    466             y: int
    467         c = C(10, 20)
    468         self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
    469 
    470     def test_not_in_compare(self):
    471         @dataclass
    472         class C:
    473             x: int = 0
    474             y: int = field(compare=False, default=4)
    475 
    476         self.assertEqual(C(), C(0, 20))
    477         self.assertEqual(C(1, 10), C(1, 20))
    478         self.assertNotEqual(C(3), C(4, 10))
    479         self.assertNotEqual(C(3, 10), C(4, 10))
    480 
    481     def test_hash_field_rules(self):
    482         # Test all 6 cases of:
    483         #  hash=True/False/None
    484         #  compare=True/False
    485         for (hash_,    compare, result  ) in [
    486             (True,     False,   'field' ),
    487             (True,     True,    'field' ),
    488             (False,    False,   'absent'),
    489             (False,    True,    'absent'),
    490             (None,     False,   'absent'),
    491             (None,     True,    'field' ),
    492             ]:
    493             with self.subTest(hash=hash_, compare=compare):
    494                 @dataclass(unsafe_hash=True)
    495                 class C:
    496                     x: int = field(compare=compare, hash=hash_, default=5)
    497 
    498                 if result == 'field':
    499                     # __hash__ contains the field.
    500                     self.assertEqual(hash(C(5)), hash((5,)))
    501                 elif result == 'absent':
    502                     # The field is not present in the hash.
    503                     self.assertEqual(hash(C(5)), hash(()))
    504                 else:
    505                     assert False, f'unknown result {result!r}'
    506 
    507     def test_init_false_no_default(self):
    508         # If init=False and no default value, then the field won't be
    509         #  present in the instance.
    510         @dataclass
    511         class C:
    512             x: int = field(init=False)
    513 
    514         self.assertNotIn('x', C().__dict__)
    515 
    516         @dataclass
    517         class C:
    518             x: int
    519             y: int = 0
    520             z: int = field(init=False)
    521             t: int = 10
    522 
    523         self.assertNotIn('z', C(0).__dict__)
    524         self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
    525 
    526     def test_class_marker(self):
    527         @dataclass
    528         class C:
    529             x: int
    530             y: str = field(init=False, default=None)
    531             z: str = field(repr=False)
    532 
    533         the_fields = fields(C)
    534         # the_fields is a tuple of 3 items, each value
    535         #  is in __annotations__.
    536         self.assertIsInstance(the_fields, tuple)
    537         for f in the_fields:
    538             self.assertIs(type(f), Field)
    539             self.assertIn(f.name, C.__annotations__)
    540 
    541         self.assertEqual(len(the_fields), 3)
    542 
    543         self.assertEqual(the_fields[0].name, 'x')
    544         self.assertEqual(the_fields[0].type, int)
    545         self.assertFalse(hasattr(C, 'x'))
    546         self.assertTrue (the_fields[0].init)
    547         self.assertTrue (the_fields[0].repr)
    548         self.assertEqual(the_fields[1].name, 'y')
    549         self.assertEqual(the_fields[1].type, str)
    550         self.assertIsNone(getattr(C, 'y'))
    551         self.assertFalse(the_fields[1].init)
    552         self.assertTrue (the_fields[1].repr)
    553         self.assertEqual(the_fields[2].name, 'z')
    554         self.assertEqual(the_fields[2].type, str)
    555         self.assertFalse(hasattr(C, 'z'))
    556         self.assertTrue (the_fields[2].init)
    557         self.assertFalse(the_fields[2].repr)
    558 
    559     def test_field_order(self):
    560         @dataclass
    561         class B:
    562             a: str = 'B:a'
    563             b: str = 'B:b'
    564             c: str = 'B:c'
    565 
    566         @dataclass
    567         class C(B):
    568             b: str = 'C:b'
    569 
    570         self.assertEqual([(f.name, f.default) for f in fields(C)],
    571                          [('a', 'B:a'),
    572                           ('b', 'C:b'),
    573                           ('c', 'B:c')])
    574 
    575         @dataclass
    576         class D(B):
    577             c: str = 'D:c'
    578 
    579         self.assertEqual([(f.name, f.default) for f in fields(D)],
    580                          [('a', 'B:a'),
    581                           ('b', 'B:b'),
    582                           ('c', 'D:c')])
    583 
    584         @dataclass
    585         class E(D):
    586             a: str = 'E:a'
    587             d: str = 'E:d'
    588 
    589         self.assertEqual([(f.name, f.default) for f in fields(E)],
    590                          [('a', 'E:a'),
    591                           ('b', 'B:b'),
    592                           ('c', 'D:c'),
    593                           ('d', 'E:d')])
    594 
    595     def test_class_attrs(self):
    596         # We only have a class attribute if a default value is
    597         #  specified, either directly or via a field with a default.
    598         default = object()
    599         @dataclass
    600         class C:
    601             x: int
    602             y: int = field(repr=False)
    603             z: object = default
    604             t: int = field(default=100)
    605 
    606         self.assertFalse(hasattr(C, 'x'))
    607         self.assertFalse(hasattr(C, 'y'))
    608         self.assertIs   (C.z, default)
    609         self.assertEqual(C.t, 100)
    610 
    611     def test_disallowed_mutable_defaults(self):
    612         # For the known types, don't allow mutable default values.
    613         for typ, empty, non_empty in [(list, [], [1]),
    614                                       (dict, {}, {0:1}),
    615                                       (set, set(), set([1])),
    616                                       ]:
    617             with self.subTest(typ=typ):
    618                 # Can't use a zero-length value.
    619                 with self.assertRaisesRegex(ValueError,
    620                                             f'mutable default {typ} for field '
    621                                             'x is not allowed'):
    622                     @dataclass
    623                     class Point:
    624                         x: typ = empty
    625 
    626 
    627                 # Nor a non-zero-length value
    628                 with self.assertRaisesRegex(ValueError,
    629                                             f'mutable default {typ} for field '
    630                                             'y is not allowed'):
    631                     @dataclass
    632                     class Point:
    633                         y: typ = non_empty
    634 
    635                 # Check subtypes also fail.
    636                 class Subclass(typ): pass
    637 
    638                 with self.assertRaisesRegex(ValueError,
    639                                             f"mutable default .*Subclass'>"
    640                                             ' for field z is not allowed'
    641                                             ):
    642                     @dataclass
    643                     class Point:
    644                         z: typ = Subclass()
    645 
    646                 # Because this is a ClassVar, it can be mutable.
    647                 @dataclass
    648                 class C:
    649                     z: ClassVar[typ] = typ()
    650 
    651                 # Because this is a ClassVar, it can be mutable.
    652                 @dataclass
    653                 class C:
    654                     x: ClassVar[typ] = Subclass()
    655 
    656     def test_deliberately_mutable_defaults(self):
    657         # If a mutable default isn't in the known list of
    658         #  (list, dict, set), then it's okay.
    659         class Mutable:
    660             def __init__(self):
    661                 self.l = []
    662 
    663         @dataclass
    664         class C:
    665             x: Mutable
    666 
    667         # These 2 instances will share this value of x.
    668         lst = Mutable()
    669         o1 = C(lst)
    670         o2 = C(lst)
    671         self.assertEqual(o1, o2)
    672         o1.x.l.extend([1, 2])
    673         self.assertEqual(o1, o2)
    674         self.assertEqual(o1.x.l, [1, 2])
    675         self.assertIs(o1.x, o2.x)
    676 
    677     def test_no_options(self):
    678         # Call with dataclass().
    679         @dataclass()
    680         class C:
    681             x: int
    682 
    683         self.assertEqual(C(42).x, 42)
    684 
    685     def test_not_tuple(self):
    686         # Make sure we can't be compared to a tuple.
    687         @dataclass
    688         class Point:
    689             x: int
    690             y: int
    691         self.assertNotEqual(Point(1, 2), (1, 2))
    692 
    693         # And that we can't compare to another unrelated dataclass.
    694         @dataclass
    695         class C:
    696             x: int
    697             y: int
    698         self.assertNotEqual(Point(1, 3), C(1, 3))
    699 
    700     def test_not_tuple(self):
    701         # Test that some of the problems with namedtuple don't happen
    702         #  here.
    703         @dataclass
    704         class Point3D:
    705             x: int
    706             y: int
    707             z: int
    708 
    709         @dataclass
    710         class Date:
    711             year: int
    712             month: int
    713             day: int
    714 
    715         self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
    716         self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
    717 
    718         # Make sure we can't unpack.
    719         with self.assertRaisesRegex(TypeError, 'unpack'):
    720             x, y, z = Point3D(4, 5, 6)
    721 
    722         # Make sure another class with the same field names isn't
    723         #  equal.
    724         @dataclass
    725         class Point3Dv1:
    726             x: int = 0
    727             y: int = 0
    728             z: int = 0
    729         self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
    730 
    731     def test_function_annotations(self):
    732         # Some dummy class and instance to use as a default.
    733         class F:
    734             pass
    735         f = F()
    736 
    737         def validate_class(cls):
    738             # First, check __annotations__, even though they're not
    739             #  function annotations.
    740             self.assertEqual(cls.__annotations__['i'], int)
    741             self.assertEqual(cls.__annotations__['j'], str)
    742             self.assertEqual(cls.__annotations__['k'], F)
    743             self.assertEqual(cls.__annotations__['l'], float)
    744             self.assertEqual(cls.__annotations__['z'], complex)
    745 
    746             # Verify __init__.
    747 
    748             signature = inspect.signature(cls.__init__)
    749             # Check the return type, should be None.
    750             self.assertIs(signature.return_annotation, None)
    751 
    752             # Check each parameter.
    753             params = iter(signature.parameters.values())
    754             param = next(params)
    755             # This is testing an internal name, and probably shouldn't be tested.
    756             self.assertEqual(param.name, 'self')
    757             param = next(params)
    758             self.assertEqual(param.name, 'i')
    759             self.assertIs   (param.annotation, int)
    760             self.assertEqual(param.default, inspect.Parameter.empty)
    761             self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    762             param = next(params)
    763             self.assertEqual(param.name, 'j')
    764             self.assertIs   (param.annotation, str)
    765             self.assertEqual(param.default, inspect.Parameter.empty)
    766             self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    767             param = next(params)
    768             self.assertEqual(param.name, 'k')
    769             self.assertIs   (param.annotation, F)
    770             # Don't test for the default, since it's set to MISSING.
    771             self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    772             param = next(params)
    773             self.assertEqual(param.name, 'l')
    774             self.assertIs   (param.annotation, float)
    775             # Don't test for the default, since it's set to MISSING.
    776             self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    777             self.assertRaises(StopIteration, next, params)
    778 
    779 
    780         @dataclass
    781         class C:
    782             i: int
    783             j: str
    784             k: F = f
    785             l: float=field(default=None)
    786             z: complex=field(default=3+4j, init=False)
    787 
    788         validate_class(C)
    789 
    790         # Now repeat with __hash__.
    791         @dataclass(frozen=True, unsafe_hash=True)
    792         class C:
    793             i: int
    794             j: str
    795             k: F = f
    796             l: float=field(default=None)
    797             z: complex=field(default=3+4j, init=False)
    798 
    799         validate_class(C)
    800 
    801     def test_missing_default(self):
    802         # Test that MISSING works the same as a default not being
    803         #  specified.
    804         @dataclass
    805         class C:
    806             x: int=field(default=MISSING)
    807         with self.assertRaisesRegex(TypeError,
    808                                     r'__init__\(\) missing 1 required '
    809                                     'positional argument'):
    810             C()
    811         self.assertNotIn('x', C.__dict__)
    812 
    813         @dataclass
    814         class D:
    815             x: int
    816         with self.assertRaisesRegex(TypeError,
    817                                     r'__init__\(\) missing 1 required '
    818                                     'positional argument'):
    819             D()
    820         self.assertNotIn('x', D.__dict__)
    821 
    822     def test_missing_default_factory(self):
    823         # Test that MISSING works the same as a default factory not
    824         #  being specified (which is really the same as a default not
    825         #  being specified, too).
    826         @dataclass
    827         class C:
    828             x: int=field(default_factory=MISSING)
    829         with self.assertRaisesRegex(TypeError,
    830                                     r'__init__\(\) missing 1 required '
    831                                     'positional argument'):
    832             C()
    833         self.assertNotIn('x', C.__dict__)
    834 
    835         @dataclass
    836         class D:
    837             x: int=field(default=MISSING, default_factory=MISSING)
    838         with self.assertRaisesRegex(TypeError,
    839                                     r'__init__\(\) missing 1 required '
    840                                     'positional argument'):
    841             D()
    842         self.assertNotIn('x', D.__dict__)
    843 
    844     def test_missing_repr(self):
    845         self.assertIn('MISSING_TYPE object', repr(MISSING))
    846 
    847     def test_dont_include_other_annotations(self):
    848         @dataclass
    849         class C:
    850             i: int
    851             def foo(self) -> int:
    852                 return 4
    853             @property
    854             def bar(self) -> int:
    855                 return 5
    856         self.assertEqual(list(C.__annotations__), ['i'])
    857         self.assertEqual(C(10).foo(), 4)
    858         self.assertEqual(C(10).bar, 5)
    859         self.assertEqual(C(10).i, 10)
    860 
    def test_post_init(self):
        # The generated __init__ invokes __post_init__.
        @dataclass
        class AlwaysRaises:
            def __post_init__(self):
                raise CustomError()
        with self.assertRaises(CustomError):
            AlwaysRaises()

        # __post_init__ sees the values __init__ has just assigned.
        @dataclass
        class RaisesOnDefault:
            i: int = 10
            def __post_init__(self):
                if self.i == 10:
                    raise CustomError()
        with self.assertRaises(CustomError):
            RaisesOnDefault()
        # Called here as well, but i != 10, so nothing is raised.
        RaisesOnDefault(5)

        # With init=False no __init__ is generated, so nothing calls
        # __post_init__ and construction must not raise.
        @dataclass(init=False)
        class NoInit:
            def __post_init__(self):
                raise CustomError()
        NoInit()

        # A regular dataclass's __post_init__ may rewrite fields ...
        @dataclass
        class Doubler:
            x: int = 0
            def __post_init__(self):
                self.x *= 2
        self.assertEqual(Doubler().x, 0)
        self.assertEqual(Doubler(2).x, 4)

        # ... but a frozen dataclass's __post_init__ may not.
        @dataclass(frozen=True)
        class FrozenDoubler:
            x: int = 0
            def __post_init__(self):
                self.x *= 2
        with self.assertRaises(FrozenInstanceError):
            FrozenDoubler()
    907 
    def test_post_init_super(self):
        # A plain (non-dataclass) base whose __post_init__ always fails.
        class Base:
            def __post_init__(self):
                raise CustomError()

        # Overriding __post_init__ without calling super() must not reach
        # the base implementation.
        @dataclass
        class Override(Base):
            def __post_init__(self):
                self.x = 5
        self.assertEqual(Override().x, 5)

        # Chaining up explicitly does reach it.
        @dataclass
        class Chained(Base):
            def __post_init__(self):
                super().__post_init__()
        with self.assertRaises(CustomError):
            Chained()

        # An inherited __post_init__ is called even though this class does
        # not define one.
        @dataclass
        class Inherited(Base):
            pass
        with self.assertRaises(CustomError):
            Inherited()
    938 
    939     def test_post_init_staticmethod(self):
    940         flag = False
    941         @dataclass
    942         class C:
    943             x: int
    944             y: int
    945             @staticmethod
    946             def __post_init__():
    947                 nonlocal flag
    948                 flag = True
    949 
    950         self.assertFalse(flag)
    951         c = C(3, 4)
    952         self.assertEqual((c.x, c.y), (3, 4))
    953         self.assertTrue(flag)
    954 
    955     def test_post_init_classmethod(self):
    956         @dataclass
    957         class C:
    958             flag = False
    959             x: int
    960             y: int
    961             @classmethod
    962             def __post_init__(cls):
    963                 cls.flag = True
    964 
    965         self.assertFalse(C.flag)
    966         c = C(3, 4)
    967         self.assertEqual((c.x, c.y), (3, 4))
    968         self.assertTrue(C.flag)
    969 
    970     def test_class_var(self):
    971         # Make sure ClassVars are ignored in __init__, __repr__, etc.
    972         @dataclass
    973         class C:
    974             x: int
    975             y: int = 10
    976             z: ClassVar[int] = 1000
    977             w: ClassVar[int] = 2000
    978             t: ClassVar[int] = 3000
    979             s: ClassVar      = 4000
    980 
    981         c = C(5)
    982         self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
    983         self.assertEqual(len(fields(C)), 2)                 # We have 2 fields.
    984         self.assertEqual(len(C.__annotations__), 6)         # And 4 ClassVars.
    985         self.assertEqual(c.z, 1000)
    986         self.assertEqual(c.w, 2000)
    987         self.assertEqual(c.t, 3000)
    988         self.assertEqual(c.s, 4000)
    989         C.z += 1
    990         self.assertEqual(c.z, 1001)
    991         c = C(20)
    992         self.assertEqual((c.x, c.y), (20, 10))
    993         self.assertEqual(c.z, 1001)
    994         self.assertEqual(c.w, 2000)
    995         self.assertEqual(c.t, 3000)
    996         self.assertEqual(c.s, 4000)
    997 
    998     def test_class_var_no_default(self):
    999         # If a ClassVar has no default value, it should not be set on the class.
   1000         @dataclass
   1001         class C:
   1002             x: ClassVar[int]
   1003 
   1004         self.assertNotIn('x', C.__dict__)
   1005 
   1006     def test_class_var_default_factory(self):
   1007         # It makes no sense for a ClassVar to have a default factory. When
   1008         #  would it be called? Call it yourself, since it's class-wide.
   1009         with self.assertRaisesRegex(TypeError,
   1010                                     'cannot have a default factory'):
   1011             @dataclass
   1012             class C:
   1013                 x: ClassVar[int] = field(default_factory=int)
   1014 
   1015             self.assertNotIn('x', C.__dict__)
   1016 
   1017     def test_class_var_with_default(self):
   1018         # If a ClassVar has a default value, it should be set on the class.
   1019         @dataclass
   1020         class C:
   1021             x: ClassVar[int] = 10
   1022         self.assertEqual(C.x, 10)
   1023 
   1024         @dataclass
   1025         class C:
   1026             x: ClassVar[int] = field(default=10)
   1027         self.assertEqual(C.x, 10)
   1028 
   1029     def test_class_var_frozen(self):
   1030         # Make sure ClassVars work even if we're frozen.
   1031         @dataclass(frozen=True)
   1032         class C:
   1033             x: int
   1034             y: int = 10
   1035             z: ClassVar[int] = 1000
   1036             w: ClassVar[int] = 2000
   1037             t: ClassVar[int] = 3000
   1038 
   1039         c = C(5)
   1040         self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
   1041         self.assertEqual(len(fields(C)), 2)                 # We have 2 fields
   1042         self.assertEqual(len(C.__annotations__), 5)         # And 3 ClassVars
   1043         self.assertEqual(c.z, 1000)
   1044         self.assertEqual(c.w, 2000)
   1045         self.assertEqual(c.t, 3000)
   1046         # We can still modify the ClassVar, it's only instances that are
   1047         #  frozen.
   1048         C.z += 1
   1049         self.assertEqual(c.z, 1001)
   1050         c = C(20)
   1051         self.assertEqual((c.x, c.y), (20, 10))
   1052         self.assertEqual(c.z, 1001)
   1053         self.assertEqual(c.w, 2000)
   1054         self.assertEqual(c.t, 3000)
   1055 
   1056     def test_init_var_no_default(self):
   1057         # If an InitVar has no default value, it should not be set on the class.
   1058         @dataclass
   1059         class C:
   1060             x: InitVar[int]
   1061 
   1062         self.assertNotIn('x', C.__dict__)
   1063 
   1064     def test_init_var_default_factory(self):
   1065         # It makes no sense for an InitVar to have a default factory. When
   1066         #  would it be called? Call it yourself, since it's class-wide.
   1067         with self.assertRaisesRegex(TypeError,
   1068                                     'cannot have a default factory'):
   1069             @dataclass
   1070             class C:
   1071                 x: InitVar[int] = field(default_factory=int)
   1072 
   1073             self.assertNotIn('x', C.__dict__)
   1074 
   1075     def test_init_var_with_default(self):
   1076         # If an InitVar has a default value, it should be set on the class.
   1077         @dataclass
   1078         class C:
   1079             x: InitVar[int] = 10
   1080         self.assertEqual(C.x, 10)
   1081 
   1082         @dataclass
   1083         class C:
   1084             x: InitVar[int] = field(default=10)
   1085         self.assertEqual(C.x, 10)
   1086 
   1087     def test_init_var(self):
   1088         @dataclass
   1089         class C:
   1090             x: int = None
   1091             init_param: InitVar[int] = None
   1092 
   1093             def __post_init__(self, init_param):
   1094                 if self.x is None:
   1095                     self.x = init_param*2
   1096 
   1097         c = C(init_param=10)
   1098         self.assertEqual(c.x, 20)
   1099 
   1100     def test_init_var_inheritance(self):
   1101         # Note that this deliberately tests that a dataclass need not
   1102         #  have a __post_init__ function if it has an InitVar field.
   1103         #  It could just be used in a derived class, as shown here.
   1104         @dataclass
   1105         class Base:
   1106             x: int
   1107             init_base: InitVar[int]
   1108 
   1109         # We can instantiate by passing the InitVar, even though
   1110         #  it's not used.
   1111         b = Base(0, 10)
   1112         self.assertEqual(vars(b), {'x': 0})
   1113 
   1114         @dataclass
   1115         class C(Base):
   1116             y: int
   1117             init_derived: InitVar[int]
   1118 
   1119             def __post_init__(self, init_base, init_derived):
   1120                 self.x = self.x + init_base
   1121                 self.y = self.y + init_derived
   1122 
   1123         c = C(10, 11, 50, 51)
   1124         self.assertEqual(vars(c), {'x': 21, 'y': 101})
   1125 
   1126     def test_default_factory(self):
   1127         # Test a factory that returns a new list.
   1128         @dataclass
   1129         class C:
   1130             x: int
   1131             y: list = field(default_factory=list)
   1132 
   1133         c0 = C(3)
   1134         c1 = C(3)
   1135         self.assertEqual(c0.x, 3)
   1136         self.assertEqual(c0.y, [])
   1137         self.assertEqual(c0, c1)
   1138         self.assertIsNot(c0.y, c1.y)
   1139         self.assertEqual(astuple(C(5, [1])), (5, [1]))
   1140 
   1141         # Test a factory that returns a shared list.
   1142         l = []
   1143         @dataclass
   1144         class C:
   1145             x: int
   1146             y: list = field(default_factory=lambda: l)
   1147 
   1148         c0 = C(3)
   1149         c1 = C(3)
   1150         self.assertEqual(c0.x, 3)
   1151         self.assertEqual(c0.y, [])
   1152         self.assertEqual(c0, c1)
   1153         self.assertIs(c0.y, c1.y)
   1154         self.assertEqual(astuple(C(5, [1])), (5, [1]))
   1155 
   1156         # Test various other field flags.
   1157         # repr
   1158         @dataclass
   1159         class C:
   1160             x: list = field(default_factory=list, repr=False)
   1161         self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
   1162         self.assertEqual(C().x, [])
   1163 
   1164         # hash
   1165         @dataclass(unsafe_hash=True)
   1166         class C:
   1167             x: list = field(default_factory=list, hash=False)
   1168         self.assertEqual(astuple(C()), ([],))
   1169         self.assertEqual(hash(C()), hash(()))
   1170 
   1171         # init (see also test_default_factory_with_no_init)
   1172         @dataclass
   1173         class C:
   1174             x: list = field(default_factory=list, init=False)
   1175         self.assertEqual(astuple(C()), ([],))
   1176 
   1177         # compare
   1178         @dataclass
   1179         class C:
   1180             x: list = field(default_factory=list, compare=False)
   1181         self.assertEqual(C(), C([1]))
   1182 
   1183     def test_default_factory_with_no_init(self):
   1184         # We need a factory with a side effect.
   1185         factory = Mock()
   1186 
   1187         @dataclass
   1188         class C:
   1189             x: list = field(default_factory=factory, init=False)
   1190 
   1191         # Make sure the default factory is called for each new instance.
   1192         C().x
   1193         self.assertEqual(factory.call_count, 1)
   1194         C().x
   1195         self.assertEqual(factory.call_count, 2)
   1196 
   1197     def test_default_factory_not_called_if_value_given(self):
   1198         # We need a factory that we can test if it's been called.
   1199         factory = Mock()
   1200 
   1201         @dataclass
   1202         class C:
   1203             x: int = field(default_factory=factory)
   1204 
   1205         # Make sure that if a field has a default factory function,
   1206         #  it's not called if a value is specified.
   1207         C().x
   1208         self.assertEqual(factory.call_count, 1)
   1209         self.assertEqual(C(10).x, 10)
   1210         self.assertEqual(factory.call_count, 1)
   1211         C().x
   1212         self.assertEqual(factory.call_count, 2)
   1213 
   1214     def test_default_factory_derived(self):
   1215         # See bpo-32896.
   1216         @dataclass
   1217         class Foo:
   1218             x: dict = field(default_factory=dict)
   1219 
   1220         @dataclass
   1221         class Bar(Foo):
   1222             y: int = 1
   1223 
   1224         self.assertEqual(Foo().x, {})
   1225         self.assertEqual(Bar().x, {})
   1226         self.assertEqual(Bar().y, 1)
   1227 
   1228         @dataclass
   1229         class Baz(Foo):
   1230             pass
   1231         self.assertEqual(Baz().x, {})
   1232 
   1233     def test_intermediate_non_dataclass(self):
   1234         # Test that an intermediate class that defines
   1235         #  annotations does not define fields.
   1236 
   1237         @dataclass
   1238         class A:
   1239             x: int
   1240 
   1241         class B(A):
   1242             y: int
   1243 
   1244         @dataclass
   1245         class C(B):
   1246             z: int
   1247 
   1248         c = C(1, 3)
   1249         self.assertEqual((c.x, c.z), (1, 3))
   1250 
   1251         # .y was not initialized.
   1252         with self.assertRaisesRegex(AttributeError,
   1253                                     'object has no attribute'):
   1254             c.y
   1255 
   1256         # And if we again derive a non-dataclass, no fields are added.
   1257         class D(C):
   1258             t: int
   1259         d = D(4, 5)
   1260         self.assertEqual((d.x, d.z), (4, 5))
   1261 
   1262     def test_classvar_default_factory(self):
   1263         # It's an error for a ClassVar to have a factory function.
   1264         with self.assertRaisesRegex(TypeError,
   1265                                     'cannot have a default factory'):
   1266             @dataclass
   1267             class C:
   1268                 x: ClassVar[int] = field(default_factory=int)
   1269 
   1270     def test_is_dataclass(self):
   1271         class NotDataClass:
   1272             pass
   1273 
   1274         self.assertFalse(is_dataclass(0))
   1275         self.assertFalse(is_dataclass(int))
   1276         self.assertFalse(is_dataclass(NotDataClass))
   1277         self.assertFalse(is_dataclass(NotDataClass()))
   1278 
   1279         @dataclass
   1280         class C:
   1281             x: int
   1282 
   1283         @dataclass
   1284         class D:
   1285             d: C
   1286             e: int
   1287 
   1288         c = C(10)
   1289         d = D(c, 4)
   1290 
   1291         self.assertTrue(is_dataclass(C))
   1292         self.assertTrue(is_dataclass(c))
   1293         self.assertFalse(is_dataclass(c.x))
   1294         self.assertTrue(is_dataclass(d.d))
   1295         self.assertFalse(is_dataclass(d.e))
   1296 
   1297     def test_helper_fields_with_class_instance(self):
   1298         # Check that we can call fields() on either a class or instance,
   1299         #  and get back the same thing.
   1300         @dataclass
   1301         class C:
   1302             x: int
   1303             y: float
   1304 
   1305         self.assertEqual(fields(C), fields(C(0, 0.0)))
   1306 
   1307     def test_helper_fields_exception(self):
   1308         # Check that TypeError is raised if not passed a dataclass or
   1309         #  instance.
   1310         with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
   1311             fields(0)
   1312 
   1313         class C: pass
   1314         with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
   1315             fields(C)
   1316         with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
   1317             fields(C())
   1318 
   1319     def test_helper_asdict(self):
   1320         # Basic tests for asdict(), it should return a new dictionary.
   1321         @dataclass
   1322         class C:
   1323             x: int
   1324             y: int
   1325         c = C(1, 2)
   1326 
   1327         self.assertEqual(asdict(c), {'x': 1, 'y': 2})
   1328         self.assertEqual(asdict(c), asdict(c))
   1329         self.assertIsNot(asdict(c), asdict(c))
   1330         c.x = 42
   1331         self.assertEqual(asdict(c), {'x': 42, 'y': 2})
   1332         self.assertIs(type(asdict(c)), dict)
   1333 
   1334     def test_helper_asdict_raises_on_classes(self):
   1335         # asdict() should raise on a class object.
   1336         @dataclass
   1337         class C:
   1338             x: int
   1339             y: int
   1340         with self.assertRaisesRegex(TypeError, 'dataclass instance'):
   1341             asdict(C)
   1342         with self.assertRaisesRegex(TypeError, 'dataclass instance'):
   1343             asdict(int)
   1344 
   1345     def test_helper_asdict_copy_values(self):
   1346         @dataclass
   1347         class C:
   1348             x: int
   1349             y: List[int] = field(default_factory=list)
   1350         initial = []
   1351         c = C(1, initial)
   1352         d = asdict(c)
   1353         self.assertEqual(d['y'], initial)
   1354         self.assertIsNot(d['y'], initial)
   1355         c = C(1)
   1356         d = asdict(c)
   1357         d['y'].append(1)
   1358         self.assertEqual(c.y, [])
   1359 
   1360     def test_helper_asdict_nested(self):
   1361         @dataclass
   1362         class UserId:
   1363             token: int
   1364             group: int
   1365         @dataclass
   1366         class User:
   1367             name: str
   1368             id: UserId
   1369         u = User('Joe', UserId(123, 1))
   1370         d = asdict(u)
   1371         self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
   1372         self.assertIsNot(asdict(u), asdict(u))
   1373         u.id.group = 2
   1374         self.assertEqual(asdict(u), {'name': 'Joe',
   1375                                      'id': {'token': 123, 'group': 2}})
   1376 
   1377     def test_helper_asdict_builtin_containers(self):
   1378         @dataclass
   1379         class User:
   1380             name: str
   1381             id: int
   1382         @dataclass
   1383         class GroupList:
   1384             id: int
   1385             users: List[User]
   1386         @dataclass
   1387         class GroupTuple:
   1388             id: int
   1389             users: Tuple[User, ...]
   1390         @dataclass
   1391         class GroupDict:
   1392             id: int
   1393             users: Dict[str, User]
   1394         a = User('Alice', 1)
   1395         b = User('Bob', 2)
   1396         gl = GroupList(0, [a, b])
   1397         gt = GroupTuple(0, (a, b))
   1398         gd = GroupDict(0, {'first': a, 'second': b})
   1399         self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
   1400                                                          {'name': 'Bob', 'id': 2}]})
   1401         self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
   1402                                                          {'name': 'Bob', 'id': 2})})
   1403         self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
   1404                                                          'second': {'name': 'Bob', 'id': 2}}})
   1405 
   1406     def test_helper_asdict_builtin_containers(self):
   1407         @dataclass
   1408         class Child:
   1409             d: object
   1410 
   1411         @dataclass
   1412         class Parent:
   1413             child: Child
   1414 
   1415         self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
   1416         self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
   1417 
   1418     def test_helper_asdict_factory(self):
   1419         @dataclass
   1420         class C:
   1421             x: int
   1422             y: int
   1423         c = C(1, 2)
   1424         d = asdict(c, dict_factory=OrderedDict)
   1425         self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
   1426         self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
   1427         c.x = 42
   1428         d = asdict(c, dict_factory=OrderedDict)
   1429         self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
   1430         self.assertIs(type(d), OrderedDict)
   1431 
   1432     def test_helper_asdict_namedtuple(self):
   1433         T = namedtuple('T', 'a b c')
   1434         @dataclass
   1435         class C:
   1436             x: str
   1437             y: T
   1438         c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
   1439 
   1440         d = asdict(c)
   1441         self.assertEqual(d, {'x': 'outer',
   1442                              'y': T(1,
   1443                                     {'x': 'inner',
   1444                                      'y': T(11, 12, 13)},
   1445                                     2),
   1446                              }
   1447                          )
   1448 
   1449         # Now with a dict_factory.  OrderedDict is convenient, but
   1450         # since it compares to dicts, we also need to have separate
   1451         # assertIs tests.
   1452         d = asdict(c, dict_factory=OrderedDict)
   1453         self.assertEqual(d, {'x': 'outer',
   1454                              'y': T(1,
   1455                                     {'x': 'inner',
   1456                                      'y': T(11, 12, 13)},
   1457                                     2),
   1458                              }
   1459                          )
   1460 
   1461         # Make sure that the returned dicts are actuall OrderedDicts.
   1462         self.assertIs(type(d), OrderedDict)
   1463         self.assertIs(type(d['y'][1]), OrderedDict)
   1464 
   1465     def test_helper_asdict_namedtuple_key(self):
   1466         # Ensure that a field that contains a dict which has a
   1467         # namedtuple as a key works with asdict().
   1468 
   1469         @dataclass
   1470         class C:
   1471             f: dict
   1472         T = namedtuple('T', 'a')
   1473 
   1474         c = C({T('an a'): 0})
   1475 
   1476         self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}})
   1477 
   1478     def test_helper_asdict_namedtuple_derived(self):
   1479         class T(namedtuple('Tbase', 'a')):
   1480             def my_a(self):
   1481                 return self.a
   1482 
   1483         @dataclass
   1484         class C:
   1485             f: T
   1486 
   1487         t = T(6)
   1488         c = C(t)
   1489 
   1490         d = asdict(c)
   1491         self.assertEqual(d, {'f': T(a=6)})
   1492         # Make sure that t has been copied, not used directly.
   1493         self.assertIsNot(d['f'], t)
   1494         self.assertEqual(d['f'].my_a(), 6)
   1495 
   1496     def test_helper_astuple(self):
   1497         # Basic tests for astuple(), it should return a new tuple.
   1498         @dataclass
   1499         class C:
   1500             x: int
   1501             y: int = 0
   1502         c = C(1)
   1503 
   1504         self.assertEqual(astuple(c), (1, 0))
   1505         self.assertEqual(astuple(c), astuple(c))
   1506         self.assertIsNot(astuple(c), astuple(c))
   1507         c.y = 42
   1508         self.assertEqual(astuple(c), (1, 42))
   1509         self.assertIs(type(astuple(c)), tuple)
   1510 
   1511     def test_helper_astuple_raises_on_classes(self):
   1512         # astuple() should raise on a class object.
   1513         @dataclass
   1514         class C:
   1515             x: int
   1516             y: int
   1517         with self.assertRaisesRegex(TypeError, 'dataclass instance'):
   1518             astuple(C)
   1519         with self.assertRaisesRegex(TypeError, 'dataclass instance'):
   1520             astuple(int)
   1521 
   1522     def test_helper_astuple_copy_values(self):
   1523         @dataclass
   1524         class C:
   1525             x: int
   1526             y: List[int] = field(default_factory=list)
   1527         initial = []
   1528         c = C(1, initial)
   1529         t = astuple(c)
   1530         self.assertEqual(t[1], initial)
   1531         self.assertIsNot(t[1], initial)
   1532         c = C(1)
   1533         t = astuple(c)
   1534         t[1].append(1)
   1535         self.assertEqual(c.y, [])
   1536 
   1537     def test_helper_astuple_nested(self):
   1538         @dataclass
   1539         class UserId:
   1540             token: int
   1541             group: int
   1542         @dataclass
   1543         class User:
   1544             name: str
   1545             id: UserId
   1546         u = User('Joe', UserId(123, 1))
   1547         t = astuple(u)
   1548         self.assertEqual(t, ('Joe', (123, 1)))
   1549         self.assertIsNot(astuple(u), astuple(u))
   1550         u.id.group = 2
   1551         self.assertEqual(astuple(u), ('Joe', (123, 2)))
   1552 
   1553     def test_helper_astuple_builtin_containers(self):
   1554         @dataclass
   1555         class User:
   1556             name: str
   1557             id: int
   1558         @dataclass
   1559         class GroupList:
   1560             id: int
   1561             users: List[User]
   1562         @dataclass
   1563         class GroupTuple:
   1564             id: int
   1565             users: Tuple[User, ...]
   1566         @dataclass
   1567         class GroupDict:
   1568             id: int
   1569             users: Dict[str, User]
   1570         a = User('Alice', 1)
   1571         b = User('Bob', 2)
   1572         gl = GroupList(0, [a, b])
   1573         gt = GroupTuple(0, (a, b))
   1574         gd = GroupDict(0, {'first': a, 'second': b})
   1575         self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
   1576         self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
   1577         self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
   1578 
   1579     def test_helper_astuple_builtin_containers(self):
   1580         @dataclass
   1581         class Child:
   1582             d: object
   1583 
   1584         @dataclass
   1585         class Parent:
   1586             child: Child
   1587 
   1588         self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
   1589         self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
   1590 
   1591     def test_helper_astuple_factory(self):
   1592         @dataclass
   1593         class C:
   1594             x: int
   1595             y: int
   1596         NT = namedtuple('NT', 'x y')
   1597         def nt(lst):
   1598             return NT(*lst)
   1599         c = C(1, 2)
   1600         t = astuple(c, tuple_factory=nt)
   1601         self.assertEqual(t, NT(1, 2))
   1602         self.assertIsNot(t, astuple(c, tuple_factory=nt))
   1603         c.x = 42
   1604         t = astuple(c, tuple_factory=nt)
   1605         self.assertEqual(t, NT(42, 2))
   1606         self.assertIs(type(t), NT)
   1607 
   1608     def test_helper_astuple_namedtuple(self):
   1609         T = namedtuple('T', 'a b c')
   1610         @dataclass
   1611         class C:
   1612             x: str
   1613             y: T
   1614         c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
   1615 
   1616         t = astuple(c)
   1617         self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2)))
   1618 
   1619         # Now, using a tuple_factory.  list is convenient here.
   1620         t = astuple(c, tuple_factory=list)
   1621         self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
   1622 
   1623     def test_dynamic_class_creation(self):
   1624         cls_dict = {'__annotations__': {'x': int, 'y': int},
   1625                     }
   1626 
   1627         # Create the class.
   1628         cls = type('C', (), cls_dict)
   1629 
   1630         # Make it a dataclass.
   1631         cls1 = dataclass(cls)
   1632 
   1633         self.assertEqual(cls1, cls)
   1634         self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
   1635 
   1636     def test_dynamic_class_creation_using_field(self):
   1637         cls_dict = {'__annotations__': {'x': int, 'y': int},
   1638                     'y': field(default=5),
   1639                     }
   1640 
   1641         # Create the class.
   1642         cls = type('C', (), cls_dict)
   1643 
   1644         # Make it a dataclass.
   1645         cls1 = dataclass(cls)
   1646 
   1647         self.assertEqual(cls1, cls)
   1648         self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
   1649 
   1650     def test_init_in_order(self):
   1651         @dataclass
   1652         class C:
   1653             a: int
   1654             b: int = field()
   1655             c: list = field(default_factory=list, init=False)
   1656             d: list = field(default_factory=list)
   1657             e: int = field(default=4, init=False)
   1658             f: int = 4
   1659 
   1660         calls = []
   1661         def setattr(self, name, value):
   1662             calls.append((name, value))
   1663 
   1664         C.__setattr__ = setattr
   1665         c = C(0, 1)
   1666         self.assertEqual(('a', 0), calls[0])
   1667         self.assertEqual(('b', 1), calls[1])
   1668         self.assertEqual(('c', []), calls[2])
   1669         self.assertEqual(('d', []), calls[3])
   1670         self.assertNotIn(('e', 4), calls)
   1671         self.assertEqual(('f', 4), calls[4])
   1672 
   1673     def test_items_in_dicts(self):
   1674         @dataclass
   1675         class C:
   1676             a: int
   1677             b: list = field(default_factory=list, init=False)
   1678             c: list = field(default_factory=list)
   1679             d: int = field(default=4, init=False)
   1680             e: int = 0
   1681 
   1682         c = C(0)
   1683         # Class dict
   1684         self.assertNotIn('a', C.__dict__)
   1685         self.assertNotIn('b', C.__dict__)
   1686         self.assertNotIn('c', C.__dict__)
   1687         self.assertIn('d', C.__dict__)
   1688         self.assertEqual(C.d, 4)
   1689         self.assertIn('e', C.__dict__)
   1690         self.assertEqual(C.e, 0)
   1691         # Instance dict
   1692         self.assertIn('a', c.__dict__)
   1693         self.assertEqual(c.a, 0)
   1694         self.assertIn('b', c.__dict__)
   1695         self.assertEqual(c.b, [])
   1696         self.assertIn('c', c.__dict__)
   1697         self.assertEqual(c.c, [])
   1698         self.assertNotIn('d', c.__dict__)
   1699         self.assertIn('e', c.__dict__)
   1700         self.assertEqual(c.e, 0)
   1701 
   1702     def test_alternate_classmethod_constructor(self):
   1703         # Since __post_init__ can't take params, use a classmethod
   1704         #  alternate constructor.  This is mostly an example to show
   1705         #  how to use this technique.
   1706         @dataclass
   1707         class C:
   1708             x: int
   1709             @classmethod
   1710             def from_file(cls, filename):
   1711                 # In a real example, create a new instance
   1712                 #  and populate 'x' from contents of a file.
   1713                 value_in_file = 20
   1714                 return cls(value_in_file)
   1715 
   1716         self.assertEqual(C.from_file('filename').x, 20)
   1717 
   1718     def test_field_metadata_default(self):
   1719         # Make sure the default metadata is read-only and of
   1720         #  zero length.
   1721         @dataclass
   1722         class C:
   1723             i: int
   1724 
   1725         self.assertFalse(fields(C)[0].metadata)
   1726         self.assertEqual(len(fields(C)[0].metadata), 0)
   1727         with self.assertRaisesRegex(TypeError,
   1728                                     'does not support item assignment'):
   1729             fields(C)[0].metadata['test'] = 3
   1730 
   1731     def test_field_metadata_mapping(self):
   1732         # Make sure only a mapping can be passed as metadata
   1733         #  zero length.
   1734         with self.assertRaises(TypeError):
   1735             @dataclass
   1736             class C:
   1737                 i: int = field(metadata=0)
   1738 
   1739         # Make sure an empty dict works.
   1740         d = {}
   1741         @dataclass
   1742         class C:
   1743             i: int = field(metadata=d)
   1744         self.assertFalse(fields(C)[0].metadata)
   1745         self.assertEqual(len(fields(C)[0].metadata), 0)
   1746         # Update should work (see bpo-35960).
   1747         d['foo'] = 1
   1748         self.assertEqual(len(fields(C)[0].metadata), 1)
   1749         self.assertEqual(fields(C)[0].metadata['foo'], 1)
   1750         with self.assertRaisesRegex(TypeError,
   1751                                     'does not support item assignment'):
   1752             fields(C)[0].metadata['test'] = 3
   1753 
   1754         # Make sure a non-empty dict works.
   1755         d = {'test': 10, 'bar': '42', 3: 'three'}
   1756         @dataclass
   1757         class C:
   1758             i: int = field(metadata=d)
   1759         self.assertEqual(len(fields(C)[0].metadata), 3)
   1760         self.assertEqual(fields(C)[0].metadata['test'], 10)
   1761         self.assertEqual(fields(C)[0].metadata['bar'], '42')
   1762         self.assertEqual(fields(C)[0].metadata[3], 'three')
   1763         # Update should work.
   1764         d['foo'] = 1
   1765         self.assertEqual(len(fields(C)[0].metadata), 4)
   1766         self.assertEqual(fields(C)[0].metadata['foo'], 1)
   1767         with self.assertRaises(KeyError):
   1768             # Non-existent key.
   1769             fields(C)[0].metadata['baz']
   1770         with self.assertRaisesRegex(TypeError,
   1771                                     'does not support item assignment'):
   1772             fields(C)[0].metadata['test'] = 3
   1773 
   1774     def test_field_metadata_custom_mapping(self):
   1775         # Try a custom mapping.
   1776         class SimpleNameSpace:
   1777             def __init__(self, **kw):
   1778                 self.__dict__.update(kw)
   1779 
   1780             def __getitem__(self, item):
   1781                 if item == 'xyzzy':
   1782                     return 'plugh'
   1783                 return getattr(self, item)
   1784 
   1785             def __len__(self):
   1786                 return self.__dict__.__len__()
   1787 
   1788         @dataclass
   1789         class C:
   1790             i: int = field(metadata=SimpleNameSpace(a=10))
   1791 
   1792         self.assertEqual(len(fields(C)[0].metadata), 1)
   1793         self.assertEqual(fields(C)[0].metadata['a'], 10)
   1794         with self.assertRaises(AttributeError):
   1795             fields(C)[0].metadata['b']
   1796         # Make sure we're still talking to our custom mapping.
   1797         self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
   1798 
   1799     def test_generic_dataclasses(self):
   1800         T = TypeVar('T')
   1801 
   1802         @dataclass
   1803         class LabeledBox(Generic[T]):
   1804             content: T
   1805             label: str = '<unknown>'
   1806 
   1807         box = LabeledBox(42)
   1808         self.assertEqual(box.content, 42)
   1809         self.assertEqual(box.label, '<unknown>')
   1810 
   1811         # Subscripting the resulting class should work, etc.
   1812         Alias = List[LabeledBox[int]]
   1813 
   1814     def test_generic_extending(self):
   1815         S = TypeVar('S')
   1816         T = TypeVar('T')
   1817 
   1818         @dataclass
   1819         class Base(Generic[T, S]):
   1820             x: T
   1821             y: S
   1822 
   1823         @dataclass
   1824         class DataDerived(Base[int, T]):
   1825             new_field: str
   1826         Alias = DataDerived[str]
   1827         c = Alias(0, 'test1', 'test2')
   1828         self.assertEqual(astuple(c), (0, 'test1', 'test2'))
   1829 
   1830         class NonDataDerived(Base[int, T]):
   1831             def new_method(self):
   1832                 return self.y
   1833         Alias = NonDataDerived[float]
   1834         c = Alias(10, 1.0)
   1835         self.assertEqual(c.new_method(), 1.0)
   1836 
   1837     def test_generic_dynamic(self):
   1838         T = TypeVar('T')
   1839 
   1840         @dataclass
   1841         class Parent(Generic[T]):
   1842             x: T
   1843         Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
   1844                                bases=(Parent[int], Generic[T]), namespace={'other': 42})
   1845         self.assertIs(Child[int](1, 2).z, None)
   1846         self.assertEqual(Child[int](1, 2, 3).z, 3)
   1847         self.assertEqual(Child[int](1, 2, 3).other, 42)
   1848         # Check that type aliases work correctly.
   1849         Alias = Child[T]
   1850         self.assertEqual(Alias[int](1, 2).x, 1)
   1851         # Check MRO resolution.
   1852         self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
   1853 
   1854     def test_dataclassses_pickleable(self):
   1855         global P, Q, R
   1856         @dataclass
   1857         class P:
   1858             x: int
   1859             y: int = 0
   1860         @dataclass
   1861         class Q:
   1862             x: int
   1863             y: int = field(default=0, init=False)
   1864         @dataclass
   1865         class R:
   1866             x: int
   1867             y: List[int] = field(default_factory=list)
   1868         q = Q(1)
   1869         q.y = 2
   1870         samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
   1871         for sample in samples:
   1872             for proto in range(pickle.HIGHEST_PROTOCOL + 1):
   1873                 with self.subTest(sample=sample, proto=proto):
   1874                     new_sample = pickle.loads(pickle.dumps(sample, proto))
   1875                     self.assertEqual(sample.x, new_sample.x)
   1876                     self.assertEqual(sample.y, new_sample.y)
   1877                     self.assertIsNot(sample, new_sample)
   1878                     new_sample.x = 42
   1879                     another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
   1880                     self.assertEqual(new_sample.x, another_new_sample.x)
   1881                     self.assertEqual(sample.y, another_new_sample.y)
   1882 
   1883 
   1884 class TestFieldNoAnnotation(unittest.TestCase):
   1885     def test_field_without_annotation(self):
   1886         with self.assertRaisesRegex(TypeError,
   1887                                     "'f' is a field but has no type annotation"):
   1888             @dataclass
   1889             class C:
   1890                 f = field()
   1891 
   1892     def test_field_without_annotation_but_annotation_in_base(self):
   1893         @dataclass
   1894         class B:
   1895             f: int
   1896 
   1897         with self.assertRaisesRegex(TypeError,
   1898                                     "'f' is a field but has no type annotation"):
   1899             # This is still an error: make sure we don't pick up the
   1900             #  type annotation in the base class.
   1901             @dataclass
   1902             class C(B):
   1903                 f = field()
   1904 
   1905     def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
   1906         # Same test, but with the base class not a dataclass.
   1907         class B:
   1908             f: int
   1909 
   1910         with self.assertRaisesRegex(TypeError,
   1911                                     "'f' is a field but has no type annotation"):
   1912             # This is still an error: make sure we don't pick up the
   1913             #  type annotation in the base class.
   1914             @dataclass
   1915             class C(B):
   1916                 f = field()
   1917 
   1918 
   1919 class TestDocString(unittest.TestCase):
   1920     def assertDocStrEqual(self, a, b):
   1921         # Because 3.6 and 3.7 differ in how inspect.signature work
   1922         #  (see bpo #32108), for the time being just compare them with
   1923         #  whitespace stripped.
   1924         self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
   1925 
   1926     def test_existing_docstring_not_overridden(self):
   1927         @dataclass
   1928         class C:
   1929             """Lorem ipsum"""
   1930             x: int
   1931 
   1932         self.assertEqual(C.__doc__, "Lorem ipsum")
   1933 
   1934     def test_docstring_no_fields(self):
   1935         @dataclass
   1936         class C:
   1937             pass
   1938 
   1939         self.assertDocStrEqual(C.__doc__, "C()")
   1940 
   1941     def test_docstring_one_field(self):
   1942         @dataclass
   1943         class C:
   1944             x: int
   1945 
   1946         self.assertDocStrEqual(C.__doc__, "C(x:int)")
   1947 
   1948     def test_docstring_two_fields(self):
   1949         @dataclass
   1950         class C:
   1951             x: int
   1952             y: int
   1953 
   1954         self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
   1955 
   1956     def test_docstring_three_fields(self):
   1957         @dataclass
   1958         class C:
   1959             x: int
   1960             y: int
   1961             z: str
   1962 
   1963         self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
   1964 
   1965     def test_docstring_one_field_with_default(self):
   1966         @dataclass
   1967         class C:
   1968             x: int = 3
   1969 
   1970         self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
   1971 
   1972     def test_docstring_one_field_with_default_none(self):
   1973         @dataclass
   1974         class C:
   1975             x: Union[int, type(None)] = None
   1976 
   1977         self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
   1978 
   1979     def test_docstring_list_field(self):
   1980         @dataclass
   1981         class C:
   1982             x: List[int]
   1983 
   1984         self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
   1985 
   1986     def test_docstring_list_field_with_default_factory(self):
   1987         @dataclass
   1988         class C:
   1989             x: List[int] = field(default_factory=list)
   1990 
   1991         self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
   1992 
   1993     def test_docstring_deque_field(self):
   1994         @dataclass
   1995         class C:
   1996             x: deque
   1997 
   1998         self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
   1999 
   2000     def test_docstring_deque_field_with_default_factory(self):
   2001         @dataclass
   2002         class C:
   2003             x: deque = field(default_factory=deque)
   2004 
   2005         self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
   2006 
   2007 
   2008 class TestInit(unittest.TestCase):
   2009     def test_base_has_init(self):
   2010         class B:
   2011             def __init__(self):
   2012                 self.z = 100
   2013                 pass
   2014 
   2015         # Make sure that declaring this class doesn't raise an error.
   2016         #  The issue is that we can't override __init__ in our class,
   2017         #  but it should be okay to add __init__ to us if our base has
   2018         #  an __init__.
   2019         @dataclass
   2020         class C(B):
   2021             x: int = 0
   2022         c = C(10)
   2023         self.assertEqual(c.x, 10)
   2024         self.assertNotIn('z', vars(c))
   2025 
   2026         # Make sure that if we don't add an init, the base __init__
   2027         #  gets called.
   2028         @dataclass(init=False)
   2029         class C(B):
   2030             x: int = 10
   2031         c = C()
   2032         self.assertEqual(c.x, 10)
   2033         self.assertEqual(c.z, 100)
   2034 
   2035     def test_no_init(self):
   2036         dataclass(init=False)
   2037         class C:
   2038             i: int = 0
   2039         self.assertEqual(C().i, 0)
   2040 
   2041         dataclass(init=False)
   2042         class C:
   2043             i: int = 2
   2044             def __init__(self):
   2045                 self.i = 3
   2046         self.assertEqual(C().i, 3)
   2047 
   2048     def test_overwriting_init(self):
   2049         # If the class has __init__, use it no matter the value of
   2050         #  init=.
   2051 
   2052         @dataclass
   2053         class C:
   2054             x: int
   2055             def __init__(self, x):
   2056                 self.x = 2 * x
   2057         self.assertEqual(C(3).x, 6)
   2058 
   2059         @dataclass(init=True)
   2060         class C:
   2061             x: int
   2062             def __init__(self, x):
   2063                 self.x = 2 * x
   2064         self.assertEqual(C(4).x, 8)
   2065 
   2066         @dataclass(init=False)
   2067         class C:
   2068             x: int
   2069             def __init__(self, x):
   2070                 self.x = 2 * x
   2071         self.assertEqual(C(5).x, 10)
   2072 
   2073 
   2074 class TestRepr(unittest.TestCase):
   2075     def test_repr(self):
   2076         @dataclass
   2077         class B:
   2078             x: int
   2079 
   2080         @dataclass
   2081         class C(B):
   2082             y: int = 10
   2083 
   2084         o = C(4)
   2085         self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')
   2086 
   2087         @dataclass
   2088         class D(C):
   2089             x: int = 20
   2090         self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')
   2091 
   2092         @dataclass
   2093         class C:
   2094             @dataclass
   2095             class D:
   2096                 i: int
   2097             @dataclass
   2098             class E:
   2099                 pass
   2100         self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
   2101         self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')
   2102 
   2103     def test_no_repr(self):
   2104         # Test a class with no __repr__ and repr=False.
   2105         @dataclass(repr=False)
   2106         class C:
   2107             x: int
   2108         self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
   2109                       repr(C(3)))
   2110 
   2111         # Test a class with a __repr__ and repr=False.
   2112         @dataclass(repr=False)
   2113         class C:
   2114             x: int
   2115             def __repr__(self):
   2116                 return 'C-class'
   2117         self.assertEqual(repr(C(3)), 'C-class')
   2118 
   2119     def test_overwriting_repr(self):
   2120         # If the class has __repr__, use it no matter the value of
   2121         #  repr=.
   2122 
   2123         @dataclass
   2124         class C:
   2125             x: int
   2126             def __repr__(self):
   2127                 return 'x'
   2128         self.assertEqual(repr(C(0)), 'x')
   2129 
   2130         @dataclass(repr=True)
   2131         class C:
   2132             x: int
   2133             def __repr__(self):
   2134                 return 'x'
   2135         self.assertEqual(repr(C(0)), 'x')
   2136 
   2137         @dataclass(repr=False)
   2138         class C:
   2139             x: int
   2140             def __repr__(self):
   2141                 return 'x'
   2142         self.assertEqual(repr(C(0)), 'x')
   2143 
   2144 
   2145 class TestEq(unittest.TestCase):
   2146     def test_no_eq(self):
   2147         # Test a class with no __eq__ and eq=False.
   2148         @dataclass(eq=False)
   2149         class C:
   2150             x: int
   2151         self.assertNotEqual(C(0), C(0))
   2152         c = C(3)
   2153         self.assertEqual(c, c)
   2154 
   2155         # Test a class with an __eq__ and eq=False.
   2156         @dataclass(eq=False)
   2157         class C:
   2158             x: int
   2159             def __eq__(self, other):
   2160                 return other == 10
   2161         self.assertEqual(C(3), 10)
   2162 
   2163     def test_overwriting_eq(self):
   2164         # If the class has __eq__, use it no matter the value of
   2165         #  eq=.
   2166 
   2167         @dataclass
   2168         class C:
   2169             x: int
   2170             def __eq__(self, other):
   2171                 return other == 3
   2172         self.assertEqual(C(1), 3)
   2173         self.assertNotEqual(C(1), 1)
   2174 
   2175         @dataclass(eq=True)
   2176         class C:
   2177             x: int
   2178             def __eq__(self, other):
   2179                 return other == 4
   2180         self.assertEqual(C(1), 4)
   2181         self.assertNotEqual(C(1), 1)
   2182 
   2183         @dataclass(eq=False)
   2184         class C:
   2185             x: int
   2186             def __eq__(self, other):
   2187                 return other == 5
   2188         self.assertEqual(C(1), 5)
   2189         self.assertNotEqual(C(1), 1)
   2190 
   2191 
   2192 class TestOrdering(unittest.TestCase):
   2193     def test_functools_total_ordering(self):
   2194         # Test that functools.total_ordering works with this class.
   2195         @total_ordering
   2196         @dataclass
   2197         class C:
   2198             x: int
   2199             def __lt__(self, other):
   2200                 # Perform the test "backward", just to make
   2201                 #  sure this is being called.
   2202                 return self.x >= other
   2203 
   2204         self.assertLess(C(0), -1)
   2205         self.assertLessEqual(C(0), -1)
   2206         self.assertGreater(C(0), 1)
   2207         self.assertGreaterEqual(C(0), 1)
   2208 
   2209     def test_no_order(self):
   2210         # Test that no ordering functions are added by default.
   2211         @dataclass(order=False)
   2212         class C:
   2213             x: int
   2214         # Make sure no order methods are added.
   2215         self.assertNotIn('__le__', C.__dict__)
   2216         self.assertNotIn('__lt__', C.__dict__)
   2217         self.assertNotIn('__ge__', C.__dict__)
   2218         self.assertNotIn('__gt__', C.__dict__)
   2219 
   2220         # Test that __lt__ is still called
   2221         @dataclass(order=False)
   2222         class C:
   2223             x: int
   2224             def __lt__(self, other):
   2225                 return False
   2226         # Make sure other methods aren't added.
   2227         self.assertNotIn('__le__', C.__dict__)
   2228         self.assertNotIn('__ge__', C.__dict__)
   2229         self.assertNotIn('__gt__', C.__dict__)
   2230 
   2231     def test_overwriting_order(self):
   2232         with self.assertRaisesRegex(TypeError,
   2233                                     'Cannot overwrite attribute __lt__'
   2234                                     '.*using functools.total_ordering'):
   2235             @dataclass(order=True)
   2236             class C:
   2237                 x: int
   2238                 def __lt__(self):
   2239                     pass
   2240 
   2241         with self.assertRaisesRegex(TypeError,
   2242                                     'Cannot overwrite attribute __le__'
   2243                                     '.*using functools.total_ordering'):
   2244             @dataclass(order=True)
   2245             class C:
   2246                 x: int
   2247                 def __le__(self):
   2248                     pass
   2249 
   2250         with self.assertRaisesRegex(TypeError,
   2251                                     'Cannot overwrite attribute __gt__'
   2252                                     '.*using functools.total_ordering'):
   2253             @dataclass(order=True)
   2254             class C:
   2255                 x: int
   2256                 def __gt__(self):
   2257                     pass
   2258 
   2259         with self.assertRaisesRegex(TypeError,
   2260                                     'Cannot overwrite attribute __ge__'
   2261                                     '.*using functools.total_ordering'):
   2262             @dataclass(order=True)
   2263             class C:
   2264                 x: int
   2265                 def __ge__(self):
   2266                     pass
   2267 
   2268 class TestHash(unittest.TestCase):
   2269     def test_unsafe_hash(self):
   2270         @dataclass(unsafe_hash=True)
   2271         class C:
   2272             x: int
   2273             y: str
   2274         self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
   2275 
   2276     def test_hash_rules(self):
   2277         def non_bool(value):
   2278             # Map to something else that's True, but not a bool.
   2279             if value is None:
   2280                 return None
   2281             if value:
   2282                 return (3,)
   2283             return 0
   2284 
   2285         def test(case, unsafe_hash, eq, frozen, with_hash, result):
   2286             with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
   2287                               frozen=frozen):
   2288                 if result != 'exception':
   2289                     if with_hash:
   2290                         @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
   2291                         class C:
   2292                             def __hash__(self):
   2293                                 return 0
   2294                     else:
   2295                         @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
   2296                         class C:
   2297                             pass
   2298 
   2299                 # See if the result matches what's expected.
   2300                 if result == 'fn':
   2301                     # __hash__ contains the function we generated.
   2302                     self.assertIn('__hash__', C.__dict__)
   2303                     self.assertIsNotNone(C.__dict__['__hash__'])
   2304 
   2305                 elif result == '':
   2306                     # __hash__ is not present in our class.
   2307                     if not with_hash:
   2308                         self.assertNotIn('__hash__', C.__dict__)
   2309 
   2310                 elif result == 'none':
   2311                     # __hash__ is set to None.
   2312                     self.assertIn('__hash__', C.__dict__)
   2313                     self.assertIsNone(C.__dict__['__hash__'])
   2314 
   2315                 elif result == 'exception':
   2316                     # Creating the class should cause an exception.
   2317                     #  This only happens with with_hash==True.
   2318                     assert(with_hash)
   2319                     with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
   2320                         @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
   2321                         class C:
   2322                             def __hash__(self):
   2323                                 return 0
   2324 
   2325                 else:
   2326                     assert False, f'unknown result {result!r}'
   2327 
   2328         # There are 8 cases of:
   2329         #  unsafe_hash=True/False
   2330         #  eq=True/False
   2331         #  frozen=True/False
   2332         # And for each of these, a different result if
   2333         #  __hash__ is defined or not.
   2334         for case, (unsafe_hash,  eq,    frozen, res_no_defined_hash, res_defined_hash) in enumerate([
   2335                   (False,        False, False,  '',                  ''),
   2336                   (False,        False, True,   '',                  ''),
   2337                   (False,        True,  False,  'none',              ''),
   2338                   (False,        True,  True,   'fn',                ''),
   2339                   (True,         False, False,  'fn',                'exception'),
   2340                   (True,         False, True,   'fn',                'exception'),
   2341                   (True,         True,  False,  'fn',                'exception'),
   2342                   (True,         True,  True,   'fn',                'exception'),
   2343                   ], 1):
   2344             test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
   2345             test(case, unsafe_hash, eq, frozen, True,  res_defined_hash)
   2346 
   2347             # Test non-bool truth values, too.  This is just to
   2348             #  make sure the data-driven table in the decorator
   2349             #  handles non-bool values.
   2350             test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
   2351             test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True,  res_defined_hash)
   2352 
   2353 
   2354     def test_eq_only(self):
   2355         # If a class defines __eq__, __hash__ is automatically added
   2356         #  and set to None.  This is normal Python behavior, not
   2357         #  related to dataclasses.  Make sure we don't interfere with
   2358         #  that (see bpo=32546).
   2359 
   2360         @dataclass
   2361         class C:
   2362             i: int
   2363             def __eq__(self, other):
   2364                 return self.i == other.i
   2365         self.assertEqual(C(1), C(1))
   2366         self.assertNotEqual(C(1), C(4))
   2367 
   2368         # And make sure things work in this case if we specify
   2369         #  unsafe_hash=True.
   2370         @dataclass(unsafe_hash=True)
   2371         class C:
   2372             i: int
   2373             def __eq__(self, other):
   2374                 return self.i == other.i
   2375         self.assertEqual(C(1), C(1.0))
   2376         self.assertEqual(hash(C(1)), hash(C(1.0)))
   2377 
   2378         # And check that the classes __eq__ is being used, despite
   2379         #  specifying eq=True.
   2380         @dataclass(unsafe_hash=True, eq=True)
   2381         class C:
   2382             i: int
   2383             def __eq__(self, other):
   2384                 return self.i == 3 and self.i == other.i
   2385         self.assertEqual(C(3), C(3))
   2386         self.assertNotEqual(C(1), C(1))
   2387         self.assertEqual(hash(C(1)), hash(C(1.0)))
   2388 
   2389     def test_0_field_hash(self):
   2390         @dataclass(frozen=True)
   2391         class C:
   2392             pass
   2393         self.assertEqual(hash(C()), hash(()))
   2394 
   2395         @dataclass(unsafe_hash=True)
   2396         class C:
   2397             pass
   2398         self.assertEqual(hash(C()), hash(()))
   2399 
   2400     def test_1_field_hash(self):
   2401         @dataclass(frozen=True)
   2402         class C:
   2403             x: int
   2404         self.assertEqual(hash(C(4)), hash((4,)))
   2405         self.assertEqual(hash(C(42)), hash((42,)))
   2406 
   2407         @dataclass(unsafe_hash=True)
   2408         class C:
   2409             x: int
   2410         self.assertEqual(hash(C(4)), hash((4,)))
   2411         self.assertEqual(hash(C(42)), hash((42,)))
   2412 
   2413     def test_hash_no_args(self):
   2414         # Test dataclasses with no hash= argument.  This exists to
   2415         #  make sure that if the @dataclass parameter name is changed
   2416         #  or the non-default hashing behavior changes, the default
   2417         #  hashability keeps working the same way.
   2418 
   2419         class Base:
   2420             def __hash__(self):
   2421                 return 301
   2422 
   2423         # If frozen or eq is None, then use the default value (do not
   2424         #  specify any value in the decorator).
   2425         for frozen, eq,    base,   expected       in [
   2426             (None,  None,  object, 'unhashable'),
   2427             (None,  None,  Base,   'unhashable'),
   2428             (None,  False, object, 'object'),
   2429             (None,  False, Base,   'base'),
   2430             (None,  True,  object, 'unhashable'),
   2431             (None,  True,  Base,   'unhashable'),
   2432             (False, None,  object, 'unhashable'),
   2433             (False, None,  Base,   'unhashable'),
   2434             (False, False, object, 'object'),
   2435             (False, False, Base,   'base'),
   2436             (False, True,  object, 'unhashable'),
   2437             (False, True,  Base,   'unhashable'),
   2438             (True,  None,  object, 'tuple'),
   2439             (True,  None,  Base,   'tuple'),
   2440             (True,  False, object, 'object'),
   2441             (True,  False, Base,   'base'),
   2442             (True,  True,  object, 'tuple'),
   2443             (True,  True,  Base,   'tuple'),
   2444             ]:
   2445 
   2446             with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
   2447                 # First, create the class.
   2448                 if frozen is None and eq is None:
   2449                     @dataclass
   2450                     class C(base):
   2451                         i: int
   2452                 elif frozen is None:
   2453                     @dataclass(eq=eq)
   2454                     class C(base):
   2455                         i: int
   2456                 elif eq is None:
   2457                     @dataclass(frozen=frozen)
   2458                     class C(base):
   2459                         i: int
   2460                 else:
   2461                     @dataclass(frozen=frozen, eq=eq)
   2462                     class C(base):
   2463                         i: int
   2464 
   2465                 # Now, make sure it hashes as expected.
   2466                 if expected == 'unhashable':
   2467                     c = C(10)
   2468                     with self.assertRaisesRegex(TypeError, 'unhashable type'):
   2469                         hash(c)
   2470 
   2471                 elif expected == 'base':
   2472                     self.assertEqual(hash(C(10)), 301)
   2473 
   2474                 elif expected == 'object':
   2475                     # I'm not sure what test to use here.  object's
   2476                     #  hash isn't based on id(), so calling hash()
   2477                     #  won't tell us much.  So, just check the
   2478                     #  function used is object's.
   2479                     self.assertIs(C.__hash__, object.__hash__)
   2480 
   2481                 elif expected == 'tuple':
   2482                     self.assertEqual(hash(C(42)), hash((42,)))
   2483 
   2484                 else:
   2485                     assert False, f'unknown value for expected={expected!r}'
   2486 
   2487 
   2488 class TestFrozen(unittest.TestCase):
   2489     def test_frozen(self):
   2490         @dataclass(frozen=True)
   2491         class C:
   2492             i: int
   2493 
   2494         c = C(10)
   2495         self.assertEqual(c.i, 10)
   2496         with self.assertRaises(FrozenInstanceError):
   2497             c.i = 5
   2498         self.assertEqual(c.i, 10)
   2499 
   2500     def test_inherit(self):
   2501         @dataclass(frozen=True)
   2502         class C:
   2503             i: int
   2504 
   2505         @dataclass(frozen=True)
   2506         class D(C):
   2507             j: int
   2508 
   2509         d = D(0, 10)
   2510         with self.assertRaises(FrozenInstanceError):
   2511             d.i = 5
   2512         with self.assertRaises(FrozenInstanceError):
   2513             d.j = 6
   2514         self.assertEqual(d.i, 0)
   2515         self.assertEqual(d.j, 10)
   2516 
   2517     # Test both ways: with an intermediate normal (non-dataclass)
   2518     #  class and without an intermediate class.
   2519     def test_inherit_nonfrozen_from_frozen(self):
   2520         for intermediate_class in [True, False]:
   2521             with self.subTest(intermediate_class=intermediate_class):
   2522                 @dataclass(frozen=True)
   2523                 class C:
   2524                     i: int
   2525 
   2526                 if intermediate_class:
   2527                     class I(C): pass
   2528                 else:
   2529                     I = C
   2530 
   2531                 with self.assertRaisesRegex(TypeError,
   2532                                             'cannot inherit non-frozen dataclass from a frozen one'):
   2533                     @dataclass
   2534                     class D(I):
   2535                         pass
   2536 
   2537     def test_inherit_frozen_from_nonfrozen(self):
   2538         for intermediate_class in [True, False]:
   2539             with self.subTest(intermediate_class=intermediate_class):
   2540                 @dataclass
   2541                 class C:
   2542                     i: int
   2543 
   2544                 if intermediate_class:
   2545                     class I(C): pass
   2546                 else:
   2547                     I = C
   2548 
   2549                 with self.assertRaisesRegex(TypeError,
   2550                                             'cannot inherit frozen dataclass from a non-frozen one'):
   2551                     @dataclass(frozen=True)
   2552                     class D(I):
   2553                         pass
   2554 
   2555     def test_inherit_from_normal_class(self):
   2556         for intermediate_class in [True, False]:
   2557             with self.subTest(intermediate_class=intermediate_class):
   2558                 class C:
   2559                     pass
   2560 
   2561                 if intermediate_class:
   2562                     class I(C): pass
   2563                 else:
   2564                     I = C
   2565 
   2566                 @dataclass(frozen=True)
   2567                 class D(I):
   2568                     i: int
   2569 
   2570             d = D(10)
   2571             with self.assertRaises(FrozenInstanceError):
   2572                 d.i = 5
   2573 
   2574     def test_non_frozen_normal_derived(self):
   2575         # See bpo-32953.
   2576 
   2577         @dataclass(frozen=True)
   2578         class D:
   2579             x: int
   2580             y: int = 10
   2581 
   2582         class S(D):
   2583             pass
   2584 
   2585         s = S(3)
   2586         self.assertEqual(s.x, 3)
   2587         self.assertEqual(s.y, 10)
   2588         s.cached = True
   2589 
   2590         # But can't change the frozen attributes.
   2591         with self.assertRaises(FrozenInstanceError):
   2592             s.x = 5
   2593         with self.assertRaises(FrozenInstanceError):
   2594             s.y = 5
   2595         self.assertEqual(s.x, 3)
   2596         self.assertEqual(s.y, 10)
   2597         self.assertEqual(s.cached, True)
   2598 
   2599     def test_overwriting_frozen(self):
   2600         # frozen uses __setattr__ and __delattr__.
   2601         with self.assertRaisesRegex(TypeError,
   2602                                     'Cannot overwrite attribute __setattr__'):
   2603             @dataclass(frozen=True)
   2604             class C:
   2605                 x: int
   2606                 def __setattr__(self):
   2607                     pass
   2608 
   2609         with self.assertRaisesRegex(TypeError,
   2610                                     'Cannot overwrite attribute __delattr__'):
   2611             @dataclass(frozen=True)
   2612             class C:
   2613                 x: int
   2614                 def __delattr__(self):
   2615                     pass
   2616 
   2617         @dataclass(frozen=False)
   2618         class C:
   2619             x: int
   2620             def __setattr__(self, name, value):
   2621                 self.__dict__['x'] = value * 2
   2622         self.assertEqual(C(10).x, 20)
   2623 
   2624     def test_frozen_hash(self):
   2625         @dataclass(frozen=True)
   2626         class C:
   2627             x: Any
   2628 
   2629         # If x is immutable, we can compute the hash.  No exception is
   2630         # raised.
   2631         hash(C(3))
   2632 
   2633         # If x is mutable, computing the hash is an error.
   2634         with self.assertRaisesRegex(TypeError, 'unhashable type'):
   2635             hash(C({}))
   2636 
   2637 
   2638 class TestSlots(unittest.TestCase):
   2639     def test_simple(self):
   2640         @dataclass
   2641         class C:
   2642             __slots__ = ('x',)
   2643             x: Any
   2644 
   2645         # There was a bug where a variable in a slot was assumed to
   2646         #  also have a default value (of type
   2647         #  types.MemberDescriptorType).
   2648         with self.assertRaisesRegex(TypeError,
   2649                                     r"__init__\(\) missing 1 required positional argument: 'x'"):
   2650             C()
   2651 
   2652         # We can create an instance, and assign to x.
   2653         c = C(10)
   2654         self.assertEqual(c.x, 10)
   2655         c.x = 5
   2656         self.assertEqual(c.x, 5)
   2657 
   2658         # We can't assign to anything else.
   2659         with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
   2660             c.y = 5
   2661 
   2662     def test_derived_added_field(self):
   2663         # See bpo-33100.
   2664         @dataclass
   2665         class Base:
   2666             __slots__ = ('x',)
   2667             x: Any
   2668 
   2669         @dataclass
   2670         class Derived(Base):
   2671             x: int
   2672             y: int
   2673 
   2674         d = Derived(1, 2)
   2675         self.assertEqual((d.x, d.y), (1, 2))
   2676 
   2677         # We can add a new field to the derived instance.
   2678         d.z = 10
   2679 
   2680 class TestDescriptors(unittest.TestCase):
   2681     def test_set_name(self):
   2682         # See bpo-33141.
   2683 
   2684         # Create a descriptor.
   2685         class D:
   2686             def __set_name__(self, owner, name):
   2687                 self.name = name + 'x'
   2688             def __get__(self, instance, owner):
   2689                 if instance is not None:
   2690                     return 1
   2691                 return self
   2692 
   2693         # This is the case of just normal descriptor behavior, no
   2694         #  dataclass code is involved in initializing the descriptor.
   2695         @dataclass
   2696         class C:
   2697             c: int=D()
   2698         self.assertEqual(C.c.name, 'cx')
   2699 
   2700         # Now test with a default value and init=False, which is the
   2701         #  only time this is really meaningful.  If not using
   2702         #  init=False, then the descriptor will be overwritten, anyway.
   2703         @dataclass
   2704         class C:
   2705             c: int=field(default=D(), init=False)
   2706         self.assertEqual(C.c.name, 'cx')
   2707         self.assertEqual(C().c, 1)
   2708 
   2709     def test_non_descriptor(self):
   2710         # PEP 487 says __set_name__ should work on non-descriptors.
   2711         # Create a descriptor.
   2712 
   2713         class D:
   2714             def __set_name__(self, owner, name):
   2715                 self.name = name + 'x'
   2716 
   2717         @dataclass
   2718         class C:
   2719             c: int=field(default=D(), init=False)
   2720         self.assertEqual(C.c.name, 'cx')
   2721 
   2722     def test_lookup_on_instance(self):
   2723         # See bpo-33175.
   2724         class D:
   2725             pass
   2726 
   2727         d = D()
   2728         # Create an attribute on the instance, not type.
   2729         d.__set_name__ = Mock()
   2730 
   2731         # Make sure d.__set_name__ is not called.
   2732         @dataclass
   2733         class C:
   2734             i: int=field(default=d, init=False)
   2735 
   2736         self.assertEqual(d.__set_name__.call_count, 0)
   2737 
   2738     def test_lookup_on_class(self):
   2739         # See bpo-33175.
   2740         class D:
   2741             pass
   2742         D.__set_name__ = Mock()
   2743 
   2744         # Make sure D.__set_name__ is called.
   2745         @dataclass
   2746         class C:
   2747             i: int=field(default=D(), init=False)
   2748 
   2749         self.assertEqual(D.__set_name__.call_count, 1)
   2750 
   2751 
   2752 class TestStringAnnotations(unittest.TestCase):
   2753     def test_classvar(self):
   2754         # Some expressions recognized as ClassVar really aren't.  But
   2755         #  if you're using string annotations, it's not an exact
   2756         #  science.
   2757         # These tests assume that both "import typing" and "from
   2758         # typing import *" have been run in this file.
   2759         for typestr in ('ClassVar[int]',
   2760                         'ClassVar [int]'
   2761                         ' ClassVar [int]',
   2762                         'ClassVar',
   2763                         ' ClassVar ',
   2764                         'typing.ClassVar[int]',
   2765                         'typing.ClassVar[str]',
   2766                         ' typing.ClassVar[str]',
   2767                         'typing .ClassVar[str]',
   2768                         'typing. ClassVar[str]',
   2769                         'typing.ClassVar [str]',
   2770                         'typing.ClassVar [ str]',
   2771 
   2772                         # Not syntactically valid, but these will
   2773                         #  be treated as ClassVars.
   2774                         'typing.ClassVar.[int]',
   2775                         'typing.ClassVar+',
   2776                         ):
   2777             with self.subTest(typestr=typestr):
   2778                 @dataclass
   2779                 class C:
   2780                     x: typestr
   2781 
   2782                 # x is a ClassVar, so C() takes no args.
   2783                 C()
   2784 
   2785                 # And it won't appear in the class's dict because it doesn't
   2786                 # have a default.
   2787                 self.assertNotIn('x', C.__dict__)
   2788 
   2789     def test_isnt_classvar(self):
   2790         for typestr in ('CV',
   2791                         't.ClassVar',
   2792                         't.ClassVar[int]',
   2793                         'typing..ClassVar[int]',
   2794                         'Classvar',
   2795                         'Classvar[int]',
   2796                         'typing.ClassVarx[int]',
   2797                         'typong.ClassVar[int]',
   2798                         'dataclasses.ClassVar[int]',
   2799                         'typingxClassVar[str]',
   2800                         ):
   2801             with self.subTest(typestr=typestr):
   2802                 @dataclass
   2803                 class C:
   2804                     x: typestr
   2805 
   2806                 # x is not a ClassVar, so C() takes one arg.
   2807                 self.assertEqual(C(10).x, 10)
   2808 
   2809     def test_initvar(self):
   2810         # These tests assume that both "import dataclasses" and "from
   2811         #  dataclasses import *" have been run in this file.
   2812         for typestr in ('InitVar[int]',
   2813                         'InitVar [int]'
   2814                         ' InitVar [int]',
   2815                         'InitVar',
   2816                         ' InitVar ',
   2817                         'dataclasses.InitVar[int]',
   2818                         'dataclasses.InitVar[str]',
   2819                         ' dataclasses.InitVar[str]',
   2820                         'dataclasses .InitVar[str]',
   2821                         'dataclasses. InitVar[str]',
   2822                         'dataclasses.InitVar [str]',
   2823                         'dataclasses.InitVar [ str]',
   2824 
   2825                         # Not syntactically valid, but these will
   2826                         #  be treated as InitVars.
   2827                         'dataclasses.InitVar.[int]',
   2828                         'dataclasses.InitVar+',
   2829                         ):
   2830             with self.subTest(typestr=typestr):
   2831                 @dataclass
   2832                 class C:
   2833                     x: typestr
   2834 
   2835                 # x is an InitVar, so doesn't create a member.
   2836                 with self.assertRaisesRegex(AttributeError,
   2837                                             "object has no attribute 'x'"):
   2838                     C(1).x
   2839 
   2840     def test_isnt_initvar(self):
   2841         for typestr in ('IV',
   2842                         'dc.InitVar',
   2843                         'xdataclasses.xInitVar',
   2844                         'typing.xInitVar[int]',
   2845                         ):
   2846             with self.subTest(typestr=typestr):
   2847                 @dataclass
   2848                 class C:
   2849                     x: typestr
   2850 
   2851                 # x is not an InitVar, so there will be a member x.
   2852                 self.assertEqual(C(10).x, 10)
   2853 
   2854     def test_classvar_module_level_import(self):
   2855         from test import dataclass_module_1
   2856         from test import dataclass_module_1_str
   2857         from test import dataclass_module_2
   2858         from test import dataclass_module_2_str
   2859 
   2860         for m in (dataclass_module_1, dataclass_module_1_str,
   2861                   dataclass_module_2, dataclass_module_2_str,
   2862                   ):
   2863             with self.subTest(m=m):
   2864                 # There's a difference in how the ClassVars are
   2865                 # interpreted when using string annotations or
   2866                 # not. See the imported modules for details.
   2867                 if m.USING_STRINGS:
   2868                     c = m.CV(10)
   2869                 else:
   2870                     c = m.CV()
   2871                 self.assertEqual(c.cv0, 20)
   2872 
   2873 
   2874                 # There's a difference in how the InitVars are
   2875                 # interpreted when using string annotations or
   2876                 # not. See the imported modules for details.
   2877                 c = m.IV(0, 1, 2, 3, 4)
   2878 
   2879                 for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
   2880                     with self.subTest(field_name=field_name):
   2881                         with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
   2882                             # Since field_name is an InitVar, it's
   2883                             # not an instance field.
   2884                             getattr(c, field_name)
   2885 
   2886                 if m.USING_STRINGS:
   2887                     # iv4 is interpreted as a normal field.
   2888                     self.assertIn('not_iv4', c.__dict__)
   2889                     self.assertEqual(c.not_iv4, 4)
   2890                 else:
   2891                     # iv4 is interpreted as an InitVar, so it
   2892                     # won't exist on the instance.
   2893                     self.assertNotIn('not_iv4', c.__dict__)
   2894 
   2895 
   2896 class TestMakeDataclass(unittest.TestCase):
   2897     def test_simple(self):
   2898         C = make_dataclass('C',
   2899                            [('x', int),
   2900                             ('y', int, field(default=5))],
   2901                            namespace={'add_one': lambda self: self.x + 1})
   2902         c = C(10)
   2903         self.assertEqual((c.x, c.y), (10, 5))
   2904         self.assertEqual(c.add_one(), 11)
   2905 
   2906 
   2907     def test_no_mutate_namespace(self):
   2908         # Make sure a provided namespace isn't mutated.
   2909         ns = {}
   2910         C = make_dataclass('C',
   2911                            [('x', int),
   2912                             ('y', int, field(default=5))],
   2913                            namespace=ns)
   2914         self.assertEqual(ns, {})
   2915 
   2916     def test_base(self):
   2917         class Base1:
   2918             pass
   2919         class Base2:
   2920             pass
   2921         C = make_dataclass('C',
   2922                            [('x', int)],
   2923                            bases=(Base1, Base2))
   2924         c = C(2)
   2925         self.assertIsInstance(c, C)
   2926         self.assertIsInstance(c, Base1)
   2927         self.assertIsInstance(c, Base2)
   2928 
   2929     def test_base_dataclass(self):
   2930         @dataclass
   2931         class Base1:
   2932             x: int
   2933         class Base2:
   2934             pass
   2935         C = make_dataclass('C',
   2936                            [('y', int)],
   2937                            bases=(Base1, Base2))
   2938         with self.assertRaisesRegex(TypeError, 'required positional'):
   2939             c = C(2)
   2940         c = C(1, 2)
   2941         self.assertIsInstance(c, C)
   2942         self.assertIsInstance(c, Base1)
   2943         self.assertIsInstance(c, Base2)
   2944 
   2945         self.assertEqual((c.x, c.y), (1, 2))
   2946 
   2947     def test_init_var(self):
   2948         def post_init(self, y):
   2949             self.x *= y
   2950 
   2951         C = make_dataclass('C',
   2952                            [('x', int),
   2953                             ('y', InitVar[int]),
   2954                             ],
   2955                            namespace={'__post_init__': post_init},
   2956                            )
   2957         c = C(2, 3)
   2958         self.assertEqual(vars(c), {'x': 6})
   2959         self.assertEqual(len(fields(c)), 1)
   2960 
   2961     def test_class_var(self):
   2962         C = make_dataclass('C',
   2963                            [('x', int),
   2964                             ('y', ClassVar[int], 10),
   2965                             ('z', ClassVar[int], field(default=20)),
   2966                             ])
   2967         c = C(1)
   2968         self.assertEqual(vars(c), {'x': 1})
   2969         self.assertEqual(len(fields(c)), 1)
   2970         self.assertEqual(C.y, 10)
   2971         self.assertEqual(C.z, 20)
   2972 
   2973     def test_other_params(self):
   2974         C = make_dataclass('C',
   2975                            [('x', int),
   2976                             ('y', ClassVar[int], 10),
   2977                             ('z', ClassVar[int], field(default=20)),
   2978                             ],
   2979                            init=False)
   2980         # Make sure we have a repr, but no init.
   2981         self.assertNotIn('__init__', vars(C))
   2982         self.assertIn('__repr__', vars(C))
   2983 
   2984         # Make sure random other params don't work.
   2985         with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
   2986             C = make_dataclass('C',
   2987                                [],
   2988                                xxinit=False)
   2989 
   2990     def test_no_types(self):
   2991         C = make_dataclass('Point', ['x', 'y', 'z'])
   2992         c = C(1, 2, 3)
   2993         self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
   2994         self.assertEqual(C.__annotations__, {'x': 'typing.Any',
   2995                                              'y': 'typing.Any',
   2996                                              'z': 'typing.Any'})
   2997 
   2998         C = make_dataclass('Point', ['x', ('y', int), 'z'])
   2999         c = C(1, 2, 3)
   3000         self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
   3001         self.assertEqual(C.__annotations__, {'x': 'typing.Any',
   3002                                              'y': int,
   3003                                              'z': 'typing.Any'})
   3004 
   3005     def test_invalid_type_specification(self):
   3006         for bad_field in [(),
   3007                           (1, 2, 3, 4),
   3008                           ]:
   3009             with self.subTest(bad_field=bad_field):
   3010                 with self.assertRaisesRegex(TypeError, r'Invalid field: '):
   3011                     make_dataclass('C', ['a', bad_field])
   3012 
   3013         # And test for things with no len().
   3014         for bad_field in [float,
   3015                           lambda x:x,
   3016                           ]:
   3017             with self.subTest(bad_field=bad_field):
   3018                 with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
   3019                     make_dataclass('C', ['a', bad_field])
   3020 
   3021     def test_duplicate_field_names(self):
   3022         for field in ['a', 'ab']:
   3023             with self.subTest(field=field):
   3024                 with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
   3025                     make_dataclass('C', [field, 'a', field])
   3026 
   3027     def test_keyword_field_names(self):
   3028         for field in ['for', 'async', 'await', 'as']:
   3029             with self.subTest(field=field):
   3030                 with self.assertRaisesRegex(TypeError, 'must not be keywords'):
   3031                     make_dataclass('C', ['a', field])
   3032                 with self.assertRaisesRegex(TypeError, 'must not be keywords'):
   3033                     make_dataclass('C', [field])
   3034                 with self.assertRaisesRegex(TypeError, 'must not be keywords'):
   3035                     make_dataclass('C', [field, 'a'])
   3036 
   3037     def test_non_identifier_field_names(self):
   3038         for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
   3039             with self.subTest(field=field):
   3040                 with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
   3041                     make_dataclass('C', ['a', field])
   3042                 with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
   3043                     make_dataclass('C', [field])
   3044                 with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
   3045                     make_dataclass('C', [field, 'a'])
   3046 
   3047     def test_underscore_field_names(self):
   3048         # Unlike namedtuple, it's okay if dataclass field names have
   3049         # an underscore.
   3050         make_dataclass('C', ['_', '_a', 'a_a', 'a_'])
   3051 
   3052     def test_funny_class_names_names(self):
   3053         # No reason to prevent weird class names, since
   3054         # types.new_class allows them.
   3055         for classname in ['()', 'x,y', '*', '2@3', '']:
   3056             with self.subTest(classname=classname):
   3057                 C = make_dataclass(classname, ['a', 'b'])
   3058                 self.assertEqual(C.__name__, classname)
   3059 
   3060 class TestReplace(unittest.TestCase):
   3061     def test(self):
   3062         @dataclass(frozen=True)
   3063         class C:
   3064             x: int
   3065             y: int
   3066 
   3067         c = C(1, 2)
   3068         c1 = replace(c, x=3)
   3069         self.assertEqual(c1.x, 3)
   3070         self.assertEqual(c1.y, 2)
   3071 
   3072     def test_frozen(self):
   3073         @dataclass(frozen=True)
   3074         class C:
   3075             x: int
   3076             y: int
   3077             z: int = field(init=False, default=10)
   3078             t: int = field(init=False, default=100)
   3079 
   3080         c = C(1, 2)
   3081         c1 = replace(c, x=3)
   3082         self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
   3083         self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
   3084 
   3085 
   3086         with self.assertRaisesRegex(ValueError, 'init=False'):
   3087             replace(c, x=3, z=20, t=50)
   3088         with self.assertRaisesRegex(ValueError, 'init=False'):
   3089             replace(c, z=20)
   3090             replace(c, x=3, z=20, t=50)
   3091 
   3092         # Make sure the result is still frozen.
   3093         with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
   3094             c1.x = 3
   3095 
   3096         # Make sure we can't replace an attribute that doesn't exist,
   3097         #  if we're also replacing one that does exist.  Test this
   3098         #  here, because setting attributes on frozen instances is
   3099         #  handled slightly differently from non-frozen ones.
   3100         with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
   3101                                              "keyword argument 'a'"):
   3102             c1 = replace(c, x=20, a=5)
   3103 
   3104     def test_invalid_field_name(self):
   3105         @dataclass(frozen=True)
   3106         class C:
   3107             x: int
   3108             y: int
   3109 
   3110         c = C(1, 2)
   3111         with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
   3112                                     "keyword argument 'z'"):
   3113             c1 = replace(c, z=3)
   3114 
   3115     def test_invalid_object(self):
   3116         @dataclass(frozen=True)
   3117         class C:
   3118             x: int
   3119             y: int
   3120 
   3121         with self.assertRaisesRegex(TypeError, 'dataclass instance'):
   3122             replace(C, x=3)
   3123 
   3124         with self.assertRaisesRegex(TypeError, 'dataclass instance'):
   3125             replace(0, x=3)
   3126 
   3127     def test_no_init(self):
   3128         @dataclass
   3129         class C:
   3130             x: int
   3131             y: int = field(init=False, default=10)
   3132 
   3133         c = C(1)
   3134         c.y = 20
   3135 
   3136         # Make sure y gets the default value.
   3137         c1 = replace(c, x=5)
   3138         self.assertEqual((c1.x, c1.y), (5, 10))
   3139 
   3140         # Trying to replace y is an error.
   3141         with self.assertRaisesRegex(ValueError, 'init=False'):
   3142             replace(c, x=2, y=30)
   3143 
   3144         with self.assertRaisesRegex(ValueError, 'init=False'):
   3145             replace(c, y=30)
   3146 
   3147     def test_classvar(self):
   3148         @dataclass
   3149         class C:
   3150             x: int
   3151             y: ClassVar[int] = 1000
   3152 
   3153         c = C(1)
   3154         d = C(2)
   3155 
   3156         self.assertIs(c.y, d.y)
   3157         self.assertEqual(c.y, 1000)
   3158 
   3159         # Trying to replace y is an error: can't replace ClassVars.
   3160         with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
   3161                                     "unexpected keyword argument 'y'"):
   3162             replace(c, y=30)
   3163 
   3164         replace(c, x=5)
   3165 
   3166     def test_initvar_is_specified(self):
   3167         @dataclass
   3168         class C:
   3169             x: int
   3170             y: InitVar[int]
   3171 
   3172             def __post_init__(self, y):
   3173                 self.x *= y
   3174 
   3175         c = C(1, 10)
   3176         self.assertEqual(c.x, 10)
   3177         with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
   3178                                     "specified with replace()"):
   3179             replace(c, x=3)
   3180         c = replace(c, x=3, y=5)
   3181         self.assertEqual(c.x, 15)
   3182 
   3183     def test_recursive_repr(self):
   3184         @dataclass
   3185         class C:
   3186             f: "C"
   3187 
   3188         c = C(None)
   3189         c.f = c
   3190         self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")
   3191 
   3192     def test_recursive_repr_two_attrs(self):
   3193         @dataclass
   3194         class C:
   3195             f: "C"
   3196             g: "C"
   3197 
   3198         c = C(None, None)
   3199         c.f = c
   3200         c.g = c
   3201         self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
   3202                                   ".<locals>.C(f=..., g=...)")
   3203 
   3204     def test_recursive_repr_indirection(self):
   3205         @dataclass
   3206         class C:
   3207             f: "D"
   3208 
   3209         @dataclass
   3210         class D:
   3211             f: "C"
   3212 
   3213         c = C(None)
   3214         d = D(None)
   3215         c.f = d
   3216         d.f = c
   3217         self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
   3218                                   ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
   3219                                   ".<locals>.D(f=...))")
   3220 
   3221     def test_recursive_repr_indirection_two(self):
   3222         @dataclass
   3223         class C:
   3224             f: "D"
   3225 
   3226         @dataclass
   3227         class D:
   3228             f: "E"
   3229 
   3230         @dataclass
   3231         class E:
   3232             f: "C"
   3233 
   3234         c = C(None)
   3235         d = D(None)
   3236         e = E(None)
   3237         c.f = d
   3238         d.f = e
   3239         e.f = c
   3240         self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
   3241                                   ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
   3242                                   ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
   3243                                   ".<locals>.E(f=...)))")
   3244 
   3245     def test_recursive_repr_two_attrs(self):
   3246         @dataclass
   3247         class C:
   3248             f: "C"
   3249             g: "C"
   3250 
   3251         c = C(None, None)
   3252         c.f = c
   3253         c.g = c
   3254         self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
   3255                                   ".<locals>.C(f=..., g=...)")
   3256 
   3257     def test_recursive_repr_misc_attrs(self):
   3258         @dataclass
   3259         class C:
   3260             f: "C"
   3261             g: int
   3262 
   3263         c = C(None, 1)
   3264         c.f = c
   3265         self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
   3266                                   ".<locals>.C(f=..., g=1)")
   3267 
   3268     ## def test_initvar(self):
   3269     ##     @dataclass
   3270     ##     class C:
   3271     ##         x: int
   3272     ##         y: InitVar[int]
   3273 
   3274     ##     c = C(1, 10)
   3275     ##     d = C(2, 20)
   3276 
   3277     ##     # In our case, replacing an InitVar is a no-op
   3278     ##     self.assertEqual(c, replace(c, y=5))
   3279 
   3280     ##     replace(c, x=5)
   3281 
   3282 
# Run this module's tests through unittest when it is executed directly.
if __name__ == '__main__':
    unittest.main()
   3285