Home | History | Annotate | Download | only in protorpc
      1 #!/usr/bin/env python
      2 #
      3 # Copyright 2010 Google Inc.
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 #     http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 #
     17 
     18 """Tests for protorpc.messages."""
     19 import six
     20 
     21 __author__ = 'rafek (at] google.com (Rafe Kaplan)'
     22 
     23 
     24 import pickle
     25 import re
     26 import sys
     27 import types
     28 import unittest
     29 
     30 from protorpc import descriptor
     31 from protorpc import message_types
     32 from protorpc import messages
     33 from protorpc import test_util
     34 
     35 
     36 class ModuleInterfaceTest(test_util.ModuleInterfaceTest,
     37                           test_util.TestCase):
     38 
     39   MODULE = messages
     40 
     41 
     42 class ValidationErrorTest(test_util.TestCase):
     43 
     44   def testStr_NoFieldName(self):
     45     """Test string version of ValidationError when no name provided."""
     46     self.assertEquals('Validation error',
     47                       str(messages.ValidationError('Validation error')))
     48 
     49   def testStr_FieldName(self):
     50     """Test string version of ValidationError when no name provided."""
     51     validation_error = messages.ValidationError('Validation error')
     52     validation_error.field_name = 'a_field'
     53     self.assertEquals('Validation error', str(validation_error))
     54 
     55 
     56 class EnumTest(test_util.TestCase):
     57 
     58   def setUp(self):
     59     """Set up tests."""
     60     # Redefine Color class in case so that changes to it (an error) in one test
     61     # does not affect other tests.
     62     global Color
     63     class Color(messages.Enum):
     64       RED = 20
     65       ORANGE = 2
     66       YELLOW = 40
     67       GREEN = 4
     68       BLUE = 50
     69       INDIGO = 5
     70       VIOLET = 80
     71 
     72   def testNames(self):
     73     """Test that names iterates over enum names."""
     74     self.assertEquals(
     75         set(['BLUE', 'GREEN', 'INDIGO', 'ORANGE', 'RED', 'VIOLET', 'YELLOW']),
     76         set(Color.names()))
     77 
     78   def testNumbers(self):
     79     """Tests that numbers iterates of enum numbers."""
     80     self.assertEquals(set([2, 4, 5, 20, 40, 50, 80]), set(Color.numbers()))
     81 
     82   def testIterate(self):
     83     """Test that __iter__ iterates over all enum values."""
     84     self.assertEquals(set(Color),
     85                       set([Color.RED,
     86                            Color.ORANGE,
     87                            Color.YELLOW,
     88                            Color.GREEN,
     89                            Color.BLUE,
     90                            Color.INDIGO,
     91                            Color.VIOLET]))
     92 
     93   def testNaturalOrder(self):
     94     """Test that natural order enumeration is in numeric order."""
     95     self.assertEquals([Color.ORANGE,
     96                        Color.GREEN,
     97                        Color.INDIGO,
     98                        Color.RED,
     99                        Color.YELLOW,
    100                        Color.BLUE,
    101                        Color.VIOLET],
    102                       sorted(Color))
    103 
    104   def testByName(self):
    105     """Test look-up by name."""
    106     self.assertEquals(Color.RED, Color.lookup_by_name('RED'))
    107     self.assertRaises(KeyError, Color.lookup_by_name, 20)
    108     self.assertRaises(KeyError, Color.lookup_by_name, Color.RED)
    109 
    110   def testByNumber(self):
    111     """Test look-up by number."""
    112     self.assertRaises(KeyError, Color.lookup_by_number, 'RED')
    113     self.assertEquals(Color.RED, Color.lookup_by_number(20))
    114     self.assertRaises(KeyError, Color.lookup_by_number, Color.RED)
    115 
    116   def testConstructor(self):
    117     """Test that constructor look-up by name or number."""
    118     self.assertEquals(Color.RED, Color('RED'))
    119     self.assertEquals(Color.RED, Color(u'RED'))
    120     self.assertEquals(Color.RED, Color(20))
    121     if six.PY2:
    122         self.assertEquals(Color.RED, Color(long(20)))
    123     self.assertEquals(Color.RED, Color(Color.RED))
    124     self.assertRaises(TypeError, Color, 'Not exists')
    125     self.assertRaises(TypeError, Color, 'Red')
    126     self.assertRaises(TypeError, Color, 100)
    127     self.assertRaises(TypeError, Color, 10.0)
    128 
    129   def testLen(self):
    130     """Test that len function works to count enums."""
    131     self.assertEquals(7, len(Color))
    132 
    133   def testNoSubclasses(self):
    134     """Test that it is not possible to sub-class enum classes."""
    135     def declare_subclass():
    136       class MoreColor(Color):
    137         pass
    138     self.assertRaises(messages.EnumDefinitionError,
    139                       declare_subclass)
    140 
    141   def testClassNotMutable(self):
    142     """Test that enum classes themselves are not mutable."""
    143     self.assertRaises(AttributeError,
    144                       setattr,
    145                       Color,
    146                       'something_new',
    147                       10)
    148 
    149   def testInstancesMutable(self):
    150     """Test that enum instances are not mutable."""
    151     self.assertRaises(TypeError,
    152                       setattr,
    153                       Color.RED,
    154                       'something_new',
    155                       10)
    156 
    157   def testDefEnum(self):
    158     """Test def_enum works by building enum class from dict."""
    159     WeekDay = messages.Enum.def_enum({'Monday': 1,
    160                                       'Tuesday': 2,
    161                                       'Wednesday': 3,
    162                                       'Thursday': 4,
    163                                       'Friday': 6,
    164                                       'Saturday': 7,
    165                                       'Sunday': 8},
    166                                      'WeekDay')
    167     self.assertEquals('Wednesday', WeekDay(3).name)
    168     self.assertEquals(6, WeekDay('Friday').number)
    169     self.assertEquals(WeekDay.Sunday, WeekDay('Sunday'))
    170 
    171   def testNonInt(self):
    172     """Test that non-integer values rejection by enum def."""
    173     self.assertRaises(messages.EnumDefinitionError,
    174                       messages.Enum.def_enum,
    175                       {'Bad': '1'},
    176                       'BadEnum')
    177 
    178   def testNegativeInt(self):
    179     """Test that negative numbers rejection by enum def."""
    180     self.assertRaises(messages.EnumDefinitionError,
    181                       messages.Enum.def_enum,
    182                       {'Bad': -1},
    183                       'BadEnum')
    184 
    185   def testLowerBound(self):
    186     """Test that zero is accepted by enum def."""
    187     class NotImportant(messages.Enum):
    188       """Testing for value zero"""
    189       VALUE = 0
    190 
    191     self.assertEquals(0, int(NotImportant.VALUE))
    192 
    193   def testTooLargeInt(self):
    194     """Test that numbers too large are rejected."""
    195     self.assertRaises(messages.EnumDefinitionError,
    196                       messages.Enum.def_enum,
    197                       {'Bad': (2 ** 29)},
    198                       'BadEnum')
    199 
    200   def testRepeatedInt(self):
    201     """Test duplicated numbers are forbidden."""
    202     self.assertRaises(messages.EnumDefinitionError,
    203                       messages.Enum.def_enum,
    204                       {'Ok': 1, 'Repeated': 1},
    205                       'BadEnum')
    206 
    207   def testStr(self):
    208     """Test converting to string."""
    209     self.assertEquals('RED', str(Color.RED))
    210     self.assertEquals('ORANGE', str(Color.ORANGE))
    211 
    212   def testInt(self):
    213     """Test converting to int."""
    214     self.assertEquals(20, int(Color.RED))
    215     self.assertEquals(2, int(Color.ORANGE))
    216 
    217   def testRepr(self):
    218     """Test enum representation."""
    219     self.assertEquals('Color(RED, 20)', repr(Color.RED))
    220     self.assertEquals('Color(YELLOW, 40)', repr(Color.YELLOW))
    221 
    222   def testDocstring(self):
    223     """Test that docstring is supported ok."""
    224     class NotImportant(messages.Enum):
    225       """I have a docstring."""
    226 
    227       VALUE1 = 1
    228 
    229     self.assertEquals('I have a docstring.', NotImportant.__doc__)
    230 
    231   def testDeleteEnumValue(self):
    232     """Test that enum values cannot be deleted."""
    233     self.assertRaises(TypeError, delattr, Color, 'RED')
    234 
    235   def testEnumName(self):
    236     """Test enum name."""
    237     module_name = test_util.get_module_name(EnumTest)
    238     self.assertEquals('%s.Color' % module_name, Color.definition_name())
    239     self.assertEquals(module_name, Color.outer_definition_name())
    240     self.assertEquals(module_name, Color.definition_package())
    241 
    242   def testDefinitionName_OverrideModule(self):
    243     """Test enum module is overriden by module package name."""
    244     global package
    245     try:
    246       package = 'my.package'
    247       self.assertEquals('my.package.Color', Color.definition_name())
    248       self.assertEquals('my.package', Color.outer_definition_name())
    249       self.assertEquals('my.package', Color.definition_package())
    250     finally:
    251       del package
    252 
    253   def testDefinitionName_NoModule(self):
    254     """Test what happens when there is no module for enum."""
    255     class Enum1(messages.Enum):
    256       pass
    257 
    258     original_modules = sys.modules
    259     sys.modules = dict(sys.modules)
    260     try:
    261       del sys.modules[__name__]
    262       self.assertEquals('Enum1', Enum1.definition_name())
    263       self.assertEquals(None, Enum1.outer_definition_name())
    264       self.assertEquals(None, Enum1.definition_package())
    265       self.assertEquals(six.text_type, type(Enum1.definition_name()))
    266     finally:
    267       sys.modules = original_modules
    268 
    269   def testDefinitionName_Nested(self):
    270     """Test nested Enum names."""
    271     class MyMessage(messages.Message):
    272 
    273       class NestedEnum(messages.Enum):
    274 
    275         pass
    276 
    277       class NestedMessage(messages.Message):
    278 
    279         class NestedEnum(messages.Enum):
    280 
    281           pass
    282 
    283     module_name = test_util.get_module_name(EnumTest)
    284     self.assertEquals('%s.MyMessage.NestedEnum' % module_name,
    285                       MyMessage.NestedEnum.definition_name())
    286     self.assertEquals('%s.MyMessage' % module_name,
    287                       MyMessage.NestedEnum.outer_definition_name())
    288     self.assertEquals(module_name,
    289                       MyMessage.NestedEnum.definition_package())
    290 
    291     self.assertEquals('%s.MyMessage.NestedMessage.NestedEnum' % module_name,
    292                       MyMessage.NestedMessage.NestedEnum.definition_name())
    293     self.assertEquals(
    294       '%s.MyMessage.NestedMessage' % module_name,
    295       MyMessage.NestedMessage.NestedEnum.outer_definition_name())
    296     self.assertEquals(module_name,
    297                       MyMessage.NestedMessage.NestedEnum.definition_package())
    298 
    299   def testMessageDefinition(self):
    300     """Test that enumeration knows its enclosing message definition."""
    301     class OuterEnum(messages.Enum):
    302       pass
    303 
    304     self.assertEquals(None, OuterEnum.message_definition())
    305 
    306     class OuterMessage(messages.Message):
    307 
    308       class InnerEnum(messages.Enum):
    309         pass
    310 
    311     self.assertEquals(OuterMessage, OuterMessage.InnerEnum.message_definition())
    312 
    313   def testComparison(self):
    314     """Test comparing various enums to different types."""
    315     class Enum1(messages.Enum):
    316       VAL1 = 1
    317       VAL2 = 2
    318 
    319     class Enum2(messages.Enum):
    320       VAL1 = 1
    321 
    322     self.assertEquals(Enum1.VAL1, Enum1.VAL1)
    323     self.assertNotEquals(Enum1.VAL1, Enum1.VAL2)
    324     self.assertNotEquals(Enum1.VAL1, Enum2.VAL1)
    325     self.assertNotEquals(Enum1.VAL1, 'VAL1')
    326     self.assertNotEquals(Enum1.VAL1, 1)
    327     self.assertNotEquals(Enum1.VAL1, 2)
    328     self.assertNotEquals(Enum1.VAL1, None)
    329     self.assertNotEquals(Enum1.VAL1, Enum2.VAL1)
    330 
    331     self.assertTrue(Enum1.VAL1 < Enum1.VAL2)
    332     self.assertTrue(Enum1.VAL2 > Enum1.VAL1)
    333 
    334     self.assertNotEquals(1, Enum2.VAL1)
    335 
    336   def testPickle(self):
    337     """Testing pickling and unpickling of Enum instances."""
    338     colors = list(Color)
    339     unpickled = pickle.loads(pickle.dumps(colors))
    340     self.assertEquals(colors, unpickled)
    341     # Unpickling shouldn't create new enum instances.
    342     for i, color in enumerate(colors):
    343       self.assertTrue(color is unpickled[i])
    344 
    345 
    346 class FieldListTest(test_util.TestCase):
    347 
    348   def setUp(self):
    349     self.integer_field = messages.IntegerField(1, repeated=True)
    350 
    351   def testConstructor(self):
    352     self.assertEquals([1, 2, 3],
    353                       messages.FieldList(self.integer_field, [1, 2, 3]))
    354     self.assertEquals([1, 2, 3],
    355                       messages.FieldList(self.integer_field, (1, 2, 3)))
    356     self.assertEquals([], messages.FieldList(self.integer_field, []))
    357 
    358   def testNone(self):
    359     self.assertRaises(TypeError, messages.FieldList, self.integer_field, None)
    360 
    361   def testDoNotAutoConvertString(self):
    362     string_field = messages.StringField(1, repeated=True)
    363     self.assertRaises(messages.ValidationError,
    364                       messages.FieldList, string_field, 'abc')
    365 
    366   def testConstructorCopies(self):
    367     a_list = [1, 3, 6]
    368     field_list = messages.FieldList(self.integer_field, a_list)
    369     self.assertFalse(a_list is field_list)
    370     self.assertFalse(field_list is
    371                      messages.FieldList(self.integer_field, field_list))
    372 
    373   def testNonRepeatedField(self):
    374     self.assertRaisesWithRegexpMatch(
    375       messages.FieldDefinitionError,
    376       'FieldList may only accept repeated fields',
    377       messages.FieldList,
    378       messages.IntegerField(1),
    379       [])
    380 
    381   def testConstructor_InvalidValues(self):
    382     self.assertRaisesWithRegexpMatch(
    383       messages.ValidationError,
    384       re.escape("Expected type %r "
    385                 "for IntegerField, found 1 (type %r)"
    386                % (six.integer_types, str)),
    387       messages.FieldList, self.integer_field, ["1", "2", "3"])
    388 
    389   def testConstructor_Scalars(self):
    390     self.assertRaisesWithRegexpMatch(
    391       messages.ValidationError,
    392       "IntegerField is repeated. Found: 3",
    393       messages.FieldList, self.integer_field, 3)
    394 
    395     self.assertRaisesWithRegexpMatch(
    396       messages.ValidationError,
    397       "IntegerField is repeated. Found: <(list[_]?|sequence)iterator object",
    398       messages.FieldList, self.integer_field, iter([1, 2, 3]))
    399 
    400   def testSetSlice(self):
    401     field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])
    402     field_list[1:3] = [10, 20]
    403     self.assertEquals([1, 10, 20, 4, 5], field_list)
    404 
    405   def testSetSlice_InvalidValues(self):
    406     field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])
    407 
    408     def setslice():
    409       field_list[1:3] = ['10', '20']
    410 
    411     msg_re = re.escape("Expected type %r "
    412                        "for IntegerField, found 10 (type %r)"
    413                              % (six.integer_types, str))
    414     self.assertRaisesWithRegexpMatch(
    415       messages.ValidationError,
    416       msg_re,
    417       setslice)
    418 
    419   def testSetItem(self):
    420     field_list = messages.FieldList(self.integer_field, [2])
    421     field_list[0] = 10
    422     self.assertEquals([10], field_list)
    423 
    424   def testSetItem_InvalidValues(self):
    425     field_list = messages.FieldList(self.integer_field, [2])
    426 
    427     def setitem():
    428       field_list[0] = '10'
    429     self.assertRaisesWithRegexpMatch(
    430       messages.ValidationError,
    431       re.escape("Expected type %r "
    432                 "for IntegerField, found 10 (type %r)"
    433                % (six.integer_types, str)),
    434       setitem)
    435 
    436   def testAppend(self):
    437     field_list = messages.FieldList(self.integer_field, [2])
    438     field_list.append(10)
    439     self.assertEquals([2, 10], field_list)
    440 
    441   def testAppend_InvalidValues(self):
    442     field_list = messages.FieldList(self.integer_field, [2])
    443     field_list.name = 'a_field'
    444 
    445     def append():
    446       field_list.append('10')
    447     self.assertRaisesWithRegexpMatch(
    448       messages.ValidationError,
    449       re.escape("Expected type %r "
    450                 "for IntegerField, found 10 (type %r)"
    451                % (six.integer_types, str)),
    452       append)
    453 
    454   def testExtend(self):
    455     field_list = messages.FieldList(self.integer_field, [2])
    456     field_list.extend([10])
    457     self.assertEquals([2, 10], field_list)
    458 
    459   def testExtend_InvalidValues(self):
    460     field_list = messages.FieldList(self.integer_field, [2])
    461 
    462     def extend():
    463       field_list.extend(['10'])
    464     self.assertRaisesWithRegexpMatch(
    465       messages.ValidationError,
    466       re.escape("Expected type %r "
    467                 "for IntegerField, found 10 (type %r)"
    468                % (six.integer_types, str)),
    469       extend)
    470 
    471   def testInsert(self):
    472     field_list = messages.FieldList(self.integer_field, [2, 3])
    473     field_list.insert(1, 10)
    474     self.assertEquals([2, 10, 3], field_list)
    475 
    476   def testInsert_InvalidValues(self):
    477     field_list = messages.FieldList(self.integer_field, [2, 3])
    478 
    479     def insert():
    480       field_list.insert(1, '10')
    481     self.assertRaisesWithRegexpMatch(
    482       messages.ValidationError,
    483       re.escape("Expected type %r "
    484                 "for IntegerField, found 10 (type %r)"
    485                % (six.integer_types, str)),
    486       insert)
    487 
    488   def testPickle(self):
    489     """Testing pickling and unpickling of disconnected FieldList instances."""
    490     field_list = messages.FieldList(self.integer_field, [1, 2, 3, 4, 5])
    491     unpickled = pickle.loads(pickle.dumps(field_list))
    492     self.assertEquals(field_list, unpickled)
    493     self.assertIsInstance(unpickled.field, messages.IntegerField)
    494     self.assertEquals(1, unpickled.field.number)
    495     self.assertTrue(unpickled.field.repeated)
    496 
    497 
    498 class FieldTest(test_util.TestCase):
    499 
    500   def ActionOnAllFieldClasses(self, action):
    501     """Test all field classes except Message and Enum.
    502 
    503     Message and Enum require separate tests.
    504 
    505     Args:
    506       action: Callable that takes the field class as a parameter.
    507     """
    508     for field_class in (messages.IntegerField,
    509                         messages.FloatField,
    510                         messages.BooleanField,
    511                         messages.BytesField,
    512                         messages.StringField,
    513                        ):
    514       action(field_class)
    515 
    516   def testNumberAttribute(self):
    517     """Test setting the number attribute."""
    518     def action(field_class):
    519       # Check range.
    520       self.assertRaises(messages.InvalidNumberError,
    521                         field_class,
    522                         0)
    523       self.assertRaises(messages.InvalidNumberError,
    524                         field_class,
    525                         -1)
    526       self.assertRaises(messages.InvalidNumberError,
    527                         field_class,
    528                         messages.MAX_FIELD_NUMBER + 1)
    529 
    530       # Check reserved.
    531       self.assertRaises(messages.InvalidNumberError,
    532                         field_class,
    533                         messages.FIRST_RESERVED_FIELD_NUMBER)
    534       self.assertRaises(messages.InvalidNumberError,
    535                         field_class,
    536                         messages.LAST_RESERVED_FIELD_NUMBER)
    537       self.assertRaises(messages.InvalidNumberError,
    538                         field_class,
    539                         '1')
    540 
    541       # This one should work.
    542       field_class(number=1)
    543     self.ActionOnAllFieldClasses(action)
    544 
    545   def testRequiredAndRepeated(self):
    546     """Test setting the required and repeated fields."""
    547     def action(field_class):
    548       field_class(1, required=True)
    549       field_class(1, repeated=True)
    550       self.assertRaises(messages.FieldDefinitionError,
    551                         field_class,
    552                         1,
    553                         required=True,
    554                         repeated=True)
    555     self.ActionOnAllFieldClasses(action)
    556 
    557   def testInvalidVariant(self):
    558     """Test field with invalid variants."""
    559     def action(field_class):
    560       if field_class is not message_types.DateTimeField:
    561         self.assertRaises(messages.InvalidVariantError,
    562                           field_class,
    563                           1,
    564                           variant=messages.Variant.ENUM)
    565     self.ActionOnAllFieldClasses(action)
    566 
    567   def testDefaultVariant(self):
    568     """Test that default variant is used when not set."""
    569     def action(field_class):
    570       field = field_class(1)
    571       self.assertEquals(field_class.DEFAULT_VARIANT, field.variant)
    572 
    573     self.ActionOnAllFieldClasses(action)
    574 
    575   def testAlternateVariant(self):
    576     """Test that default variant is used when not set."""
    577     field = messages.IntegerField(1, variant=messages.Variant.UINT32)
    578     self.assertEquals(messages.Variant.UINT32, field.variant)
    579 
    580   def testDefaultFields_Single(self):
    581     """Test default field is correct type (single)."""
    582     defaults = {messages.IntegerField: 10,
    583                 messages.FloatField: 1.5,
    584                 messages.BooleanField: False,
    585                 messages.BytesField: b'abc',
    586                 messages.StringField: u'abc',
    587                }
    588 
    589     def action(field_class):
    590       field_class(1, default=defaults[field_class])
    591     self.ActionOnAllFieldClasses(action)
    592 
    593     # Run defaults test again checking for str/unicode compatiblity.
    594     defaults[messages.StringField] = 'abc'
    595     self.ActionOnAllFieldClasses(action)
    596 
    597   def testStringField_BadUnicodeInDefault(self):
    598     """Test binary values in string field."""
    599     self.assertRaisesWithRegexpMatch(
    600       messages.InvalidDefaultError,
    601       r"Invalid default value for StringField:.*: "
    602       r"Field encountered non-ASCII string .*: "
    603       r"'ascii' codec can't decode byte 0x89 in position 0: "
    604       r"ordinal not in range",
    605       messages.StringField, 1, default=b'\x89')
    606 
    607   def testDefaultFields_InvalidSingle(self):
    608     """Test default field is correct type (invalid single)."""
    609     def action(field_class):
    610       self.assertRaises(messages.InvalidDefaultError,
    611                         field_class,
    612                         1,
    613                         default=object())
    614     self.ActionOnAllFieldClasses(action)
    615 
    616   def testDefaultFields_InvalidRepeated(self):
    617     """Test default field does not accept defaults."""
    618     self.assertRaisesWithRegexpMatch(
    619       messages.FieldDefinitionError,
    620       'Repeated fields may not have defaults',
    621       messages.StringField, 1, repeated=True, default=[1, 2, 3])
    622 
    623   def testDefaultFields_None(self):
    624     """Test none is always acceptable."""
    625     def action(field_class):
    626       field_class(1, default=None)
    627       field_class(1, required=True, default=None)
    628       field_class(1, repeated=True, default=None)
    629     self.ActionOnAllFieldClasses(action)
    630 
    631   def testDefaultFields_Enum(self):
    632     """Test the default for enum fields."""
    633     class Symbol(messages.Enum):
    634 
    635       ALPHA = 1
    636       BETA = 2
    637       GAMMA = 3
    638 
    639     field = messages.EnumField(Symbol, 1, default=Symbol.ALPHA)
    640 
    641     self.assertEquals(Symbol.ALPHA, field.default)
    642 
    643   def testDefaultFields_EnumStringDelayedResolution(self):
    644     """Test that enum fields resolve default strings."""
    645     field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label',
    646                                1,
    647                                default='OPTIONAL')
    648 
    649     self.assertEquals(descriptor.FieldDescriptor.Label.OPTIONAL, field.default)
    650 
    651   def testDefaultFields_EnumIntDelayedResolution(self):
    652     """Test that enum fields resolve default integers."""
    653     field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label',
    654                                1,
    655                                default=2)
    656 
    657     self.assertEquals(descriptor.FieldDescriptor.Label.REQUIRED, field.default)
    658 
    659   def testDefaultFields_EnumOkIfTypeKnown(self):
    660     """Test that enum fields accept valid default values when type is known."""
    661     field = messages.EnumField(descriptor.FieldDescriptor.Label,
    662                                1,
    663                                default='REPEATED')
    664 
    665     self.assertEquals(descriptor.FieldDescriptor.Label.REPEATED, field.default)
    666 
    667   def testDefaultFields_EnumForceCheckIfTypeKnown(self):
    668     """Test that enum fields validate default values if type is known."""
    669     self.assertRaisesWithRegexpMatch(TypeError,
    670                                      'No such value for NOT_A_LABEL in '
    671                                      'Enum Label',
    672                                      messages.EnumField,
    673                                      descriptor.FieldDescriptor.Label,
    674                                      1,
    675                                      default='NOT_A_LABEL')
    676 
    677   def testDefaultFields_EnumInvalidDelayedResolution(self):
    678     """Test that enum fields raise errors upon delayed resolution error."""
    679     field = messages.EnumField('protorpc.descriptor.FieldDescriptor.Label',
    680                                1,
    681                                default=200)
    682 
    683     self.assertRaisesWithRegexpMatch(TypeError,
    684                                      'No such value for 200 in Enum Label',
    685                                      getattr,
    686                                      field,
    687                                      'default')
    688 
    689   def testValidate_Valid(self):
    690     """Test validation of valid values."""
    691     values = {messages.IntegerField: 10,
    692               messages.FloatField: 1.5,
    693               messages.BooleanField: False,
    694               messages.BytesField: b'abc',
    695               messages.StringField: u'abc',
    696              }
    697     def action(field_class):
    698       # Optional.
    699       field = field_class(1)
    700       field.validate(values[field_class])
    701 
    702       # Required.
    703       field = field_class(1, required=True)
    704       field.validate(values[field_class])
    705 
    706       # Repeated.
    707       field = field_class(1, repeated=True)
    708       field.validate([])
    709       field.validate(())
    710       field.validate([values[field_class]])
    711       field.validate((values[field_class],))
    712 
    713       # Right value, but not repeated.
    714       self.assertRaises(messages.ValidationError,
    715                         field.validate,
    716                         values[field_class])
    717       self.assertRaises(messages.ValidationError,
    718                         field.validate,
    719                         values[field_class])
    720 
    721     self.ActionOnAllFieldClasses(action)
    722 
    723   def testValidate_Invalid(self):
    724     """Test validation of valid values."""
    725     values = {messages.IntegerField: "10",
    726               messages.FloatField: 1,
    727               messages.BooleanField: 0,
    728               messages.BytesField: 10.20,
    729               messages.StringField: 42,
    730              }
    731     def action(field_class):
    732       # Optional.
    733       field = field_class(1)
    734       self.assertRaises(messages.ValidationError,
    735                         field.validate,
    736                         values[field_class])
    737 
    738       # Required.
    739       field = field_class(1, required=True)
    740       self.assertRaises(messages.ValidationError,
    741                         field.validate,
    742                         values[field_class])
    743 
    744       # Repeated.
    745       field = field_class(1, repeated=True)
    746       self.assertRaises(messages.ValidationError,
    747                         field.validate,
    748                         [values[field_class]])
    749       self.assertRaises(messages.ValidationError,
    750                         field.validate,
    751                         (values[field_class],))
    752     self.ActionOnAllFieldClasses(action)
    753 
    754   def testValidate_None(self):
    755     """Test that None is valid for non-required fields."""
    756     def action(field_class):
    757       # Optional.
    758       field = field_class(1)
    759       field.validate(None)
    760 
    761       # Required.
    762       field = field_class(1, required=True)
    763       self.assertRaisesWithRegexpMatch(messages.ValidationError,
    764                                        'Required field is missing',
    765                                        field.validate,
    766                                        None)
    767 
    768       # Repeated.
    769       field = field_class(1, repeated=True)
    770       field.validate(None)
    771       self.assertRaisesWithRegexpMatch(messages.ValidationError,
    772                                        'Repeated values for %s may '
    773                                        'not be None' % field_class.__name__,
    774                                        field.validate,
    775                                        [None])
    776       self.assertRaises(messages.ValidationError,
    777                         field.validate,
    778                         (None,))
    779     self.ActionOnAllFieldClasses(action)
    780 
    781   def testValidateElement(self):
    782     """Test validation of valid values."""
    783     values = {messages.IntegerField: 10,
    784               messages.FloatField: 1.5,
    785               messages.BooleanField: False,
    786               messages.BytesField: 'abc',
    787               messages.StringField: u'abc',
    788              }
    789     def action(field_class):
    790       # Optional.
    791       field = field_class(1)
    792       field.validate_element(values[field_class])
    793 
    794       # Required.
    795       field = field_class(1, required=True)
    796       field.validate_element(values[field_class])
    797 
    798       # Repeated.
    799       field = field_class(1, repeated=True)
    800       self.assertRaises(message.VAlidationError,
    801                         field.validate_element,
    802                         [])
    803       self.assertRaises(message.VAlidationError,
    804                         field.validate_element,
    805                         ())
    806       field.validate_element(values[field_class])
    807       field.validate_element(values[field_class])
    808 
    809       # Right value, but repeated.
    810       self.assertRaises(messages.ValidationError,
    811                         field.validate_element,
    812                         [values[field_class]])
    813       self.assertRaises(messages.ValidationError,
    814                         field.validate_element,
    815                         (values[field_class],))
    816 
    817   def testReadOnly(self):
    818     """Test that objects are all read-only."""
    819     def action(field_class):
    820       field = field_class(10)
    821       self.assertRaises(AttributeError,
    822                         setattr,
    823                         field,
    824                         'number',
    825                         20)
    826       self.assertRaises(AttributeError,
    827                         setattr,
    828                         field,
    829                         'anything_else',
    830                         'whatever')
    831     self.ActionOnAllFieldClasses(action)
    832 
    833   def testMessageField(self):
    834     """Test the construction of message fields."""
    835     self.assertRaises(messages.FieldDefinitionError,
    836                       messages.MessageField,
    837                       str,
    838                       10)
    839 
    840     self.assertRaises(messages.FieldDefinitionError,
    841                       messages.MessageField,
    842                       messages.Message,
    843                       10)
    844 
    845     class MyMessage(messages.Message):
    846       pass
    847 
    848     field = messages.MessageField(MyMessage, 10)
    849     self.assertEquals(MyMessage, field.type)
    850 
    851   def testMessageField_ForwardReference(self):
    852     """Test the construction of forward reference message fields."""
    853     global MyMessage
    854     global ForwardMessage
    855     try:
    856       class MyMessage(messages.Message):
    857 
    858         self_reference = messages.MessageField('MyMessage', 1)
    859         forward = messages.MessageField('ForwardMessage', 2)
    860         nested = messages.MessageField('ForwardMessage.NestedMessage', 3)
    861         inner = messages.MessageField('Inner', 4)
    862 
    863         class Inner(messages.Message):
    864 
    865           sibling = messages.MessageField('Sibling', 1)
    866 
    867         class Sibling(messages.Message):
    868 
    869           pass
    870 
    871       class ForwardMessage(messages.Message):
    872 
    873         class NestedMessage(messages.Message):
    874 
    875           pass
    876 
    877       self.assertEquals(MyMessage,
    878                         MyMessage.field_by_name('self_reference').type)
    879 
    880       self.assertEquals(ForwardMessage,
    881                         MyMessage.field_by_name('forward').type)
    882 
    883       self.assertEquals(ForwardMessage.NestedMessage,
    884                         MyMessage.field_by_name('nested').type)
    885 
    886       self.assertEquals(MyMessage.Inner,
    887                         MyMessage.field_by_name('inner').type)
    888 
    889       self.assertEquals(MyMessage.Sibling,
    890                         MyMessage.Inner.field_by_name('sibling').type)
    891     finally:
    892       try:
    893         del MyMessage
    894         del ForwardMessage
    895       except:
    896         pass
    897 
    898   def testMessageField_WrongType(self):
    899     """Test that forward referencing the wrong type raises an error."""
    900     global AnEnum
    901     try:
    902       class AnEnum(messages.Enum):
    903         pass
    904 
    905       class AnotherMessage(messages.Message):
    906 
    907         a_field = messages.MessageField('AnEnum', 1)
    908 
    909       self.assertRaises(messages.FieldDefinitionError,
    910                         getattr,
    911                         AnotherMessage.field_by_name('a_field'),
    912                         'type')
    913     finally:
    914       del AnEnum
    915 
    916   def testMessageFieldValidate(self):
    917     """Test validation on message field."""
    918     class MyMessage(messages.Message):
    919       pass
    920 
    921     class AnotherMessage(messages.Message):
    922       pass
    923 
    924     field = messages.MessageField(MyMessage, 10)
    925     field.validate(MyMessage())
    926 
    927     self.assertRaises(messages.ValidationError,
    928                       field.validate,
    929                       AnotherMessage())
    930 
    931   def testMessageFieldMessageType(self):
    932     """Test message_type property."""
    933     class MyMessage(messages.Message):
    934       pass
    935 
    936     class HasMessage(messages.Message):
    937       field = messages.MessageField(MyMessage, 1)
    938 
    939     self.assertEqual(HasMessage.field.type, HasMessage.field.message_type)
    940 
    941   def testMessageFieldValueFromMessage(self):
    942     class MyMessage(messages.Message):
    943       pass
    944 
    945     class HasMessage(messages.Message):
    946       field = messages.MessageField(MyMessage, 1)
    947 
    948     instance = MyMessage()
    949 
    950     self.assertTrue(instance is HasMessage.field.value_from_message(instance))
    951 
    952   def testMessageFieldValueFromMessageWrongType(self):
    953     class MyMessage(messages.Message):
    954       pass
    955 
    956     class HasMessage(messages.Message):
    957       field = messages.MessageField(MyMessage, 1)
    958 
    959     self.assertRaisesWithRegexpMatch(
    960         messages.DecodeError,
    961         'Expected type MyMessage, got int: 10',
    962         HasMessage.field.value_from_message, 10)
    963 
    964   def testMessageFieldValueToMessage(self):
    965     class MyMessage(messages.Message):
    966       pass
    967 
    968     class HasMessage(messages.Message):
    969       field = messages.MessageField(MyMessage, 1)
    970 
    971     instance = MyMessage()
    972 
    973     self.assertTrue(instance is HasMessage.field.value_to_message(instance))
    974 
    975   def testMessageFieldValueToMessageWrongType(self):
    976     class MyMessage(messages.Message):
    977       pass
    978 
    979     class MyOtherMessage(messages.Message):
    980       pass
    981 
    982     class HasMessage(messages.Message):
    983       field = messages.MessageField(MyMessage, 1)
    984 
    985     instance = MyOtherMessage()
    986 
    987     self.assertRaisesWithRegexpMatch(
    988         messages.EncodeError,
    989         'Expected type MyMessage, got MyOtherMessage: <MyOtherMessage>',
    990         HasMessage.field.value_to_message, instance)
    991 
    992   def testIntegerField_AllowLong(self):
    993     """Test that the integer field allows for longs."""
    994     if six.PY2:
    995         messages.IntegerField(10, default=long(10))
    996 
    997   def testMessageFieldValidate_Initialized(self):
    998     """Test validation on message field."""
    999     class MyMessage(messages.Message):
   1000       field1 = messages.IntegerField(1, required=True)
   1001 
   1002     field = messages.MessageField(MyMessage, 10)
   1003 
   1004     # Will validate messages where is_initialized() is False.
   1005     message = MyMessage()
   1006     field.validate(message)
   1007     message.field1 = 20
   1008     field.validate(message)
   1009 
   1010   def testEnumField(self):
   1011     """Test the construction of enum fields."""
   1012     self.assertRaises(messages.FieldDefinitionError,
   1013                       messages.EnumField,
   1014                       str,
   1015                       10)
   1016 
   1017     self.assertRaises(messages.FieldDefinitionError,
   1018                       messages.EnumField,
   1019                       messages.Enum,
   1020                       10)
   1021 
   1022     class Color(messages.Enum):
   1023       RED = 1
   1024       GREEN = 2
   1025       BLUE = 3
   1026 
   1027     field = messages.EnumField(Color, 10)
   1028     self.assertEquals(Color, field.type)
   1029 
   1030     class Another(messages.Enum):
   1031       VALUE = 1
   1032 
   1033     self.assertRaises(messages.InvalidDefaultError,
   1034                       messages.EnumField,
   1035                       Color,
   1036                       10,
   1037                       default=Another.VALUE)
   1038 
   1039   def testEnumField_ForwardReference(self):
   1040     """Test the construction of forward reference enum fields."""
   1041     global MyMessage
   1042     global ForwardEnum
   1043     global ForwardMessage
   1044     try:
   1045       class MyMessage(messages.Message):
   1046 
   1047         forward = messages.EnumField('ForwardEnum', 1)
   1048         nested = messages.EnumField('ForwardMessage.NestedEnum', 2)
   1049         inner = messages.EnumField('Inner', 3)
   1050 
   1051         class Inner(messages.Enum):
   1052           pass
   1053 
   1054       class ForwardEnum(messages.Enum):
   1055         pass
   1056 
   1057       class ForwardMessage(messages.Message):
   1058 
   1059         class NestedEnum(messages.Enum):
   1060           pass
   1061 
   1062       self.assertEquals(ForwardEnum,
   1063                         MyMessage.field_by_name('forward').type)
   1064 
   1065       self.assertEquals(ForwardMessage.NestedEnum,
   1066                         MyMessage.field_by_name('nested').type)
   1067 
   1068       self.assertEquals(MyMessage.Inner,
   1069                         MyMessage.field_by_name('inner').type)
   1070     finally:
   1071       try:
   1072         del MyMessage
   1073         del ForwardEnum
   1074         del ForwardMessage
   1075       except:
   1076         pass
   1077 
   1078   def testEnumField_WrongType(self):
   1079     """Test that forward referencing the wrong type raises an error."""
   1080     global AMessage
   1081     try:
   1082       class AMessage(messages.Message):
   1083         pass
   1084 
   1085       class AnotherMessage(messages.Message):
   1086 
   1087         a_field = messages.EnumField('AMessage', 1)
   1088 
   1089       self.assertRaises(messages.FieldDefinitionError,
   1090                         getattr,
   1091                         AnotherMessage.field_by_name('a_field'),
   1092                         'type')
   1093     finally:
   1094       del AMessage
   1095 
   1096   def testMessageDefinition(self):
   1097     """Test that message definition is set on fields."""
   1098     class MyMessage(messages.Message):
   1099 
   1100       my_field = messages.StringField(1)
   1101 
   1102     self.assertEquals(MyMessage,
   1103                       MyMessage.field_by_name('my_field').message_definition())
   1104 
   1105   def testNoneAssignment(self):
   1106     """Test that assigning None does not change comparison."""
   1107     class MyMessage(messages.Message):
   1108 
   1109       my_field = messages.StringField(1)
   1110 
   1111     m1 = MyMessage()
   1112     m2 = MyMessage()
   1113     m2.my_field = None
   1114     self.assertEquals(m1, m2)
   1115 
   1116   def testNonAsciiStr(self):
   1117     """Test validation fails for non-ascii StringField values."""
   1118     class Thing(messages.Message):
   1119       string_field = messages.StringField(2)
   1120 
   1121     thing = Thing()
   1122     self.assertRaisesWithRegexpMatch(
   1123       messages.ValidationError,
   1124       'Field string_field encountered non-ASCII string',
   1125       setattr, thing, 'string_field', test_util.BINARY)
   1126 
   1127 
   1128 class MessageTest(test_util.TestCase):
   1129   """Tests for message class."""
   1130 
   1131   def CreateMessageClass(self):
   1132     """Creates a simple message class with 3 fields.
   1133 
   1134     Fields are defined in alphabetical order but with conflicting numeric
   1135     order.
   1136     """
   1137     class ComplexMessage(messages.Message):
   1138       a3 = messages.IntegerField(3)
   1139       b1 = messages.StringField(1)
   1140       c2 = messages.StringField(2)
   1141 
   1142     return ComplexMessage
   1143 
   1144   def testSameNumbers(self):
   1145     """Test that cannot assign two fields with same numbers."""
   1146 
   1147     def action():
   1148       class BadMessage(messages.Message):
   1149         f1 = messages.IntegerField(1)
   1150         f2 = messages.IntegerField(1)
   1151     self.assertRaises(messages.DuplicateNumberError,
   1152                       action)
   1153 
   1154   def testStrictAssignment(self):
   1155     """Tests that cannot assign to unknown or non-reserved attributes."""
   1156     class SimpleMessage(messages.Message):
   1157       field = messages.IntegerField(1)
   1158 
   1159     simple_message = SimpleMessage()
   1160     self.assertRaises(AttributeError,
   1161                       setattr,
   1162                       simple_message,
   1163                       'does_not_exist',
   1164                       10)
   1165 
   1166   def testListAssignmentDoesNotCopy(self):
   1167     class SimpleMessage(messages.Message):
   1168       repeated = messages.IntegerField(1, repeated=True)
   1169 
   1170     message = SimpleMessage()
   1171     original = message.repeated
   1172     message.repeated = []
   1173     self.assertFalse(original is message.repeated)
   1174 
   1175   def testValidate_Optional(self):
   1176     """Tests validation of optional fields."""
   1177     class SimpleMessage(messages.Message):
   1178       non_required = messages.IntegerField(1)
   1179 
   1180     simple_message = SimpleMessage()
   1181     simple_message.check_initialized()
   1182     simple_message.non_required = 10
   1183     simple_message.check_initialized()
   1184 
   1185   def testValidate_Required(self):
   1186     """Tests validation of required fields."""
   1187     class SimpleMessage(messages.Message):
   1188       required = messages.IntegerField(1, required=True)
   1189 
   1190     simple_message = SimpleMessage()
   1191     self.assertRaises(messages.ValidationError,
   1192                       simple_message.check_initialized)
   1193     simple_message.required = 10
   1194     simple_message.check_initialized()
   1195 
   1196   def testValidate_Repeated(self):
   1197     """Tests validation of repeated fields."""
   1198     class SimpleMessage(messages.Message):
   1199       repeated = messages.IntegerField(1, repeated=True)
   1200 
   1201     simple_message = SimpleMessage()
   1202 
   1203     # Check valid values.
   1204     for valid_value in [], [10], [10, 20], (), (10,), (10, 20):
   1205       simple_message.repeated = valid_value
   1206       simple_message.check_initialized()
   1207 
   1208     # Check cleared.
   1209     simple_message.repeated = []
   1210     simple_message.check_initialized()
   1211 
   1212     # Check invalid values.
   1213     for invalid_value in 10, ['10', '20'], [None], (None,):
   1214       self.assertRaises(messages.ValidationError,
   1215                         setattr, simple_message, 'repeated', invalid_value)
   1216 
   1217   def testIsInitialized(self):
   1218     """Tests is_initialized."""
   1219     class SimpleMessage(messages.Message):
   1220       required = messages.IntegerField(1, required=True)
   1221 
   1222     simple_message = SimpleMessage()
   1223     self.assertFalse(simple_message.is_initialized())
   1224 
   1225     simple_message.required = 10
   1226 
   1227     self.assertTrue(simple_message.is_initialized())
   1228 
   1229   def testIsInitializedNestedField(self):
   1230     """Tests is_initialized for nested fields."""
   1231     class SimpleMessage(messages.Message):
   1232       required = messages.IntegerField(1, required=True)
   1233 
   1234     class NestedMessage(messages.Message):
   1235       simple = messages.MessageField(SimpleMessage, 1)
   1236 
   1237     simple_message = SimpleMessage()
   1238     self.assertFalse(simple_message.is_initialized())
   1239     nested_message = NestedMessage(simple=simple_message)
   1240     self.assertFalse(nested_message.is_initialized())
   1241 
   1242     simple_message.required = 10
   1243 
   1244     self.assertTrue(simple_message.is_initialized())
   1245     self.assertTrue(nested_message.is_initialized())
   1246 
   1247   def testInitializeNestedFieldFromDict(self):
   1248     """Tests initializing nested fields from dict."""
   1249     class SimpleMessage(messages.Message):
   1250       required = messages.IntegerField(1, required=True)
   1251 
   1252     class NestedMessage(messages.Message):
   1253       simple = messages.MessageField(SimpleMessage, 1)
   1254 
   1255     class RepeatedMessage(messages.Message):
   1256       simple = messages.MessageField(SimpleMessage, 1, repeated=True)
   1257 
   1258     nested_message1 = NestedMessage(simple={'required': 10})
   1259     self.assertTrue(nested_message1.is_initialized())
   1260     self.assertTrue(nested_message1.simple.is_initialized())
   1261 
   1262     nested_message2 = NestedMessage()
   1263     nested_message2.simple = {'required': 10}
   1264     self.assertTrue(nested_message2.is_initialized())
   1265     self.assertTrue(nested_message2.simple.is_initialized())
   1266 
   1267     repeated_values = [{}, {'required': 10}, SimpleMessage(required=20)]
   1268 
   1269     repeated_message1 = RepeatedMessage(simple=repeated_values)
   1270     self.assertEquals(3, len(repeated_message1.simple))
   1271     self.assertFalse(repeated_message1.is_initialized())
   1272 
   1273     repeated_message1.simple[0].required = 0
   1274     self.assertTrue(repeated_message1.is_initialized())
   1275 
   1276     repeated_message2 = RepeatedMessage()
   1277     repeated_message2.simple = repeated_values
   1278     self.assertEquals(3, len(repeated_message2.simple))
   1279     self.assertFalse(repeated_message2.is_initialized())
   1280 
   1281     repeated_message2.simple[0].required = 0
   1282     self.assertTrue(repeated_message2.is_initialized())
   1283 
   1284   def testNestedMethodsNotAllowed(self):
   1285     """Test that method definitions on Message classes are not allowed."""
   1286     def action():
   1287       class WithMethods(messages.Message):
   1288         def not_allowed(self):
   1289           pass
   1290 
   1291     self.assertRaises(messages.MessageDefinitionError,
   1292                       action)
   1293 
   1294   def testNestedAttributesNotAllowed(self):
   1295     """Test that attribute assignment on Message classes are not allowed."""
   1296     def int_attribute():
   1297       class WithMethods(messages.Message):
   1298         not_allowed = 1
   1299 
   1300     def string_attribute():
   1301       class WithMethods(messages.Message):
   1302         not_allowed = 'not allowed'
   1303 
   1304     def enum_attribute():
   1305       class WithMethods(messages.Message):
   1306         not_allowed = Color.RED
   1307 
   1308     for action in (int_attribute, string_attribute, enum_attribute):
   1309       self.assertRaises(messages.MessageDefinitionError,
   1310                         action)
   1311 
   1312   def testNameIsSetOnFields(self):
   1313     """Make sure name is set on fields after Message class init."""
   1314     class HasNamedFields(messages.Message):
   1315       field = messages.StringField(1)
   1316 
   1317     self.assertEquals('field', HasNamedFields.field_by_number(1).name)
   1318 
   1319   def testSubclassingMessageDisallowed(self):
   1320     """Not permitted to create sub-classes of message classes."""
   1321     class SuperClass(messages.Message):
   1322       pass
   1323 
   1324     def action():
   1325       class SubClass(SuperClass):
   1326         pass
   1327 
   1328     self.assertRaises(messages.MessageDefinitionError,
   1329                       action)
   1330 
   1331   def testAllFields(self):
   1332     """Test all_fields method."""
   1333     ComplexMessage = self.CreateMessageClass()
   1334     fields = list(ComplexMessage.all_fields())
   1335 
   1336     # Order does not matter, so sort now.
   1337     fields = sorted(fields, key=lambda f: f.name)
   1338 
   1339     self.assertEquals(3, len(fields))
   1340     self.assertEquals('a3', fields[0].name)
   1341     self.assertEquals('b1', fields[1].name)
   1342     self.assertEquals('c2', fields[2].name)
   1343 
   1344   def testFieldByName(self):
   1345     """Test getting field by name."""
   1346     ComplexMessage = self.CreateMessageClass()
   1347 
   1348     self.assertEquals(3, ComplexMessage.field_by_name('a3').number)
   1349     self.assertEquals(1, ComplexMessage.field_by_name('b1').number)
   1350     self.assertEquals(2, ComplexMessage.field_by_name('c2').number)
   1351 
   1352     self.assertRaises(KeyError,
   1353                       ComplexMessage.field_by_name,
   1354                       'unknown')
   1355 
   1356   def testFieldByNumber(self):
   1357     """Test getting field by number."""
   1358     ComplexMessage = self.CreateMessageClass()
   1359 
   1360     self.assertEquals('a3', ComplexMessage.field_by_number(3).name)
   1361     self.assertEquals('b1', ComplexMessage.field_by_number(1).name)
   1362     self.assertEquals('c2', ComplexMessage.field_by_number(2).name)
   1363 
   1364     self.assertRaises(KeyError,
   1365                       ComplexMessage.field_by_number,
   1366                       4)
   1367 
   1368   def testGetAssignedValue(self):
   1369     """Test getting the assigned value of a field."""
   1370     class SomeMessage(messages.Message):
   1371       a_value = messages.StringField(1, default=u'a default')
   1372 
   1373     message = SomeMessage()
   1374     self.assertEquals(None, message.get_assigned_value('a_value'))
   1375 
   1376     message.a_value = u'a string'
   1377     self.assertEquals(u'a string', message.get_assigned_value('a_value'))
   1378 
   1379     message.a_value = u'a default'
   1380     self.assertEquals(u'a default', message.get_assigned_value('a_value'))
   1381 
   1382     self.assertRaisesWithRegexpMatch(
   1383         AttributeError,
   1384         'Message SomeMessage has no field no_such_field',
   1385         message.get_assigned_value,
   1386         'no_such_field')
   1387 
   1388   def testReset(self):
   1389     """Test resetting a field value."""
   1390     class SomeMessage(messages.Message):
   1391       a_value = messages.StringField(1, default=u'a default')
   1392       repeated = messages.IntegerField(2, repeated=True)
   1393 
   1394     message = SomeMessage()
   1395 
   1396     self.assertRaises(AttributeError, message.reset, 'unknown')
   1397 
   1398     self.assertEquals(u'a default', message.a_value)
   1399     message.reset('a_value')
   1400     self.assertEquals(u'a default', message.a_value)
   1401 
   1402     message.a_value = u'a new value'
   1403     self.assertEquals(u'a new value', message.a_value)
   1404     message.reset('a_value')
   1405     self.assertEquals(u'a default', message.a_value)
   1406 
   1407     message.repeated = [1, 2, 3]
   1408     self.assertEquals([1, 2, 3], message.repeated)
   1409     saved = message.repeated
   1410     message.reset('repeated')
   1411     self.assertEquals([], message.repeated)
   1412     self.assertIsInstance(message.repeated, messages.FieldList)
   1413     self.assertEquals([1, 2, 3], saved)
   1414 
   1415   def testAllowNestedEnums(self):
   1416     """Test allowing nested enums in a message definition."""
   1417     class Trade(messages.Message):
   1418       class Duration(messages.Enum):
   1419         GTC = 1
   1420         DAY = 2
   1421 
   1422       class Currency(messages.Enum):
   1423         USD = 1
   1424         GBP = 2
   1425         INR = 3
   1426 
   1427     # Sorted by name order seems to be the only feasible option.
   1428     self.assertEquals(['Currency', 'Duration'], Trade.__enums__)
   1429 
   1430     # Message definition will now be set on Enumerated objects.
   1431     self.assertEquals(Trade, Trade.Duration.message_definition())
   1432 
   1433   def testAllowNestedMessages(self):
   1434     """Test allowing nested messages in a message definition."""
   1435     class Trade(messages.Message):
   1436       class Lot(messages.Message):
   1437         pass
   1438 
   1439       class Agent(messages.Message):
   1440         pass
   1441 
   1442     # Sorted by name order seems to be the only feasible option.
   1443     self.assertEquals(['Agent', 'Lot'], Trade.__messages__)
   1444     self.assertEquals(Trade, Trade.Agent.message_definition())
   1445     self.assertEquals(Trade, Trade.Lot.message_definition())
   1446 
   1447     # But not Message itself.
   1448     def action():
   1449       class Trade(messages.Message):
   1450         NiceTry = messages.Message
   1451     self.assertRaises(messages.MessageDefinitionError, action)
   1452 
   1453   def testDisallowClassAssignments(self):
   1454     """Test setting class attributes may not happen."""
   1455     class MyMessage(messages.Message):
   1456       pass
   1457 
   1458     self.assertRaises(AttributeError,
   1459                       setattr,
   1460                       MyMessage,
   1461                       'x',
   1462                       'do not assign')
   1463 
   1464   def testEquality(self):
   1465     """Test message class equality."""
   1466     # Comparison against enums must work.
   1467     class MyEnum(messages.Enum):
   1468       val1 = 1
   1469       val2 = 2
   1470 
   1471     # Comparisons against nested messages must work.
   1472     class AnotherMessage(messages.Message):
   1473       string = messages.StringField(1)
   1474 
   1475     class MyMessage(messages.Message):
   1476       field1 = messages.IntegerField(1)
   1477       field2 = messages.EnumField(MyEnum, 2)
   1478       field3 = messages.MessageField(AnotherMessage, 3)
   1479 
   1480     message1 = MyMessage()
   1481 
   1482     self.assertNotEquals('hi', message1)
   1483     self.assertNotEquals(AnotherMessage(), message1)
   1484     self.assertEquals(message1, message1)
   1485 
   1486     message2 = MyMessage()
   1487 
   1488     self.assertEquals(message1, message2)
   1489 
   1490     message1.field1 = 10
   1491     self.assertNotEquals(message1, message2)
   1492 
   1493     message2.field1 = 20
   1494     self.assertNotEquals(message1, message2)
   1495 
   1496     message2.field1 = 10
   1497     self.assertEquals(message1, message2)
   1498 
   1499     message1.field2 = MyEnum.val1
   1500     self.assertNotEquals(message1, message2)
   1501 
   1502     message2.field2 = MyEnum.val2
   1503     self.assertNotEquals(message1, message2)
   1504 
   1505     message2.field2 = MyEnum.val1
   1506     self.assertEquals(message1, message2)
   1507 
   1508     message1.field3 = AnotherMessage()
   1509     message1.field3.string = 'value1'
   1510     self.assertNotEquals(message1, message2)
   1511 
   1512     message2.field3 = AnotherMessage()
   1513     message2.field3.string = 'value2'
   1514     self.assertNotEquals(message1, message2)
   1515 
   1516     message2.field3.string = 'value1'
   1517     self.assertEquals(message1, message2)
   1518 
   1519   def testEqualityWithUnknowns(self):
   1520     """Test message class equality with unknown fields."""
   1521 
   1522     class MyMessage(messages.Message):
   1523       field1 = messages.IntegerField(1)
   1524 
   1525     message1 = MyMessage()
   1526     message2 = MyMessage()
   1527     self.assertEquals(message1, message2)
   1528     message1.set_unrecognized_field('unknown1', 'value1',
   1529                                     messages.Variant.STRING)
   1530     self.assertEquals(message1, message2)
   1531 
   1532     message1.set_unrecognized_field('unknown2', ['asdf', 3],
   1533                                     messages.Variant.STRING)
   1534     message1.set_unrecognized_field('unknown3', 4.7,
   1535                                     messages.Variant.DOUBLE)
   1536     self.assertEquals(message1, message2)
   1537 
   1538   def testUnrecognizedFieldInvalidVariant(self):
   1539     class MyMessage(messages.Message):
   1540       field1 = messages.IntegerField(1)
   1541 
   1542     message1 = MyMessage()
   1543     self.assertRaises(TypeError, message1.set_unrecognized_field, 'unknown4',
   1544                       {'unhandled': 'type'}, None)
   1545     self.assertRaises(TypeError, message1.set_unrecognized_field, 'unknown4',
   1546                       {'unhandled': 'type'}, 123)
   1547 
   1548   def testRepr(self):
   1549     """Test represtation of Message object."""
   1550     class MyMessage(messages.Message):
   1551       integer_value = messages.IntegerField(1)
   1552       string_value = messages.StringField(2)
   1553       unassigned = messages.StringField(3)
   1554       unassigned_with_default = messages.StringField(4, default=u'a default')
   1555 
   1556     my_message = MyMessage()
   1557     my_message.integer_value = 42
   1558     my_message.string_value = u'A string'
   1559 
   1560     pat = re.compile(r"<MyMessage\n integer_value: 42\n"
   1561                       " string_value: [u]?'A string'>")
   1562     self.assertTrue(pat.match(repr(my_message)) is not None)
   1563 
   1564   def testValidation(self):
   1565     """Test validation of message values."""
   1566     # Test optional.
   1567     class SubMessage(messages.Message):
   1568       pass
   1569 
   1570     class Message(messages.Message):
   1571       val = messages.MessageField(SubMessage, 1)
   1572 
   1573     message = Message()
   1574 
   1575     message_field = messages.MessageField(Message, 1)
   1576     message_field.validate(message)
   1577     message.val = SubMessage()
   1578     message_field.validate(message)
   1579     self.assertRaises(messages.ValidationError,
   1580                       setattr, message, 'val', [SubMessage()])
   1581 
   1582     # Test required.
   1583     class Message(messages.Message):
   1584       val = messages.MessageField(SubMessage, 1, required=True)
   1585 
   1586     message = Message()
   1587 
   1588     message_field = messages.MessageField(Message, 1)
   1589     message_field.validate(message)
   1590     message.val = SubMessage()
   1591     message_field.validate(message)
   1592     self.assertRaises(messages.ValidationError,
   1593                       setattr, message, 'val', [SubMessage()])
   1594 
   1595     # Test repeated.
   1596     class Message(messages.Message):
   1597       val = messages.MessageField(SubMessage, 1, repeated=True)
   1598 
   1599     message = Message()
   1600 
   1601     message_field = messages.MessageField(Message, 1)
   1602     message_field.validate(message)
   1603     self.assertRaisesWithRegexpMatch(
   1604       messages.ValidationError,
   1605       "Field val is repeated. Found: <SubMessage>",
   1606       setattr, message, 'val', SubMessage())
   1607     message.val = [SubMessage()]
   1608     message_field.validate(message)
   1609 
   1610   def testDefinitionName(self):
   1611     """Test message name."""
   1612     class MyMessage(messages.Message):
   1613       pass
   1614 
   1615     module_name = test_util.get_module_name(FieldTest)
   1616     self.assertEquals('%s.MyMessage' % module_name,
   1617                       MyMessage.definition_name())
   1618     self.assertEquals(module_name, MyMessage.outer_definition_name())
   1619     self.assertEquals(module_name, MyMessage.definition_package())
   1620 
   1621     self.assertEquals(six.text_type, type(MyMessage.definition_name()))
   1622     self.assertEquals(six.text_type, type(MyMessage.outer_definition_name()))
   1623     self.assertEquals(six.text_type, type(MyMessage.definition_package()))
   1624 
   1625   def testDefinitionName_OverrideModule(self):
   1626     """Test message module is overriden by module package name."""
   1627     class MyMessage(messages.Message):
   1628       pass
   1629 
   1630     global package
   1631     package = 'my.package'
   1632 
   1633     try:
   1634       self.assertEquals('my.package.MyMessage', MyMessage.definition_name())
   1635       self.assertEquals('my.package', MyMessage.outer_definition_name())
   1636       self.assertEquals('my.package', MyMessage.definition_package())
   1637 
   1638       self.assertEquals(six.text_type, type(MyMessage.definition_name()))
   1639       self.assertEquals(six.text_type, type(MyMessage.outer_definition_name()))
   1640       self.assertEquals(six.text_type, type(MyMessage.definition_package()))
   1641     finally:
   1642       del package
   1643 
   1644   def testDefinitionName_NoModule(self):
   1645     """Test what happens when there is no module for message."""
   1646     class MyMessage(messages.Message):
   1647       pass
   1648 
   1649     original_modules = sys.modules
   1650     sys.modules = dict(sys.modules)
   1651     try:
   1652       del sys.modules[__name__]
   1653       self.assertEquals('MyMessage', MyMessage.definition_name())
   1654       self.assertEquals(None, MyMessage.outer_definition_name())
   1655       self.assertEquals(None, MyMessage.definition_package())
   1656 
   1657       self.assertEquals(six.text_type, type(MyMessage.definition_name()))
   1658     finally:
   1659       sys.modules = original_modules
   1660 
   1661   def testDefinitionName_Nested(self):
   1662     """Test nested message names."""
   1663     class MyMessage(messages.Message):
   1664 
   1665       class NestedMessage(messages.Message):
   1666 
   1667         class NestedMessage(messages.Message):
   1668 
   1669           pass
   1670 
   1671     module_name = test_util.get_module_name(MessageTest)
   1672     self.assertEquals('%s.MyMessage.NestedMessage' % module_name,
   1673                       MyMessage.NestedMessage.definition_name())
   1674     self.assertEquals('%s.MyMessage' % module_name,
   1675                       MyMessage.NestedMessage.outer_definition_name())
   1676     self.assertEquals(module_name,
   1677                       MyMessage.NestedMessage.definition_package())
   1678 
   1679     self.assertEquals('%s.MyMessage.NestedMessage.NestedMessage' % module_name,
   1680                       MyMessage.NestedMessage.NestedMessage.definition_name())
   1681     self.assertEquals(
   1682       '%s.MyMessage.NestedMessage' % module_name,
   1683       MyMessage.NestedMessage.NestedMessage.outer_definition_name())
   1684     self.assertEquals(
   1685       module_name,
   1686       MyMessage.NestedMessage.NestedMessage.definition_package())
   1687 
   1688 
   1689   def testMessageDefinition(self):
   1690     """Test that enumeration knows its enclosing message definition."""
   1691     class OuterMessage(messages.Message):
   1692 
   1693       class InnerMessage(messages.Message):
   1694         pass
   1695 
   1696     self.assertEquals(None, OuterMessage.message_definition())
   1697     self.assertEquals(OuterMessage,
   1698                       OuterMessage.InnerMessage.message_definition())
   1699 
   1700   def testConstructorKwargs(self):
   1701     """Test kwargs via constructor."""
   1702     class SomeMessage(messages.Message):
   1703       name = messages.StringField(1)
   1704       number = messages.IntegerField(2)
   1705 
   1706     expected = SomeMessage()
   1707     expected.name = 'my name'
   1708     expected.number = 200
   1709     self.assertEquals(expected, SomeMessage(name='my name', number=200))
   1710 
   1711   def testConstructorNotAField(self):
   1712     """Test kwargs via constructor with wrong names."""
   1713     class SomeMessage(messages.Message):
   1714       pass
   1715 
   1716     self.assertRaisesWithRegexpMatch(
   1717       AttributeError,
   1718       'May not assign arbitrary value does_not_exist to message SomeMessage',
   1719       SomeMessage,
   1720       does_not_exist=10)
   1721 
   1722   def testGetUnsetRepeatedValue(self):
   1723     class SomeMessage(messages.Message):
   1724       repeated = messages.IntegerField(1, repeated=True)
   1725 
   1726     instance = SomeMessage()
   1727     self.assertEquals([], instance.repeated)
   1728     self.assertTrue(isinstance(instance.repeated, messages.FieldList))
   1729 
   1730   def testCompareAutoInitializedRepeatedFields(self):
   1731     class SomeMessage(messages.Message):
   1732       repeated = messages.IntegerField(1, repeated=True)
   1733 
   1734     message1 = SomeMessage(repeated=[])
   1735     message2 = SomeMessage()
   1736     self.assertEquals(message1, message2)
   1737 
   1738   def testUnknownValues(self):
   1739     """Test message class equality with unknown fields."""
   1740     class MyMessage(messages.Message):
   1741       field1 = messages.IntegerField(1)
   1742 
   1743     message = MyMessage()
   1744     self.assertEquals([], message.all_unrecognized_fields())
   1745     self.assertEquals((None, None),
   1746                       message.get_unrecognized_field_info('doesntexist'))
   1747     self.assertEquals((None, None),
   1748                       message.get_unrecognized_field_info(
   1749                           'doesntexist', None, None))
   1750     self.assertEquals(('defaultvalue', 'defaultwire'),
   1751                       message.get_unrecognized_field_info(
   1752                           'doesntexist', 'defaultvalue', 'defaultwire'))
   1753     self.assertEquals((3, None),
   1754                       message.get_unrecognized_field_info(
   1755                           'doesntexist', value_default=3))
   1756 
   1757     message.set_unrecognized_field('exists', 9.5, messages.Variant.DOUBLE)
   1758     self.assertEquals(1, len(message.all_unrecognized_fields()))
   1759     self.assertTrue('exists' in message.all_unrecognized_fields())
   1760     self.assertEquals((9.5, messages.Variant.DOUBLE),
   1761                       message.get_unrecognized_field_info('exists'))
   1762     self.assertEquals((9.5, messages.Variant.DOUBLE),
   1763                       message.get_unrecognized_field_info('exists', 'type',
   1764                                                           1234))
   1765     self.assertEquals((1234, None),
   1766                       message.get_unrecognized_field_info('doesntexist', 1234))
   1767 
   1768     message.set_unrecognized_field('another', 'value', messages.Variant.STRING)
   1769     self.assertEquals(2, len(message.all_unrecognized_fields()))
   1770     self.assertTrue('exists' in message.all_unrecognized_fields())
   1771     self.assertTrue('another' in message.all_unrecognized_fields())
   1772     self.assertEquals((9.5, messages.Variant.DOUBLE),
   1773                       message.get_unrecognized_field_info('exists'))
   1774     self.assertEquals(('value', messages.Variant.STRING),
   1775                       message.get_unrecognized_field_info('another'))
   1776 
   1777     message.set_unrecognized_field('typetest1', ['list', 0, ('test',)],
   1778                                    messages.Variant.STRING)
   1779     self.assertEquals((['list', 0, ('test',)], messages.Variant.STRING),
   1780                       message.get_unrecognized_field_info('typetest1'))
   1781     message.set_unrecognized_field('typetest2', '', messages.Variant.STRING)
   1782     self.assertEquals(('', messages.Variant.STRING),
   1783                       message.get_unrecognized_field_info('typetest2'))
   1784 
   1785   def testPickle(self):
   1786     """Testing pickling and unpickling of Message instances."""
   1787     global MyEnum
   1788     global AnotherMessage
   1789     global MyMessage
   1790 
   1791     class MyEnum(messages.Enum):
   1792       val1 = 1
   1793       val2 = 2
   1794 
   1795     class AnotherMessage(messages.Message):
   1796       string = messages.StringField(1, repeated=True)
   1797 
   1798     class MyMessage(messages.Message):
   1799       field1 = messages.IntegerField(1)
   1800       field2 = messages.EnumField(MyEnum, 2)
   1801       field3 = messages.MessageField(AnotherMessage, 3)
   1802 
   1803     message = MyMessage(field1=1, field2=MyEnum.val2,
   1804                         field3=AnotherMessage(string=['a', 'b', 'c']))
   1805     message.set_unrecognized_field('exists', 'value', messages.Variant.STRING)
   1806     message.set_unrecognized_field('repeated', ['list', 0, ('test',)],
   1807                                    messages.Variant.STRING)
   1808     unpickled = pickle.loads(pickle.dumps(message))
   1809     self.assertEquals(message, unpickled)
   1810     self.assertTrue(AnotherMessage.string is unpickled.field3.string.field)
   1811     self.assertTrue('exists' in message.all_unrecognized_fields())
   1812     self.assertEquals(('value', messages.Variant.STRING),
   1813                       message.get_unrecognized_field_info('exists'))
   1814     self.assertEquals((['list', 0, ('test',)], messages.Variant.STRING),
   1815                       message.get_unrecognized_field_info('repeated'))
   1816 
   1817 
   1818 class FindDefinitionTest(test_util.TestCase):
   1819   """Test finding definitions relative to various definitions and modules."""
   1820 
   1821   def setUp(self):
   1822     """Set up module-space.  Starts off empty."""
   1823     self.modules = {}
   1824 
   1825   def DefineModule(self, name):
   1826     """Define a module and its parents in module space.
   1827 
   1828     Modules that are already defined in self.modules are not re-created.
   1829 
   1830     Args:
   1831       name: Fully qualified name of modules to create.
   1832 
   1833     Returns:
   1834       Deepest nested module.  For example:
   1835 
   1836         DefineModule('a.b.c')  # Returns c.
   1837     """
   1838     name_path = name.split('.')
   1839     full_path = []
   1840     for node in name_path:
   1841       full_path.append(node)
   1842       full_name = '.'.join(full_path)
   1843       self.modules.setdefault(full_name, types.ModuleType(full_name))
   1844     return self.modules[name]
   1845 
   1846   def DefineMessage(self, module, name, children={}, add_to_module=True):
   1847     """Define a new Message class in the context of a module.
   1848 
   1849     Used for easily describing complex Message hierarchy.  Message is defined
   1850     including all child definitions.
   1851 
   1852     Args:
   1853       module: Fully qualified name of module to place Message class in.
   1854       name: Name of Message to define within module.
   1855       children: Define any level of nesting of children definitions.  To define
   1856         a message, map the name to another dictionary.  The dictionary can
   1857         itself contain additional definitions, and so on.  To map to an Enum,
   1858         define the Enum class separately and map it by name.
   1859       add_to_module: If True, new Message class is added to module.  If False,
   1860         new Message is not added.
   1861     """
   1862     # Make sure module exists.
   1863     module_instance = self.DefineModule(module)
   1864 
   1865     # Recursively define all child messages.
   1866     for attribute, value in children.items():
   1867       if isinstance(value, dict):
   1868         children[attribute] = self.DefineMessage(
   1869             module, attribute, value, False)
   1870 
   1871     # Override default __module__ variable.
   1872     children['__module__'] = module
   1873 
   1874     # Instantiate and possibly add to module.
   1875     message_class = type(name, (messages.Message,), dict(children))
   1876     if add_to_module:
   1877       setattr(module_instance, name, message_class)
   1878     return message_class
   1879 
   1880   def Importer(self, module, globals='', locals='', fromlist=None):
   1881     """Importer function.
   1882 
   1883     Acts like __import__.  Only loads modules from self.modules.  Does not
   1884     try to load real modules defined elsewhere.  Does not try to handle relative
   1885     imports.
   1886 
   1887     Args:
   1888       module: Fully qualified name of module to load from self.modules.
   1889     """
   1890     if fromlist is None:
   1891       module = module.split('.')[0]
   1892     try:
   1893       return self.modules[module]
   1894     except KeyError:
   1895       raise ImportError()
   1896 
   1897   def testNoSuchModule(self):
   1898     """Test searching for definitions that do no exist."""
   1899     self.assertRaises(messages.DefinitionNotFoundError,
   1900                       messages.find_definition,
   1901                       'does.not.exist',
   1902                       importer=self.Importer)
   1903 
   1904   def testRefersToModule(self):
   1905     """Test that referring to a module does not return that module."""
   1906     self.DefineModule('i.am.a.module')
   1907     self.assertRaises(messages.DefinitionNotFoundError,
   1908                       messages.find_definition,
   1909                       'i.am.a.module',
   1910                       importer=self.Importer)
   1911 
   1912   def testNoDefinition(self):
   1913     """Test not finding a definition in an existing module."""
   1914     self.DefineModule('i.am.a.module')
   1915     self.assertRaises(messages.DefinitionNotFoundError,
   1916                       messages.find_definition,
   1917                       'i.am.a.module.MyMessage',
   1918                       importer=self.Importer)
   1919 
   1920   def testNotADefinition(self):
   1921     """Test trying to fetch something that is not a definition."""
   1922     module = self.DefineModule('i.am.a.module')
   1923     setattr(module, 'A', 'a string')
   1924     self.assertRaises(messages.DefinitionNotFoundError,
   1925                       messages.find_definition,
   1926                       'i.am.a.module.A',
   1927                       importer=self.Importer)
   1928 
   1929   def testGlobalFind(self):
   1930     """Test finding definitions from fully qualified module names."""
   1931     A = self.DefineMessage('a.b.c', 'A', {})
   1932     self.assertEquals(A, messages.find_definition('a.b.c.A',
   1933                                                   importer=self.Importer))
   1934     B = self.DefineMessage('a.b.c', 'B', {'C':{}})
   1935     self.assertEquals(B.C, messages.find_definition('a.b.c.B.C',
   1936                                                     importer=self.Importer))
   1937 
   1938   def testRelativeToModule(self):
   1939     """Test finding definitions relative to modules."""
   1940     # Define modules.
   1941     a = self.DefineModule('a')
   1942     b = self.DefineModule('a.b')
   1943     c = self.DefineModule('a.b.c')
   1944 
   1945     # Define messages.
   1946     A = self.DefineMessage('a', 'A')
   1947     B = self.DefineMessage('a.b', 'B')
   1948     C = self.DefineMessage('a.b.c', 'C')
   1949     D = self.DefineMessage('a.b.d', 'D')
   1950 
   1951     # Find A, B, C and D relative to a.
   1952     self.assertEquals(A, messages.find_definition(
   1953         'A', a, importer=self.Importer))
   1954     self.assertEquals(B, messages.find_definition(
   1955         'b.B', a, importer=self.Importer))
   1956     self.assertEquals(C, messages.find_definition(
   1957         'b.c.C', a, importer=self.Importer))
   1958     self.assertEquals(D, messages.find_definition(
   1959         'b.d.D', a, importer=self.Importer))
   1960 
   1961     # Find A, B, C and D relative to b.
   1962     self.assertEquals(A, messages.find_definition(
   1963         'A', b, importer=self.Importer))
   1964     self.assertEquals(B, messages.find_definition(
   1965         'B', b, importer=self.Importer))
   1966     self.assertEquals(C, messages.find_definition(
   1967         'c.C', b, importer=self.Importer))
   1968     self.assertEquals(D, messages.find_definition(
   1969         'd.D', b, importer=self.Importer))
   1970 
   1971     # Find A, B, C and D relative to c.  Module d is the same case as c.
   1972     self.assertEquals(A, messages.find_definition(
   1973         'A', c, importer=self.Importer))
   1974     self.assertEquals(B, messages.find_definition(
   1975         'B', c, importer=self.Importer))
   1976     self.assertEquals(C, messages.find_definition(
   1977         'C', c, importer=self.Importer))
   1978     self.assertEquals(D, messages.find_definition(
   1979         'd.D', c, importer=self.Importer))
   1980 
   1981   def testRelativeToMessages(self):
   1982     """Test finding definitions relative to Message definitions."""
   1983     A = self.DefineMessage('a.b', 'A', {'B': {'C': {}, 'D': {}}})
   1984     B = A.B
   1985     C = A.B.C
   1986     D = A.B.D
   1987 
   1988     # Find relative to A.
   1989     self.assertEquals(A, messages.find_definition(
   1990         'A', A, importer=self.Importer))
   1991     self.assertEquals(B, messages.find_definition(
   1992         'B', A, importer=self.Importer))
   1993     self.assertEquals(C, messages.find_definition(
   1994         'B.C', A, importer=self.Importer))
   1995     self.assertEquals(D, messages.find_definition(
   1996         'B.D', A, importer=self.Importer))
   1997 
   1998     # Find relative to B.
   1999     self.assertEquals(A, messages.find_definition(
   2000         'A', B, importer=self.Importer))
   2001     self.assertEquals(B, messages.find_definition(
   2002         'B', B, importer=self.Importer))
   2003     self.assertEquals(C, messages.find_definition(
   2004         'C', B, importer=self.Importer))
   2005     self.assertEquals(D, messages.find_definition(
   2006         'D', B, importer=self.Importer))
   2007 
   2008     # Find relative to C.
   2009     self.assertEquals(A, messages.find_definition(
   2010         'A', C, importer=self.Importer))
   2011     self.assertEquals(B, messages.find_definition(
   2012         'B', C, importer=self.Importer))
   2013     self.assertEquals(C, messages.find_definition(
   2014         'C', C, importer=self.Importer))
   2015     self.assertEquals(D, messages.find_definition(
   2016         'D', C, importer=self.Importer))
   2017 
   2018     # Find relative to C searching from c.
   2019     self.assertEquals(A, messages.find_definition(
   2020         'b.A', C, importer=self.Importer))
   2021     self.assertEquals(B, messages.find_definition(
   2022         'b.A.B', C, importer=self.Importer))
   2023     self.assertEquals(C, messages.find_definition(
   2024         'b.A.B.C', C, importer=self.Importer))
   2025     self.assertEquals(D, messages.find_definition(
   2026         'b.A.B.D', C, importer=self.Importer))
   2027 
   2028   def testAbsoluteReference(self):
   2029     """Test finding absolute definition names."""
   2030     # Define modules.
   2031     a = self.DefineModule('a')
   2032     b = self.DefineModule('a.a')
   2033 
   2034     # Define messages.
   2035     aA = self.DefineMessage('a', 'A')
   2036     aaA = self.DefineMessage('a.a', 'A')
   2037 
   2038     # Always find a.A.
   2039     self.assertEquals(aA, messages.find_definition('.a.A', None,
   2040                                                    importer=self.Importer))
   2041     self.assertEquals(aA, messages.find_definition('.a.A', a,
   2042                                                    importer=self.Importer))
   2043     self.assertEquals(aA, messages.find_definition('.a.A', aA,
   2044                                                    importer=self.Importer))
   2045     self.assertEquals(aA, messages.find_definition('.a.A', aaA,
   2046                                                    importer=self.Importer))
   2047 
   2048   def testFindEnum(self):
   2049     """Test that Enums are found."""
   2050     class Color(messages.Enum):
   2051       pass
   2052     A = self.DefineMessage('a', 'A', {'Color': Color})
   2053 
   2054     self.assertEquals(
   2055         Color,
   2056         messages.find_definition('Color', A, importer=self.Importer))
   2057 
   2058   def testFalseScope(self):
   2059     """Test that Message definitions nested in strange objects are hidden."""
   2060     global X
   2061     class X(object):
   2062       class A(messages.Message):
   2063         pass
   2064 
   2065     self.assertRaises(TypeError, messages.find_definition, 'A', X)
   2066     self.assertRaises(messages.DefinitionNotFoundError,
   2067                       messages.find_definition,
   2068                       'X.A', sys.modules[__name__])
   2069 
   2070   def testSearchAttributeFirst(self):
   2071     """Make sure not faked out by module, but continues searching."""
   2072     A = self.DefineMessage('a', 'A')
   2073     module_A = self.DefineModule('a.A')
   2074 
   2075     self.assertEquals(A, messages.find_definition(
   2076         'a.A', None, importer=self.Importer))
   2077 
   2078 
   2079 class FindDefinitionUnicodeTests(test_util.TestCase):
   2080 
   2081   # TODO(craigcitro): Fix this test and re-enable it.
   2082   def notatestUnicodeString(self):
   2083     """Test using unicode names."""
   2084     from protorpc import registry
   2085     self.assertEquals('ServiceMapping',
   2086                       messages.find_definition(
   2087                         u'protorpc.registry.ServiceMapping',
   2088                         None).__name__)
   2089 
   2090 
   2091 def main():
   2092   unittest.main()
   2093 
   2094 
   2095 if __name__ == '__main__':
   2096   main()
   2097