1 #!/usr/bin/env python 2 # 3 # Copyright 2010 Google Inc. 4 # 5 # Licensed under the Apache License, Version 2.0 (the "License"); 6 # you may not use this file except in compliance with the License. 7 # You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """Tests for protorpc.messages.""" 19 import six 20 21 __author__ = 'rafek (at] google.com (Rafe Kaplan)' 22 23 24 import pickle 25 import re 26 import sys 27 import types 28 import unittest 29 30 from protorpc import descriptor 31 from protorpc import message_types 32 from protorpc import messages 33 from protorpc import test_util 34 35 36 class ModuleInterfaceTest(test_util.ModuleInterfaceTest, 37 test_util.TestCase): 38 39 MODULE = messages 40 41 42 class ValidationErrorTest(test_util.TestCase): 43 44 def testStr_NoFieldName(self): 45 """Test string version of ValidationError when no name provided.""" 46 self.assertEquals('Validation error', 47 str(messages.ValidationError('Validation error'))) 48 49 def testStr_FieldName(self): 50 """Test string version of ValidationError when no name provided.""" 51 validation_error = messages.ValidationError('Validation error') 52 validation_error.field_name = 'a_field' 53 self.assertEquals('Validation error', str(validation_error)) 54 55 56 class EnumTest(test_util.TestCase): 57 58 def setUp(self): 59 """Set up tests.""" 60 # Redefine Color class in case so that changes to it (an error) in one test 61 # does not affect other tests. 62 global Color 63 class Color(messages.Enum): 64 RED = 20 65 ORANGE = 2 66 YELLOW = 40 67 GREEN = 4 68 BLUE = 50 69 INDIGO = 5 70 VIOLET = 80 71 72 def testNames(self): 73 """Test that names iterates over enum names.""" 74 self.assertEquals( 75 set(['BLUE', 'GREEN', 'INDIGO', 'ORANGE', 'RED', 'VIOLET', 'YELLOW']), 76 set(Color.names())) 77 78 def testNumbers(self): 79 """Tests that numbers iterates of enum numbers.""" 80 self.assertEquals(set([2, 4, 5, 20, 40, 50, 80]), set(Color.numbers())) 81 82 def testIterate(self): 83 """Test that __iter__ iterates over all enum values.""" 84 self.assertEquals(set(Color), 85 set([Color.RED, 86 Color.ORANGE, 87 Color.YELLOW, 88 Color.GREEN, 89 Color.BLUE, 90 Color.INDIGO, 91 Color.VIOLET])) 92 93 def testNaturalOrder(self): 94 """Test that natural order enumeration is in numeric order.""" 95 self.assertEquals([Color.ORANGE, 96 Color.GREEN, 97 Color.INDIGO, 98 Color.RED, 99 Color.YELLOW, 100 Color.BLUE, 101 Color.VIOLET], 102 sorted(Color)) 103 104 def testByName(self): 105 """Test look-up by name.""" 106 self.assertEquals(Color.RED, Color.lookup_by_name('RED')) 107 self.assertRaises(KeyError, Color.lookup_by_name, 20) 108 self.assertRaises(KeyError, Color.lookup_by_name, Color.RED) 109 110 def testByNumber(self): 111 """Test look-up by number.""" 112 self.assertRaises(KeyError, Color.lookup_by_number, 'RED') 113 self.assertEquals(Color.RED, Color.lookup_by_number(20)) 114 self.assertRaises(KeyError, Color.lookup_by_number, Color.RED) 115 116 def testConstructor(self): 117 """Test that constructor look-up by name or number.""" 118 self.assertEquals(Color.RED, Color('RED')) 119 self.assertEquals(Color.RED, Color(u'RED')) 120 self.assertEquals(Color.RED, Color(20)) 121 if six.PY2: 122 self.assertEquals(Color.RED, Color(long(20))) 123 self.assertEquals(Color.RED, Color(Color.RED)) 124 self.assertRaises(TypeError, Color, 'Not exists') 125 self.assertRaises(TypeError, Color, 'Red') 126 self.assertRaises(TypeError, Color, 100) 127 self.assertRaises(TypeError, Color, 10.0) 128 129 def testLen(self): 130 """Test that len function works to count enums.""" 131 self.assertEquals(7, len(Color)) 132 133 def testNoSubclasses(self): 134 """Test that it is not possible to sub-class enum classes.""" 135 def declare_subclass(): 136 class MoreColor(Color): 137 pass 138 self.assertRaises(messages.EnumDefinitionError, 139 declare_subclass) 140 141 def testClassNotMutable(self): 142 """Test that enum classes themselves are not mutable.""" 143 self.assertRaises(AttributeError, 144 setattr, 145 Color, 146 'something_new', 147 10) 148 149 def testInstancesMutable(self): 150 """Test that enum instances are not mutable.""" 151 self.assertRaises(TypeError, 152 setattr, 153 Color.RED, 154 'something_new', 155 10) 156 157 def testDefEnum(self): 158 """Test def_enum works by building enum class from dict.""" 159 WeekDay = messages.Enum.def_enum({'Monday': 1, 160 'Tuesday': 2, 161 'Wednesday': 3, 162 'Thursday': 4, 163 'Friday': 6, 164 'Saturday': 7, 165 'Sunday': 8}, 166 'WeekDay') 167 self.assertEquals('Wednesday', WeekDay(3).name) 168 self.assertEquals(6, WeekDay('Friday').number) 169 self.assertEquals(WeekDay.Sunday, WeekDay('Sunday')) 170 171 def testNonInt(self): 172 """Test that non-integer values rejection by enum def.""" 173 self.assertRaises(messages.EnumDefinitionError, 174 messages.Enum.def_enum, 175 {'Bad': '1'}, 176 'BadEnum') 177 178 def testNegativeInt(self): 179 """Test that negative numbers rejection by enum def.""" 180 self.assertRaises(messages.EnumDefinitionError, 181 messages.Enum.def_enum, 182 {'Bad': -1}, 183 'BadEnum') 184 185 def testLowerBound(self): 186 """Test that zero is accepted by enum def.""" 187 class NotImportant(messages.Enum): 188 """Testing for value zero""" 189 VALUE = 0 190 191 self.assertEquals(0, int(NotImportant.VALUE)) 192 193 def testTooLargeInt(self): 194 """Test that numbers too large are rejected.""" 195 self.assertRaises(messages.EnumDefinitionError, 196 messages.Enum.def_enum, 197 {'Bad': (2 ** 29)}, 198 'BadEnum') 199 200 def testRepeatedInt(self): 201 """Test duplicated numbers are forbidden.""" 202 self.assertRaises(messages.EnumDefinitionError, 203 messages.Enum.def_enum, 204 {'Ok': 1, 'Repeated': 1}, 205 'BadEnum') 206 207 def testStr(self): 208 """Test converting to string.""" 209 self.assertEquals('RED', str(Color.RED)) 210 self.assertEquals('ORANGE', str(Color.ORANGE)) 211 212 def testInt(self): 213 """Test converting to int.""" 214 self.assertEquals(20, int(Color.RED)) 215 self.assertEquals(2, int(Color.ORANGE)) 216 217 def testRepr(self): 218 """Test enum representation.""" 219 self.assertEquals('Color(RED, 20)', repr(Color.RED)) 220 self.assertEquals('Color(YELLOW, 40)', repr(Color.YELLOW)) 221 222 def testDocstring(self): 223 """Test that docstring is supported ok.""" 224 class NotImportant(messages.Enum): 225 """I have a docstring.""" 226 227 VALUE1 = 1 228 229 self.assertEquals('I have a docstring.', NotImportant.__doc__) 230 231 def testDeleteEnumValue(self): 232 """Test that enum values cannot be deleted.""" 233 self.assertRaises(TypeError, delattr, Color, 'RED') 234 235 def testEnumName(self): 236 """Test enum name.""" 237 module_name = test_util.get_module_name(EnumTest) 238 self.assertEquals('%s.Color' % module_name, Color.definition_name()) 239 self.assertEquals(module_name, Color.outer_definition_name()) 240 self.assertEquals(module_name, Color.definition_package()) 241 242 def testDefinitionName_OverrideModule(self): 243 """Test enum module is overriden by module package name.""" 244 global package 245 try: 246 package = 'my.package' 247 self.assertEquals('my.package.Color', Color.definition_name()) 248 self.assertEquals('my.package', Color.outer_definition_name()) 249 self.assertEquals('my.package', Color.definition_package()) 250 finally: 251 del package 252 253 def testDefinitionName_NoModule(self): 254 """Test what happens when there is no module for enum.""" 255 class Enum1(messages.Enum): 256 pass 257 258 original_modules = sys.modules 259 sys.modules = dict(sys.modules) 260 try: 261 del sys.modules[__name__] 262 self.assertEquals('Enum1', Enum1.definition_name()) 263 self.assertEquals(None, Enum1.outer_definition_name()) 264 self.assertEquals(None, Enum1.definition_package()) 265 self.assertEquals(six.text_type, type(Enum1.definition_name())) 266 finally: 267 sys.modules = original_modules 268 269 def testDefinitionName_Nested(self): 270 """Test nested Enum names.""" 271 class MyMessage(messages.Message): 272 273 class NestedEnum(messages.Enum): 274 275 pass 276 277 class NestedMessage(messages.Message): 278 279 class NestedEnum(messages.Enum): 280 281 pass 282 283 module_name = test_util.get_module_name(EnumTest) 284 self.assertEquals('%s.MyMessage.NestedEnum' % module_name, 285 MyMessage.NestedEnum.definition_name()) 286 self.assertEquals('%s.MyMessage' % module_name, 287 MyMessage.NestedEnum.outer_definition_name()) 288 self.assertEquals(module_name, 289 MyMessage.NestedEnum.definition_package()) 290 291 self.assertEquals('%s.MyMessage.NestedMessage.NestedEnum' % module_name, 292 MyMessage.NestedMessage.NestedEnum.definition_name()) 293 self.assertEquals( 294 '%s.MyMessage.NestedMessage' % module_name, 295 MyMessage.NestedMessage.NestedEnum.outer_definition_name()) 296 self.assertEquals(module_name, 297 MyMessage.NestedMessage.NestedEnum.definition_package()) 298 299 def testMessageDefinition(self): 300 """Test that enumeration knows its enclosing message definition.""" 301 class OuterEnum(messages.Enum): 302 pass 303 304 self.assertEquals(None, OuterEnum.message_definition()) 305 306 class OuterMessage(messages.Message): 307 308 class InnerEnum(messages.Enum): 309 pass 310 311 self.assertEquals(OuterMessage, OuterMessage.InnerEnum.message_definition()) 312 313 def testComparison(self): 314 """Test comparing various enums to different types.""" 315 class Enum1(messages.Enum): 316 VAL1 = 1 317 VAL2 = 2 318 319 class Enum2(messages.Enum): 320 VAL1 = 1 321 322 self.assertEquals(Enum1.VAL1, Enum1.VAL1) 323 self.assertNotEquals(Enum1.VAL1, Enum1.VAL2) 324 self.assertNotEquals(Enum1.VAL1, Enum2.VAL1) 325 self.assertNotEquals(Enum1.VAL1, 'VAL1') 326 self.assertNotEquals(Enum1.VAL1, 1) 327 self.assertNotEquals(Enum1.VAL1, 2) 328 self.assertNotEquals(Enum1.VAL1, None) 329 self.assertNotEquals(Enum1.VAL1, Enum2.VAL1) 330 331 self.assertTrue(Enum1.VAL1 < Enum1.VAL2) 332 self.assertTrue(Enum1.VAL2 > Enum1.VAL1) 333 334 self.assertNotEquals(1, Enum2.VAL1) 335 336 def testPickle(self): 337 """Testing pickling and unpickling of Enum instances.""" 338 colors = list(Color) 339 unpickled = pickle.loads(pickle.dumps(colors)) 340 self.assertEquals(colors, unpickled) 341 # Unpickling shouldn't create new enum instances. 342 for i, color in enumerate(colors): 343 self.assertTrue(color is unpickled[i]) 344 345 346 class FieldListTest(test_util.TestCase): 347 348 def setUp(self): 349 self.integer_field = messages.IntegerField(1, repeated=True) 350 351 def testConstructor(self): 352 self.assertEquals([1, 2, 3], 353 messages.FieldList(self.integer_field, [1, 2, 3])) 354 self.assertEquals([1, 2, 3], 355 messages.FieldList(self.integer_field, (1, 2, 3))) 356 self.assertEquals([], messages.FieldList(self.integer_field, [])) 357 358 def testNone(self): 359 self.assertRaises(TypeError, messages.FieldList, self.integer_field, None) 360 361 def testDoNotAutoConvertString(self): 362 string_field = messages.StringField(1, repeated=True) 363 self.assertRaises(messages.ValidationError, 364 messages.FieldList, string_field, 'abc') 365 366 def testConstructorCopies(self): 367 a_list = [1, 3, 6] 368 field_list = messages.FieldList(self.integer_field, a_list) 369 self.assertFalse(a_list is field_list) 370 self.assertFalse(field_list is 371 messages.FieldList(self.integer_field, field_list)) 372 373 def testNonRepeatedField(self): 374 self.assertRaisesWithRegexpMatch( 375 messages.FieldDefinitionError, 376 'FieldList may only accept repeated fields', 377 messages.FieldList, 378 messages.IntegerField(1), 379 []) 380 381 def testConstructor_InvalidValues(self): 382 self.assertRaisesWithRegexpMatch( 383 messages.ValidationError, 384 re.escape("Expected type %r " 385 "for IntegerField, found 1 (type %r)" 386 % (six.integer_types, str)), 387 messages.FieldList, self.integer_field, ["1", "2", "3"]) 388 389 def testConstructor_Scalars(self): 390 self.assertRaisesWithRegexpMatch( 391 messages.ValidationError, 392 "IntegerField is repeated. Found: 3", 393 messages.FieldList, self.integer_field, 3) 394 395 self.assertRaisesWithRegexpMatch( 396 messages.ValidationError, 397 "IntegerField is repeated. Found: <(list[_]?|sequence)iterator object", 398 messages.FieldList, self.integer_field, iter([1, 2, 3])) 399 400 def testSetSlice(self): 401 field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5]) 402 field_list[1:3] = [10, 20] 403 self.assertEquals([1, 10, 20, 4, 5], field_list) 404 405 def testSetSlice_InvalidValues(self): 406 field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5]) 407 408 def setslice(): 409 field_list[1:3] = ['10', '20'] 410 411 msg_re = re.escape("Expected type %r " 412 "for IntegerField, found 10 (type %r)" 413 % (six.integer_types, str)) 414 self.assertRaisesWithRegexpMatch( 415 messages.ValidationError, 416 msg_re, 417 setslice) 418 419 def testSetItem(self): 420 field_list = messages.FieldList(self.integer_field, [2]) 421 field_list[0] = 10 422 self.assertEquals([10], field_list) 423 424 def testSetItem_InvalidValues(self): 425 field_list = messages.FieldList(self.integer_field, [2]) 426 427 def setitem(): 428 field_list[0] = '10' 429 self.assertRaisesWithRegexpMatch( 430 messages.ValidationError, 431 re.escape("Expected type %r " 432 "for IntegerField, found 10 (type %r)" 433 % (six.integer_types, str)), 434 setitem) 435 436 def testAppend(self): 437 field_list = messages.FieldList(self.integer_field, [2]) 438 field_list.append(10) 439 self.assertEquals([2, 10], field_list) 440 441 def testAppend_InvalidValues(self): 442 field_list = messages.FieldList(self.integer_field, [2]) 443 field_list.name = 'a_field' 444 445 def append(): 446 field_list.append('10') 447 self.assertRaisesWithRegexpMatch( 448 messages.ValidationError, 449 re.escape("Expected type %r " 450 "for IntegerField, found 10 (type %r)" 451 % (six.integer_types, str)), 452 append) 453 454 def testExtend(self): 455 field_list = messages.FieldList(self.integer_field, [2]) 456 field_list.extend([10]) 457 self.assertEquals([2, 10], field_list) 458 459 def testExtend_InvalidValues(self): 460 field_list = messages.FieldList(self.integer_field, [2]) 461 462 def extend(): 463 field_list.extend(['10']) 464 self.assertRaisesWithRegexpMatch( 465 messages.ValidationError, 466 re.escape("Expected type %r " 467 "for IntegerField, found 10 (type %r)" 468 % (six.integer_types, str)), 469 extend) 470 471 def testInsert(self): 472 field_list = messages.FieldList(self.integer_field, [2, 3]) 473 field_list.insert(1, 10) 474 self.assertEquals([2, 10, 3], field_list) 475 476 def testInsert_InvalidValues(self): 477 field_list = messages.FieldList(self.integer_field, [2, 3]) 478 479 def insert(): 480 field_list.insert(1, '10') 481 self.assertRaisesWithRegexpMatch( 482 messages.ValidationError, 483 re.escape("Expected type %r " 484 "for IntegerField, found 10 (type %r)" 485 % (six.integer_types, str)), 486 insert) 487 488 def testPickle(self): 489 """Testing pickling and unpickling of disconnected FieldList instances.""" 490 field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5]) 491 unpickled = pickle.loads(pickle.dumps(field_list)) 492 self.assertEquals(field_list, unpickled) 493 self.assertIsInstance(unpickled.field, messages.IntegerField) 494 self.assertEquals(1, unpickled.field.number) 495 self.assertTrue(unpickled.field.repeated) 496 497 498 class FieldTest(test_util.TestCase): 499 500 def ActionOnAllFieldClasses(self, action): 501 """Test all field classes except Message and Enum. 502 503 Message and Enum require separate tests. 504 505 Args: 506 action: Callable that takes the field class as a parameter. 507 """ 508 for field_class in (messages.IntegerField, 509 messages.FloatField, 510 messages.BooleanField, 511 messages.BytesField, 512 messages.StringField, 513 ): 514 action(field_class) 515 516 def testNumberAttribute(self): 517 """Test setting the number attribute.""" 518 def action(field_class): 519 # Check range. 520 self.assertRaises(messages.InvalidNumberError, 521 field_class, 522 0) 523 self.assertRaises(messages.InvalidNumberError, 524 field_class, 525 -1) 526 self.assertRaises(messages.InvalidNumberError, 527 field_class, 528 messages.MAX_FIELD_NUMBER + 1) 529 530 # Check reserved. 531 self.assertRaises(messages.InvalidNumberError, 532 field_class, 533 messages.FIRST_RESERVED_FIELD_NUMBER) 534 self.assertRaises(messages.InvalidNumberError, 535 field_class, 536 messages.LAST_RESERVED_FIELD_NUMBER) 537 self.assertRaises(messages.InvalidNumberError, 538 field_class, 539 '1') 540 541 # This one should work. 542 field_class(number=1) 543 self.ActionOnAllFieldClasses(action) 544 545 def testRequiredAndRepeated(self): 546 """Test setting the required and repeated fields.""" 547 def action(field_class): 548 field_class(1, required=True) 549 field_class(1, repeated=True) 550 self.assertRaises(messages.FieldDefinitionError, 551 field_class, 552 1, 553 required=True, 554 repeated=True) 555 self.ActionOnAllFieldClasses(action) 556 557 def testInvalidVariant(self): 558 """Test field with invalid variants.""" 559 def action(field_class): 560 if field_class is not message_types.DateTimeField: 561 self.assertRaises(messages.InvalidVariantError, 562 field_class, 563 1, 564 variant=messages.Variant.ENUM) 565 self.ActionOnAllFieldClasses(action) 566 567 def testDefaultVariant(self): 568 """Test that default variant is used when not set.""" 569 def action(field_class): 570 field = field_class(1) 571 self.assertEquals(field_class.DEFAULT_VARIANT, field.variant) 572 573 self.ActionOnAllFieldClasses(action) 574 575 def testAlternateVariant(self): 576 """Test that default variant is used when not set.""" 577 field = messages.IntegerField(1, variant=messages.Variant.UINT32) 578 self.assertEquals(messages.Variant.UINT32, field.variant) 579 580 def testDefaultFields_Single(self): 581 """Test default field is correct type (single).""" 582 defaults = {messages.IntegerField: 10, 583 messages.FloatField: 1.5, 584 messages.BooleanField: False, 585 messages.BytesField: b'abc', 586 messages.StringField: u'abc', 587 } 588 589 def action(field_class): 590 field_class(1, default=defaults[field_class]) 591 self.ActionOnAllFieldClasses(action) 592 593 # Run defaults test again checking for str/unicode compatiblity. 594 defaults[messages.StringField] = 'abc' 595 self.ActionOnAllFieldClasses(action) 596 597 def testStringField_BadUnicodeInDefault(self): 598 """Test binary values in string field.""" 599 self.assertRaisesWithRegexpMatch( 600 messages.InvalidDefaultError, 601 r"Invalid default value for StringField:.*: " 602 r"Field encountered non-ASCII string .*: " 603 r"'ascii' codec can't decode byte 0x89 in position 0: " 604 r"ordinal not in range", 605 messages.StringField, 1, default=b'\x89') 606 607 def testDefaultFields_InvalidSingle(self): 608 """Test default field is correct type (invalid single).""" 609 def action(field_class): 610 self.assertRaises(messages.InvalidDefaultError, 611 field_class, 612 1, 613 default=object()) 614 self.ActionOnAllFieldClasses(action) 615 616 def testDefaultFields_InvalidRepeated(self): 617 """Test default field does not accept defaults.""" 618 self.assertRaisesWithRegexpMatch( 619 messages.FieldDefinitionError, 620 'Repeated fields may not have defaults', 621 messages.StringField, 1, repeated=True, default=[1, 2, 3]) 622 623 def testDefaultFields_None(self): 624 """Test none is always acceptable.""" 625 def action(field_class): 626 field_class(1, default=None) 627 field_class(1, required=True, default=None) 628 field_class(1, repeated=True, default=None) 629 self.ActionOnAllFieldClasses(action) 630 631 def testDefaultFields_Enum(self): 632 """Test the default for enum fields.""" 633 class Symbol(messages.Enum): 634 635 ALPHA = 1 636 BETA = 2 637 GAMMA = 3 638 639 field = messages.EnumField(Symbol, 1, default=Symbol.ALPHA) 640 641 self.assertEquals(Symbol.ALPHA, field.default) 642 643 def testDefaultFields_EnumStringDelayedResolution(self): 644 """Test that enum fields resolve default strings.""" 645 field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label', 646 1, 647 default='OPTIONAL') 648 649 self.assertEquals(descriptor.FieldDescriptor.Label.OPTIONAL, field.default) 650 651 def testDefaultFields_EnumIntDelayedResolution(self): 652 """Test that enum fields resolve default integers.""" 653 field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label', 654 1, 655 default=2) 656 657 self.assertEquals(descriptor.FieldDescriptor.Label.REQUIRED, field.default) 658 659 def testDefaultFields_EnumOkIfTypeKnown(self): 660 """Test that enum fields accept valid default values when type is known.""" 661 field = messages.EnumField(descriptor.FieldDescriptor.Label, 662 1, 663 default='REPEATED') 664 665 self.assertEquals(descriptor.FieldDescriptor.Label.REPEATED, field.default) 666 667 def testDefaultFields_EnumForceCheckIfTypeKnown(self): 668 """Test that enum fields validate default values if type is known.""" 669 self.assertRaisesWithRegexpMatch(TypeError, 670 'No such value for NOT_A_LABEL in ' 671 'Enum Label', 672 messages.EnumField, 673 descriptor.FieldDescriptor.Label, 674 1, 675 default='NOT_A_LABEL') 676 677 def testDefaultFields_EnumInvalidDelayedResolution(self): 678 """Test that enum fields raise errors upon delayed resolution error.""" 679 field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label', 680 1, 681 default=200) 682 683 self.assertRaisesWithRegexpMatch(TypeError, 684 'No such value for 200 in Enum Label', 685 getattr, 686 field, 687 'default') 688 689 def testValidate_Valid(self): 690 """Test validation of valid values.""" 691 values = {messages.IntegerField: 10, 692 messages.FloatField: 1.5, 693 messages.BooleanField: False, 694 messages.BytesField: b'abc', 695 messages.StringField: u'abc', 696 } 697 def action(field_class): 698 # Optional. 699 field = field_class(1) 700 field.validate(values[field_class]) 701 702 # Required. 703 field = field_class(1, required=True) 704 field.validate(values[field_class]) 705 706 # Repeated. 707 field = field_class(1, repeated=True) 708 field.validate([]) 709 field.validate(()) 710 field.validate([values[field_class]]) 711 field.validate((values[field_class],)) 712 713 # Right value, but not repeated. 714 self.assertRaises(messages.ValidationError, 715 field.validate, 716 values[field_class]) 717 self.assertRaises(messages.ValidationError, 718 field.validate, 719 values[field_class]) 720 721 self.ActionOnAllFieldClasses(action) 722 723 def testValidate_Invalid(self): 724 """Test validation of valid values.""" 725 values = {messages.IntegerField: "10", 726 messages.FloatField: 1, 727 messages.BooleanField: 0, 728 messages.BytesField: 10.20, 729 messages.StringField: 42, 730 } 731 def action(field_class): 732 # Optional. 733 field = field_class(1) 734 self.assertRaises(messages.ValidationError, 735 field.validate, 736 values[field_class]) 737 738 # Required. 739 field = field_class(1, required=True) 740 self.assertRaises(messages.ValidationError, 741 field.validate, 742 values[field_class]) 743 744 # Repeated. 745 field = field_class(1, repeated=True) 746 self.assertRaises(messages.ValidationError, 747 field.validate, 748 [values[field_class]]) 749 self.assertRaises(messages.ValidationError, 750 field.validate, 751 (values[field_class],)) 752 self.ActionOnAllFieldClasses(action) 753 754 def testValidate_None(self): 755 """Test that None is valid for non-required fields.""" 756 def action(field_class): 757 # Optional. 758 field = field_class(1) 759 field.validate(None) 760 761 # Required. 762 field = field_class(1, required=True) 763 self.assertRaisesWithRegexpMatch(messages.ValidationError, 764 'Required field is missing', 765 field.validate, 766 None) 767 768 # Repeated. 769 field = field_class(1, repeated=True) 770 field.validate(None) 771 self.assertRaisesWithRegexpMatch(messages.ValidationError, 772 'Repeated values for %s may ' 773 'not be None' % field_class.__name__, 774 field.validate, 775 [None]) 776 self.assertRaises(messages.ValidationError, 777 field.validate, 778 (None,)) 779 self.ActionOnAllFieldClasses(action) 780 781 def testValidateElement(self): 782 """Test validation of valid values.""" 783 values = {messages.IntegerField: 10, 784 messages.FloatField: 1.5, 785 messages.BooleanField: False, 786 messages.BytesField: 'abc', 787 messages.StringField: u'abc', 788 } 789 def action(field_class): 790 # Optional. 791 field = field_class(1) 792 field.validate_element(values[field_class]) 793 794 # Required. 795 field = field_class(1, required=True) 796 field.validate_element(values[field_class]) 797 798 # Repeated. 799 field = field_class(1, repeated=True) 800 self.assertRaises(message.VAlidationError, 801 field.validate_element, 802 []) 803 self.assertRaises(message.VAlidationError, 804 field.validate_element, 805 ()) 806 field.validate_element(values[field_class]) 807 field.validate_element(values[field_class]) 808 809 # Right value, but repeated. 810 self.assertRaises(messages.ValidationError, 811 field.validate_element, 812 [values[field_class]]) 813 self.assertRaises(messages.ValidationError, 814 field.validate_element, 815 (values[field_class],)) 816 817 def testReadOnly(self): 818 """Test that objects are all read-only.""" 819 def action(field_class): 820 field = field_class(10) 821 self.assertRaises(AttributeError, 822 setattr, 823 field, 824 'number', 825 20) 826 self.assertRaises(AttributeError, 827 setattr, 828 field, 829 'anything_else', 830 'whatever') 831 self.ActionOnAllFieldClasses(action) 832 833 def testMessageField(self): 834 """Test the construction of message fields.""" 835 self.assertRaises(messages.FieldDefinitionError, 836 messages.MessageField, 837 str, 838 10) 839 840 self.assertRaises(messages.FieldDefinitionError, 841 messages.MessageField, 842 messages.Message, 843 10) 844 845 class MyMessage(messages.Message): 846 pass 847 848 field = messages.MessageField(MyMessage, 10) 849 self.assertEquals(MyMessage, field.type) 850 851 def testMessageField_ForwardReference(self): 852 """Test the construction of forward reference message fields.""" 853 global MyMessage 854 global ForwardMessage 855 try: 856 class MyMessage(messages.Message): 857 858 self_reference = messages.MessageField('MyMessage', 1) 859 forward = messages.MessageField('ForwardMessage', 2) 860 nested = messages.MessageField('ForwardMessage.NestedMessage', 3) 861 inner = messages.MessageField('Inner', 4) 862 863 class Inner(messages.Message): 864 865 sibling = messages.MessageField('Sibling', 1) 866 867 class Sibling(messages.Message): 868 869 pass 870 871 class ForwardMessage(messages.Message): 872 873 class NestedMessage(messages.Message): 874 875 pass 876 877 self.assertEquals(MyMessage, 878 MyMessage.field_by_name('self_reference').type) 879 880 self.assertEquals(ForwardMessage, 881 MyMessage.field_by_name('forward').type) 882 883 self.assertEquals(ForwardMessage.NestedMessage, 884 MyMessage.field_by_name('nested').type) 885 886 self.assertEquals(MyMessage.Inner, 887 MyMessage.field_by_name('inner').type) 888 889 self.assertEquals(MyMessage.Sibling, 890 MyMessage.Inner.field_by_name('sibling').type) 891 finally: 892 try: 893 del MyMessage 894 del ForwardMessage 895 except: 896 pass 897 898 def testMessageField_WrongType(self): 899 """Test that forward referencing the wrong type raises an error.""" 900 global AnEnum 901 try: 902 class AnEnum(messages.Enum): 903 pass 904 905 class AnotherMessage(messages.Message): 906 907 a_field = messages.MessageField('AnEnum', 1) 908 909 self.assertRaises(messages.FieldDefinitionError, 910 getattr, 911 AnotherMessage.field_by_name('a_field'), 912 'type') 913 finally: 914 del AnEnum 915 916 def testMessageFieldValidate(self): 917 """Test validation on message field.""" 918 class MyMessage(messages.Message): 919 pass 920 921 class AnotherMessage(messages.Message): 922 pass 923 924 field = messages.MessageField(MyMessage, 10) 925 field.validate(MyMessage()) 926 927 self.assertRaises(messages.ValidationError, 928 field.validate, 929 AnotherMessage()) 930 931 def testMessageFieldMessageType(self): 932 """Test message_type property.""" 933 class MyMessage(messages.Message): 934 pass 935 936 class HasMessage(messages.Message): 937 field = messages.MessageField(MyMessage, 1) 938 939 self.assertEqual(HasMessage.field.type, HasMessage.field.message_type) 940 941 def testMessageFieldValueFromMessage(self): 942 class MyMessage(messages.Message): 943 pass 944 945 class HasMessage(messages.Message): 946 field = messages.MessageField(MyMessage, 1) 947 948 instance = MyMessage() 949 950 self.assertTrue(instance is HasMessage.field.value_from_message(instance)) 951 952 def testMessageFieldValueFromMessageWrongType(self): 953 class MyMessage(messages.Message): 954 pass 955 956 class HasMessage(messages.Message): 957 field = messages.MessageField(MyMessage, 1) 958 959 self.assertRaisesWithRegexpMatch( 960 messages.DecodeError, 961 'Expected type MyMessage, got int: 10', 962 HasMessage.field.value_from_message, 10) 963 964 def testMessageFieldValueToMessage(self): 965 class MyMessage(messages.Message): 966 pass 967 968 class HasMessage(messages.Message): 969 field = messages.MessageField(MyMessage, 1) 970 971 instance = MyMessage() 972 973 self.assertTrue(instance is HasMessage.field.value_to_message(instance)) 974 975 def testMessageFieldValueToMessageWrongType(self): 976 class MyMessage(messages.Message): 977 pass 978 979 class MyOtherMessage(messages.Message): 980 pass 981 982 class HasMessage(messages.Message): 983 field = messages.MessageField(MyMessage, 1) 984 985 instance = MyOtherMessage() 986 987 self.assertRaisesWithRegexpMatch( 988 messages.EncodeError, 989 'Expected type MyMessage, got MyOtherMessage: <MyOtherMessage>', 990 HasMessage.field.value_to_message, instance) 991 992 def testIntegerField_AllowLong(self): 993 """Test that the integer field allows for longs.""" 994 if six.PY2: 995 messages.IntegerField(10, default=long(10)) 996 997 def testMessageFieldValidate_Initialized(self): 998 """Test validation on message field.""" 999 class MyMessage(messages.Message): 1000 field1 = messages.IntegerField(1, required=True) 1001 1002 field = messages.MessageField(MyMessage, 10) 1003 1004 # Will validate messages where is_initialized() is False. 1005 message = MyMessage() 1006 field.validate(message) 1007 message.field1 = 20 1008 field.validate(message) 1009 1010 def testEnumField(self): 1011 """Test the construction of enum fields.""" 1012 self.assertRaises(messages.FieldDefinitionError, 1013 messages.EnumField, 1014 str, 1015 10) 1016 1017 self.assertRaises(messages.FieldDefinitionError, 1018 messages.EnumField, 1019 messages.Enum, 1020 10) 1021 1022 class Color(messages.Enum): 1023 RED = 1 1024 GREEN = 2 1025 BLUE = 3 1026 1027 field = messages.EnumField(Color, 10) 1028 self.assertEquals(Color, field.type) 1029 1030 class Another(messages.Enum): 1031 VALUE = 1 1032 1033 self.assertRaises(messages.InvalidDefaultError, 1034 messages.EnumField, 1035 Color, 1036 10, 1037 default=Another.VALUE) 1038 1039 def testEnumField_ForwardReference(self): 1040 """Test the construction of forward reference enum fields.""" 1041 global MyMessage 1042 global ForwardEnum 1043 global ForwardMessage 1044 try: 1045 class MyMessage(messages.Message): 1046 1047 forward = messages.EnumField('ForwardEnum', 1) 1048 nested = messages.EnumField('ForwardMessage.NestedEnum', 2) 1049 inner = messages.EnumField('Inner', 3) 1050 1051 class Inner(messages.Enum): 1052 pass 1053 1054 class ForwardEnum(messages.Enum): 1055 pass 1056 1057 class ForwardMessage(messages.Message): 1058 1059 class NestedEnum(messages.Enum): 1060 pass 1061 1062 self.assertEquals(ForwardEnum, 1063 MyMessage.field_by_name('forward').type) 1064 1065 self.assertEquals(ForwardMessage.NestedEnum, 1066 MyMessage.field_by_name('nested').type) 1067 1068 self.assertEquals(MyMessage.Inner, 1069 MyMessage.field_by_name('inner').type) 1070 finally: 1071 try: 1072 del MyMessage 1073 del ForwardEnum 1074 del ForwardMessage 1075 except: 1076 pass 1077 1078 def testEnumField_WrongType(self): 1079 """Test that forward referencing the wrong type raises an error.""" 1080 global AMessage 1081 try: 1082 class AMessage(messages.Message): 1083 pass 1084 1085 class AnotherMessage(messages.Message): 1086 1087 a_field = messages.EnumField('AMessage', 1) 1088 1089 self.assertRaises(messages.FieldDefinitionError, 1090 getattr, 1091 AnotherMessage.field_by_name('a_field'), 1092 'type') 1093 finally: 1094 del AMessage 1095 1096 def testMessageDefinition(self): 1097 """Test that message definition is set on fields.""" 1098 class MyMessage(messages.Message): 1099 1100 my_field = messages.StringField(1) 1101 1102 self.assertEquals(MyMessage, 1103 MyMessage.field_by_name('my_field').message_definition()) 1104 1105 def testNoneAssignment(self): 1106 """Test that assigning None does not change comparison.""" 1107 class MyMessage(messages.Message): 1108 1109 my_field = messages.StringField(1) 1110 1111 m1 = MyMessage() 1112 m2 = MyMessage() 1113 m2.my_field = None 1114 self.assertEquals(m1, m2) 1115 1116 def testNonAsciiStr(self): 1117 """Test validation fails for non-ascii StringField values.""" 1118 class Thing(messages.Message): 1119 string_field = messages.StringField(2) 1120 1121 thing = Thing() 1122 self.assertRaisesWithRegexpMatch( 1123 messages.ValidationError, 1124 'Field string_field encountered non-ASCII string', 1125 setattr, thing, 'string_field', test_util.BINARY) 1126 1127 1128 class MessageTest(test_util.TestCase): 1129 """Tests for message class.""" 1130 1131 def CreateMessageClass(self): 1132 """Creates a simple message class with 3 fields. 1133 1134 Fields are defined in alphabetical order but with conflicting numeric 1135 order. 1136 """ 1137 class ComplexMessage(messages.Message): 1138 a3 = messages.IntegerField(3) 1139 b1 = messages.StringField(1) 1140 c2 = messages.StringField(2) 1141 1142 return ComplexMessage 1143 1144 def testSameNumbers(self): 1145 """Test that cannot assign two fields with same numbers.""" 1146 1147 def action(): 1148 class BadMessage(messages.Message): 1149 f1 = messages.IntegerField(1) 1150 f2 = messages.IntegerField(1) 1151 self.assertRaises(messages.DuplicateNumberError, 1152 action) 1153 1154 def testStrictAssignment(self): 1155 """Tests that cannot assign to unknown or non-reserved attributes.""" 1156 class SimpleMessage(messages.Message): 1157 field = messages.IntegerField(1) 1158 1159 simple_message = SimpleMessage() 1160 self.assertRaises(AttributeError, 1161 setattr, 1162 simple_message, 1163 'does_not_exist', 1164 10) 1165 1166 def testListAssignmentDoesNotCopy(self): 1167 class SimpleMessage(messages.Message): 1168 repeated = messages.IntegerField(1, repeated=True) 1169 1170 message = SimpleMessage() 1171 original = message.repeated 1172 message.repeated = [] 1173 self.assertFalse(original is message.repeated) 1174 1175 def testValidate_Optional(self): 1176 """Tests validation of optional fields.""" 1177 class SimpleMessage(messages.Message): 1178 non_required = messages.IntegerField(1) 1179 1180 simple_message = SimpleMessage() 1181 simple_message.check_initialized() 1182 simple_message.non_required = 10 1183 simple_message.check_initialized() 1184 1185 def testValidate_Required(self): 1186 """Tests validation of required fields.""" 1187 class SimpleMessage(messages.Message): 1188 required = messages.IntegerField(1, required=True) 1189 1190 simple_message = SimpleMessage() 1191 self.assertRaises(messages.ValidationError, 1192 simple_message.check_initialized) 1193 simple_message.required = 10 1194 simple_message.check_initialized() 1195 1196 def testValidate_Repeated(self): 1197 """Tests validation of repeated fields.""" 1198 class SimpleMessage(messages.Message): 1199 repeated = messages.IntegerField(1, repeated=True) 1200 1201 simple_message = SimpleMessage() 1202 1203 # Check valid values. 1204 for valid_value in [], [10], [10, 20], (), (10,), (10, 20): 1205 simple_message.repeated = valid_value 1206 simple_message.check_initialized() 1207 1208 # Check cleared. 1209 simple_message.repeated = [] 1210 simple_message.check_initialized() 1211 1212 # Check invalid values. 1213 for invalid_value in 10, ['10', '20'], [None], (None,): 1214 self.assertRaises(messages.ValidationError, 1215 setattr, simple_message, 'repeated', invalid_value) 1216 1217 def testIsInitialized(self): 1218 """Tests is_initialized.""" 1219 class SimpleMessage(messages.Message): 1220 required = messages.IntegerField(1, required=True) 1221 1222 simple_message = SimpleMessage() 1223 self.assertFalse(simple_message.is_initialized()) 1224 1225 simple_message.required = 10 1226 1227 self.assertTrue(simple_message.is_initialized()) 1228 1229 def testIsInitializedNestedField(self): 1230 """Tests is_initialized for nested fields.""" 1231 class SimpleMessage(messages.Message): 1232 required = messages.IntegerField(1, required=True) 1233 1234 class NestedMessage(messages.Message): 1235 simple = messages.MessageField(SimpleMessage, 1) 1236 1237 simple_message = SimpleMessage() 1238 self.assertFalse(simple_message.is_initialized()) 1239 nested_message = NestedMessage(simple=simple_message) 1240 self.assertFalse(nested_message.is_initialized()) 1241 1242 simple_message.required = 10 1243 1244 self.assertTrue(simple_message.is_initialized()) 1245 self.assertTrue(nested_message.is_initialized()) 1246 1247 def testInitializeNestedFieldFromDict(self): 1248 """Tests initializing nested fields from dict.""" 1249 class SimpleMessage(messages.Message): 1250 required = messages.IntegerField(1, required=True) 1251 1252 class NestedMessage(messages.Message): 1253 simple = messages.MessageField(SimpleMessage, 1) 1254 1255 class RepeatedMessage(messages.Message): 1256 simple = messages.MessageField(SimpleMessage, 1, repeated=True) 1257 1258 nested_message1 = NestedMessage(simple={'required': 10}) 1259 self.assertTrue(nested_message1.is_initialized()) 1260 self.assertTrue(nested_message1.simple.is_initialized()) 1261 1262 nested_message2 = NestedMessage() 1263 nested_message2.simple = {'required': 10} 1264 self.assertTrue(nested_message2.is_initialized()) 1265 self.assertTrue(nested_message2.simple.is_initialized()) 1266 1267 repeated_values = [{}, {'required': 10}, SimpleMessage(required=20)] 1268 1269 repeated_message1 = RepeatedMessage(simple=repeated_values) 1270 self.assertEquals(3, len(repeated_message1.simple)) 1271 self.assertFalse(repeated_message1.is_initialized()) 1272 1273 repeated_message1.simple[0].required = 0 1274 self.assertTrue(repeated_message1.is_initialized()) 1275 1276 repeated_message2 = RepeatedMessage() 1277 repeated_message2.simple = repeated_values 1278 self.assertEquals(3, len(repeated_message2.simple)) 1279 self.assertFalse(repeated_message2.is_initialized()) 1280 1281 repeated_message2.simple[0].required = 0 1282 self.assertTrue(repeated_message2.is_initialized()) 1283 1284 def testNestedMethodsNotAllowed(self): 1285 """Test that method definitions on Message classes are not allowed.""" 1286 def action(): 1287 class WithMethods(messages.Message): 1288 def not_allowed(self): 1289 pass 1290 1291 self.assertRaises(messages.MessageDefinitionError, 1292 action) 1293 1294 def testNestedAttributesNotAllowed(self): 1295 """Test that attribute assignment on Message classes are not allowed.""" 1296 def int_attribute(): 1297 class WithMethods(messages.Message): 1298 not_allowed = 1 1299 1300 def string_attribute(): 1301 class WithMethods(messages.Message): 1302 not_allowed = 'not allowed' 1303 1304 def enum_attribute(): 1305 class WithMethods(messages.Message): 1306 not_allowed = Color.RED 1307 1308 for action in (int_attribute, string_attribute, enum_attribute): 1309 self.assertRaises(messages.MessageDefinitionError, 1310 action) 1311 1312 def testNameIsSetOnFields(self): 1313 """Make sure name is set on fields after Message class init.""" 1314 class HasNamedFields(messages.Message): 1315 field = messages.StringField(1) 1316 1317 self.assertEquals('field', HasNamedFields.field_by_number(1).name) 1318 1319 def testSubclassingMessageDisallowed(self): 1320 """Not permitted to create sub-classes of message classes.""" 1321 class SuperClass(messages.Message): 1322 pass 1323 1324 def action(): 1325 class SubClass(SuperClass): 1326 pass 1327 1328 self.assertRaises(messages.MessageDefinitionError, 1329 action) 1330 1331 def testAllFields(self): 1332 """Test all_fields method.""" 1333 ComplexMessage = self.CreateMessageClass() 1334 fields = list(ComplexMessage.all_fields()) 1335 1336 # Order does not matter, so sort now. 1337 fields = sorted(fields, key=lambda f: f.name) 1338 1339 self.assertEquals(3, len(fields)) 1340 self.assertEquals('a3', fields[0].name) 1341 self.assertEquals('b1', fields[1].name) 1342 self.assertEquals('c2', fields[2].name) 1343 1344 def testFieldByName(self): 1345 """Test getting field by name.""" 1346 ComplexMessage = self.CreateMessageClass() 1347 1348 self.assertEquals(3, ComplexMessage.field_by_name('a3').number) 1349 self.assertEquals(1, ComplexMessage.field_by_name('b1').number) 1350 self.assertEquals(2, ComplexMessage.field_by_name('c2').number) 1351 1352 self.assertRaises(KeyError, 1353 ComplexMessage.field_by_name, 1354 'unknown') 1355 1356 def testFieldByNumber(self): 1357 """Test getting field by number.""" 1358 ComplexMessage = self.CreateMessageClass() 1359 1360 self.assertEquals('a3', ComplexMessage.field_by_number(3).name) 1361 self.assertEquals('b1', ComplexMessage.field_by_number(1).name) 1362 self.assertEquals('c2', ComplexMessage.field_by_number(2).name) 1363 1364 self.assertRaises(KeyError, 1365 ComplexMessage.field_by_number, 1366 4) 1367 1368 def testGetAssignedValue(self): 1369 """Test getting the assigned value of a field.""" 1370 class SomeMessage(messages.Message): 1371 a_value = messages.StringField(1, default=u'a default') 1372 1373 message = SomeMessage() 1374 self.assertEquals(None, message.get_assigned_value('a_value')) 1375 1376 message.a_value = u'a string' 1377 self.assertEquals(u'a string', message.get_assigned_value('a_value')) 1378 1379 message.a_value = u'a default' 1380 self.assertEquals(u'a default', message.get_assigned_value('a_value')) 1381 1382 self.assertRaisesWithRegexpMatch( 1383 AttributeError, 1384 'Message SomeMessage has no field no_such_field', 1385 message.get_assigned_value, 1386 'no_such_field') 1387 1388 def testReset(self): 1389 """Test resetting a field value.""" 1390 class SomeMessage(messages.Message): 1391 a_value = messages.StringField(1, default=u'a default') 1392 repeated = messages.IntegerField(2, repeated=True) 1393 1394 message = SomeMessage() 1395 1396 self.assertRaises(AttributeError, message.reset, 'unknown') 1397 1398 self.assertEquals(u'a default', message.a_value) 1399 message.reset('a_value') 1400 self.assertEquals(u'a default', message.a_value) 1401 1402 message.a_value = u'a new value' 1403 self.assertEquals(u'a new value', message.a_value) 1404 message.reset('a_value') 1405 self.assertEquals(u'a default', message.a_value) 1406 1407 message.repeated = [1, 2, 3] 1408 self.assertEquals([1, 2, 3], message.repeated) 1409 saved = message.repeated 1410 message.reset('repeated') 1411 self.assertEquals([], message.repeated) 1412 self.assertIsInstance(message.repeated, messages.FieldList) 1413 self.assertEquals([1, 2, 3], saved) 1414 1415 def testAllowNestedEnums(self): 1416 """Test allowing nested enums in a message definition.""" 1417 class Trade(messages.Message): 1418 class Duration(messages.Enum): 1419 GTC = 1 1420 DAY = 2 1421 1422 class Currency(messages.Enum): 1423 USD = 1 1424 GBP = 2 1425 INR = 3 1426 1427 # Sorted by name order seems to be the only feasible option. 1428 self.assertEquals(['Currency', 'Duration'], Trade.__enums__) 1429 1430 # Message definition will now be set on Enumerated objects. 1431 self.assertEquals(Trade, Trade.Duration.message_definition()) 1432 1433 def testAllowNestedMessages(self): 1434 """Test allowing nested messages in a message definition.""" 1435 class Trade(messages.Message): 1436 class Lot(messages.Message): 1437 pass 1438 1439 class Agent(messages.Message): 1440 pass 1441 1442 # Sorted by name order seems to be the only feasible option. 1443 self.assertEquals(['Agent', 'Lot'], Trade.__messages__) 1444 self.assertEquals(Trade, Trade.Agent.message_definition()) 1445 self.assertEquals(Trade, Trade.Lot.message_definition()) 1446 1447 # But not Message itself. 1448 def action(): 1449 class Trade(messages.Message): 1450 NiceTry = messages.Message 1451 self.assertRaises(messages.MessageDefinitionError, action) 1452 1453 def testDisallowClassAssignments(self): 1454 """Test setting class attributes may not happen.""" 1455 class MyMessage(messages.Message): 1456 pass 1457 1458 self.assertRaises(AttributeError, 1459 setattr, 1460 MyMessage, 1461 'x', 1462 'do not assign') 1463 1464 def testEquality(self): 1465 """Test message class equality.""" 1466 # Comparison against enums must work. 1467 class MyEnum(messages.Enum): 1468 val1 = 1 1469 val2 = 2 1470 1471 # Comparisons against nested messages must work. 1472 class AnotherMessage(messages.Message): 1473 string = messages.StringField(1) 1474 1475 class MyMessage(messages.Message): 1476 field1 = messages.IntegerField(1) 1477 field2 = messages.EnumField(MyEnum, 2) 1478 field3 = messages.MessageField(AnotherMessage, 3) 1479 1480 message1 = MyMessage() 1481 1482 self.assertNotEquals('hi', message1) 1483 self.assertNotEquals(AnotherMessage(), message1) 1484 self.assertEquals(message1, message1) 1485 1486 message2 = MyMessage() 1487 1488 self.assertEquals(message1, message2) 1489 1490 message1.field1 = 10 1491 self.assertNotEquals(message1, message2) 1492 1493 message2.field1 = 20 1494 self.assertNotEquals(message1, message2) 1495 1496 message2.field1 = 10 1497 self.assertEquals(message1, message2) 1498 1499 message1.field2 = MyEnum.val1 1500 self.assertNotEquals(message1, message2) 1501 1502 message2.field2 = MyEnum.val2 1503 self.assertNotEquals(message1, message2) 1504 1505 message2.field2 = MyEnum.val1 1506 self.assertEquals(message1, message2) 1507 1508 message1.field3 = AnotherMessage() 1509 message1.field3.string = 'value1' 1510 self.assertNotEquals(message1, message2) 1511 1512 message2.field3 = AnotherMessage() 1513 message2.field3.string = 'value2' 1514 self.assertNotEquals(message1, message2) 1515 1516 message2.field3.string = 'value1' 1517 self.assertEquals(message1, message2) 1518 1519 def testEqualityWithUnknowns(self): 1520 """Test message class equality with unknown fields.""" 1521 1522 class MyMessage(messages.Message): 1523 field1 = messages.IntegerField(1) 1524 1525 message1 = MyMessage() 1526 message2 = MyMessage() 1527 self.assertEquals(message1, message2) 1528 message1.set_unrecognized_field('unknown1', 'value1', 1529 messages.Variant.STRING) 1530 self.assertEquals(message1, message2) 1531 1532 message1.set_unrecognized_field('unknown2', ['asdf', 3], 1533 messages.Variant.STRING) 1534 message1.set_unrecognized_field('unknown3', 4.7, 1535 messages.Variant.DOUBLE) 1536 self.assertEquals(message1, message2) 1537 1538 def testUnrecognizedFieldInvalidVariant(self): 1539 class MyMessage(messages.Message): 1540 field1 = messages.IntegerField(1) 1541 1542 message1 = MyMessage() 1543 self.assertRaises(TypeError, message1.set_unrecognized_field, 'unknown4', 1544 {'unhandled': 'type'}, None) 1545 self.assertRaises(TypeError, message1.set_unrecognized_field, 'unknown4', 1546 {'unhandled': 'type'}, 123) 1547 1548 def testRepr(self): 1549 """Test represtation of Message object.""" 1550 class MyMessage(messages.Message): 1551 integer_value = messages.IntegerField(1) 1552 string_value = messages.StringField(2) 1553 unassigned = messages.StringField(3) 1554 unassigned_with_default = messages.StringField(4, default=u'a default') 1555 1556 my_message = MyMessage() 1557 my_message.integer_value = 42 1558 my_message.string_value = u'A string' 1559 1560 pat = re.compile(r"<MyMessage\n integer_value: 42\n" 1561 " string_value: [u]?'A string'>") 1562 self.assertTrue(pat.match(repr(my_message)) is not None) 1563 1564 def testValidation(self): 1565 """Test validation of message values.""" 1566 # Test optional. 1567 class SubMessage(messages.Message): 1568 pass 1569 1570 class Message(messages.Message): 1571 val = messages.MessageField(SubMessage, 1) 1572 1573 message = Message() 1574 1575 message_field = messages.MessageField(Message, 1) 1576 message_field.validate(message) 1577 message.val = SubMessage() 1578 message_field.validate(message) 1579 self.assertRaises(messages.ValidationError, 1580 setattr, message, 'val', [SubMessage()]) 1581 1582 # Test required. 1583 class Message(messages.Message): 1584 val = messages.MessageField(SubMessage, 1, required=True) 1585 1586 message = Message() 1587 1588 message_field = messages.MessageField(Message, 1) 1589 message_field.validate(message) 1590 message.val = SubMessage() 1591 message_field.validate(message) 1592 self.assertRaises(messages.ValidationError, 1593 setattr, message, 'val', [SubMessage()]) 1594 1595 # Test repeated. 1596 class Message(messages.Message): 1597 val = messages.MessageField(SubMessage, 1, repeated=True) 1598 1599 message = Message() 1600 1601 message_field = messages.MessageField(Message, 1) 1602 message_field.validate(message) 1603 self.assertRaisesWithRegexpMatch( 1604 messages.ValidationError, 1605 "Field val is repeated. Found: <SubMessage>", 1606 setattr, message, 'val', SubMessage()) 1607 message.val = [SubMessage()] 1608 message_field.validate(message) 1609 1610 def testDefinitionName(self): 1611 """Test message name.""" 1612 class MyMessage(messages.Message): 1613 pass 1614 1615 module_name = test_util.get_module_name(FieldTest) 1616 self.assertEquals('%s.MyMessage' % module_name, 1617 MyMessage.definition_name()) 1618 self.assertEquals(module_name, MyMessage.outer_definition_name()) 1619 self.assertEquals(module_name, MyMessage.definition_package()) 1620 1621 self.assertEquals(six.text_type, type(MyMessage.definition_name())) 1622 self.assertEquals(six.text_type, type(MyMessage.outer_definition_name())) 1623 self.assertEquals(six.text_type, type(MyMessage.definition_package())) 1624 1625 def testDefinitionName_OverrideModule(self): 1626 """Test message module is overriden by module package name.""" 1627 class MyMessage(messages.Message): 1628 pass 1629 1630 global package 1631 package = 'my.package' 1632 1633 try: 1634 self.assertEquals('my.package.MyMessage', MyMessage.definition_name()) 1635 self.assertEquals('my.package', MyMessage.outer_definition_name()) 1636 self.assertEquals('my.package', MyMessage.definition_package()) 1637 1638 self.assertEquals(six.text_type, type(MyMessage.definition_name())) 1639 self.assertEquals(six.text_type, type(MyMessage.outer_definition_name())) 1640 self.assertEquals(six.text_type, type(MyMessage.definition_package())) 1641 finally: 1642 del package 1643 1644 def testDefinitionName_NoModule(self): 1645 """Test what happens when there is no module for message.""" 1646 class MyMessage(messages.Message): 1647 pass 1648 1649 original_modules = sys.modules 1650 sys.modules = dict(sys.modules) 1651 try: 1652 del sys.modules[__name__] 1653 self.assertEquals('MyMessage', MyMessage.definition_name()) 1654 self.assertEquals(None, MyMessage.outer_definition_name()) 1655 self.assertEquals(None, MyMessage.definition_package()) 1656 1657 self.assertEquals(six.text_type, type(MyMessage.definition_name())) 1658 finally: 1659 sys.modules = original_modules 1660 1661 def testDefinitionName_Nested(self): 1662 """Test nested message names.""" 1663 class MyMessage(messages.Message): 1664 1665 class NestedMessage(messages.Message): 1666 1667 class NestedMessage(messages.Message): 1668 1669 pass 1670 1671 module_name = test_util.get_module_name(MessageTest) 1672 self.assertEquals('%s.MyMessage.NestedMessage' % module_name, 1673 MyMessage.NestedMessage.definition_name()) 1674 self.assertEquals('%s.MyMessage' % module_name, 1675 MyMessage.NestedMessage.outer_definition_name()) 1676 self.assertEquals(module_name, 1677 MyMessage.NestedMessage.definition_package()) 1678 1679 self.assertEquals('%s.MyMessage.NestedMessage.NestedMessage' % module_name, 1680 MyMessage.NestedMessage.NestedMessage.definition_name()) 1681 self.assertEquals( 1682 '%s.MyMessage.NestedMessage' % module_name, 1683 MyMessage.NestedMessage.NestedMessage.outer_definition_name()) 1684 self.assertEquals( 1685 module_name, 1686 MyMessage.NestedMessage.NestedMessage.definition_package()) 1687 1688 1689 def testMessageDefinition(self): 1690 """Test that enumeration knows its enclosing message definition.""" 1691 class OuterMessage(messages.Message): 1692 1693 class InnerMessage(messages.Message): 1694 pass 1695 1696 self.assertEquals(None, OuterMessage.message_definition()) 1697 self.assertEquals(OuterMessage, 1698 OuterMessage.InnerMessage.message_definition()) 1699 1700 def testConstructorKwargs(self): 1701 """Test kwargs via constructor.""" 1702 class SomeMessage(messages.Message): 1703 name = messages.StringField(1) 1704 number = messages.IntegerField(2) 1705 1706 expected = SomeMessage() 1707 expected.name = 'my name' 1708 expected.number = 200 1709 self.assertEquals(expected, SomeMessage(name='my name', number=200)) 1710 1711 def testConstructorNotAField(self): 1712 """Test kwargs via constructor with wrong names.""" 1713 class SomeMessage(messages.Message): 1714 pass 1715 1716 self.assertRaisesWithRegexpMatch( 1717 AttributeError, 1718 'May not assign arbitrary value does_not_exist to message SomeMessage', 1719 SomeMessage, 1720 does_not_exist=10) 1721 1722 def testGetUnsetRepeatedValue(self): 1723 class SomeMessage(messages.Message): 1724 repeated = messages.IntegerField(1, repeated=True) 1725 1726 instance = SomeMessage() 1727 self.assertEquals([], instance.repeated) 1728 self.assertTrue(isinstance(instance.repeated, messages.FieldList)) 1729 1730 def testCompareAutoInitializedRepeatedFields(self): 1731 class SomeMessage(messages.Message): 1732 repeated = messages.IntegerField(1, repeated=True) 1733 1734 message1 = SomeMessage(repeated=[]) 1735 message2 = SomeMessage() 1736 self.assertEquals(message1, message2) 1737 1738 def testUnknownValues(self): 1739 """Test message class equality with unknown fields.""" 1740 class MyMessage(messages.Message): 1741 field1 = messages.IntegerField(1) 1742 1743 message = MyMessage() 1744 self.assertEquals([], message.all_unrecognized_fields()) 1745 self.assertEquals((None, None), 1746 message.get_unrecognized_field_info('doesntexist')) 1747 self.assertEquals((None, None), 1748 message.get_unrecognized_field_info( 1749 'doesntexist', None, None)) 1750 self.assertEquals(('defaultvalue', 'defaultwire'), 1751 message.get_unrecognized_field_info( 1752 'doesntexist', 'defaultvalue', 'defaultwire')) 1753 self.assertEquals((3, None), 1754 message.get_unrecognized_field_info( 1755 'doesntexist', value_default=3)) 1756 1757 message.set_unrecognized_field('exists', 9.5, messages.Variant.DOUBLE) 1758 self.assertEquals(1, len(message.all_unrecognized_fields())) 1759 self.assertTrue('exists' in message.all_unrecognized_fields()) 1760 self.assertEquals((9.5, messages.Variant.DOUBLE), 1761 message.get_unrecognized_field_info('exists')) 1762 self.assertEquals((9.5, messages.Variant.DOUBLE), 1763 message.get_unrecognized_field_info('exists', 'type', 1764 1234)) 1765 self.assertEquals((1234, None), 1766 message.get_unrecognized_field_info('doesntexist', 1234)) 1767 1768 message.set_unrecognized_field('another', 'value', messages.Variant.STRING) 1769 self.assertEquals(2, len(message.all_unrecognized_fields())) 1770 self.assertTrue('exists' in message.all_unrecognized_fields()) 1771 self.assertTrue('another' in message.all_unrecognized_fields()) 1772 self.assertEquals((9.5, messages.Variant.DOUBLE), 1773 message.get_unrecognized_field_info('exists')) 1774 self.assertEquals(('value', messages.Variant.STRING), 1775 message.get_unrecognized_field_info('another')) 1776 1777 message.set_unrecognized_field('typetest1', ['list', 0, ('test',)], 1778 messages.Variant.STRING) 1779 self.assertEquals((['list', 0, ('test',)], messages.Variant.STRING), 1780 message.get_unrecognized_field_info('typetest1')) 1781 message.set_unrecognized_field('typetest2', '', messages.Variant.STRING) 1782 self.assertEquals(('', messages.Variant.STRING), 1783 message.get_unrecognized_field_info('typetest2')) 1784 1785 def testPickle(self): 1786 """Testing pickling and unpickling of Message instances.""" 1787 global MyEnum 1788 global AnotherMessage 1789 global MyMessage 1790 1791 class MyEnum(messages.Enum): 1792 val1 = 1 1793 val2 = 2 1794 1795 class AnotherMessage(messages.Message): 1796 string = messages.StringField(1, repeated=True) 1797 1798 class MyMessage(messages.Message): 1799 field1 = messages.IntegerField(1) 1800 field2 = messages.EnumField(MyEnum, 2) 1801 field3 = messages.MessageField(AnotherMessage, 3) 1802 1803 message = MyMessage(field1=1, field2=MyEnum.val2, 1804 field3=AnotherMessage(string=['a', 'b', 'c'])) 1805 message.set_unrecognized_field('exists', 'value', messages.Variant.STRING) 1806 message.set_unrecognized_field('repeated', ['list', 0, ('test',)], 1807 messages.Variant.STRING) 1808 unpickled = pickle.loads(pickle.dumps(message)) 1809 self.assertEquals(message, unpickled) 1810 self.assertTrue(AnotherMessage.string is unpickled.field3.string.field) 1811 self.assertTrue('exists' in message.all_unrecognized_fields()) 1812 self.assertEquals(('value', messages.Variant.STRING), 1813 message.get_unrecognized_field_info('exists')) 1814 self.assertEquals((['list', 0, ('test',)], messages.Variant.STRING), 1815 message.get_unrecognized_field_info('repeated')) 1816 1817 1818 class FindDefinitionTest(test_util.TestCase): 1819 """Test finding definitions relative to various definitions and modules.""" 1820 1821 def setUp(self): 1822 """Set up module-space. Starts off empty.""" 1823 self.modules = {} 1824 1825 def DefineModule(self, name): 1826 """Define a module and its parents in module space. 1827 1828 Modules that are already defined in self.modules are not re-created. 1829 1830 Args: 1831 name: Fully qualified name of modules to create. 1832 1833 Returns: 1834 Deepest nested module. For example: 1835 1836 DefineModule('a.b.c') # Returns c. 1837 """ 1838 name_path = name.split('.') 1839 full_path = [] 1840 for node in name_path: 1841 full_path.append(node) 1842 full_name = '.'.join(full_path) 1843 self.modules.setdefault(full_name, types.ModuleType(full_name)) 1844 return self.modules[name] 1845 1846 def DefineMessage(self, module, name, children={}, add_to_module=True): 1847 """Define a new Message class in the context of a module. 1848 1849 Used for easily describing complex Message hierarchy. Message is defined 1850 including all child definitions. 1851 1852 Args: 1853 module: Fully qualified name of module to place Message class in. 1854 name: Name of Message to define within module. 1855 children: Define any level of nesting of children definitions. To define 1856 a message, map the name to another dictionary. The dictionary can 1857 itself contain additional definitions, and so on. To map to an Enum, 1858 define the Enum class separately and map it by name. 1859 add_to_module: If True, new Message class is added to module. If False, 1860 new Message is not added. 1861 """ 1862 # Make sure module exists. 1863 module_instance = self.DefineModule(module) 1864 1865 # Recursively define all child messages. 1866 for attribute, value in children.items(): 1867 if isinstance(value, dict): 1868 children[attribute] = self.DefineMessage( 1869 module, attribute, value, False) 1870 1871 # Override default __module__ variable. 1872 children['__module__'] = module 1873 1874 # Instantiate and possibly add to module. 1875 message_class = type(name, (messages.Message,), dict(children)) 1876 if add_to_module: 1877 setattr(module_instance, name, message_class) 1878 return message_class 1879 1880 def Importer(self, module, globals='', locals='', fromlist=None): 1881 """Importer function. 1882 1883 Acts like __import__. Only loads modules from self.modules. Does not 1884 try to load real modules defined elsewhere. Does not try to handle relative 1885 imports. 1886 1887 Args: 1888 module: Fully qualified name of module to load from self.modules. 1889 """ 1890 if fromlist is None: 1891 module = module.split('.')[0] 1892 try: 1893 return self.modules[module] 1894 except KeyError: 1895 raise ImportError() 1896 1897 def testNoSuchModule(self): 1898 """Test searching for definitions that do no exist.""" 1899 self.assertRaises(messages.DefinitionNotFoundError, 1900 messages.find_definition, 1901 'does.not.exist', 1902 importer=self.Importer) 1903 1904 def testRefersToModule(self): 1905 """Test that referring to a module does not return that module.""" 1906 self.DefineModule('i.am.a.module') 1907 self.assertRaises(messages.DefinitionNotFoundError, 1908 messages.find_definition, 1909 'i.am.a.module', 1910 importer=self.Importer) 1911 1912 def testNoDefinition(self): 1913 """Test not finding a definition in an existing module.""" 1914 self.DefineModule('i.am.a.module') 1915 self.assertRaises(messages.DefinitionNotFoundError, 1916 messages.find_definition, 1917 'i.am.a.module.MyMessage', 1918 importer=self.Importer) 1919 1920 def testNotADefinition(self): 1921 """Test trying to fetch something that is not a definition.""" 1922 module = self.DefineModule('i.am.a.module') 1923 setattr(module, 'A', 'a string') 1924 self.assertRaises(messages.DefinitionNotFoundError, 1925 messages.find_definition, 1926 'i.am.a.module.A', 1927 importer=self.Importer) 1928 1929 def testGlobalFind(self): 1930 """Test finding definitions from fully qualified module names.""" 1931 A = self.DefineMessage('a.b.c', 'A', {}) 1932 self.assertEquals(A, messages.find_definition('a.b.c.A', 1933 importer=self.Importer)) 1934 B = self.DefineMessage('a.b.c', 'B', {'C':{}}) 1935 self.assertEquals(B.C, messages.find_definition('a.b.c.B.C', 1936 importer=self.Importer)) 1937 1938 def testRelativeToModule(self): 1939 """Test finding definitions relative to modules.""" 1940 # Define modules. 1941 a = self.DefineModule('a') 1942 b = self.DefineModule('a.b') 1943 c = self.DefineModule('a.b.c') 1944 1945 # Define messages. 1946 A = self.DefineMessage('a', 'A') 1947 B = self.DefineMessage('a.b', 'B') 1948 C = self.DefineMessage('a.b.c', 'C') 1949 D = self.DefineMessage('a.b.d', 'D') 1950 1951 # Find A, B, C and D relative to a. 1952 self.assertEquals(A, messages.find_definition( 1953 'A', a, importer=self.Importer)) 1954 self.assertEquals(B, messages.find_definition( 1955 'b.B', a, importer=self.Importer)) 1956 self.assertEquals(C, messages.find_definition( 1957 'b.c.C', a, importer=self.Importer)) 1958 self.assertEquals(D, messages.find_definition( 1959 'b.d.D', a, importer=self.Importer)) 1960 1961 # Find A, B, C and D relative to b. 1962 self.assertEquals(A, messages.find_definition( 1963 'A', b, importer=self.Importer)) 1964 self.assertEquals(B, messages.find_definition( 1965 'B', b, importer=self.Importer)) 1966 self.assertEquals(C, messages.find_definition( 1967 'c.C', b, importer=self.Importer)) 1968 self.assertEquals(D, messages.find_definition( 1969 'd.D', b, importer=self.Importer)) 1970 1971 # Find A, B, C and D relative to c. Module d is the same case as c. 1972 self.assertEquals(A, messages.find_definition( 1973 'A', c, importer=self.Importer)) 1974 self.assertEquals(B, messages.find_definition( 1975 'B', c, importer=self.Importer)) 1976 self.assertEquals(C, messages.find_definition( 1977 'C', c, importer=self.Importer)) 1978 self.assertEquals(D, messages.find_definition( 1979 'd.D', c, importer=self.Importer)) 1980 1981 def testRelativeToMessages(self): 1982 """Test finding definitions relative to Message definitions.""" 1983 A = self.DefineMessage('a.b', 'A', {'B': {'C': {}, 'D': {}}}) 1984 B = A.B 1985 C = A.B.C 1986 D = A.B.D 1987 1988 # Find relative to A. 1989 self.assertEquals(A, messages.find_definition( 1990 'A', A, importer=self.Importer)) 1991 self.assertEquals(B, messages.find_definition( 1992 'B', A, importer=self.Importer)) 1993 self.assertEquals(C, messages.find_definition( 1994 'B.C', A, importer=self.Importer)) 1995 self.assertEquals(D, messages.find_definition( 1996 'B.D', A, importer=self.Importer)) 1997 1998 # Find relative to B. 1999 self.assertEquals(A, messages.find_definition( 2000 'A', B, importer=self.Importer)) 2001 self.assertEquals(B, messages.find_definition( 2002 'B', B, importer=self.Importer)) 2003 self.assertEquals(C, messages.find_definition( 2004 'C', B, importer=self.Importer)) 2005 self.assertEquals(D, messages.find_definition( 2006 'D', B, importer=self.Importer)) 2007 2008 # Find relative to C. 2009 self.assertEquals(A, messages.find_definition( 2010 'A', C, importer=self.Importer)) 2011 self.assertEquals(B, messages.find_definition( 2012 'B', C, importer=self.Importer)) 2013 self.assertEquals(C, messages.find_definition( 2014 'C', C, importer=self.Importer)) 2015 self.assertEquals(D, messages.find_definition( 2016 'D', C, importer=self.Importer)) 2017 2018 # Find relative to C searching from c. 2019 self.assertEquals(A, messages.find_definition( 2020 'b.A', C, importer=self.Importer)) 2021 self.assertEquals(B, messages.find_definition( 2022 'b.A.B', C, importer=self.Importer)) 2023 self.assertEquals(C, messages.find_definition( 2024 'b.A.B.C', C, importer=self.Importer)) 2025 self.assertEquals(D, messages.find_definition( 2026 'b.A.B.D', C, importer=self.Importer)) 2027 2028 def testAbsoluteReference(self): 2029 """Test finding absolute definition names.""" 2030 # Define modules. 2031 a = self.DefineModule('a') 2032 b = self.DefineModule('a.a') 2033 2034 # Define messages. 2035 aA = self.DefineMessage('a', 'A') 2036 aaA = self.DefineMessage('a.a', 'A') 2037 2038 # Always find a.A. 2039 self.assertEquals(aA, messages.find_definition('.a.A', None, 2040 importer=self.Importer)) 2041 self.assertEquals(aA, messages.find_definition('.a.A', a, 2042 importer=self.Importer)) 2043 self.assertEquals(aA, messages.find_definition('.a.A', aA, 2044 importer=self.Importer)) 2045 self.assertEquals(aA, messages.find_definition('.a.A', aaA, 2046 importer=self.Importer)) 2047 2048 def testFindEnum(self): 2049 """Test that Enums are found.""" 2050 class Color(messages.Enum): 2051 pass 2052 A = self.DefineMessage('a', 'A', {'Color': Color}) 2053 2054 self.assertEquals( 2055 Color, 2056 messages.find_definition('Color', A, importer=self.Importer)) 2057 2058 def testFalseScope(self): 2059 """Test that Message definitions nested in strange objects are hidden.""" 2060 global X 2061 class X(object): 2062 class A(messages.Message): 2063 pass 2064 2065 self.assertRaises(TypeError, messages.find_definition, 'A', X) 2066 self.assertRaises(messages.DefinitionNotFoundError, 2067 messages.find_definition, 2068 'X.A', sys.modules[__name__]) 2069 2070 def testSearchAttributeFirst(self): 2071 """Make sure not faked out by module, but continues searching.""" 2072 A = self.DefineMessage('a', 'A') 2073 module_A = self.DefineModule('a.A') 2074 2075 self.assertEquals(A, messages.find_definition( 2076 'a.A', None, importer=self.Importer)) 2077 2078 2079 class FindDefinitionUnicodeTests(test_util.TestCase): 2080 2081 # TODO(craigcitro): Fix this test and re-enable it. 2082 def notatestUnicodeString(self): 2083 """Test using unicode names.""" 2084 from protorpc import registry 2085 self.assertEquals('ServiceMapping', 2086 messages.find_definition( 2087 u'protorpc.registry.ServiceMapping', 2088 None).__name__) 2089 2090 2091 def main(): 2092 unittest.main() 2093 2094 2095 if __name__ == '__main__': 2096 main() 2097