1 # Deliberately use "from dataclasses import *". Every name in __all__ 2 # is tested, so they all must be present. This is a way to catch 3 # missing ones. 4 5 from dataclasses import * 6 7 import pickle 8 import inspect 9 import builtins 10 import unittest 11 from unittest.mock import Mock 12 from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional 13 from collections import deque, OrderedDict, namedtuple 14 from functools import total_ordering 15 16 import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation. 17 import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation. 18 19 # Just any custom exception we can catch. 20 class CustomError(Exception): pass 21 22 class TestCase(unittest.TestCase): 23 def test_no_fields(self): 24 @dataclass 25 class C: 26 pass 27 28 o = C() 29 self.assertEqual(len(fields(C)), 0) 30 31 def test_no_fields_but_member_variable(self): 32 @dataclass 33 class C: 34 i = 0 35 36 o = C() 37 self.assertEqual(len(fields(C)), 0) 38 39 def test_one_field_no_default(self): 40 @dataclass 41 class C: 42 x: int 43 44 o = C(42) 45 self.assertEqual(o.x, 42) 46 47 def test_named_init_params(self): 48 @dataclass 49 class C: 50 x: int 51 52 o = C(x=32) 53 self.assertEqual(o.x, 32) 54 55 def test_two_fields_one_default(self): 56 @dataclass 57 class C: 58 x: int 59 y: int = 0 60 61 o = C(3) 62 self.assertEqual((o.x, o.y), (3, 0)) 63 64 # Non-defaults following defaults. 65 with self.assertRaisesRegex(TypeError, 66 "non-default argument 'y' follows " 67 "default argument"): 68 @dataclass 69 class C: 70 x: int = 0 71 y: int 72 73 # A derived class adds a non-default field after a default one. 74 with self.assertRaisesRegex(TypeError, 75 "non-default argument 'y' follows " 76 "default argument"): 77 @dataclass 78 class B: 79 x: int = 0 80 81 @dataclass 82 class C(B): 83 y: int 84 85 # Override a base class field and add a default to 86 # a field which didn't use to have a default. 87 with self.assertRaisesRegex(TypeError, 88 "non-default argument 'y' follows " 89 "default argument"): 90 @dataclass 91 class B: 92 x: int 93 y: int 94 95 @dataclass 96 class C(B): 97 x: int = 0 98 99 def test_overwrite_hash(self): 100 # Test that declaring this class isn't an error. It should 101 # use the user-provided __hash__. 102 @dataclass(frozen=True) 103 class C: 104 x: int 105 def __hash__(self): 106 return 301 107 self.assertEqual(hash(C(100)), 301) 108 109 # Test that declaring this class isn't an error. It should 110 # use the generated __hash__. 111 @dataclass(frozen=True) 112 class C: 113 x: int 114 def __eq__(self, other): 115 return False 116 self.assertEqual(hash(C(100)), hash((100,))) 117 118 # But this one should generate an exception, because with 119 # unsafe_hash=True, it's an error to have a __hash__ defined. 120 with self.assertRaisesRegex(TypeError, 121 'Cannot overwrite attribute __hash__'): 122 @dataclass(unsafe_hash=True) 123 class C: 124 def __hash__(self): 125 pass 126 127 # Creating this class should not generate an exception, 128 # because even though __hash__ exists before @dataclass is 129 # called, (due to __eq__ being defined), since it's None 130 # that's okay. 131 @dataclass(unsafe_hash=True) 132 class C: 133 x: int 134 def __eq__(self): 135 pass 136 # The generated hash function works as we'd expect. 137 self.assertEqual(hash(C(10)), hash((10,))) 138 139 # Creating this class should generate an exception, because 140 # __hash__ exists and is not None, which it would be if it 141 # had been auto-generated due to __eq__ being defined. 142 with self.assertRaisesRegex(TypeError, 143 'Cannot overwrite attribute __hash__'): 144 @dataclass(unsafe_hash=True) 145 class C: 146 x: int 147 def __eq__(self): 148 pass 149 def __hash__(self): 150 pass 151 152 def test_overwrite_fields_in_derived_class(self): 153 # Note that x from C1 replaces x in Base, but the order remains 154 # the same as defined in Base. 155 @dataclass 156 class Base: 157 x: Any = 15.0 158 y: int = 0 159 160 @dataclass 161 class C1(Base): 162 z: int = 10 163 x: int = 15 164 165 o = Base() 166 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)') 167 168 o = C1() 169 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)') 170 171 o = C1(x=5) 172 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)') 173 174 def test_field_named_self(self): 175 @dataclass 176 class C: 177 self: str 178 c=C('foo') 179 self.assertEqual(c.self, 'foo') 180 181 # Make sure the first parameter is not named 'self'. 182 sig = inspect.signature(C.__init__) 183 first = next(iter(sig.parameters)) 184 self.assertNotEqual('self', first) 185 186 # But we do use 'self' if no field named self. 187 @dataclass 188 class C: 189 selfx: str 190 191 # Make sure the first parameter is named 'self'. 192 sig = inspect.signature(C.__init__) 193 first = next(iter(sig.parameters)) 194 self.assertEqual('self', first) 195 196 def test_field_named_object(self): 197 @dataclass 198 class C: 199 object: str 200 c = C('foo') 201 self.assertEqual(c.object, 'foo') 202 203 def test_field_named_object_frozen(self): 204 @dataclass(frozen=True) 205 class C: 206 object: str 207 c = C('foo') 208 self.assertEqual(c.object, 'foo') 209 210 def test_field_named_like_builtin(self): 211 # Attribute names can shadow built-in names 212 # since code generation is used. 213 # Ensure that this is not happening. 214 exclusions = {'None', 'True', 'False'} 215 builtins_names = sorted( 216 b for b in builtins.__dict__.keys() 217 if not b.startswith('__') and b not in exclusions 218 ) 219 attributes = [(name, str) for name in builtins_names] 220 C = make_dataclass('C', attributes) 221 222 c = C(*[name for name in builtins_names]) 223 224 for name in builtins_names: 225 self.assertEqual(getattr(c, name), name) 226 227 def test_field_named_like_builtin_frozen(self): 228 # Attribute names can shadow built-in names 229 # since code generation is used. 230 # Ensure that this is not happening 231 # for frozen data classes. 232 exclusions = {'None', 'True', 'False'} 233 builtins_names = sorted( 234 b for b in builtins.__dict__.keys() 235 if not b.startswith('__') and b not in exclusions 236 ) 237 attributes = [(name, str) for name in builtins_names] 238 C = make_dataclass('C', attributes, frozen=True) 239 240 c = C(*[name for name in builtins_names]) 241 242 for name in builtins_names: 243 self.assertEqual(getattr(c, name), name) 244 245 def test_0_field_compare(self): 246 # Ensure that order=False is the default. 247 @dataclass 248 class C0: 249 pass 250 251 @dataclass(order=False) 252 class C1: 253 pass 254 255 for cls in [C0, C1]: 256 with self.subTest(cls=cls): 257 self.assertEqual(cls(), cls()) 258 for idx, fn in enumerate([lambda a, b: a < b, 259 lambda a, b: a <= b, 260 lambda a, b: a > b, 261 lambda a, b: a >= b]): 262 with self.subTest(idx=idx): 263 with self.assertRaisesRegex(TypeError, 264 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 265 fn(cls(), cls()) 266 267 @dataclass(order=True) 268 class C: 269 pass 270 self.assertLessEqual(C(), C()) 271 self.assertGreaterEqual(C(), C()) 272 273 def test_1_field_compare(self): 274 # Ensure that order=False is the default. 275 @dataclass 276 class C0: 277 x: int 278 279 @dataclass(order=False) 280 class C1: 281 x: int 282 283 for cls in [C0, C1]: 284 with self.subTest(cls=cls): 285 self.assertEqual(cls(1), cls(1)) 286 self.assertNotEqual(cls(0), cls(1)) 287 for idx, fn in enumerate([lambda a, b: a < b, 288 lambda a, b: a <= b, 289 lambda a, b: a > b, 290 lambda a, b: a >= b]): 291 with self.subTest(idx=idx): 292 with self.assertRaisesRegex(TypeError, 293 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 294 fn(cls(0), cls(0)) 295 296 @dataclass(order=True) 297 class C: 298 x: int 299 self.assertLess(C(0), C(1)) 300 self.assertLessEqual(C(0), C(1)) 301 self.assertLessEqual(C(1), C(1)) 302 self.assertGreater(C(1), C(0)) 303 self.assertGreaterEqual(C(1), C(0)) 304 self.assertGreaterEqual(C(1), C(1)) 305 306 def test_simple_compare(self): 307 # Ensure that order=False is the default. 308 @dataclass 309 class C0: 310 x: int 311 y: int 312 313 @dataclass(order=False) 314 class C1: 315 x: int 316 y: int 317 318 for cls in [C0, C1]: 319 with self.subTest(cls=cls): 320 self.assertEqual(cls(0, 0), cls(0, 0)) 321 self.assertEqual(cls(1, 2), cls(1, 2)) 322 self.assertNotEqual(cls(1, 0), cls(0, 0)) 323 self.assertNotEqual(cls(1, 0), cls(1, 1)) 324 for idx, fn in enumerate([lambda a, b: a < b, 325 lambda a, b: a <= b, 326 lambda a, b: a > b, 327 lambda a, b: a >= b]): 328 with self.subTest(idx=idx): 329 with self.assertRaisesRegex(TypeError, 330 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"): 331 fn(cls(0, 0), cls(0, 0)) 332 333 @dataclass(order=True) 334 class C: 335 x: int 336 y: int 337 338 for idx, fn in enumerate([lambda a, b: a == b, 339 lambda a, b: a <= b, 340 lambda a, b: a >= b]): 341 with self.subTest(idx=idx): 342 self.assertTrue(fn(C(0, 0), C(0, 0))) 343 344 for idx, fn in enumerate([lambda a, b: a < b, 345 lambda a, b: a <= b, 346 lambda a, b: a != b]): 347 with self.subTest(idx=idx): 348 self.assertTrue(fn(C(0, 0), C(0, 1))) 349 self.assertTrue(fn(C(0, 1), C(1, 0))) 350 self.assertTrue(fn(C(1, 0), C(1, 1))) 351 352 for idx, fn in enumerate([lambda a, b: a > b, 353 lambda a, b: a >= b, 354 lambda a, b: a != b]): 355 with self.subTest(idx=idx): 356 self.assertTrue(fn(C(0, 1), C(0, 0))) 357 self.assertTrue(fn(C(1, 0), C(0, 1))) 358 self.assertTrue(fn(C(1, 1), C(1, 0))) 359 360 def test_compare_subclasses(self): 361 # Comparisons fail for subclasses, even if no fields 362 # are added. 363 @dataclass 364 class B: 365 i: int 366 367 @dataclass 368 class C(B): 369 pass 370 371 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False), 372 (lambda a, b: a != b, True)]): 373 with self.subTest(idx=idx): 374 self.assertEqual(fn(B(0), C(0)), expected) 375 376 for idx, fn in enumerate([lambda a, b: a < b, 377 lambda a, b: a <= b, 378 lambda a, b: a > b, 379 lambda a, b: a >= b]): 380 with self.subTest(idx=idx): 381 with self.assertRaisesRegex(TypeError, 382 "not supported between instances of 'B' and 'C'"): 383 fn(B(0), C(0)) 384 385 def test_eq_order(self): 386 # Test combining eq and order. 387 for (eq, order, result ) in [ 388 (False, False, 'neither'), 389 (False, True, 'exception'), 390 (True, False, 'eq_only'), 391 (True, True, 'both'), 392 ]: 393 with self.subTest(eq=eq, order=order): 394 if result == 'exception': 395 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'): 396 @dataclass(eq=eq, order=order) 397 class C: 398 pass 399 else: 400 @dataclass(eq=eq, order=order) 401 class C: 402 pass 403 404 if result == 'neither': 405 self.assertNotIn('__eq__', C.__dict__) 406 self.assertNotIn('__lt__', C.__dict__) 407 self.assertNotIn('__le__', C.__dict__) 408 self.assertNotIn('__gt__', C.__dict__) 409 self.assertNotIn('__ge__', C.__dict__) 410 elif result == 'both': 411 self.assertIn('__eq__', C.__dict__) 412 self.assertIn('__lt__', C.__dict__) 413 self.assertIn('__le__', C.__dict__) 414 self.assertIn('__gt__', C.__dict__) 415 self.assertIn('__ge__', C.__dict__) 416 elif result == 'eq_only': 417 self.assertIn('__eq__', C.__dict__) 418 self.assertNotIn('__lt__', C.__dict__) 419 self.assertNotIn('__le__', C.__dict__) 420 self.assertNotIn('__gt__', C.__dict__) 421 self.assertNotIn('__ge__', C.__dict__) 422 else: 423 assert False, f'unknown result {result!r}' 424 425 def test_field_no_default(self): 426 @dataclass 427 class C: 428 x: int = field() 429 430 self.assertEqual(C(5).x, 5) 431 432 with self.assertRaisesRegex(TypeError, 433 r"__init__\(\) missing 1 required " 434 "positional argument: 'x'"): 435 C() 436 437 def test_field_default(self): 438 default = object() 439 @dataclass 440 class C: 441 x: object = field(default=default) 442 443 self.assertIs(C.x, default) 444 c = C(10) 445 self.assertEqual(c.x, 10) 446 447 # If we delete the instance attribute, we should then see the 448 # class attribute. 449 del c.x 450 self.assertIs(c.x, default) 451 452 self.assertIs(C().x, default) 453 454 def test_not_in_repr(self): 455 @dataclass 456 class C: 457 x: int = field(repr=False) 458 with self.assertRaises(TypeError): 459 C() 460 c = C(10) 461 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()') 462 463 @dataclass 464 class C: 465 x: int = field(repr=False) 466 y: int 467 c = C(10, 20) 468 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)') 469 470 def test_not_in_compare(self): 471 @dataclass 472 class C: 473 x: int = 0 474 y: int = field(compare=False, default=4) 475 476 self.assertEqual(C(), C(0, 20)) 477 self.assertEqual(C(1, 10), C(1, 20)) 478 self.assertNotEqual(C(3), C(4, 10)) 479 self.assertNotEqual(C(3, 10), C(4, 10)) 480 481 def test_hash_field_rules(self): 482 # Test all 6 cases of: 483 # hash=True/False/None 484 # compare=True/False 485 for (hash_, compare, result ) in [ 486 (True, False, 'field' ), 487 (True, True, 'field' ), 488 (False, False, 'absent'), 489 (False, True, 'absent'), 490 (None, False, 'absent'), 491 (None, True, 'field' ), 492 ]: 493 with self.subTest(hash=hash_, compare=compare): 494 @dataclass(unsafe_hash=True) 495 class C: 496 x: int = field(compare=compare, hash=hash_, default=5) 497 498 if result == 'field': 499 # __hash__ contains the field. 500 self.assertEqual(hash(C(5)), hash((5,))) 501 elif result == 'absent': 502 # The field is not present in the hash. 503 self.assertEqual(hash(C(5)), hash(())) 504 else: 505 assert False, f'unknown result {result!r}' 506 507 def test_init_false_no_default(self): 508 # If init=False and no default value, then the field won't be 509 # present in the instance. 510 @dataclass 511 class C: 512 x: int = field(init=False) 513 514 self.assertNotIn('x', C().__dict__) 515 516 @dataclass 517 class C: 518 x: int 519 y: int = 0 520 z: int = field(init=False) 521 t: int = 10 522 523 self.assertNotIn('z', C(0).__dict__) 524 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0}) 525 526 def test_class_marker(self): 527 @dataclass 528 class C: 529 x: int 530 y: str = field(init=False, default=None) 531 z: str = field(repr=False) 532 533 the_fields = fields(C) 534 # the_fields is a tuple of 3 items, each value 535 # is in __annotations__. 536 self.assertIsInstance(the_fields, tuple) 537 for f in the_fields: 538 self.assertIs(type(f), Field) 539 self.assertIn(f.name, C.__annotations__) 540 541 self.assertEqual(len(the_fields), 3) 542 543 self.assertEqual(the_fields[0].name, 'x') 544 self.assertEqual(the_fields[0].type, int) 545 self.assertFalse(hasattr(C, 'x')) 546 self.assertTrue (the_fields[0].init) 547 self.assertTrue (the_fields[0].repr) 548 self.assertEqual(the_fields[1].name, 'y') 549 self.assertEqual(the_fields[1].type, str) 550 self.assertIsNone(getattr(C, 'y')) 551 self.assertFalse(the_fields[1].init) 552 self.assertTrue (the_fields[1].repr) 553 self.assertEqual(the_fields[2].name, 'z') 554 self.assertEqual(the_fields[2].type, str) 555 self.assertFalse(hasattr(C, 'z')) 556 self.assertTrue (the_fields[2].init) 557 self.assertFalse(the_fields[2].repr) 558 559 def test_field_order(self): 560 @dataclass 561 class B: 562 a: str = 'B:a' 563 b: str = 'B:b' 564 c: str = 'B:c' 565 566 @dataclass 567 class C(B): 568 b: str = 'C:b' 569 570 self.assertEqual([(f.name, f.default) for f in fields(C)], 571 [('a', 'B:a'), 572 ('b', 'C:b'), 573 ('c', 'B:c')]) 574 575 @dataclass 576 class D(B): 577 c: str = 'D:c' 578 579 self.assertEqual([(f.name, f.default) for f in fields(D)], 580 [('a', 'B:a'), 581 ('b', 'B:b'), 582 ('c', 'D:c')]) 583 584 @dataclass 585 class E(D): 586 a: str = 'E:a' 587 d: str = 'E:d' 588 589 self.assertEqual([(f.name, f.default) for f in fields(E)], 590 [('a', 'E:a'), 591 ('b', 'B:b'), 592 ('c', 'D:c'), 593 ('d', 'E:d')]) 594 595 def test_class_attrs(self): 596 # We only have a class attribute if a default value is 597 # specified, either directly or via a field with a default. 598 default = object() 599 @dataclass 600 class C: 601 x: int 602 y: int = field(repr=False) 603 z: object = default 604 t: int = field(default=100) 605 606 self.assertFalse(hasattr(C, 'x')) 607 self.assertFalse(hasattr(C, 'y')) 608 self.assertIs (C.z, default) 609 self.assertEqual(C.t, 100) 610 611 def test_disallowed_mutable_defaults(self): 612 # For the known types, don't allow mutable default values. 613 for typ, empty, non_empty in [(list, [], [1]), 614 (dict, {}, {0:1}), 615 (set, set(), set([1])), 616 ]: 617 with self.subTest(typ=typ): 618 # Can't use a zero-length value. 619 with self.assertRaisesRegex(ValueError, 620 f'mutable default {typ} for field ' 621 'x is not allowed'): 622 @dataclass 623 class Point: 624 x: typ = empty 625 626 627 # Nor a non-zero-length value 628 with self.assertRaisesRegex(ValueError, 629 f'mutable default {typ} for field ' 630 'y is not allowed'): 631 @dataclass 632 class Point: 633 y: typ = non_empty 634 635 # Check subtypes also fail. 636 class Subclass(typ): pass 637 638 with self.assertRaisesRegex(ValueError, 639 f"mutable default .*Subclass'>" 640 ' for field z is not allowed' 641 ): 642 @dataclass 643 class Point: 644 z: typ = Subclass() 645 646 # Because this is a ClassVar, it can be mutable. 647 @dataclass 648 class C: 649 z: ClassVar[typ] = typ() 650 651 # Because this is a ClassVar, it can be mutable. 652 @dataclass 653 class C: 654 x: ClassVar[typ] = Subclass() 655 656 def test_deliberately_mutable_defaults(self): 657 # If a mutable default isn't in the known list of 658 # (list, dict, set), then it's okay. 659 class Mutable: 660 def __init__(self): 661 self.l = [] 662 663 @dataclass 664 class C: 665 x: Mutable 666 667 # These 2 instances will share this value of x. 668 lst = Mutable() 669 o1 = C(lst) 670 o2 = C(lst) 671 self.assertEqual(o1, o2) 672 o1.x.l.extend([1, 2]) 673 self.assertEqual(o1, o2) 674 self.assertEqual(o1.x.l, [1, 2]) 675 self.assertIs(o1.x, o2.x) 676 677 def test_no_options(self): 678 # Call with dataclass(). 679 @dataclass() 680 class C: 681 x: int 682 683 self.assertEqual(C(42).x, 42) 684 685 def test_not_tuple(self): 686 # Make sure we can't be compared to a tuple. 687 @dataclass 688 class Point: 689 x: int 690 y: int 691 self.assertNotEqual(Point(1, 2), (1, 2)) 692 693 # And that we can't compare to another unrelated dataclass. 694 @dataclass 695 class C: 696 x: int 697 y: int 698 self.assertNotEqual(Point(1, 3), C(1, 3)) 699 700 def test_not_tuple(self): 701 # Test that some of the problems with namedtuple don't happen 702 # here. 703 @dataclass 704 class Point3D: 705 x: int 706 y: int 707 z: int 708 709 @dataclass 710 class Date: 711 year: int 712 month: int 713 day: int 714 715 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3)) 716 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3)) 717 718 # Make sure we can't unpack. 719 with self.assertRaisesRegex(TypeError, 'unpack'): 720 x, y, z = Point3D(4, 5, 6) 721 722 # Make sure another class with the same field names isn't 723 # equal. 724 @dataclass 725 class Point3Dv1: 726 x: int = 0 727 y: int = 0 728 z: int = 0 729 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1()) 730 731 def test_function_annotations(self): 732 # Some dummy class and instance to use as a default. 733 class F: 734 pass 735 f = F() 736 737 def validate_class(cls): 738 # First, check __annotations__, even though they're not 739 # function annotations. 740 self.assertEqual(cls.__annotations__['i'], int) 741 self.assertEqual(cls.__annotations__['j'], str) 742 self.assertEqual(cls.__annotations__['k'], F) 743 self.assertEqual(cls.__annotations__['l'], float) 744 self.assertEqual(cls.__annotations__['z'], complex) 745 746 # Verify __init__. 747 748 signature = inspect.signature(cls.__init__) 749 # Check the return type, should be None. 750 self.assertIs(signature.return_annotation, None) 751 752 # Check each parameter. 753 params = iter(signature.parameters.values()) 754 param = next(params) 755 # This is testing an internal name, and probably shouldn't be tested. 756 self.assertEqual(param.name, 'self') 757 param = next(params) 758 self.assertEqual(param.name, 'i') 759 self.assertIs (param.annotation, int) 760 self.assertEqual(param.default, inspect.Parameter.empty) 761 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 762 param = next(params) 763 self.assertEqual(param.name, 'j') 764 self.assertIs (param.annotation, str) 765 self.assertEqual(param.default, inspect.Parameter.empty) 766 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 767 param = next(params) 768 self.assertEqual(param.name, 'k') 769 self.assertIs (param.annotation, F) 770 # Don't test for the default, since it's set to MISSING. 771 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 772 param = next(params) 773 self.assertEqual(param.name, 'l') 774 self.assertIs (param.annotation, float) 775 # Don't test for the default, since it's set to MISSING. 776 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD) 777 self.assertRaises(StopIteration, next, params) 778 779 780 @dataclass 781 class C: 782 i: int 783 j: str 784 k: F = f 785 l: float=field(default=None) 786 z: complex=field(default=3+4j, init=False) 787 788 validate_class(C) 789 790 # Now repeat with __hash__. 791 @dataclass(frozen=True, unsafe_hash=True) 792 class C: 793 i: int 794 j: str 795 k: F = f 796 l: float=field(default=None) 797 z: complex=field(default=3+4j, init=False) 798 799 validate_class(C) 800 801 def test_missing_default(self): 802 # Test that MISSING works the same as a default not being 803 # specified. 804 @dataclass 805 class C: 806 x: int=field(default=MISSING) 807 with self.assertRaisesRegex(TypeError, 808 r'__init__\(\) missing 1 required ' 809 'positional argument'): 810 C() 811 self.assertNotIn('x', C.__dict__) 812 813 @dataclass 814 class D: 815 x: int 816 with self.assertRaisesRegex(TypeError, 817 r'__init__\(\) missing 1 required ' 818 'positional argument'): 819 D() 820 self.assertNotIn('x', D.__dict__) 821 822 def test_missing_default_factory(self): 823 # Test that MISSING works the same as a default factory not 824 # being specified (which is really the same as a default not 825 # being specified, too). 826 @dataclass 827 class C: 828 x: int=field(default_factory=MISSING) 829 with self.assertRaisesRegex(TypeError, 830 r'__init__\(\) missing 1 required ' 831 'positional argument'): 832 C() 833 self.assertNotIn('x', C.__dict__) 834 835 @dataclass 836 class D: 837 x: int=field(default=MISSING, default_factory=MISSING) 838 with self.assertRaisesRegex(TypeError, 839 r'__init__\(\) missing 1 required ' 840 'positional argument'): 841 D() 842 self.assertNotIn('x', D.__dict__) 843 844 def test_missing_repr(self): 845 self.assertIn('MISSING_TYPE object', repr(MISSING)) 846 847 def test_dont_include_other_annotations(self): 848 @dataclass 849 class C: 850 i: int 851 def foo(self) -> int: 852 return 4 853 @property 854 def bar(self) -> int: 855 return 5 856 self.assertEqual(list(C.__annotations__), ['i']) 857 self.assertEqual(C(10).foo(), 4) 858 self.assertEqual(C(10).bar, 5) 859 self.assertEqual(C(10).i, 10) 860 861 def test_post_init(self): 862 # Just make sure it gets called 863 @dataclass 864 class C: 865 def __post_init__(self): 866 raise CustomError() 867 with self.assertRaises(CustomError): 868 C() 869 870 @dataclass 871 class C: 872 i: int = 10 873 def __post_init__(self): 874 if self.i == 10: 875 raise CustomError() 876 with self.assertRaises(CustomError): 877 C() 878 # post-init gets called, but doesn't raise. This is just 879 # checking that self is used correctly. 880 C(5) 881 882 # If there's not an __init__, then post-init won't get called. 883 @dataclass(init=False) 884 class C: 885 def __post_init__(self): 886 raise CustomError() 887 # Creating the class won't raise 888 C() 889 890 @dataclass 891 class C: 892 x: int = 0 893 def __post_init__(self): 894 self.x *= 2 895 self.assertEqual(C().x, 0) 896 self.assertEqual(C(2).x, 4) 897 898 # Make sure that if we're frozen, post-init can't set 899 # attributes. 900 @dataclass(frozen=True) 901 class C: 902 x: int = 0 903 def __post_init__(self): 904 self.x *= 2 905 with self.assertRaises(FrozenInstanceError): 906 C() 907 908 def test_post_init_super(self): 909 # Make sure super() post-init isn't called by default. 910 class B: 911 def __post_init__(self): 912 raise CustomError() 913 914 @dataclass 915 class C(B): 916 def __post_init__(self): 917 self.x = 5 918 919 self.assertEqual(C().x, 5) 920 921 # Now call super(), and it will raise. 922 @dataclass 923 class C(B): 924 def __post_init__(self): 925 super().__post_init__() 926 927 with self.assertRaises(CustomError): 928 C() 929 930 # Make sure post-init is called, even if not defined in our 931 # class. 932 @dataclass 933 class C(B): 934 pass 935 936 with self.assertRaises(CustomError): 937 C() 938 939 def test_post_init_staticmethod(self): 940 flag = False 941 @dataclass 942 class C: 943 x: int 944 y: int 945 @staticmethod 946 def __post_init__(): 947 nonlocal flag 948 flag = True 949 950 self.assertFalse(flag) 951 c = C(3, 4) 952 self.assertEqual((c.x, c.y), (3, 4)) 953 self.assertTrue(flag) 954 955 def test_post_init_classmethod(self): 956 @dataclass 957 class C: 958 flag = False 959 x: int 960 y: int 961 @classmethod 962 def __post_init__(cls): 963 cls.flag = True 964 965 self.assertFalse(C.flag) 966 c = C(3, 4) 967 self.assertEqual((c.x, c.y), (3, 4)) 968 self.assertTrue(C.flag) 969 970 def test_class_var(self): 971 # Make sure ClassVars are ignored in __init__, __repr__, etc. 972 @dataclass 973 class C: 974 x: int 975 y: int = 10 976 z: ClassVar[int] = 1000 977 w: ClassVar[int] = 2000 978 t: ClassVar[int] = 3000 979 s: ClassVar = 4000 980 981 c = C(5) 982 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)') 983 self.assertEqual(len(fields(C)), 2) # We have 2 fields. 984 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars. 985 self.assertEqual(c.z, 1000) 986 self.assertEqual(c.w, 2000) 987 self.assertEqual(c.t, 3000) 988 self.assertEqual(c.s, 4000) 989 C.z += 1 990 self.assertEqual(c.z, 1001) 991 c = C(20) 992 self.assertEqual((c.x, c.y), (20, 10)) 993 self.assertEqual(c.z, 1001) 994 self.assertEqual(c.w, 2000) 995 self.assertEqual(c.t, 3000) 996 self.assertEqual(c.s, 4000) 997 998 def test_class_var_no_default(self): 999 # If a ClassVar has no default value, it should not be set on the class. 1000 @dataclass 1001 class C: 1002 x: ClassVar[int] 1003 1004 self.assertNotIn('x', C.__dict__) 1005 1006 def test_class_var_default_factory(self): 1007 # It makes no sense for a ClassVar to have a default factory. When 1008 # would it be called? Call it yourself, since it's class-wide. 1009 with self.assertRaisesRegex(TypeError, 1010 'cannot have a default factory'): 1011 @dataclass 1012 class C: 1013 x: ClassVar[int] = field(default_factory=int) 1014 1015 self.assertNotIn('x', C.__dict__) 1016 1017 def test_class_var_with_default(self): 1018 # If a ClassVar has a default value, it should be set on the class. 1019 @dataclass 1020 class C: 1021 x: ClassVar[int] = 10 1022 self.assertEqual(C.x, 10) 1023 1024 @dataclass 1025 class C: 1026 x: ClassVar[int] = field(default=10) 1027 self.assertEqual(C.x, 10) 1028 1029 def test_class_var_frozen(self): 1030 # Make sure ClassVars work even if we're frozen. 1031 @dataclass(frozen=True) 1032 class C: 1033 x: int 1034 y: int = 10 1035 z: ClassVar[int] = 1000 1036 w: ClassVar[int] = 2000 1037 t: ClassVar[int] = 3000 1038 1039 c = C(5) 1040 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)') 1041 self.assertEqual(len(fields(C)), 2) # We have 2 fields 1042 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars 1043 self.assertEqual(c.z, 1000) 1044 self.assertEqual(c.w, 2000) 1045 self.assertEqual(c.t, 3000) 1046 # We can still modify the ClassVar, it's only instances that are 1047 # frozen. 1048 C.z += 1 1049 self.assertEqual(c.z, 1001) 1050 c = C(20) 1051 self.assertEqual((c.x, c.y), (20, 10)) 1052 self.assertEqual(c.z, 1001) 1053 self.assertEqual(c.w, 2000) 1054 self.assertEqual(c.t, 3000) 1055 1056 def test_init_var_no_default(self): 1057 # If an InitVar has no default value, it should not be set on the class. 1058 @dataclass 1059 class C: 1060 x: InitVar[int] 1061 1062 self.assertNotIn('x', C.__dict__) 1063 1064 def test_init_var_default_factory(self): 1065 # It makes no sense for an InitVar to have a default factory. When 1066 # would it be called? Call it yourself, since it's class-wide. 1067 with self.assertRaisesRegex(TypeError, 1068 'cannot have a default factory'): 1069 @dataclass 1070 class C: 1071 x: InitVar[int] = field(default_factory=int) 1072 1073 self.assertNotIn('x', C.__dict__) 1074 1075 def test_init_var_with_default(self): 1076 # If an InitVar has a default value, it should be set on the class. 1077 @dataclass 1078 class C: 1079 x: InitVar[int] = 10 1080 self.assertEqual(C.x, 10) 1081 1082 @dataclass 1083 class C: 1084 x: InitVar[int] = field(default=10) 1085 self.assertEqual(C.x, 10) 1086 1087 def test_init_var(self): 1088 @dataclass 1089 class C: 1090 x: int = None 1091 init_param: InitVar[int] = None 1092 1093 def __post_init__(self, init_param): 1094 if self.x is None: 1095 self.x = init_param*2 1096 1097 c = C(init_param=10) 1098 self.assertEqual(c.x, 20) 1099 1100 def test_init_var_inheritance(self): 1101 # Note that this deliberately tests that a dataclass need not 1102 # have a __post_init__ function if it has an InitVar field. 1103 # It could just be used in a derived class, as shown here. 1104 @dataclass 1105 class Base: 1106 x: int 1107 init_base: InitVar[int] 1108 1109 # We can instantiate by passing the InitVar, even though 1110 # it's not used. 1111 b = Base(0, 10) 1112 self.assertEqual(vars(b), {'x': 0}) 1113 1114 @dataclass 1115 class C(Base): 1116 y: int 1117 init_derived: InitVar[int] 1118 1119 def __post_init__(self, init_base, init_derived): 1120 self.x = self.x + init_base 1121 self.y = self.y + init_derived 1122 1123 c = C(10, 11, 50, 51) 1124 self.assertEqual(vars(c), {'x': 21, 'y': 101}) 1125 1126 def test_default_factory(self): 1127 # Test a factory that returns a new list. 1128 @dataclass 1129 class C: 1130 x: int 1131 y: list = field(default_factory=list) 1132 1133 c0 = C(3) 1134 c1 = C(3) 1135 self.assertEqual(c0.x, 3) 1136 self.assertEqual(c0.y, []) 1137 self.assertEqual(c0, c1) 1138 self.assertIsNot(c0.y, c1.y) 1139 self.assertEqual(astuple(C(5, [1])), (5, [1])) 1140 1141 # Test a factory that returns a shared list. 1142 l = [] 1143 @dataclass 1144 class C: 1145 x: int 1146 y: list = field(default_factory=lambda: l) 1147 1148 c0 = C(3) 1149 c1 = C(3) 1150 self.assertEqual(c0.x, 3) 1151 self.assertEqual(c0.y, []) 1152 self.assertEqual(c0, c1) 1153 self.assertIs(c0.y, c1.y) 1154 self.assertEqual(astuple(C(5, [1])), (5, [1])) 1155 1156 # Test various other field flags. 1157 # repr 1158 @dataclass 1159 class C: 1160 x: list = field(default_factory=list, repr=False) 1161 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()') 1162 self.assertEqual(C().x, []) 1163 1164 # hash 1165 @dataclass(unsafe_hash=True) 1166 class C: 1167 x: list = field(default_factory=list, hash=False) 1168 self.assertEqual(astuple(C()), ([],)) 1169 self.assertEqual(hash(C()), hash(())) 1170 1171 # init (see also test_default_factory_with_no_init) 1172 @dataclass 1173 class C: 1174 x: list = field(default_factory=list, init=False) 1175 self.assertEqual(astuple(C()), ([],)) 1176 1177 # compare 1178 @dataclass 1179 class C: 1180 x: list = field(default_factory=list, compare=False) 1181 self.assertEqual(C(), C([1])) 1182 1183 def test_default_factory_with_no_init(self): 1184 # We need a factory with a side effect. 1185 factory = Mock() 1186 1187 @dataclass 1188 class C: 1189 x: list = field(default_factory=factory, init=False) 1190 1191 # Make sure the default factory is called for each new instance. 1192 C().x 1193 self.assertEqual(factory.call_count, 1) 1194 C().x 1195 self.assertEqual(factory.call_count, 2) 1196 1197 def test_default_factory_not_called_if_value_given(self): 1198 # We need a factory that we can test if it's been called. 1199 factory = Mock() 1200 1201 @dataclass 1202 class C: 1203 x: int = field(default_factory=factory) 1204 1205 # Make sure that if a field has a default factory function, 1206 # it's not called if a value is specified. 1207 C().x 1208 self.assertEqual(factory.call_count, 1) 1209 self.assertEqual(C(10).x, 10) 1210 self.assertEqual(factory.call_count, 1) 1211 C().x 1212 self.assertEqual(factory.call_count, 2) 1213 1214 def test_default_factory_derived(self): 1215 # See bpo-32896. 1216 @dataclass 1217 class Foo: 1218 x: dict = field(default_factory=dict) 1219 1220 @dataclass 1221 class Bar(Foo): 1222 y: int = 1 1223 1224 self.assertEqual(Foo().x, {}) 1225 self.assertEqual(Bar().x, {}) 1226 self.assertEqual(Bar().y, 1) 1227 1228 @dataclass 1229 class Baz(Foo): 1230 pass 1231 self.assertEqual(Baz().x, {}) 1232 1233 def test_intermediate_non_dataclass(self): 1234 # Test that an intermediate class that defines 1235 # annotations does not define fields. 1236 1237 @dataclass 1238 class A: 1239 x: int 1240 1241 class B(A): 1242 y: int 1243 1244 @dataclass 1245 class C(B): 1246 z: int 1247 1248 c = C(1, 3) 1249 self.assertEqual((c.x, c.z), (1, 3)) 1250 1251 # .y was not initialized. 1252 with self.assertRaisesRegex(AttributeError, 1253 'object has no attribute'): 1254 c.y 1255 1256 # And if we again derive a non-dataclass, no fields are added. 1257 class D(C): 1258 t: int 1259 d = D(4, 5) 1260 self.assertEqual((d.x, d.z), (4, 5)) 1261 1262 def test_classvar_default_factory(self): 1263 # It's an error for a ClassVar to have a factory function. 1264 with self.assertRaisesRegex(TypeError, 1265 'cannot have a default factory'): 1266 @dataclass 1267 class C: 1268 x: ClassVar[int] = field(default_factory=int) 1269 1270 def test_is_dataclass(self): 1271 class NotDataClass: 1272 pass 1273 1274 self.assertFalse(is_dataclass(0)) 1275 self.assertFalse(is_dataclass(int)) 1276 self.assertFalse(is_dataclass(NotDataClass)) 1277 self.assertFalse(is_dataclass(NotDataClass())) 1278 1279 @dataclass 1280 class C: 1281 x: int 1282 1283 @dataclass 1284 class D: 1285 d: C 1286 e: int 1287 1288 c = C(10) 1289 d = D(c, 4) 1290 1291 self.assertTrue(is_dataclass(C)) 1292 self.assertTrue(is_dataclass(c)) 1293 self.assertFalse(is_dataclass(c.x)) 1294 self.assertTrue(is_dataclass(d.d)) 1295 self.assertFalse(is_dataclass(d.e)) 1296 1297 def test_helper_fields_with_class_instance(self): 1298 # Check that we can call fields() on either a class or instance, 1299 # and get back the same thing. 1300 @dataclass 1301 class C: 1302 x: int 1303 y: float 1304 1305 self.assertEqual(fields(C), fields(C(0, 0.0))) 1306 1307 def test_helper_fields_exception(self): 1308 # Check that TypeError is raised if not passed a dataclass or 1309 # instance. 1310 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 1311 fields(0) 1312 1313 class C: pass 1314 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 1315 fields(C) 1316 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'): 1317 fields(C()) 1318 1319 def test_helper_asdict(self): 1320 # Basic tests for asdict(), it should return a new dictionary. 1321 @dataclass 1322 class C: 1323 x: int 1324 y: int 1325 c = C(1, 2) 1326 1327 self.assertEqual(asdict(c), {'x': 1, 'y': 2}) 1328 self.assertEqual(asdict(c), asdict(c)) 1329 self.assertIsNot(asdict(c), asdict(c)) 1330 c.x = 42 1331 self.assertEqual(asdict(c), {'x': 42, 'y': 2}) 1332 self.assertIs(type(asdict(c)), dict) 1333 1334 def test_helper_asdict_raises_on_classes(self): 1335 # asdict() should raise on a class object. 1336 @dataclass 1337 class C: 1338 x: int 1339 y: int 1340 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1341 asdict(C) 1342 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1343 asdict(int) 1344 1345 def test_helper_asdict_copy_values(self): 1346 @dataclass 1347 class C: 1348 x: int 1349 y: List[int] = field(default_factory=list) 1350 initial = [] 1351 c = C(1, initial) 1352 d = asdict(c) 1353 self.assertEqual(d['y'], initial) 1354 self.assertIsNot(d['y'], initial) 1355 c = C(1) 1356 d = asdict(c) 1357 d['y'].append(1) 1358 self.assertEqual(c.y, []) 1359 1360 def test_helper_asdict_nested(self): 1361 @dataclass 1362 class UserId: 1363 token: int 1364 group: int 1365 @dataclass 1366 class User: 1367 name: str 1368 id: UserId 1369 u = User('Joe', UserId(123, 1)) 1370 d = asdict(u) 1371 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}}) 1372 self.assertIsNot(asdict(u), asdict(u)) 1373 u.id.group = 2 1374 self.assertEqual(asdict(u), {'name': 'Joe', 1375 'id': {'token': 123, 'group': 2}}) 1376 1377 def test_helper_asdict_builtin_containers(self): 1378 @dataclass 1379 class User: 1380 name: str 1381 id: int 1382 @dataclass 1383 class GroupList: 1384 id: int 1385 users: List[User] 1386 @dataclass 1387 class GroupTuple: 1388 id: int 1389 users: Tuple[User, ...] 1390 @dataclass 1391 class GroupDict: 1392 id: int 1393 users: Dict[str, User] 1394 a = User('Alice', 1) 1395 b = User('Bob', 2) 1396 gl = GroupList(0, [a, b]) 1397 gt = GroupTuple(0, (a, b)) 1398 gd = GroupDict(0, {'first': a, 'second': b}) 1399 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1}, 1400 {'name': 'Bob', 'id': 2}]}) 1401 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1}, 1402 {'name': 'Bob', 'id': 2})}) 1403 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1}, 1404 'second': {'name': 'Bob', 'id': 2}}}) 1405 1406 def test_helper_asdict_builtin_containers(self): 1407 @dataclass 1408 class Child: 1409 d: object 1410 1411 @dataclass 1412 class Parent: 1413 child: Child 1414 1415 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}}) 1416 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}}) 1417 1418 def test_helper_asdict_factory(self): 1419 @dataclass 1420 class C: 1421 x: int 1422 y: int 1423 c = C(1, 2) 1424 d = asdict(c, dict_factory=OrderedDict) 1425 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)])) 1426 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict)) 1427 c.x = 42 1428 d = asdict(c, dict_factory=OrderedDict) 1429 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)])) 1430 self.assertIs(type(d), OrderedDict) 1431 1432 def test_helper_asdict_namedtuple(self): 1433 T = namedtuple('T', 'a b c') 1434 @dataclass 1435 class C: 1436 x: str 1437 y: T 1438 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2)) 1439 1440 d = asdict(c) 1441 self.assertEqual(d, {'x': 'outer', 1442 'y': T(1, 1443 {'x': 'inner', 1444 'y': T(11, 12, 13)}, 1445 2), 1446 } 1447 ) 1448 1449 # Now with a dict_factory. OrderedDict is convenient, but 1450 # since it compares to dicts, we also need to have separate 1451 # assertIs tests. 1452 d = asdict(c, dict_factory=OrderedDict) 1453 self.assertEqual(d, {'x': 'outer', 1454 'y': T(1, 1455 {'x': 'inner', 1456 'y': T(11, 12, 13)}, 1457 2), 1458 } 1459 ) 1460 1461 # Make sure that the returned dicts are actuall OrderedDicts. 1462 self.assertIs(type(d), OrderedDict) 1463 self.assertIs(type(d['y'][1]), OrderedDict) 1464 1465 def test_helper_asdict_namedtuple_key(self): 1466 # Ensure that a field that contains a dict which has a 1467 # namedtuple as a key works with asdict(). 1468 1469 @dataclass 1470 class C: 1471 f: dict 1472 T = namedtuple('T', 'a') 1473 1474 c = C({T('an a'): 0}) 1475 1476 self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}}) 1477 1478 def test_helper_asdict_namedtuple_derived(self): 1479 class T(namedtuple('Tbase', 'a')): 1480 def my_a(self): 1481 return self.a 1482 1483 @dataclass 1484 class C: 1485 f: T 1486 1487 t = T(6) 1488 c = C(t) 1489 1490 d = asdict(c) 1491 self.assertEqual(d, {'f': T(a=6)}) 1492 # Make sure that t has been copied, not used directly. 1493 self.assertIsNot(d['f'], t) 1494 self.assertEqual(d['f'].my_a(), 6) 1495 1496 def test_helper_astuple(self): 1497 # Basic tests for astuple(), it should return a new tuple. 1498 @dataclass 1499 class C: 1500 x: int 1501 y: int = 0 1502 c = C(1) 1503 1504 self.assertEqual(astuple(c), (1, 0)) 1505 self.assertEqual(astuple(c), astuple(c)) 1506 self.assertIsNot(astuple(c), astuple(c)) 1507 c.y = 42 1508 self.assertEqual(astuple(c), (1, 42)) 1509 self.assertIs(type(astuple(c)), tuple) 1510 1511 def test_helper_astuple_raises_on_classes(self): 1512 # astuple() should raise on a class object. 1513 @dataclass 1514 class C: 1515 x: int 1516 y: int 1517 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1518 astuple(C) 1519 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 1520 astuple(int) 1521 1522 def test_helper_astuple_copy_values(self): 1523 @dataclass 1524 class C: 1525 x: int 1526 y: List[int] = field(default_factory=list) 1527 initial = [] 1528 c = C(1, initial) 1529 t = astuple(c) 1530 self.assertEqual(t[1], initial) 1531 self.assertIsNot(t[1], initial) 1532 c = C(1) 1533 t = astuple(c) 1534 t[1].append(1) 1535 self.assertEqual(c.y, []) 1536 1537 def test_helper_astuple_nested(self): 1538 @dataclass 1539 class UserId: 1540 token: int 1541 group: int 1542 @dataclass 1543 class User: 1544 name: str 1545 id: UserId 1546 u = User('Joe', UserId(123, 1)) 1547 t = astuple(u) 1548 self.assertEqual(t, ('Joe', (123, 1))) 1549 self.assertIsNot(astuple(u), astuple(u)) 1550 u.id.group = 2 1551 self.assertEqual(astuple(u), ('Joe', (123, 2))) 1552 1553 def test_helper_astuple_builtin_containers(self): 1554 @dataclass 1555 class User: 1556 name: str 1557 id: int 1558 @dataclass 1559 class GroupList: 1560 id: int 1561 users: List[User] 1562 @dataclass 1563 class GroupTuple: 1564 id: int 1565 users: Tuple[User, ...] 1566 @dataclass 1567 class GroupDict: 1568 id: int 1569 users: Dict[str, User] 1570 a = User('Alice', 1) 1571 b = User('Bob', 2) 1572 gl = GroupList(0, [a, b]) 1573 gt = GroupTuple(0, (a, b)) 1574 gd = GroupDict(0, {'first': a, 'second': b}) 1575 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)])) 1576 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2)))) 1577 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)})) 1578 1579 def test_helper_astuple_builtin_containers(self): 1580 @dataclass 1581 class Child: 1582 d: object 1583 1584 @dataclass 1585 class Parent: 1586 child: Child 1587 1588 self.assertEqual(astuple(Parent(Child([1]))), (([1],),)) 1589 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),)) 1590 1591 def test_helper_astuple_factory(self): 1592 @dataclass 1593 class C: 1594 x: int 1595 y: int 1596 NT = namedtuple('NT', 'x y') 1597 def nt(lst): 1598 return NT(*lst) 1599 c = C(1, 2) 1600 t = astuple(c, tuple_factory=nt) 1601 self.assertEqual(t, NT(1, 2)) 1602 self.assertIsNot(t, astuple(c, tuple_factory=nt)) 1603 c.x = 42 1604 t = astuple(c, tuple_factory=nt) 1605 self.assertEqual(t, NT(42, 2)) 1606 self.assertIs(type(t), NT) 1607 1608 def test_helper_astuple_namedtuple(self): 1609 T = namedtuple('T', 'a b c') 1610 @dataclass 1611 class C: 1612 x: str 1613 y: T 1614 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2)) 1615 1616 t = astuple(c) 1617 self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2))) 1618 1619 # Now, using a tuple_factory. list is convenient here. 1620 t = astuple(c, tuple_factory=list) 1621 self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)]) 1622 1623 def test_dynamic_class_creation(self): 1624 cls_dict = {'__annotations__': {'x': int, 'y': int}, 1625 } 1626 1627 # Create the class. 1628 cls = type('C', (), cls_dict) 1629 1630 # Make it a dataclass. 1631 cls1 = dataclass(cls) 1632 1633 self.assertEqual(cls1, cls) 1634 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2}) 1635 1636 def test_dynamic_class_creation_using_field(self): 1637 cls_dict = {'__annotations__': {'x': int, 'y': int}, 1638 'y': field(default=5), 1639 } 1640 1641 # Create the class. 1642 cls = type('C', (), cls_dict) 1643 1644 # Make it a dataclass. 1645 cls1 = dataclass(cls) 1646 1647 self.assertEqual(cls1, cls) 1648 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5}) 1649 1650 def test_init_in_order(self): 1651 @dataclass 1652 class C: 1653 a: int 1654 b: int = field() 1655 c: list = field(default_factory=list, init=False) 1656 d: list = field(default_factory=list) 1657 e: int = field(default=4, init=False) 1658 f: int = 4 1659 1660 calls = [] 1661 def setattr(self, name, value): 1662 calls.append((name, value)) 1663 1664 C.__setattr__ = setattr 1665 c = C(0, 1) 1666 self.assertEqual(('a', 0), calls[0]) 1667 self.assertEqual(('b', 1), calls[1]) 1668 self.assertEqual(('c', []), calls[2]) 1669 self.assertEqual(('d', []), calls[3]) 1670 self.assertNotIn(('e', 4), calls) 1671 self.assertEqual(('f', 4), calls[4]) 1672 1673 def test_items_in_dicts(self): 1674 @dataclass 1675 class C: 1676 a: int 1677 b: list = field(default_factory=list, init=False) 1678 c: list = field(default_factory=list) 1679 d: int = field(default=4, init=False) 1680 e: int = 0 1681 1682 c = C(0) 1683 # Class dict 1684 self.assertNotIn('a', C.__dict__) 1685 self.assertNotIn('b', C.__dict__) 1686 self.assertNotIn('c', C.__dict__) 1687 self.assertIn('d', C.__dict__) 1688 self.assertEqual(C.d, 4) 1689 self.assertIn('e', C.__dict__) 1690 self.assertEqual(C.e, 0) 1691 # Instance dict 1692 self.assertIn('a', c.__dict__) 1693 self.assertEqual(c.a, 0) 1694 self.assertIn('b', c.__dict__) 1695 self.assertEqual(c.b, []) 1696 self.assertIn('c', c.__dict__) 1697 self.assertEqual(c.c, []) 1698 self.assertNotIn('d', c.__dict__) 1699 self.assertIn('e', c.__dict__) 1700 self.assertEqual(c.e, 0) 1701 1702 def test_alternate_classmethod_constructor(self): 1703 # Since __post_init__ can't take params, use a classmethod 1704 # alternate constructor. This is mostly an example to show 1705 # how to use this technique. 1706 @dataclass 1707 class C: 1708 x: int 1709 @classmethod 1710 def from_file(cls, filename): 1711 # In a real example, create a new instance 1712 # and populate 'x' from contents of a file. 1713 value_in_file = 20 1714 return cls(value_in_file) 1715 1716 self.assertEqual(C.from_file('filename').x, 20) 1717 1718 def test_field_metadata_default(self): 1719 # Make sure the default metadata is read-only and of 1720 # zero length. 1721 @dataclass 1722 class C: 1723 i: int 1724 1725 self.assertFalse(fields(C)[0].metadata) 1726 self.assertEqual(len(fields(C)[0].metadata), 0) 1727 with self.assertRaisesRegex(TypeError, 1728 'does not support item assignment'): 1729 fields(C)[0].metadata['test'] = 3 1730 1731 def test_field_metadata_mapping(self): 1732 # Make sure only a mapping can be passed as metadata 1733 # zero length. 1734 with self.assertRaises(TypeError): 1735 @dataclass 1736 class C: 1737 i: int = field(metadata=0) 1738 1739 # Make sure an empty dict works. 1740 d = {} 1741 @dataclass 1742 class C: 1743 i: int = field(metadata=d) 1744 self.assertFalse(fields(C)[0].metadata) 1745 self.assertEqual(len(fields(C)[0].metadata), 0) 1746 # Update should work (see bpo-35960). 1747 d['foo'] = 1 1748 self.assertEqual(len(fields(C)[0].metadata), 1) 1749 self.assertEqual(fields(C)[0].metadata['foo'], 1) 1750 with self.assertRaisesRegex(TypeError, 1751 'does not support item assignment'): 1752 fields(C)[0].metadata['test'] = 3 1753 1754 # Make sure a non-empty dict works. 1755 d = {'test': 10, 'bar': '42', 3: 'three'} 1756 @dataclass 1757 class C: 1758 i: int = field(metadata=d) 1759 self.assertEqual(len(fields(C)[0].metadata), 3) 1760 self.assertEqual(fields(C)[0].metadata['test'], 10) 1761 self.assertEqual(fields(C)[0].metadata['bar'], '42') 1762 self.assertEqual(fields(C)[0].metadata[3], 'three') 1763 # Update should work. 1764 d['foo'] = 1 1765 self.assertEqual(len(fields(C)[0].metadata), 4) 1766 self.assertEqual(fields(C)[0].metadata['foo'], 1) 1767 with self.assertRaises(KeyError): 1768 # Non-existent key. 1769 fields(C)[0].metadata['baz'] 1770 with self.assertRaisesRegex(TypeError, 1771 'does not support item assignment'): 1772 fields(C)[0].metadata['test'] = 3 1773 1774 def test_field_metadata_custom_mapping(self): 1775 # Try a custom mapping. 1776 class SimpleNameSpace: 1777 def __init__(self, **kw): 1778 self.__dict__.update(kw) 1779 1780 def __getitem__(self, item): 1781 if item == 'xyzzy': 1782 return 'plugh' 1783 return getattr(self, item) 1784 1785 def __len__(self): 1786 return self.__dict__.__len__() 1787 1788 @dataclass 1789 class C: 1790 i: int = field(metadata=SimpleNameSpace(a=10)) 1791 1792 self.assertEqual(len(fields(C)[0].metadata), 1) 1793 self.assertEqual(fields(C)[0].metadata['a'], 10) 1794 with self.assertRaises(AttributeError): 1795 fields(C)[0].metadata['b'] 1796 # Make sure we're still talking to our custom mapping. 1797 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh') 1798 1799 def test_generic_dataclasses(self): 1800 T = TypeVar('T') 1801 1802 @dataclass 1803 class LabeledBox(Generic[T]): 1804 content: T 1805 label: str = '<unknown>' 1806 1807 box = LabeledBox(42) 1808 self.assertEqual(box.content, 42) 1809 self.assertEqual(box.label, '<unknown>') 1810 1811 # Subscripting the resulting class should work, etc. 1812 Alias = List[LabeledBox[int]] 1813 1814 def test_generic_extending(self): 1815 S = TypeVar('S') 1816 T = TypeVar('T') 1817 1818 @dataclass 1819 class Base(Generic[T, S]): 1820 x: T 1821 y: S 1822 1823 @dataclass 1824 class DataDerived(Base[int, T]): 1825 new_field: str 1826 Alias = DataDerived[str] 1827 c = Alias(0, 'test1', 'test2') 1828 self.assertEqual(astuple(c), (0, 'test1', 'test2')) 1829 1830 class NonDataDerived(Base[int, T]): 1831 def new_method(self): 1832 return self.y 1833 Alias = NonDataDerived[float] 1834 c = Alias(10, 1.0) 1835 self.assertEqual(c.new_method(), 1.0) 1836 1837 def test_generic_dynamic(self): 1838 T = TypeVar('T') 1839 1840 @dataclass 1841 class Parent(Generic[T]): 1842 x: T 1843 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)], 1844 bases=(Parent[int], Generic[T]), namespace={'other': 42}) 1845 self.assertIs(Child[int](1, 2).z, None) 1846 self.assertEqual(Child[int](1, 2, 3).z, 3) 1847 self.assertEqual(Child[int](1, 2, 3).other, 42) 1848 # Check that type aliases work correctly. 1849 Alias = Child[T] 1850 self.assertEqual(Alias[int](1, 2).x, 1) 1851 # Check MRO resolution. 1852 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object)) 1853 1854 def test_dataclassses_pickleable(self): 1855 global P, Q, R 1856 @dataclass 1857 class P: 1858 x: int 1859 y: int = 0 1860 @dataclass 1861 class Q: 1862 x: int 1863 y: int = field(default=0, init=False) 1864 @dataclass 1865 class R: 1866 x: int 1867 y: List[int] = field(default_factory=list) 1868 q = Q(1) 1869 q.y = 2 1870 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])] 1871 for sample in samples: 1872 for proto in range(pickle.HIGHEST_PROTOCOL + 1): 1873 with self.subTest(sample=sample, proto=proto): 1874 new_sample = pickle.loads(pickle.dumps(sample, proto)) 1875 self.assertEqual(sample.x, new_sample.x) 1876 self.assertEqual(sample.y, new_sample.y) 1877 self.assertIsNot(sample, new_sample) 1878 new_sample.x = 42 1879 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto)) 1880 self.assertEqual(new_sample.x, another_new_sample.x) 1881 self.assertEqual(sample.y, another_new_sample.y) 1882 1883 1884 class TestFieldNoAnnotation(unittest.TestCase): 1885 def test_field_without_annotation(self): 1886 with self.assertRaisesRegex(TypeError, 1887 "'f' is a field but has no type annotation"): 1888 @dataclass 1889 class C: 1890 f = field() 1891 1892 def test_field_without_annotation_but_annotation_in_base(self): 1893 @dataclass 1894 class B: 1895 f: int 1896 1897 with self.assertRaisesRegex(TypeError, 1898 "'f' is a field but has no type annotation"): 1899 # This is still an error: make sure we don't pick up the 1900 # type annotation in the base class. 1901 @dataclass 1902 class C(B): 1903 f = field() 1904 1905 def test_field_without_annotation_but_annotation_in_base_not_dataclass(self): 1906 # Same test, but with the base class not a dataclass. 1907 class B: 1908 f: int 1909 1910 with self.assertRaisesRegex(TypeError, 1911 "'f' is a field but has no type annotation"): 1912 # This is still an error: make sure we don't pick up the 1913 # type annotation in the base class. 1914 @dataclass 1915 class C(B): 1916 f = field() 1917 1918 1919 class TestDocString(unittest.TestCase): 1920 def assertDocStrEqual(self, a, b): 1921 # Because 3.6 and 3.7 differ in how inspect.signature work 1922 # (see bpo #32108), for the time being just compare them with 1923 # whitespace stripped. 1924 self.assertEqual(a.replace(' ', ''), b.replace(' ', '')) 1925 1926 def test_existing_docstring_not_overridden(self): 1927 @dataclass 1928 class C: 1929 """Lorem ipsum""" 1930 x: int 1931 1932 self.assertEqual(C.__doc__, "Lorem ipsum") 1933 1934 def test_docstring_no_fields(self): 1935 @dataclass 1936 class C: 1937 pass 1938 1939 self.assertDocStrEqual(C.__doc__, "C()") 1940 1941 def test_docstring_one_field(self): 1942 @dataclass 1943 class C: 1944 x: int 1945 1946 self.assertDocStrEqual(C.__doc__, "C(x:int)") 1947 1948 def test_docstring_two_fields(self): 1949 @dataclass 1950 class C: 1951 x: int 1952 y: int 1953 1954 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)") 1955 1956 def test_docstring_three_fields(self): 1957 @dataclass 1958 class C: 1959 x: int 1960 y: int 1961 z: str 1962 1963 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)") 1964 1965 def test_docstring_one_field_with_default(self): 1966 @dataclass 1967 class C: 1968 x: int = 3 1969 1970 self.assertDocStrEqual(C.__doc__, "C(x:int=3)") 1971 1972 def test_docstring_one_field_with_default_none(self): 1973 @dataclass 1974 class C: 1975 x: Union[int, type(None)] = None 1976 1977 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)") 1978 1979 def test_docstring_list_field(self): 1980 @dataclass 1981 class C: 1982 x: List[int] 1983 1984 self.assertDocStrEqual(C.__doc__, "C(x:List[int])") 1985 1986 def test_docstring_list_field_with_default_factory(self): 1987 @dataclass 1988 class C: 1989 x: List[int] = field(default_factory=list) 1990 1991 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)") 1992 1993 def test_docstring_deque_field(self): 1994 @dataclass 1995 class C: 1996 x: deque 1997 1998 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)") 1999 2000 def test_docstring_deque_field_with_default_factory(self): 2001 @dataclass 2002 class C: 2003 x: deque = field(default_factory=deque) 2004 2005 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)") 2006 2007 2008 class TestInit(unittest.TestCase): 2009 def test_base_has_init(self): 2010 class B: 2011 def __init__(self): 2012 self.z = 100 2013 pass 2014 2015 # Make sure that declaring this class doesn't raise an error. 2016 # The issue is that we can't override __init__ in our class, 2017 # but it should be okay to add __init__ to us if our base has 2018 # an __init__. 2019 @dataclass 2020 class C(B): 2021 x: int = 0 2022 c = C(10) 2023 self.assertEqual(c.x, 10) 2024 self.assertNotIn('z', vars(c)) 2025 2026 # Make sure that if we don't add an init, the base __init__ 2027 # gets called. 2028 @dataclass(init=False) 2029 class C(B): 2030 x: int = 10 2031 c = C() 2032 self.assertEqual(c.x, 10) 2033 self.assertEqual(c.z, 100) 2034 2035 def test_no_init(self): 2036 dataclass(init=False) 2037 class C: 2038 i: int = 0 2039 self.assertEqual(C().i, 0) 2040 2041 dataclass(init=False) 2042 class C: 2043 i: int = 2 2044 def __init__(self): 2045 self.i = 3 2046 self.assertEqual(C().i, 3) 2047 2048 def test_overwriting_init(self): 2049 # If the class has __init__, use it no matter the value of 2050 # init=. 2051 2052 @dataclass 2053 class C: 2054 x: int 2055 def __init__(self, x): 2056 self.x = 2 * x 2057 self.assertEqual(C(3).x, 6) 2058 2059 @dataclass(init=True) 2060 class C: 2061 x: int 2062 def __init__(self, x): 2063 self.x = 2 * x 2064 self.assertEqual(C(4).x, 8) 2065 2066 @dataclass(init=False) 2067 class C: 2068 x: int 2069 def __init__(self, x): 2070 self.x = 2 * x 2071 self.assertEqual(C(5).x, 10) 2072 2073 2074 class TestRepr(unittest.TestCase): 2075 def test_repr(self): 2076 @dataclass 2077 class B: 2078 x: int 2079 2080 @dataclass 2081 class C(B): 2082 y: int = 10 2083 2084 o = C(4) 2085 self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)') 2086 2087 @dataclass 2088 class D(C): 2089 x: int = 20 2090 self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)') 2091 2092 @dataclass 2093 class C: 2094 @dataclass 2095 class D: 2096 i: int 2097 @dataclass 2098 class E: 2099 pass 2100 self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)') 2101 self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()') 2102 2103 def test_no_repr(self): 2104 # Test a class with no __repr__ and repr=False. 2105 @dataclass(repr=False) 2106 class C: 2107 x: int 2108 self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at', 2109 repr(C(3))) 2110 2111 # Test a class with a __repr__ and repr=False. 2112 @dataclass(repr=False) 2113 class C: 2114 x: int 2115 def __repr__(self): 2116 return 'C-class' 2117 self.assertEqual(repr(C(3)), 'C-class') 2118 2119 def test_overwriting_repr(self): 2120 # If the class has __repr__, use it no matter the value of 2121 # repr=. 2122 2123 @dataclass 2124 class C: 2125 x: int 2126 def __repr__(self): 2127 return 'x' 2128 self.assertEqual(repr(C(0)), 'x') 2129 2130 @dataclass(repr=True) 2131 class C: 2132 x: int 2133 def __repr__(self): 2134 return 'x' 2135 self.assertEqual(repr(C(0)), 'x') 2136 2137 @dataclass(repr=False) 2138 class C: 2139 x: int 2140 def __repr__(self): 2141 return 'x' 2142 self.assertEqual(repr(C(0)), 'x') 2143 2144 2145 class TestEq(unittest.TestCase): 2146 def test_no_eq(self): 2147 # Test a class with no __eq__ and eq=False. 2148 @dataclass(eq=False) 2149 class C: 2150 x: int 2151 self.assertNotEqual(C(0), C(0)) 2152 c = C(3) 2153 self.assertEqual(c, c) 2154 2155 # Test a class with an __eq__ and eq=False. 2156 @dataclass(eq=False) 2157 class C: 2158 x: int 2159 def __eq__(self, other): 2160 return other == 10 2161 self.assertEqual(C(3), 10) 2162 2163 def test_overwriting_eq(self): 2164 # If the class has __eq__, use it no matter the value of 2165 # eq=. 2166 2167 @dataclass 2168 class C: 2169 x: int 2170 def __eq__(self, other): 2171 return other == 3 2172 self.assertEqual(C(1), 3) 2173 self.assertNotEqual(C(1), 1) 2174 2175 @dataclass(eq=True) 2176 class C: 2177 x: int 2178 def __eq__(self, other): 2179 return other == 4 2180 self.assertEqual(C(1), 4) 2181 self.assertNotEqual(C(1), 1) 2182 2183 @dataclass(eq=False) 2184 class C: 2185 x: int 2186 def __eq__(self, other): 2187 return other == 5 2188 self.assertEqual(C(1), 5) 2189 self.assertNotEqual(C(1), 1) 2190 2191 2192 class TestOrdering(unittest.TestCase): 2193 def test_functools_total_ordering(self): 2194 # Test that functools.total_ordering works with this class. 2195 @total_ordering 2196 @dataclass 2197 class C: 2198 x: int 2199 def __lt__(self, other): 2200 # Perform the test "backward", just to make 2201 # sure this is being called. 2202 return self.x >= other 2203 2204 self.assertLess(C(0), -1) 2205 self.assertLessEqual(C(0), -1) 2206 self.assertGreater(C(0), 1) 2207 self.assertGreaterEqual(C(0), 1) 2208 2209 def test_no_order(self): 2210 # Test that no ordering functions are added by default. 2211 @dataclass(order=False) 2212 class C: 2213 x: int 2214 # Make sure no order methods are added. 2215 self.assertNotIn('__le__', C.__dict__) 2216 self.assertNotIn('__lt__', C.__dict__) 2217 self.assertNotIn('__ge__', C.__dict__) 2218 self.assertNotIn('__gt__', C.__dict__) 2219 2220 # Test that __lt__ is still called 2221 @dataclass(order=False) 2222 class C: 2223 x: int 2224 def __lt__(self, other): 2225 return False 2226 # Make sure other methods aren't added. 2227 self.assertNotIn('__le__', C.__dict__) 2228 self.assertNotIn('__ge__', C.__dict__) 2229 self.assertNotIn('__gt__', C.__dict__) 2230 2231 def test_overwriting_order(self): 2232 with self.assertRaisesRegex(TypeError, 2233 'Cannot overwrite attribute __lt__' 2234 '.*using functools.total_ordering'): 2235 @dataclass(order=True) 2236 class C: 2237 x: int 2238 def __lt__(self): 2239 pass 2240 2241 with self.assertRaisesRegex(TypeError, 2242 'Cannot overwrite attribute __le__' 2243 '.*using functools.total_ordering'): 2244 @dataclass(order=True) 2245 class C: 2246 x: int 2247 def __le__(self): 2248 pass 2249 2250 with self.assertRaisesRegex(TypeError, 2251 'Cannot overwrite attribute __gt__' 2252 '.*using functools.total_ordering'): 2253 @dataclass(order=True) 2254 class C: 2255 x: int 2256 def __gt__(self): 2257 pass 2258 2259 with self.assertRaisesRegex(TypeError, 2260 'Cannot overwrite attribute __ge__' 2261 '.*using functools.total_ordering'): 2262 @dataclass(order=True) 2263 class C: 2264 x: int 2265 def __ge__(self): 2266 pass 2267 2268 class TestHash(unittest.TestCase): 2269 def test_unsafe_hash(self): 2270 @dataclass(unsafe_hash=True) 2271 class C: 2272 x: int 2273 y: str 2274 self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo'))) 2275 2276 def test_hash_rules(self): 2277 def non_bool(value): 2278 # Map to something else that's True, but not a bool. 2279 if value is None: 2280 return None 2281 if value: 2282 return (3,) 2283 return 0 2284 2285 def test(case, unsafe_hash, eq, frozen, with_hash, result): 2286 with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq, 2287 frozen=frozen): 2288 if result != 'exception': 2289 if with_hash: 2290 @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) 2291 class C: 2292 def __hash__(self): 2293 return 0 2294 else: 2295 @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) 2296 class C: 2297 pass 2298 2299 # See if the result matches what's expected. 2300 if result == 'fn': 2301 # __hash__ contains the function we generated. 2302 self.assertIn('__hash__', C.__dict__) 2303 self.assertIsNotNone(C.__dict__['__hash__']) 2304 2305 elif result == '': 2306 # __hash__ is not present in our class. 2307 if not with_hash: 2308 self.assertNotIn('__hash__', C.__dict__) 2309 2310 elif result == 'none': 2311 # __hash__ is set to None. 2312 self.assertIn('__hash__', C.__dict__) 2313 self.assertIsNone(C.__dict__['__hash__']) 2314 2315 elif result == 'exception': 2316 # Creating the class should cause an exception. 2317 # This only happens with with_hash==True. 2318 assert(with_hash) 2319 with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'): 2320 @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen) 2321 class C: 2322 def __hash__(self): 2323 return 0 2324 2325 else: 2326 assert False, f'unknown result {result!r}' 2327 2328 # There are 8 cases of: 2329 # unsafe_hash=True/False 2330 # eq=True/False 2331 # frozen=True/False 2332 # And for each of these, a different result if 2333 # __hash__ is defined or not. 2334 for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([ 2335 (False, False, False, '', ''), 2336 (False, False, True, '', ''), 2337 (False, True, False, 'none', ''), 2338 (False, True, True, 'fn', ''), 2339 (True, False, False, 'fn', 'exception'), 2340 (True, False, True, 'fn', 'exception'), 2341 (True, True, False, 'fn', 'exception'), 2342 (True, True, True, 'fn', 'exception'), 2343 ], 1): 2344 test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash) 2345 test(case, unsafe_hash, eq, frozen, True, res_defined_hash) 2346 2347 # Test non-bool truth values, too. This is just to 2348 # make sure the data-driven table in the decorator 2349 # handles non-bool values. 2350 test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash) 2351 test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True, res_defined_hash) 2352 2353 2354 def test_eq_only(self): 2355 # If a class defines __eq__, __hash__ is automatically added 2356 # and set to None. This is normal Python behavior, not 2357 # related to dataclasses. Make sure we don't interfere with 2358 # that (see bpo=32546). 2359 2360 @dataclass 2361 class C: 2362 i: int 2363 def __eq__(self, other): 2364 return self.i == other.i 2365 self.assertEqual(C(1), C(1)) 2366 self.assertNotEqual(C(1), C(4)) 2367 2368 # And make sure things work in this case if we specify 2369 # unsafe_hash=True. 2370 @dataclass(unsafe_hash=True) 2371 class C: 2372 i: int 2373 def __eq__(self, other): 2374 return self.i == other.i 2375 self.assertEqual(C(1), C(1.0)) 2376 self.assertEqual(hash(C(1)), hash(C(1.0))) 2377 2378 # And check that the classes __eq__ is being used, despite 2379 # specifying eq=True. 2380 @dataclass(unsafe_hash=True, eq=True) 2381 class C: 2382 i: int 2383 def __eq__(self, other): 2384 return self.i == 3 and self.i == other.i 2385 self.assertEqual(C(3), C(3)) 2386 self.assertNotEqual(C(1), C(1)) 2387 self.assertEqual(hash(C(1)), hash(C(1.0))) 2388 2389 def test_0_field_hash(self): 2390 @dataclass(frozen=True) 2391 class C: 2392 pass 2393 self.assertEqual(hash(C()), hash(())) 2394 2395 @dataclass(unsafe_hash=True) 2396 class C: 2397 pass 2398 self.assertEqual(hash(C()), hash(())) 2399 2400 def test_1_field_hash(self): 2401 @dataclass(frozen=True) 2402 class C: 2403 x: int 2404 self.assertEqual(hash(C(4)), hash((4,))) 2405 self.assertEqual(hash(C(42)), hash((42,))) 2406 2407 @dataclass(unsafe_hash=True) 2408 class C: 2409 x: int 2410 self.assertEqual(hash(C(4)), hash((4,))) 2411 self.assertEqual(hash(C(42)), hash((42,))) 2412 2413 def test_hash_no_args(self): 2414 # Test dataclasses with no hash= argument. This exists to 2415 # make sure that if the @dataclass parameter name is changed 2416 # or the non-default hashing behavior changes, the default 2417 # hashability keeps working the same way. 2418 2419 class Base: 2420 def __hash__(self): 2421 return 301 2422 2423 # If frozen or eq is None, then use the default value (do not 2424 # specify any value in the decorator). 2425 for frozen, eq, base, expected in [ 2426 (None, None, object, 'unhashable'), 2427 (None, None, Base, 'unhashable'), 2428 (None, False, object, 'object'), 2429 (None, False, Base, 'base'), 2430 (None, True, object, 'unhashable'), 2431 (None, True, Base, 'unhashable'), 2432 (False, None, object, 'unhashable'), 2433 (False, None, Base, 'unhashable'), 2434 (False, False, object, 'object'), 2435 (False, False, Base, 'base'), 2436 (False, True, object, 'unhashable'), 2437 (False, True, Base, 'unhashable'), 2438 (True, None, object, 'tuple'), 2439 (True, None, Base, 'tuple'), 2440 (True, False, object, 'object'), 2441 (True, False, Base, 'base'), 2442 (True, True, object, 'tuple'), 2443 (True, True, Base, 'tuple'), 2444 ]: 2445 2446 with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected): 2447 # First, create the class. 2448 if frozen is None and eq is None: 2449 @dataclass 2450 class C(base): 2451 i: int 2452 elif frozen is None: 2453 @dataclass(eq=eq) 2454 class C(base): 2455 i: int 2456 elif eq is None: 2457 @dataclass(frozen=frozen) 2458 class C(base): 2459 i: int 2460 else: 2461 @dataclass(frozen=frozen, eq=eq) 2462 class C(base): 2463 i: int 2464 2465 # Now, make sure it hashes as expected. 2466 if expected == 'unhashable': 2467 c = C(10) 2468 with self.assertRaisesRegex(TypeError, 'unhashable type'): 2469 hash(c) 2470 2471 elif expected == 'base': 2472 self.assertEqual(hash(C(10)), 301) 2473 2474 elif expected == 'object': 2475 # I'm not sure what test to use here. object's 2476 # hash isn't based on id(), so calling hash() 2477 # won't tell us much. So, just check the 2478 # function used is object's. 2479 self.assertIs(C.__hash__, object.__hash__) 2480 2481 elif expected == 'tuple': 2482 self.assertEqual(hash(C(42)), hash((42,))) 2483 2484 else: 2485 assert False, f'unknown value for expected={expected!r}' 2486 2487 2488 class TestFrozen(unittest.TestCase): 2489 def test_frozen(self): 2490 @dataclass(frozen=True) 2491 class C: 2492 i: int 2493 2494 c = C(10) 2495 self.assertEqual(c.i, 10) 2496 with self.assertRaises(FrozenInstanceError): 2497 c.i = 5 2498 self.assertEqual(c.i, 10) 2499 2500 def test_inherit(self): 2501 @dataclass(frozen=True) 2502 class C: 2503 i: int 2504 2505 @dataclass(frozen=True) 2506 class D(C): 2507 j: int 2508 2509 d = D(0, 10) 2510 with self.assertRaises(FrozenInstanceError): 2511 d.i = 5 2512 with self.assertRaises(FrozenInstanceError): 2513 d.j = 6 2514 self.assertEqual(d.i, 0) 2515 self.assertEqual(d.j, 10) 2516 2517 # Test both ways: with an intermediate normal (non-dataclass) 2518 # class and without an intermediate class. 2519 def test_inherit_nonfrozen_from_frozen(self): 2520 for intermediate_class in [True, False]: 2521 with self.subTest(intermediate_class=intermediate_class): 2522 @dataclass(frozen=True) 2523 class C: 2524 i: int 2525 2526 if intermediate_class: 2527 class I(C): pass 2528 else: 2529 I = C 2530 2531 with self.assertRaisesRegex(TypeError, 2532 'cannot inherit non-frozen dataclass from a frozen one'): 2533 @dataclass 2534 class D(I): 2535 pass 2536 2537 def test_inherit_frozen_from_nonfrozen(self): 2538 for intermediate_class in [True, False]: 2539 with self.subTest(intermediate_class=intermediate_class): 2540 @dataclass 2541 class C: 2542 i: int 2543 2544 if intermediate_class: 2545 class I(C): pass 2546 else: 2547 I = C 2548 2549 with self.assertRaisesRegex(TypeError, 2550 'cannot inherit frozen dataclass from a non-frozen one'): 2551 @dataclass(frozen=True) 2552 class D(I): 2553 pass 2554 2555 def test_inherit_from_normal_class(self): 2556 for intermediate_class in [True, False]: 2557 with self.subTest(intermediate_class=intermediate_class): 2558 class C: 2559 pass 2560 2561 if intermediate_class: 2562 class I(C): pass 2563 else: 2564 I = C 2565 2566 @dataclass(frozen=True) 2567 class D(I): 2568 i: int 2569 2570 d = D(10) 2571 with self.assertRaises(FrozenInstanceError): 2572 d.i = 5 2573 2574 def test_non_frozen_normal_derived(self): 2575 # See bpo-32953. 2576 2577 @dataclass(frozen=True) 2578 class D: 2579 x: int 2580 y: int = 10 2581 2582 class S(D): 2583 pass 2584 2585 s = S(3) 2586 self.assertEqual(s.x, 3) 2587 self.assertEqual(s.y, 10) 2588 s.cached = True 2589 2590 # But can't change the frozen attributes. 2591 with self.assertRaises(FrozenInstanceError): 2592 s.x = 5 2593 with self.assertRaises(FrozenInstanceError): 2594 s.y = 5 2595 self.assertEqual(s.x, 3) 2596 self.assertEqual(s.y, 10) 2597 self.assertEqual(s.cached, True) 2598 2599 def test_overwriting_frozen(self): 2600 # frozen uses __setattr__ and __delattr__. 2601 with self.assertRaisesRegex(TypeError, 2602 'Cannot overwrite attribute __setattr__'): 2603 @dataclass(frozen=True) 2604 class C: 2605 x: int 2606 def __setattr__(self): 2607 pass 2608 2609 with self.assertRaisesRegex(TypeError, 2610 'Cannot overwrite attribute __delattr__'): 2611 @dataclass(frozen=True) 2612 class C: 2613 x: int 2614 def __delattr__(self): 2615 pass 2616 2617 @dataclass(frozen=False) 2618 class C: 2619 x: int 2620 def __setattr__(self, name, value): 2621 self.__dict__['x'] = value * 2 2622 self.assertEqual(C(10).x, 20) 2623 2624 def test_frozen_hash(self): 2625 @dataclass(frozen=True) 2626 class C: 2627 x: Any 2628 2629 # If x is immutable, we can compute the hash. No exception is 2630 # raised. 2631 hash(C(3)) 2632 2633 # If x is mutable, computing the hash is an error. 2634 with self.assertRaisesRegex(TypeError, 'unhashable type'): 2635 hash(C({})) 2636 2637 2638 class TestSlots(unittest.TestCase): 2639 def test_simple(self): 2640 @dataclass 2641 class C: 2642 __slots__ = ('x',) 2643 x: Any 2644 2645 # There was a bug where a variable in a slot was assumed to 2646 # also have a default value (of type 2647 # types.MemberDescriptorType). 2648 with self.assertRaisesRegex(TypeError, 2649 r"__init__\(\) missing 1 required positional argument: 'x'"): 2650 C() 2651 2652 # We can create an instance, and assign to x. 2653 c = C(10) 2654 self.assertEqual(c.x, 10) 2655 c.x = 5 2656 self.assertEqual(c.x, 5) 2657 2658 # We can't assign to anything else. 2659 with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"): 2660 c.y = 5 2661 2662 def test_derived_added_field(self): 2663 # See bpo-33100. 2664 @dataclass 2665 class Base: 2666 __slots__ = ('x',) 2667 x: Any 2668 2669 @dataclass 2670 class Derived(Base): 2671 x: int 2672 y: int 2673 2674 d = Derived(1, 2) 2675 self.assertEqual((d.x, d.y), (1, 2)) 2676 2677 # We can add a new field to the derived instance. 2678 d.z = 10 2679 2680 class TestDescriptors(unittest.TestCase): 2681 def test_set_name(self): 2682 # See bpo-33141. 2683 2684 # Create a descriptor. 2685 class D: 2686 def __set_name__(self, owner, name): 2687 self.name = name + 'x' 2688 def __get__(self, instance, owner): 2689 if instance is not None: 2690 return 1 2691 return self 2692 2693 # This is the case of just normal descriptor behavior, no 2694 # dataclass code is involved in initializing the descriptor. 2695 @dataclass 2696 class C: 2697 c: int=D() 2698 self.assertEqual(C.c.name, 'cx') 2699 2700 # Now test with a default value and init=False, which is the 2701 # only time this is really meaningful. If not using 2702 # init=False, then the descriptor will be overwritten, anyway. 2703 @dataclass 2704 class C: 2705 c: int=field(default=D(), init=False) 2706 self.assertEqual(C.c.name, 'cx') 2707 self.assertEqual(C().c, 1) 2708 2709 def test_non_descriptor(self): 2710 # PEP 487 says __set_name__ should work on non-descriptors. 2711 # Create a descriptor. 2712 2713 class D: 2714 def __set_name__(self, owner, name): 2715 self.name = name + 'x' 2716 2717 @dataclass 2718 class C: 2719 c: int=field(default=D(), init=False) 2720 self.assertEqual(C.c.name, 'cx') 2721 2722 def test_lookup_on_instance(self): 2723 # See bpo-33175. 2724 class D: 2725 pass 2726 2727 d = D() 2728 # Create an attribute on the instance, not type. 2729 d.__set_name__ = Mock() 2730 2731 # Make sure d.__set_name__ is not called. 2732 @dataclass 2733 class C: 2734 i: int=field(default=d, init=False) 2735 2736 self.assertEqual(d.__set_name__.call_count, 0) 2737 2738 def test_lookup_on_class(self): 2739 # See bpo-33175. 2740 class D: 2741 pass 2742 D.__set_name__ = Mock() 2743 2744 # Make sure D.__set_name__ is called. 2745 @dataclass 2746 class C: 2747 i: int=field(default=D(), init=False) 2748 2749 self.assertEqual(D.__set_name__.call_count, 1) 2750 2751 2752 class TestStringAnnotations(unittest.TestCase): 2753 def test_classvar(self): 2754 # Some expressions recognized as ClassVar really aren't. But 2755 # if you're using string annotations, it's not an exact 2756 # science. 2757 # These tests assume that both "import typing" and "from 2758 # typing import *" have been run in this file. 2759 for typestr in ('ClassVar[int]', 2760 'ClassVar [int]' 2761 ' ClassVar [int]', 2762 'ClassVar', 2763 ' ClassVar ', 2764 'typing.ClassVar[int]', 2765 'typing.ClassVar[str]', 2766 ' typing.ClassVar[str]', 2767 'typing .ClassVar[str]', 2768 'typing. ClassVar[str]', 2769 'typing.ClassVar [str]', 2770 'typing.ClassVar [ str]', 2771 2772 # Not syntactically valid, but these will 2773 # be treated as ClassVars. 2774 'typing.ClassVar.[int]', 2775 'typing.ClassVar+', 2776 ): 2777 with self.subTest(typestr=typestr): 2778 @dataclass 2779 class C: 2780 x: typestr 2781 2782 # x is a ClassVar, so C() takes no args. 2783 C() 2784 2785 # And it won't appear in the class's dict because it doesn't 2786 # have a default. 2787 self.assertNotIn('x', C.__dict__) 2788 2789 def test_isnt_classvar(self): 2790 for typestr in ('CV', 2791 't.ClassVar', 2792 't.ClassVar[int]', 2793 'typing..ClassVar[int]', 2794 'Classvar', 2795 'Classvar[int]', 2796 'typing.ClassVarx[int]', 2797 'typong.ClassVar[int]', 2798 'dataclasses.ClassVar[int]', 2799 'typingxClassVar[str]', 2800 ): 2801 with self.subTest(typestr=typestr): 2802 @dataclass 2803 class C: 2804 x: typestr 2805 2806 # x is not a ClassVar, so C() takes one arg. 2807 self.assertEqual(C(10).x, 10) 2808 2809 def test_initvar(self): 2810 # These tests assume that both "import dataclasses" and "from 2811 # dataclasses import *" have been run in this file. 2812 for typestr in ('InitVar[int]', 2813 'InitVar [int]' 2814 ' InitVar [int]', 2815 'InitVar', 2816 ' InitVar ', 2817 'dataclasses.InitVar[int]', 2818 'dataclasses.InitVar[str]', 2819 ' dataclasses.InitVar[str]', 2820 'dataclasses .InitVar[str]', 2821 'dataclasses. InitVar[str]', 2822 'dataclasses.InitVar [str]', 2823 'dataclasses.InitVar [ str]', 2824 2825 # Not syntactically valid, but these will 2826 # be treated as InitVars. 2827 'dataclasses.InitVar.[int]', 2828 'dataclasses.InitVar+', 2829 ): 2830 with self.subTest(typestr=typestr): 2831 @dataclass 2832 class C: 2833 x: typestr 2834 2835 # x is an InitVar, so doesn't create a member. 2836 with self.assertRaisesRegex(AttributeError, 2837 "object has no attribute 'x'"): 2838 C(1).x 2839 2840 def test_isnt_initvar(self): 2841 for typestr in ('IV', 2842 'dc.InitVar', 2843 'xdataclasses.xInitVar', 2844 'typing.xInitVar[int]', 2845 ): 2846 with self.subTest(typestr=typestr): 2847 @dataclass 2848 class C: 2849 x: typestr 2850 2851 # x is not an InitVar, so there will be a member x. 2852 self.assertEqual(C(10).x, 10) 2853 2854 def test_classvar_module_level_import(self): 2855 from test import dataclass_module_1 2856 from test import dataclass_module_1_str 2857 from test import dataclass_module_2 2858 from test import dataclass_module_2_str 2859 2860 for m in (dataclass_module_1, dataclass_module_1_str, 2861 dataclass_module_2, dataclass_module_2_str, 2862 ): 2863 with self.subTest(m=m): 2864 # There's a difference in how the ClassVars are 2865 # interpreted when using string annotations or 2866 # not. See the imported modules for details. 2867 if m.USING_STRINGS: 2868 c = m.CV(10) 2869 else: 2870 c = m.CV() 2871 self.assertEqual(c.cv0, 20) 2872 2873 2874 # There's a difference in how the InitVars are 2875 # interpreted when using string annotations or 2876 # not. See the imported modules for details. 2877 c = m.IV(0, 1, 2, 3, 4) 2878 2879 for field_name in ('iv0', 'iv1', 'iv2', 'iv3'): 2880 with self.subTest(field_name=field_name): 2881 with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"): 2882 # Since field_name is an InitVar, it's 2883 # not an instance field. 2884 getattr(c, field_name) 2885 2886 if m.USING_STRINGS: 2887 # iv4 is interpreted as a normal field. 2888 self.assertIn('not_iv4', c.__dict__) 2889 self.assertEqual(c.not_iv4, 4) 2890 else: 2891 # iv4 is interpreted as an InitVar, so it 2892 # won't exist on the instance. 2893 self.assertNotIn('not_iv4', c.__dict__) 2894 2895 2896 class TestMakeDataclass(unittest.TestCase): 2897 def test_simple(self): 2898 C = make_dataclass('C', 2899 [('x', int), 2900 ('y', int, field(default=5))], 2901 namespace={'add_one': lambda self: self.x + 1}) 2902 c = C(10) 2903 self.assertEqual((c.x, c.y), (10, 5)) 2904 self.assertEqual(c.add_one(), 11) 2905 2906 2907 def test_no_mutate_namespace(self): 2908 # Make sure a provided namespace isn't mutated. 2909 ns = {} 2910 C = make_dataclass('C', 2911 [('x', int), 2912 ('y', int, field(default=5))], 2913 namespace=ns) 2914 self.assertEqual(ns, {}) 2915 2916 def test_base(self): 2917 class Base1: 2918 pass 2919 class Base2: 2920 pass 2921 C = make_dataclass('C', 2922 [('x', int)], 2923 bases=(Base1, Base2)) 2924 c = C(2) 2925 self.assertIsInstance(c, C) 2926 self.assertIsInstance(c, Base1) 2927 self.assertIsInstance(c, Base2) 2928 2929 def test_base_dataclass(self): 2930 @dataclass 2931 class Base1: 2932 x: int 2933 class Base2: 2934 pass 2935 C = make_dataclass('C', 2936 [('y', int)], 2937 bases=(Base1, Base2)) 2938 with self.assertRaisesRegex(TypeError, 'required positional'): 2939 c = C(2) 2940 c = C(1, 2) 2941 self.assertIsInstance(c, C) 2942 self.assertIsInstance(c, Base1) 2943 self.assertIsInstance(c, Base2) 2944 2945 self.assertEqual((c.x, c.y), (1, 2)) 2946 2947 def test_init_var(self): 2948 def post_init(self, y): 2949 self.x *= y 2950 2951 C = make_dataclass('C', 2952 [('x', int), 2953 ('y', InitVar[int]), 2954 ], 2955 namespace={'__post_init__': post_init}, 2956 ) 2957 c = C(2, 3) 2958 self.assertEqual(vars(c), {'x': 6}) 2959 self.assertEqual(len(fields(c)), 1) 2960 2961 def test_class_var(self): 2962 C = make_dataclass('C', 2963 [('x', int), 2964 ('y', ClassVar[int], 10), 2965 ('z', ClassVar[int], field(default=20)), 2966 ]) 2967 c = C(1) 2968 self.assertEqual(vars(c), {'x': 1}) 2969 self.assertEqual(len(fields(c)), 1) 2970 self.assertEqual(C.y, 10) 2971 self.assertEqual(C.z, 20) 2972 2973 def test_other_params(self): 2974 C = make_dataclass('C', 2975 [('x', int), 2976 ('y', ClassVar[int], 10), 2977 ('z', ClassVar[int], field(default=20)), 2978 ], 2979 init=False) 2980 # Make sure we have a repr, but no init. 2981 self.assertNotIn('__init__', vars(C)) 2982 self.assertIn('__repr__', vars(C)) 2983 2984 # Make sure random other params don't work. 2985 with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'): 2986 C = make_dataclass('C', 2987 [], 2988 xxinit=False) 2989 2990 def test_no_types(self): 2991 C = make_dataclass('Point', ['x', 'y', 'z']) 2992 c = C(1, 2, 3) 2993 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) 2994 self.assertEqual(C.__annotations__, {'x': 'typing.Any', 2995 'y': 'typing.Any', 2996 'z': 'typing.Any'}) 2997 2998 C = make_dataclass('Point', ['x', ('y', int), 'z']) 2999 c = C(1, 2, 3) 3000 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3}) 3001 self.assertEqual(C.__annotations__, {'x': 'typing.Any', 3002 'y': int, 3003 'z': 'typing.Any'}) 3004 3005 def test_invalid_type_specification(self): 3006 for bad_field in [(), 3007 (1, 2, 3, 4), 3008 ]: 3009 with self.subTest(bad_field=bad_field): 3010 with self.assertRaisesRegex(TypeError, r'Invalid field: '): 3011 make_dataclass('C', ['a', bad_field]) 3012 3013 # And test for things with no len(). 3014 for bad_field in [float, 3015 lambda x:x, 3016 ]: 3017 with self.subTest(bad_field=bad_field): 3018 with self.assertRaisesRegex(TypeError, r'has no len\(\)'): 3019 make_dataclass('C', ['a', bad_field]) 3020 3021 def test_duplicate_field_names(self): 3022 for field in ['a', 'ab']: 3023 with self.subTest(field=field): 3024 with self.assertRaisesRegex(TypeError, 'Field name duplicated'): 3025 make_dataclass('C', [field, 'a', field]) 3026 3027 def test_keyword_field_names(self): 3028 for field in ['for', 'async', 'await', 'as']: 3029 with self.subTest(field=field): 3030 with self.assertRaisesRegex(TypeError, 'must not be keywords'): 3031 make_dataclass('C', ['a', field]) 3032 with self.assertRaisesRegex(TypeError, 'must not be keywords'): 3033 make_dataclass('C', [field]) 3034 with self.assertRaisesRegex(TypeError, 'must not be keywords'): 3035 make_dataclass('C', [field, 'a']) 3036 3037 def test_non_identifier_field_names(self): 3038 for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']: 3039 with self.subTest(field=field): 3040 with self.assertRaisesRegex(TypeError, 'must be valid identifers'): 3041 make_dataclass('C', ['a', field]) 3042 with self.assertRaisesRegex(TypeError, 'must be valid identifers'): 3043 make_dataclass('C', [field]) 3044 with self.assertRaisesRegex(TypeError, 'must be valid identifers'): 3045 make_dataclass('C', [field, 'a']) 3046 3047 def test_underscore_field_names(self): 3048 # Unlike namedtuple, it's okay if dataclass field names have 3049 # an underscore. 3050 make_dataclass('C', ['_', '_a', 'a_a', 'a_']) 3051 3052 def test_funny_class_names_names(self): 3053 # No reason to prevent weird class names, since 3054 # types.new_class allows them. 3055 for classname in ['()', 'x,y', '*', '2@3', '']: 3056 with self.subTest(classname=classname): 3057 C = make_dataclass(classname, ['a', 'b']) 3058 self.assertEqual(C.__name__, classname) 3059 3060 class TestReplace(unittest.TestCase): 3061 def test(self): 3062 @dataclass(frozen=True) 3063 class C: 3064 x: int 3065 y: int 3066 3067 c = C(1, 2) 3068 c1 = replace(c, x=3) 3069 self.assertEqual(c1.x, 3) 3070 self.assertEqual(c1.y, 2) 3071 3072 def test_frozen(self): 3073 @dataclass(frozen=True) 3074 class C: 3075 x: int 3076 y: int 3077 z: int = field(init=False, default=10) 3078 t: int = field(init=False, default=100) 3079 3080 c = C(1, 2) 3081 c1 = replace(c, x=3) 3082 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100)) 3083 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100)) 3084 3085 3086 with self.assertRaisesRegex(ValueError, 'init=False'): 3087 replace(c, x=3, z=20, t=50) 3088 with self.assertRaisesRegex(ValueError, 'init=False'): 3089 replace(c, z=20) 3090 replace(c, x=3, z=20, t=50) 3091 3092 # Make sure the result is still frozen. 3093 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"): 3094 c1.x = 3 3095 3096 # Make sure we can't replace an attribute that doesn't exist, 3097 # if we're also replacing one that does exist. Test this 3098 # here, because setting attributes on frozen instances is 3099 # handled slightly differently from non-frozen ones. 3100 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected " 3101 "keyword argument 'a'"): 3102 c1 = replace(c, x=20, a=5) 3103 3104 def test_invalid_field_name(self): 3105 @dataclass(frozen=True) 3106 class C: 3107 x: int 3108 y: int 3109 3110 c = C(1, 2) 3111 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected " 3112 "keyword argument 'z'"): 3113 c1 = replace(c, z=3) 3114 3115 def test_invalid_object(self): 3116 @dataclass(frozen=True) 3117 class C: 3118 x: int 3119 y: int 3120 3121 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 3122 replace(C, x=3) 3123 3124 with self.assertRaisesRegex(TypeError, 'dataclass instance'): 3125 replace(0, x=3) 3126 3127 def test_no_init(self): 3128 @dataclass 3129 class C: 3130 x: int 3131 y: int = field(init=False, default=10) 3132 3133 c = C(1) 3134 c.y = 20 3135 3136 # Make sure y gets the default value. 3137 c1 = replace(c, x=5) 3138 self.assertEqual((c1.x, c1.y), (5, 10)) 3139 3140 # Trying to replace y is an error. 3141 with self.assertRaisesRegex(ValueError, 'init=False'): 3142 replace(c, x=2, y=30) 3143 3144 with self.assertRaisesRegex(ValueError, 'init=False'): 3145 replace(c, y=30) 3146 3147 def test_classvar(self): 3148 @dataclass 3149 class C: 3150 x: int 3151 y: ClassVar[int] = 1000 3152 3153 c = C(1) 3154 d = C(2) 3155 3156 self.assertIs(c.y, d.y) 3157 self.assertEqual(c.y, 1000) 3158 3159 # Trying to replace y is an error: can't replace ClassVars. 3160 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an " 3161 "unexpected keyword argument 'y'"): 3162 replace(c, y=30) 3163 3164 replace(c, x=5) 3165 3166 def test_initvar_is_specified(self): 3167 @dataclass 3168 class C: 3169 x: int 3170 y: InitVar[int] 3171 3172 def __post_init__(self, y): 3173 self.x *= y 3174 3175 c = C(1, 10) 3176 self.assertEqual(c.x, 10) 3177 with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be " 3178 "specified with replace()"): 3179 replace(c, x=3) 3180 c = replace(c, x=3, y=5) 3181 self.assertEqual(c.x, 15) 3182 3183 def test_recursive_repr(self): 3184 @dataclass 3185 class C: 3186 f: "C" 3187 3188 c = C(None) 3189 c.f = c 3190 self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)") 3191 3192 def test_recursive_repr_two_attrs(self): 3193 @dataclass 3194 class C: 3195 f: "C" 3196 g: "C" 3197 3198 c = C(None, None) 3199 c.f = c 3200 c.g = c 3201 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs" 3202 ".<locals>.C(f=..., g=...)") 3203 3204 def test_recursive_repr_indirection(self): 3205 @dataclass 3206 class C: 3207 f: "D" 3208 3209 @dataclass 3210 class D: 3211 f: "C" 3212 3213 c = C(None) 3214 d = D(None) 3215 c.f = d 3216 d.f = c 3217 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection" 3218 ".<locals>.C(f=TestReplace.test_recursive_repr_indirection" 3219 ".<locals>.D(f=...))") 3220 3221 def test_recursive_repr_indirection_two(self): 3222 @dataclass 3223 class C: 3224 f: "D" 3225 3226 @dataclass 3227 class D: 3228 f: "E" 3229 3230 @dataclass 3231 class E: 3232 f: "C" 3233 3234 c = C(None) 3235 d = D(None) 3236 e = E(None) 3237 c.f = d 3238 d.f = e 3239 e.f = c 3240 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two" 3241 ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two" 3242 ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two" 3243 ".<locals>.E(f=...)))") 3244 3245 def test_recursive_repr_two_attrs(self): 3246 @dataclass 3247 class C: 3248 f: "C" 3249 g: "C" 3250 3251 c = C(None, None) 3252 c.f = c 3253 c.g = c 3254 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs" 3255 ".<locals>.C(f=..., g=...)") 3256 3257 def test_recursive_repr_misc_attrs(self): 3258 @dataclass 3259 class C: 3260 f: "C" 3261 g: int 3262 3263 c = C(None, 1) 3264 c.f = c 3265 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs" 3266 ".<locals>.C(f=..., g=1)") 3267 3268 ## def test_initvar(self): 3269 ## @dataclass 3270 ## class C: 3271 ## x: int 3272 ## y: InitVar[int] 3273 3274 ## c = C(1, 10) 3275 ## d = C(2, 20) 3276 3277 ## # In our case, replacing an InitVar is a no-op 3278 ## self.assertEqual(c, replace(c, y=5)) 3279 3280 ## replace(c, x=5) 3281 3282 3283 if __name__ == '__main__': 3284 unittest.main() 3285