Home | History | Annotate | Download | only in internal
      1 #! /usr/bin/env python
      2 #
      3 # Protocol Buffers - Google's data interchange format
      4 # Copyright 2008 Google Inc.  All rights reserved.
      5 # https://developers.google.com/protocol-buffers/
      6 #
      7 # Redistribution and use in source and binary forms, with or without
      8 # modification, are permitted provided that the following conditions are
      9 # met:
     10 #
     11 #     * Redistributions of source code must retain the above copyright
     12 # notice, this list of conditions and the following disclaimer.
     13 #     * Redistributions in binary form must reproduce the above
     14 # copyright notice, this list of conditions and the following disclaimer
     15 # in the documentation and/or other materials provided with the
     16 # distribution.
     17 #     * Neither the name of Google Inc. nor the names of its
     18 # contributors may be used to endorse or promote products derived from
     19 # this software without specific prior written permission.
     20 #
     21 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     22 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     23 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     24 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     25 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     26 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     27 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     28 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     29 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     30 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     31 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     32 
     33 """Tests python protocol buffers against the golden message.
     34 
     35 Note that the golden messages exercise every known field type, thus this
     36 test ends up exercising and verifying nearly all of the parsing and
     37 serialization code in the whole library.
     38 
     39 TODO(kenton):  Merge with wire_format_test?  It doesn't make a whole lot of
     40 sense to call this a test of the "message" module, which only declares an
     41 abstract interface.
     42 """
     43 
     44 __author__ = 'gps (at] google.com (Gregory P. Smith)'
     45 
     46 
     47 import collections
     48 import copy
     49 import math
     50 import operator
     51 import pickle
     52 import six
     53 import sys
     54 
     55 try:
     56   import unittest2 as unittest  #PY26
     57 except ImportError:
     58   import unittest
     59 
     60 from google.protobuf import map_unittest_pb2
     61 from google.protobuf import unittest_pb2
     62 from google.protobuf import unittest_proto3_arena_pb2
     63 from google.protobuf import descriptor_pb2
     64 from google.protobuf import descriptor_pool
     65 from google.protobuf import message_factory
     66 from google.protobuf import text_format
     67 from google.protobuf.internal import api_implementation
     68 from google.protobuf.internal import packed_field_test_pb2
     69 from google.protobuf.internal import test_util
     70 from google.protobuf import message
     71 from google.protobuf.internal import _parameterized
     72 
     73 if six.PY3:
     74   long = int
     75 
     76 
     77 # Python pre-2.6 does not have isinf() or isnan() functions, so we have
     78 # to provide our own.
     79 def isnan(val):
     80   # NaN is never equal to itself.
     81   return val != val
     82 def isinf(val):
     83   # Infinity times zero equals NaN.
     84   return not isnan(val) and isnan(val * 0)
     85 def IsPosInf(val):
     86   return isinf(val) and (val > 0)
     87 def IsNegInf(val):
     88   return isinf(val) and (val < 0)
     89 
     90 
     91 @_parameterized.Parameters(
     92     (unittest_pb2),
     93     (unittest_proto3_arena_pb2))
     94 class MessageTest(unittest.TestCase):
     95 
     96   def testBadUtf8String(self, message_module):
     97     if api_implementation.Type() != 'python':
     98       self.skipTest("Skipping testBadUtf8String, currently only the python "
     99                     "api implementation raises UnicodeDecodeError when a "
    100                     "string field contains bad utf-8.")
    101     bad_utf8_data = test_util.GoldenFileData('bad_utf8_string')
    102     with self.assertRaises(UnicodeDecodeError) as context:
    103       message_module.TestAllTypes.FromString(bad_utf8_data)
    104     self.assertIn('TestAllTypes.optional_string', str(context.exception))
    105 
    106   def testGoldenMessage(self, message_module):
    107     # Proto3 doesn't have the "default_foo" members or foreign enums,
    108     # and doesn't preserve unknown fields, so for proto3 we use a golden
    109     # message that doesn't have these fields set.
    110     if message_module is unittest_pb2:
    111       golden_data = test_util.GoldenFileData(
    112           'golden_message_oneof_implemented')
    113     else:
    114       golden_data = test_util.GoldenFileData('golden_message_proto3')
    115 
    116     golden_message = message_module.TestAllTypes()
    117     golden_message.ParseFromString(golden_data)
    118     if message_module is unittest_pb2:
    119       test_util.ExpectAllFieldsSet(self, golden_message)
    120     self.assertEqual(golden_data, golden_message.SerializeToString())
    121     golden_copy = copy.deepcopy(golden_message)
    122     self.assertEqual(golden_data, golden_copy.SerializeToString())
    123 
    124   def testGoldenPackedMessage(self, message_module):
    125     golden_data = test_util.GoldenFileData('golden_packed_fields_message')
    126     golden_message = message_module.TestPackedTypes()
    127     golden_message.ParseFromString(golden_data)
    128     all_set = message_module.TestPackedTypes()
    129     test_util.SetAllPackedFields(all_set)
    130     self.assertEqual(all_set, golden_message)
    131     self.assertEqual(golden_data, all_set.SerializeToString())
    132     golden_copy = copy.deepcopy(golden_message)
    133     self.assertEqual(golden_data, golden_copy.SerializeToString())
    134 
    135   def testPickleSupport(self, message_module):
    136     golden_data = test_util.GoldenFileData('golden_message')
    137     golden_message = message_module.TestAllTypes()
    138     golden_message.ParseFromString(golden_data)
    139     pickled_message = pickle.dumps(golden_message)
    140 
    141     unpickled_message = pickle.loads(pickled_message)
    142     self.assertEqual(unpickled_message, golden_message)
    143 
    144   def testPositiveInfinity(self, message_module):
    145     if message_module is unittest_pb2:
    146       golden_data = (b'\x5D\x00\x00\x80\x7F'
    147                      b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
    148                      b'\xCD\x02\x00\x00\x80\x7F'
    149                      b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F')
    150     else:
    151       golden_data = (b'\x5D\x00\x00\x80\x7F'
    152                      b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F'
    153                      b'\xCA\x02\x04\x00\x00\x80\x7F'
    154                      b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
    155 
    156     golden_message = message_module.TestAllTypes()
    157     golden_message.ParseFromString(golden_data)
    158     self.assertTrue(IsPosInf(golden_message.optional_float))
    159     self.assertTrue(IsPosInf(golden_message.optional_double))
    160     self.assertTrue(IsPosInf(golden_message.repeated_float[0]))
    161     self.assertTrue(IsPosInf(golden_message.repeated_double[0]))
    162     self.assertEqual(golden_data, golden_message.SerializeToString())
    163 
    164   def testNegativeInfinity(self, message_module):
    165     if message_module is unittest_pb2:
    166       golden_data = (b'\x5D\x00\x00\x80\xFF'
    167                      b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
    168                      b'\xCD\x02\x00\x00\x80\xFF'
    169                      b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF')
    170     else:
    171       golden_data = (b'\x5D\x00\x00\x80\xFF'
    172                      b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF'
    173                      b'\xCA\x02\x04\x00\x00\x80\xFF'
    174                      b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
    175 
    176     golden_message = message_module.TestAllTypes()
    177     golden_message.ParseFromString(golden_data)
    178     self.assertTrue(IsNegInf(golden_message.optional_float))
    179     self.assertTrue(IsNegInf(golden_message.optional_double))
    180     self.assertTrue(IsNegInf(golden_message.repeated_float[0]))
    181     self.assertTrue(IsNegInf(golden_message.repeated_double[0]))
    182     self.assertEqual(golden_data, golden_message.SerializeToString())
    183 
    184   def testNotANumber(self, message_module):
    185     golden_data = (b'\x5D\x00\x00\xC0\x7F'
    186                    b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F'
    187                    b'\xCD\x02\x00\x00\xC0\x7F'
    188                    b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F')
    189     golden_message = message_module.TestAllTypes()
    190     golden_message.ParseFromString(golden_data)
    191     self.assertTrue(isnan(golden_message.optional_float))
    192     self.assertTrue(isnan(golden_message.optional_double))
    193     self.assertTrue(isnan(golden_message.repeated_float[0]))
    194     self.assertTrue(isnan(golden_message.repeated_double[0]))
    195 
    196     # The protocol buffer may serialize to any one of multiple different
    197     # representations of a NaN.  Rather than verify a specific representation,
    198     # verify the serialized string can be converted into a correctly
    199     # behaving protocol buffer.
    200     serialized = golden_message.SerializeToString()
    201     message = message_module.TestAllTypes()
    202     message.ParseFromString(serialized)
    203     self.assertTrue(isnan(message.optional_float))
    204     self.assertTrue(isnan(message.optional_double))
    205     self.assertTrue(isnan(message.repeated_float[0]))
    206     self.assertTrue(isnan(message.repeated_double[0]))
    207 
    208   def testPositiveInfinityPacked(self, message_module):
    209     golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F'
    210                    b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F')
    211     golden_message = message_module.TestPackedTypes()
    212     golden_message.ParseFromString(golden_data)
    213     self.assertTrue(IsPosInf(golden_message.packed_float[0]))
    214     self.assertTrue(IsPosInf(golden_message.packed_double[0]))
    215     self.assertEqual(golden_data, golden_message.SerializeToString())
    216 
    217   def testNegativeInfinityPacked(self, message_module):
    218     golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF'
    219                    b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF')
    220     golden_message = message_module.TestPackedTypes()
    221     golden_message.ParseFromString(golden_data)
    222     self.assertTrue(IsNegInf(golden_message.packed_float[0]))
    223     self.assertTrue(IsNegInf(golden_message.packed_double[0]))
    224     self.assertEqual(golden_data, golden_message.SerializeToString())
    225 
    226   def testNotANumberPacked(self, message_module):
    227     golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F'
    228                    b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F')
    229     golden_message = message_module.TestPackedTypes()
    230     golden_message.ParseFromString(golden_data)
    231     self.assertTrue(isnan(golden_message.packed_float[0]))
    232     self.assertTrue(isnan(golden_message.packed_double[0]))
    233 
    234     serialized = golden_message.SerializeToString()
    235     message = message_module.TestPackedTypes()
    236     message.ParseFromString(serialized)
    237     self.assertTrue(isnan(message.packed_float[0]))
    238     self.assertTrue(isnan(message.packed_double[0]))
    239 
    240   def testExtremeFloatValues(self, message_module):
    241     message = message_module.TestAllTypes()
    242 
    243     # Most positive exponent, no significand bits set.
    244     kMostPosExponentNoSigBits = math.pow(2, 127)
    245     message.optional_float = kMostPosExponentNoSigBits
    246     message.ParseFromString(message.SerializeToString())
    247     self.assertTrue(message.optional_float == kMostPosExponentNoSigBits)
    248 
    249     # Most positive exponent, one significand bit set.
    250     kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127)
    251     message.optional_float = kMostPosExponentOneSigBit
    252     message.ParseFromString(message.SerializeToString())
    253     self.assertTrue(message.optional_float == kMostPosExponentOneSigBit)
    254 
    255     # Repeat last two cases with values of same magnitude, but negative.
    256     message.optional_float = -kMostPosExponentNoSigBits
    257     message.ParseFromString(message.SerializeToString())
    258     self.assertTrue(message.optional_float == -kMostPosExponentNoSigBits)
    259 
    260     message.optional_float = -kMostPosExponentOneSigBit
    261     message.ParseFromString(message.SerializeToString())
    262     self.assertTrue(message.optional_float == -kMostPosExponentOneSigBit)
    263 
    264     # Most negative exponent, no significand bits set.
    265     kMostNegExponentNoSigBits = math.pow(2, -127)
    266     message.optional_float = kMostNegExponentNoSigBits
    267     message.ParseFromString(message.SerializeToString())
    268     self.assertTrue(message.optional_float == kMostNegExponentNoSigBits)
    269 
    270     # Most negative exponent, one significand bit set.
    271     kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127)
    272     message.optional_float = kMostNegExponentOneSigBit
    273     message.ParseFromString(message.SerializeToString())
    274     self.assertTrue(message.optional_float == kMostNegExponentOneSigBit)
    275 
    276     # Repeat last two cases with values of the same magnitude, but negative.
    277     message.optional_float = -kMostNegExponentNoSigBits
    278     message.ParseFromString(message.SerializeToString())
    279     self.assertTrue(message.optional_float == -kMostNegExponentNoSigBits)
    280 
    281     message.optional_float = -kMostNegExponentOneSigBit
    282     message.ParseFromString(message.SerializeToString())
    283     self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit)
    284 
    285   def testExtremeDoubleValues(self, message_module):
    286     message = message_module.TestAllTypes()
    287 
    288     # Most positive exponent, no significand bits set.
    289     kMostPosExponentNoSigBits = math.pow(2, 1023)
    290     message.optional_double = kMostPosExponentNoSigBits
    291     message.ParseFromString(message.SerializeToString())
    292     self.assertTrue(message.optional_double == kMostPosExponentNoSigBits)
    293 
    294     # Most positive exponent, one significand bit set.
    295     kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023)
    296     message.optional_double = kMostPosExponentOneSigBit
    297     message.ParseFromString(message.SerializeToString())
    298     self.assertTrue(message.optional_double == kMostPosExponentOneSigBit)
    299 
    300     # Repeat last two cases with values of same magnitude, but negative.
    301     message.optional_double = -kMostPosExponentNoSigBits
    302     message.ParseFromString(message.SerializeToString())
    303     self.assertTrue(message.optional_double == -kMostPosExponentNoSigBits)
    304 
    305     message.optional_double = -kMostPosExponentOneSigBit
    306     message.ParseFromString(message.SerializeToString())
    307     self.assertTrue(message.optional_double == -kMostPosExponentOneSigBit)
    308 
    309     # Most negative exponent, no significand bits set.
    310     kMostNegExponentNoSigBits = math.pow(2, -1023)
    311     message.optional_double = kMostNegExponentNoSigBits
    312     message.ParseFromString(message.SerializeToString())
    313     self.assertTrue(message.optional_double == kMostNegExponentNoSigBits)
    314 
    315     # Most negative exponent, one significand bit set.
    316     kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023)
    317     message.optional_double = kMostNegExponentOneSigBit
    318     message.ParseFromString(message.SerializeToString())
    319     self.assertTrue(message.optional_double == kMostNegExponentOneSigBit)
    320 
    321     # Repeat last two cases with values of the same magnitude, but negative.
    322     message.optional_double = -kMostNegExponentNoSigBits
    323     message.ParseFromString(message.SerializeToString())
    324     self.assertTrue(message.optional_double == -kMostNegExponentNoSigBits)
    325 
    326     message.optional_double = -kMostNegExponentOneSigBit
    327     message.ParseFromString(message.SerializeToString())
    328     self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit)
    329 
    330   def testFloatPrinting(self, message_module):
    331     message = message_module.TestAllTypes()
    332     message.optional_float = 2.0
    333     self.assertEqual(str(message), 'optional_float: 2.0\n')
    334 
    335   def testHighPrecisionFloatPrinting(self, message_module):
    336     message = message_module.TestAllTypes()
    337     message.optional_double = 0.12345678912345678
    338     if sys.version_info >= (3,):
    339       self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n')
    340     else:
    341       self.assertEqual(str(message), 'optional_double: 0.123456789123\n')
    342 
    343   def testUnknownFieldPrinting(self, message_module):
    344     populated = message_module.TestAllTypes()
    345     test_util.SetAllNonLazyFields(populated)
    346     empty = message_module.TestEmptyMessage()
    347     empty.ParseFromString(populated.SerializeToString())
    348     self.assertEqual(str(empty), '')
    349 
    350   def testRepeatedNestedFieldIteration(self, message_module):
    351     msg = message_module.TestAllTypes()
    352     msg.repeated_nested_message.add(bb=1)
    353     msg.repeated_nested_message.add(bb=2)
    354     msg.repeated_nested_message.add(bb=3)
    355     msg.repeated_nested_message.add(bb=4)
    356 
    357     self.assertEqual([1, 2, 3, 4],
    358                      [m.bb for m in msg.repeated_nested_message])
    359     self.assertEqual([4, 3, 2, 1],
    360                      [m.bb for m in reversed(msg.repeated_nested_message)])
    361     self.assertEqual([4, 3, 2, 1],
    362                      [m.bb for m in msg.repeated_nested_message[::-1]])
    363 
    364   def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module):
    365     """Check some different types with the default comparator."""
    366     message = message_module.TestAllTypes()
    367 
    368     # TODO(mattp): would testing more scalar types strengthen test?
    369     message.repeated_int32.append(1)
    370     message.repeated_int32.append(3)
    371     message.repeated_int32.append(2)
    372     message.repeated_int32.sort()
    373     self.assertEqual(message.repeated_int32[0], 1)
    374     self.assertEqual(message.repeated_int32[1], 2)
    375     self.assertEqual(message.repeated_int32[2], 3)
    376 
    377     message.repeated_float.append(1.1)
    378     message.repeated_float.append(1.3)
    379     message.repeated_float.append(1.2)
    380     message.repeated_float.sort()
    381     self.assertAlmostEqual(message.repeated_float[0], 1.1)
    382     self.assertAlmostEqual(message.repeated_float[1], 1.2)
    383     self.assertAlmostEqual(message.repeated_float[2], 1.3)
    384 
    385     message.repeated_string.append('a')
    386     message.repeated_string.append('c')
    387     message.repeated_string.append('b')
    388     message.repeated_string.sort()
    389     self.assertEqual(message.repeated_string[0], 'a')
    390     self.assertEqual(message.repeated_string[1], 'b')
    391     self.assertEqual(message.repeated_string[2], 'c')
    392 
    393     message.repeated_bytes.append(b'a')
    394     message.repeated_bytes.append(b'c')
    395     message.repeated_bytes.append(b'b')
    396     message.repeated_bytes.sort()
    397     self.assertEqual(message.repeated_bytes[0], b'a')
    398     self.assertEqual(message.repeated_bytes[1], b'b')
    399     self.assertEqual(message.repeated_bytes[2], b'c')
    400 
    401   def testSortingRepeatedScalarFieldsCustomComparator(self, message_module):
    402     """Check some different types with custom comparator."""
    403     message = message_module.TestAllTypes()
    404 
    405     message.repeated_int32.append(-3)
    406     message.repeated_int32.append(-2)
    407     message.repeated_int32.append(-1)
    408     message.repeated_int32.sort(key=abs)
    409     self.assertEqual(message.repeated_int32[0], -1)
    410     self.assertEqual(message.repeated_int32[1], -2)
    411     self.assertEqual(message.repeated_int32[2], -3)
    412 
    413     message.repeated_string.append('aaa')
    414     message.repeated_string.append('bb')
    415     message.repeated_string.append('c')
    416     message.repeated_string.sort(key=len)
    417     self.assertEqual(message.repeated_string[0], 'c')
    418     self.assertEqual(message.repeated_string[1], 'bb')
    419     self.assertEqual(message.repeated_string[2], 'aaa')
    420 
    421   def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module):
    422     """Check passing a custom comparator to sort a repeated composite field."""
    423     message = message_module.TestAllTypes()
    424 
    425     message.repeated_nested_message.add().bb = 1
    426     message.repeated_nested_message.add().bb = 3
    427     message.repeated_nested_message.add().bb = 2
    428     message.repeated_nested_message.add().bb = 6
    429     message.repeated_nested_message.add().bb = 5
    430     message.repeated_nested_message.add().bb = 4
    431     message.repeated_nested_message.sort(key=operator.attrgetter('bb'))
    432     self.assertEqual(message.repeated_nested_message[0].bb, 1)
    433     self.assertEqual(message.repeated_nested_message[1].bb, 2)
    434     self.assertEqual(message.repeated_nested_message[2].bb, 3)
    435     self.assertEqual(message.repeated_nested_message[3].bb, 4)
    436     self.assertEqual(message.repeated_nested_message[4].bb, 5)
    437     self.assertEqual(message.repeated_nested_message[5].bb, 6)
    438 
    439   def testSortingRepeatedCompositeFieldsStable(self, message_module):
    440     """Check passing a custom comparator to sort a repeated composite field."""
    441     message = message_module.TestAllTypes()
    442 
    443     message.repeated_nested_message.add().bb = 21
    444     message.repeated_nested_message.add().bb = 20
    445     message.repeated_nested_message.add().bb = 13
    446     message.repeated_nested_message.add().bb = 33
    447     message.repeated_nested_message.add().bb = 11
    448     message.repeated_nested_message.add().bb = 24
    449     message.repeated_nested_message.add().bb = 10
    450     message.repeated_nested_message.sort(key=lambda z: z.bb // 10)
    451     self.assertEqual(
    452         [13, 11, 10, 21, 20, 24, 33],
    453         [n.bb for n in message.repeated_nested_message])
    454 
    455     # Make sure that for the C++ implementation, the underlying fields
    456     # are actually reordered.
    457     pb = message.SerializeToString()
    458     message.Clear()
    459     message.MergeFromString(pb)
    460     self.assertEqual(
    461         [13, 11, 10, 21, 20, 24, 33],
    462         [n.bb for n in message.repeated_nested_message])
    463 
    464   def testRepeatedCompositeFieldSortArguments(self, message_module):
    465     """Check sorting a repeated composite field using list.sort() arguments."""
    466     message = message_module.TestAllTypes()
    467 
    468     get_bb = operator.attrgetter('bb')
    469     cmp_bb = lambda a, b: cmp(a.bb, b.bb)
    470     message.repeated_nested_message.add().bb = 1
    471     message.repeated_nested_message.add().bb = 3
    472     message.repeated_nested_message.add().bb = 2
    473     message.repeated_nested_message.add().bb = 6
    474     message.repeated_nested_message.add().bb = 5
    475     message.repeated_nested_message.add().bb = 4
    476     message.repeated_nested_message.sort(key=get_bb)
    477     self.assertEqual([k.bb for k in message.repeated_nested_message],
    478                      [1, 2, 3, 4, 5, 6])
    479     message.repeated_nested_message.sort(key=get_bb, reverse=True)
    480     self.assertEqual([k.bb for k in message.repeated_nested_message],
    481                      [6, 5, 4, 3, 2, 1])
    482     if sys.version_info >= (3,): return  # No cmp sorting in PY3.
    483     message.repeated_nested_message.sort(sort_function=cmp_bb)
    484     self.assertEqual([k.bb for k in message.repeated_nested_message],
    485                      [1, 2, 3, 4, 5, 6])
    486     message.repeated_nested_message.sort(cmp=cmp_bb, reverse=True)
    487     self.assertEqual([k.bb for k in message.repeated_nested_message],
    488                      [6, 5, 4, 3, 2, 1])
    489 
    490   def testRepeatedScalarFieldSortArguments(self, message_module):
    491     """Check sorting a scalar field using list.sort() arguments."""
    492     message = message_module.TestAllTypes()
    493 
    494     message.repeated_int32.append(-3)
    495     message.repeated_int32.append(-2)
    496     message.repeated_int32.append(-1)
    497     message.repeated_int32.sort(key=abs)
    498     self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
    499     message.repeated_int32.sort(key=abs, reverse=True)
    500     self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
    501     if sys.version_info < (3,):  # No cmp sorting in PY3.
    502       abs_cmp = lambda a, b: cmp(abs(a), abs(b))
    503       message.repeated_int32.sort(sort_function=abs_cmp)
    504       self.assertEqual(list(message.repeated_int32), [-1, -2, -3])
    505       message.repeated_int32.sort(cmp=abs_cmp, reverse=True)
    506       self.assertEqual(list(message.repeated_int32), [-3, -2, -1])
    507 
    508     message.repeated_string.append('aaa')
    509     message.repeated_string.append('bb')
    510     message.repeated_string.append('c')
    511     message.repeated_string.sort(key=len)
    512     self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
    513     message.repeated_string.sort(key=len, reverse=True)
    514     self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
    515     if sys.version_info < (3,):  # No cmp sorting in PY3.
    516       len_cmp = lambda a, b: cmp(len(a), len(b))
    517       message.repeated_string.sort(sort_function=len_cmp)
    518       self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa'])
    519       message.repeated_string.sort(cmp=len_cmp, reverse=True)
    520       self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c'])
    521 
    522   def testRepeatedFieldsComparable(self, message_module):
    523     m1 = message_module.TestAllTypes()
    524     m2 = message_module.TestAllTypes()
    525     m1.repeated_int32.append(0)
    526     m1.repeated_int32.append(1)
    527     m1.repeated_int32.append(2)
    528     m2.repeated_int32.append(0)
    529     m2.repeated_int32.append(1)
    530     m2.repeated_int32.append(2)
    531     m1.repeated_nested_message.add().bb = 1
    532     m1.repeated_nested_message.add().bb = 2
    533     m1.repeated_nested_message.add().bb = 3
    534     m2.repeated_nested_message.add().bb = 1
    535     m2.repeated_nested_message.add().bb = 2
    536     m2.repeated_nested_message.add().bb = 3
    537 
    538     if sys.version_info >= (3,): return  # No cmp() in PY3.
    539 
    540     # These comparisons should not raise errors.
    541     _ = m1 < m2
    542     _ = m1.repeated_nested_message < m2.repeated_nested_message
    543 
    544     # Make sure cmp always works. If it wasn't defined, these would be
    545     # id() comparisons and would all fail.
    546     self.assertEqual(cmp(m1, m2), 0)
    547     self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0)
    548     self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0)
    549     self.assertEqual(cmp(m1.repeated_nested_message,
    550                          m2.repeated_nested_message), 0)
    551     with self.assertRaises(TypeError):
    552       # Can't compare repeated composite containers to lists.
    553       cmp(m1.repeated_nested_message, m2.repeated_nested_message[:])
    554 
    555     # TODO(anuraag): Implement extensiondict comparison in C++ and then add test
    556 
    557   def testRepeatedFieldsAreSequences(self, message_module):
    558     m = message_module.TestAllTypes()
    559     self.assertIsInstance(m.repeated_int32, collections.MutableSequence)
    560     self.assertIsInstance(m.repeated_nested_message,
    561                           collections.MutableSequence)
    562 
    563   def ensureNestedMessageExists(self, msg, attribute):
    564     """Make sure that a nested message object exists.
    565 
    566     As soon as a nested message attribute is accessed, it will be present in the
    567     _fields dict, without being marked as actually being set.
    568     """
    569     getattr(msg, attribute)
    570     self.assertFalse(msg.HasField(attribute))
    571 
    572   def testOneofGetCaseNonexistingField(self, message_module):
    573     m = message_module.TestAllTypes()
    574     self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field')
    575 
    576   def testOneofDefaultValues(self, message_module):
    577     m = message_module.TestAllTypes()
    578     self.assertIs(None, m.WhichOneof('oneof_field'))
    579     self.assertFalse(m.HasField('oneof_uint32'))
    580 
    581     # Oneof is set even when setting it to a default value.
    582     m.oneof_uint32 = 0
    583     self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
    584     self.assertTrue(m.HasField('oneof_uint32'))
    585     self.assertFalse(m.HasField('oneof_string'))
    586 
    587     m.oneof_string = ""
    588     self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
    589     self.assertTrue(m.HasField('oneof_string'))
    590     self.assertFalse(m.HasField('oneof_uint32'))
    591 
    592   def testOneofSemantics(self, message_module):
    593     m = message_module.TestAllTypes()
    594     self.assertIs(None, m.WhichOneof('oneof_field'))
    595 
    596     m.oneof_uint32 = 11
    597     self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
    598     self.assertTrue(m.HasField('oneof_uint32'))
    599 
    600     m.oneof_string = u'foo'
    601     self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
    602     self.assertFalse(m.HasField('oneof_uint32'))
    603     self.assertTrue(m.HasField('oneof_string'))
    604 
    605     # Read nested message accessor without accessing submessage.
    606     m.oneof_nested_message
    607     self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
    608     self.assertTrue(m.HasField('oneof_string'))
    609     self.assertFalse(m.HasField('oneof_nested_message'))
    610 
    611     # Read accessor of nested message without accessing submessage.
    612     m.oneof_nested_message.bb
    613     self.assertEqual('oneof_string', m.WhichOneof('oneof_field'))
    614     self.assertTrue(m.HasField('oneof_string'))
    615     self.assertFalse(m.HasField('oneof_nested_message'))
    616 
    617     m.oneof_nested_message.bb = 11
    618     self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
    619     self.assertFalse(m.HasField('oneof_string'))
    620     self.assertTrue(m.HasField('oneof_nested_message'))
    621 
    622     m.oneof_bytes = b'bb'
    623     self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
    624     self.assertFalse(m.HasField('oneof_nested_message'))
    625     self.assertTrue(m.HasField('oneof_bytes'))
    626 
    627   def testOneofCompositeFieldReadAccess(self, message_module):
    628     m = message_module.TestAllTypes()
    629     m.oneof_uint32 = 11
    630 
    631     self.ensureNestedMessageExists(m, 'oneof_nested_message')
    632     self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
    633     self.assertEqual(11, m.oneof_uint32)
    634 
    635   def testOneofWhichOneof(self, message_module):
    636     m = message_module.TestAllTypes()
    637     self.assertIs(None, m.WhichOneof('oneof_field'))
    638     if message_module is unittest_pb2:
    639       self.assertFalse(m.HasField('oneof_field'))
    640 
    641     m.oneof_uint32 = 11
    642     self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
    643     if message_module is unittest_pb2:
    644       self.assertTrue(m.HasField('oneof_field'))
    645 
    646     m.oneof_bytes = b'bb'
    647     self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
    648 
    649     m.ClearField('oneof_bytes')
    650     self.assertIs(None, m.WhichOneof('oneof_field'))
    651     if message_module is unittest_pb2:
    652       self.assertFalse(m.HasField('oneof_field'))
    653 
    654   def testOneofClearField(self, message_module):
    655     m = message_module.TestAllTypes()
    656     m.oneof_uint32 = 11
    657     m.ClearField('oneof_field')
    658     if message_module is unittest_pb2:
    659       self.assertFalse(m.HasField('oneof_field'))
    660     self.assertFalse(m.HasField('oneof_uint32'))
    661     self.assertIs(None, m.WhichOneof('oneof_field'))
    662 
    663   def testOneofClearSetField(self, message_module):
    664     m = message_module.TestAllTypes()
    665     m.oneof_uint32 = 11
    666     m.ClearField('oneof_uint32')
    667     if message_module is unittest_pb2:
    668       self.assertFalse(m.HasField('oneof_field'))
    669     self.assertFalse(m.HasField('oneof_uint32'))
    670     self.assertIs(None, m.WhichOneof('oneof_field'))
    671 
    672   def testOneofClearUnsetField(self, message_module):
    673     m = message_module.TestAllTypes()
    674     m.oneof_uint32 = 11
    675     self.ensureNestedMessageExists(m, 'oneof_nested_message')
    676     m.ClearField('oneof_nested_message')
    677     self.assertEqual(11, m.oneof_uint32)
    678     if message_module is unittest_pb2:
    679       self.assertTrue(m.HasField('oneof_field'))
    680     self.assertTrue(m.HasField('oneof_uint32'))
    681     self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field'))
    682 
    683   def testOneofDeserialize(self, message_module):
    684     m = message_module.TestAllTypes()
    685     m.oneof_uint32 = 11
    686     m2 = message_module.TestAllTypes()
    687     m2.ParseFromString(m.SerializeToString())
    688     self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
    689 
    690   def testOneofCopyFrom(self, message_module):
    691     m = message_module.TestAllTypes()
    692     m.oneof_uint32 = 11
    693     m2 = message_module.TestAllTypes()
    694     m2.CopyFrom(m)
    695     self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
    696 
    697   def testOneofNestedMergeFrom(self, message_module):
    698     m = message_module.NestedTestAllTypes()
    699     m.payload.oneof_uint32 = 11
    700     m2 = message_module.NestedTestAllTypes()
    701     m2.payload.oneof_bytes = b'bb'
    702     m2.child.payload.oneof_bytes = b'bb'
    703     m2.MergeFrom(m)
    704     self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field'))
    705     self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field'))
    706 
    707   def testOneofMessageMergeFrom(self, message_module):
    708     m = message_module.NestedTestAllTypes()
    709     m.payload.oneof_nested_message.bb = 11
    710     m.child.payload.oneof_nested_message.bb = 12
    711     m2 = message_module.NestedTestAllTypes()
    712     m2.payload.oneof_uint32 = 13
    713     m2.MergeFrom(m)
    714     self.assertEqual('oneof_nested_message',
    715                      m2.payload.WhichOneof('oneof_field'))
    716     self.assertEqual('oneof_nested_message',
    717                      m2.child.payload.WhichOneof('oneof_field'))
    718 
    719   def testOneofNestedMessageInit(self, message_module):
    720     m = message_module.TestAllTypes(
    721         oneof_nested_message=message_module.TestAllTypes.NestedMessage())
    722     self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field'))
    723 
    724   def testOneofClear(self, message_module):
    725     m = message_module.TestAllTypes()
    726     m.oneof_uint32 = 11
    727     m.Clear()
    728     self.assertIsNone(m.WhichOneof('oneof_field'))
    729     m.oneof_bytes = b'bb'
    730     self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field'))
    731 
    732   def testAssignByteStringToUnicodeField(self, message_module):
    733     """Assigning a byte string to a string field should result
    734     in the value being converted to a Unicode string."""
    735     m = message_module.TestAllTypes()
    736     m.optional_string = str('')
    737     self.assertIsInstance(m.optional_string, six.text_type)
    738 
    739   def testLongValuedSlice(self, message_module):
    740     """It should be possible to use long-valued indicies in slices
    741 
    742     This didn't used to work in the v2 C++ implementation.
    743     """
    744     m = message_module.TestAllTypes()
    745 
    746     # Repeated scalar
    747     m.repeated_int32.append(1)
    748     sl = m.repeated_int32[long(0):long(len(m.repeated_int32))]
    749     self.assertEqual(len(m.repeated_int32), len(sl))
    750 
    751     # Repeated composite
    752     m.repeated_nested_message.add().bb = 3
    753     sl = m.repeated_nested_message[long(0):long(len(m.repeated_nested_message))]
    754     self.assertEqual(len(m.repeated_nested_message), len(sl))
    755 
    756   def testExtendShouldNotSwallowExceptions(self, message_module):
    757     """This didn't use to work in the v2 C++ implementation."""
    758     m = message_module.TestAllTypes()
    759     with self.assertRaises(NameError) as _:
    760       m.repeated_int32.extend(a for i in range(10))  # pylint: disable=undefined-variable
    761     with self.assertRaises(NameError) as _:
    762       m.repeated_nested_enum.extend(
    763           a for i in range(10))  # pylint: disable=undefined-variable
    764 
    765   FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()]
    766 
    767   def testExtendInt32WithNothing(self, message_module):
    768     """Test no-ops extending repeated int32 fields."""
    769     m = message_module.TestAllTypes()
    770     self.assertSequenceEqual([], m.repeated_int32)
    771 
    772     # TODO(ptucker): Deprecate this behavior. b/18413862
    773     for falsy_value in MessageTest.FALSY_VALUES:
    774       m.repeated_int32.extend(falsy_value)
    775       self.assertSequenceEqual([], m.repeated_int32)
    776 
    777     m.repeated_int32.extend([])
    778     self.assertSequenceEqual([], m.repeated_int32)
    779 
    780   def testExtendFloatWithNothing(self, message_module):
    781     """Test no-ops extending repeated float fields."""
    782     m = message_module.TestAllTypes()
    783     self.assertSequenceEqual([], m.repeated_float)
    784 
    785     # TODO(ptucker): Deprecate this behavior. b/18413862
    786     for falsy_value in MessageTest.FALSY_VALUES:
    787       m.repeated_float.extend(falsy_value)
    788       self.assertSequenceEqual([], m.repeated_float)
    789 
    790     m.repeated_float.extend([])
    791     self.assertSequenceEqual([], m.repeated_float)
    792 
    793   def testExtendStringWithNothing(self, message_module):
    794     """Test no-ops extending repeated string fields."""
    795     m = message_module.TestAllTypes()
    796     self.assertSequenceEqual([], m.repeated_string)
    797 
    798     # TODO(ptucker): Deprecate this behavior. b/18413862
    799     for falsy_value in MessageTest.FALSY_VALUES:
    800       m.repeated_string.extend(falsy_value)
    801       self.assertSequenceEqual([], m.repeated_string)
    802 
    803     m.repeated_string.extend([])
    804     self.assertSequenceEqual([], m.repeated_string)
    805 
    806   def testExtendInt32WithPythonList(self, message_module):
    807     """Test extending repeated int32 fields with python lists."""
    808     m = message_module.TestAllTypes()
    809     self.assertSequenceEqual([], m.repeated_int32)
    810     m.repeated_int32.extend([0])
    811     self.assertSequenceEqual([0], m.repeated_int32)
    812     m.repeated_int32.extend([1, 2])
    813     self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
    814     m.repeated_int32.extend([3, 4])
    815     self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
    816 
    817   def testExtendFloatWithPythonList(self, message_module):
    818     """Test extending repeated float fields with python lists."""
    819     m = message_module.TestAllTypes()
    820     self.assertSequenceEqual([], m.repeated_float)
    821     m.repeated_float.extend([0.0])
    822     self.assertSequenceEqual([0.0], m.repeated_float)
    823     m.repeated_float.extend([1.0, 2.0])
    824     self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
    825     m.repeated_float.extend([3.0, 4.0])
    826     self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
    827 
    828   def testExtendStringWithPythonList(self, message_module):
    829     """Test extending repeated string fields with python lists."""
    830     m = message_module.TestAllTypes()
    831     self.assertSequenceEqual([], m.repeated_string)
    832     m.repeated_string.extend([''])
    833     self.assertSequenceEqual([''], m.repeated_string)
    834     m.repeated_string.extend(['11', '22'])
    835     self.assertSequenceEqual(['', '11', '22'], m.repeated_string)
    836     m.repeated_string.extend(['33', '44'])
    837     self.assertSequenceEqual(['', '11', '22', '33', '44'], m.repeated_string)
    838 
    839   def testExtendStringWithString(self, message_module):
    840     """Test extending repeated string fields with characters from a string."""
    841     m = message_module.TestAllTypes()
    842     self.assertSequenceEqual([], m.repeated_string)
    843     m.repeated_string.extend('abc')
    844     self.assertSequenceEqual(['a', 'b', 'c'], m.repeated_string)
    845 
    846   class TestIterable(object):
    847     """This iterable object mimics the behavior of numpy.array.
    848 
    849     __nonzero__ fails for length > 1, and returns bool(item[0]) for length == 1.
    850 
    851     """
    852 
    853     def __init__(self, values=None):
    854       self._list = values or []
    855 
    856     def __nonzero__(self):
    857       size = len(self._list)
    858       if size == 0:
    859         return False
    860       if size == 1:
    861         return bool(self._list[0])
    862       raise ValueError('Truth value is ambiguous.')
    863 
    864     def __len__(self):
    865       return len(self._list)
    866 
    867     def __iter__(self):
    868       return self._list.__iter__()
    869 
    870   def testExtendInt32WithIterable(self, message_module):
    871     """Test extending repeated int32 fields with iterable."""
    872     m = message_module.TestAllTypes()
    873     self.assertSequenceEqual([], m.repeated_int32)
    874     m.repeated_int32.extend(MessageTest.TestIterable([]))
    875     self.assertSequenceEqual([], m.repeated_int32)
    876     m.repeated_int32.extend(MessageTest.TestIterable([0]))
    877     self.assertSequenceEqual([0], m.repeated_int32)
    878     m.repeated_int32.extend(MessageTest.TestIterable([1, 2]))
    879     self.assertSequenceEqual([0, 1, 2], m.repeated_int32)
    880     m.repeated_int32.extend(MessageTest.TestIterable([3, 4]))
    881     self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32)
    882 
    883   def testExtendFloatWithIterable(self, message_module):
    884     """Test extending repeated float fields with iterable."""
    885     m = message_module.TestAllTypes()
    886     self.assertSequenceEqual([], m.repeated_float)
    887     m.repeated_float.extend(MessageTest.TestIterable([]))
    888     self.assertSequenceEqual([], m.repeated_float)
    889     m.repeated_float.extend(MessageTest.TestIterable([0.0]))
    890     self.assertSequenceEqual([0.0], m.repeated_float)
    891     m.repeated_float.extend(MessageTest.TestIterable([1.0, 2.0]))
    892     self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float)
    893     m.repeated_float.extend(MessageTest.TestIterable([3.0, 4.0]))
    894     self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float)
    895 
    896   def testExtendStringWithIterable(self, message_module):
    897     """Test extending repeated string fields with iterable."""
    898     m = message_module.TestAllTypes()
    899     self.assertSequenceEqual([], m.repeated_string)
    900     m.repeated_string.extend(MessageTest.TestIterable([]))
    901     self.assertSequenceEqual([], m.repeated_string)
    902     m.repeated_string.extend(MessageTest.TestIterable(['']))
    903     self.assertSequenceEqual([''], m.repeated_string)
    904     m.repeated_string.extend(MessageTest.TestIterable(['1', '2']))
    905     self.assertSequenceEqual(['', '1', '2'], m.repeated_string)
    906     m.repeated_string.extend(MessageTest.TestIterable(['3', '4']))
    907     self.assertSequenceEqual(['', '1', '2', '3', '4'], m.repeated_string)
    908 
    909   def testPickleRepeatedScalarContainer(self, message_module):
    910     # TODO(tibell): The pure-Python implementation support pickling of
    911     #   scalar containers in *some* cases. For now the cpp2 version
    912     #   throws an exception to avoid a segfault. Investigate if we
    913     #   want to support pickling of these fields.
    914     #
    915     # For more information see: https://b2.corp.google.com/u/0/issues/18677897
    916     if (api_implementation.Type() != 'cpp' or
    917         api_implementation.Version() == 2):
    918       return
    919     m = message_module.TestAllTypes()
    920     with self.assertRaises(pickle.PickleError) as _:
    921       pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL)
    922 
    923   def testSortEmptyRepeatedCompositeContainer(self, message_module):
    924     """Exercise a scenario that has led to segfaults in the past.
    925     """
    926     m = message_module.TestAllTypes()
    927     m.repeated_nested_message.sort()
    928 
    929   def testHasFieldOnRepeatedField(self, message_module):
    930     """Using HasField on a repeated field should raise an exception.
    931     """
    932     m = message_module.TestAllTypes()
    933     with self.assertRaises(ValueError) as _:
    934       m.HasField('repeated_int32')
    935 
    936   def testRepeatedScalarFieldPop(self, message_module):
    937     m = message_module.TestAllTypes()
    938     with self.assertRaises(IndexError) as _:
    939       m.repeated_int32.pop()
    940     m.repeated_int32.extend(range(5))
    941     self.assertEqual(4, m.repeated_int32.pop())
    942     self.assertEqual(0, m.repeated_int32.pop(0))
    943     self.assertEqual(2, m.repeated_int32.pop(1))
    944     self.assertEqual([1, 3], m.repeated_int32)
    945 
    946   def testRepeatedCompositeFieldPop(self, message_module):
    947     m = message_module.TestAllTypes()
    948     with self.assertRaises(IndexError) as _:
    949       m.repeated_nested_message.pop()
    950     for i in range(5):
    951       n = m.repeated_nested_message.add()
    952       n.bb = i
    953     self.assertEqual(4, m.repeated_nested_message.pop().bb)
    954     self.assertEqual(0, m.repeated_nested_message.pop(0).bb)
    955     self.assertEqual(2, m.repeated_nested_message.pop(1).bb)
    956     self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message])
    957 
    958 
    959 # Class to test proto2-only features (required, extensions, etc.)
    960 class Proto2Test(unittest.TestCase):
    961 
    962   def testFieldPresence(self):
    963     message = unittest_pb2.TestAllTypes()
    964 
    965     self.assertFalse(message.HasField("optional_int32"))
    966     self.assertFalse(message.HasField("optional_bool"))
    967     self.assertFalse(message.HasField("optional_nested_message"))
    968 
    969     with self.assertRaises(ValueError):
    970       message.HasField("field_doesnt_exist")
    971 
    972     with self.assertRaises(ValueError):
    973       message.HasField("repeated_int32")
    974     with self.assertRaises(ValueError):
    975       message.HasField("repeated_nested_message")
    976 
    977     self.assertEqual(0, message.optional_int32)
    978     self.assertEqual(False, message.optional_bool)
    979     self.assertEqual(0, message.optional_nested_message.bb)
    980 
    981     # Fields are set even when setting the values to default values.
    982     message.optional_int32 = 0
    983     message.optional_bool = False
    984     message.optional_nested_message.bb = 0
    985     self.assertTrue(message.HasField("optional_int32"))
    986     self.assertTrue(message.HasField("optional_bool"))
    987     self.assertTrue(message.HasField("optional_nested_message"))
    988 
    989     # Set the fields to non-default values.
    990     message.optional_int32 = 5
    991     message.optional_bool = True
    992     message.optional_nested_message.bb = 15
    993 
    994     self.assertTrue(message.HasField("optional_int32"))
    995     self.assertTrue(message.HasField("optional_bool"))
    996     self.assertTrue(message.HasField("optional_nested_message"))
    997 
    998     # Clearing the fields unsets them and resets their value to default.
    999     message.ClearField("optional_int32")
   1000     message.ClearField("optional_bool")
   1001     message.ClearField("optional_nested_message")
   1002 
   1003     self.assertFalse(message.HasField("optional_int32"))
   1004     self.assertFalse(message.HasField("optional_bool"))
   1005     self.assertFalse(message.HasField("optional_nested_message"))
   1006     self.assertEqual(0, message.optional_int32)
   1007     self.assertEqual(False, message.optional_bool)
   1008     self.assertEqual(0, message.optional_nested_message.bb)
   1009 
   1010   # TODO(tibell): The C++ implementations actually allows assignment
   1011   # of unknown enum values to *scalar* fields (but not repeated
   1012   # fields). Once checked enum fields becomes the default in the
   1013   # Python implementation, the C++ implementation should follow suit.
   1014   def testAssignInvalidEnum(self):
   1015     """It should not be possible to assign an invalid enum number to an
   1016     enum field."""
   1017     m = unittest_pb2.TestAllTypes()
   1018 
   1019     with self.assertRaises(ValueError) as _:
   1020       m.optional_nested_enum = 1234567
   1021     self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567)
   1022 
   1023   def testGoldenExtensions(self):
   1024     golden_data = test_util.GoldenFileData('golden_message')
   1025     golden_message = unittest_pb2.TestAllExtensions()
   1026     golden_message.ParseFromString(golden_data)
   1027     all_set = unittest_pb2.TestAllExtensions()
   1028     test_util.SetAllExtensions(all_set)
   1029     self.assertEqual(all_set, golden_message)
   1030     self.assertEqual(golden_data, golden_message.SerializeToString())
   1031     golden_copy = copy.deepcopy(golden_message)
   1032     self.assertEqual(golden_data, golden_copy.SerializeToString())
   1033 
   1034   def testGoldenPackedExtensions(self):
   1035     golden_data = test_util.GoldenFileData('golden_packed_fields_message')
   1036     golden_message = unittest_pb2.TestPackedExtensions()
   1037     golden_message.ParseFromString(golden_data)
   1038     all_set = unittest_pb2.TestPackedExtensions()
   1039     test_util.SetAllPackedExtensions(all_set)
   1040     self.assertEqual(all_set, golden_message)
   1041     self.assertEqual(golden_data, all_set.SerializeToString())
   1042     golden_copy = copy.deepcopy(golden_message)
   1043     self.assertEqual(golden_data, golden_copy.SerializeToString())
   1044 
   1045   def testPickleIncompleteProto(self):
   1046     golden_message = unittest_pb2.TestRequired(a=1)
   1047     pickled_message = pickle.dumps(golden_message)
   1048 
   1049     unpickled_message = pickle.loads(pickled_message)
   1050     self.assertEqual(unpickled_message, golden_message)
   1051     self.assertEqual(unpickled_message.a, 1)
   1052     # This is still an incomplete proto - so serializing should fail
   1053     self.assertRaises(message.EncodeError, unpickled_message.SerializeToString)
   1054 
   1055 
   1056   # TODO(haberman): this isn't really a proto2-specific test except that this
   1057   # message has a required field in it.  Should probably be factored out so
   1058   # that we can test the other parts with proto3.
   1059   def testParsingMerge(self):
   1060     """Check the merge behavior when a required or optional field appears
   1061     multiple times in the input."""
   1062     messages = [
   1063         unittest_pb2.TestAllTypes(),
   1064         unittest_pb2.TestAllTypes(),
   1065         unittest_pb2.TestAllTypes() ]
   1066     messages[0].optional_int32 = 1
   1067     messages[1].optional_int64 = 2
   1068     messages[2].optional_int32 = 3
   1069     messages[2].optional_string = 'hello'
   1070 
   1071     merged_message = unittest_pb2.TestAllTypes()
   1072     merged_message.optional_int32 = 3
   1073     merged_message.optional_int64 = 2
   1074     merged_message.optional_string = 'hello'
   1075 
   1076     generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator()
   1077     generator.field1.extend(messages)
   1078     generator.field2.extend(messages)
   1079     generator.field3.extend(messages)
   1080     generator.ext1.extend(messages)
   1081     generator.ext2.extend(messages)
   1082     generator.group1.add().field1.MergeFrom(messages[0])
   1083     generator.group1.add().field1.MergeFrom(messages[1])
   1084     generator.group1.add().field1.MergeFrom(messages[2])
   1085     generator.group2.add().field1.MergeFrom(messages[0])
   1086     generator.group2.add().field1.MergeFrom(messages[1])
   1087     generator.group2.add().field1.MergeFrom(messages[2])
   1088 
   1089     data = generator.SerializeToString()
   1090     parsing_merge = unittest_pb2.TestParsingMerge()
   1091     parsing_merge.ParseFromString(data)
   1092 
   1093     # Required and optional fields should be merged.
   1094     self.assertEqual(parsing_merge.required_all_types, merged_message)
   1095     self.assertEqual(parsing_merge.optional_all_types, merged_message)
   1096     self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types,
   1097                      merged_message)
   1098     self.assertEqual(parsing_merge.Extensions[
   1099                      unittest_pb2.TestParsingMerge.optional_ext],
   1100                      merged_message)
   1101 
   1102     # Repeated fields should not be merged.
   1103     self.assertEqual(len(parsing_merge.repeated_all_types), 3)
   1104     self.assertEqual(len(parsing_merge.repeatedgroup), 3)
   1105     self.assertEqual(len(parsing_merge.Extensions[
   1106         unittest_pb2.TestParsingMerge.repeated_ext]), 3)
   1107 
   1108   def testPythonicInit(self):
   1109     message = unittest_pb2.TestAllTypes(
   1110         optional_int32=100,
   1111         optional_fixed32=200,
   1112         optional_float=300.5,
   1113         optional_bytes=b'x',
   1114         optionalgroup={'a': 400},
   1115         optional_nested_message={'bb': 500},
   1116         optional_nested_enum='BAZ',
   1117         repeatedgroup=[{'a': 600},
   1118                        {'a': 700}],
   1119         repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR],
   1120         default_int32=800,
   1121         oneof_string='y')
   1122     self.assertIsInstance(message, unittest_pb2.TestAllTypes)
   1123     self.assertEqual(100, message.optional_int32)
   1124     self.assertEqual(200, message.optional_fixed32)
   1125     self.assertEqual(300.5, message.optional_float)
   1126     self.assertEqual(b'x', message.optional_bytes)
   1127     self.assertEqual(400, message.optionalgroup.a)
   1128     self.assertIsInstance(message.optional_nested_message, unittest_pb2.TestAllTypes.NestedMessage)
   1129     self.assertEqual(500, message.optional_nested_message.bb)
   1130     self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
   1131                      message.optional_nested_enum)
   1132     self.assertEqual(2, len(message.repeatedgroup))
   1133     self.assertEqual(600, message.repeatedgroup[0].a)
   1134     self.assertEqual(700, message.repeatedgroup[1].a)
   1135     self.assertEqual(2, len(message.repeated_nested_enum))
   1136     self.assertEqual(unittest_pb2.TestAllTypes.FOO,
   1137                      message.repeated_nested_enum[0])
   1138     self.assertEqual(unittest_pb2.TestAllTypes.BAR,
   1139                      message.repeated_nested_enum[1])
   1140     self.assertEqual(800, message.default_int32)
   1141     self.assertEqual('y', message.oneof_string)
   1142     self.assertFalse(message.HasField('optional_int64'))
   1143     self.assertEqual(0, len(message.repeated_float))
   1144     self.assertEqual(42, message.default_int64)
   1145 
   1146     message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ')
   1147     self.assertEqual(unittest_pb2.TestAllTypes.BAZ,
   1148                      message.optional_nested_enum)
   1149 
   1150     with self.assertRaises(ValueError):
   1151       unittest_pb2.TestAllTypes(
   1152           optional_nested_message={'INVALID_NESTED_FIELD': 17})
   1153 
   1154     with self.assertRaises(TypeError):
   1155       unittest_pb2.TestAllTypes(
   1156           optional_nested_message={'bb': 'INVALID_VALUE_TYPE'})
   1157 
   1158     with self.assertRaises(ValueError):
   1159       unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL')
   1160 
   1161     with self.assertRaises(ValueError):
   1162       unittest_pb2.TestAllTypes(repeated_nested_enum='FOO')
   1163 
   1164 
   1165 
   1166 # Class to test proto3-only features/behavior (updated field presence & enums)
   1167 class Proto3Test(unittest.TestCase):
   1168 
   1169   # Utility method for comparing equality with a map.
   1170   def assertMapIterEquals(self, map_iter, dict_value):
   1171     # Avoid mutating caller's copy.
   1172     dict_value = dict(dict_value)
   1173 
   1174     for k, v in map_iter:
   1175       self.assertEqual(v, dict_value[k])
   1176       del dict_value[k]
   1177 
   1178     self.assertEqual({}, dict_value)
   1179 
   1180   def testFieldPresence(self):
   1181     message = unittest_proto3_arena_pb2.TestAllTypes()
   1182 
   1183     # We can't test presence of non-repeated, non-submessage fields.
   1184     with self.assertRaises(ValueError):
   1185       message.HasField('optional_int32')
   1186     with self.assertRaises(ValueError):
   1187       message.HasField('optional_float')
   1188     with self.assertRaises(ValueError):
   1189       message.HasField('optional_string')
   1190     with self.assertRaises(ValueError):
   1191       message.HasField('optional_bool')
   1192 
   1193     # But we can still test presence of submessage fields.
   1194     self.assertFalse(message.HasField('optional_nested_message'))
   1195 
   1196     # As with proto2, we can't test presence of fields that don't exist, or
   1197     # repeated fields.
   1198     with self.assertRaises(ValueError):
   1199       message.HasField('field_doesnt_exist')
   1200 
   1201     with self.assertRaises(ValueError):
   1202       message.HasField('repeated_int32')
   1203     with self.assertRaises(ValueError):
   1204       message.HasField('repeated_nested_message')
   1205 
   1206     # Fields should default to their type-specific default.
   1207     self.assertEqual(0, message.optional_int32)
   1208     self.assertEqual(0, message.optional_float)
   1209     self.assertEqual('', message.optional_string)
   1210     self.assertEqual(False, message.optional_bool)
   1211     self.assertEqual(0, message.optional_nested_message.bb)
   1212 
   1213     # Setting a submessage should still return proper presence information.
   1214     message.optional_nested_message.bb = 0
   1215     self.assertTrue(message.HasField('optional_nested_message'))
   1216 
   1217     # Set the fields to non-default values.
   1218     message.optional_int32 = 5
   1219     message.optional_float = 1.1
   1220     message.optional_string = 'abc'
   1221     message.optional_bool = True
   1222     message.optional_nested_message.bb = 15
   1223 
   1224     # Clearing the fields unsets them and resets their value to default.
   1225     message.ClearField('optional_int32')
   1226     message.ClearField('optional_float')
   1227     message.ClearField('optional_string')
   1228     message.ClearField('optional_bool')
   1229     message.ClearField('optional_nested_message')
   1230 
   1231     self.assertEqual(0, message.optional_int32)
   1232     self.assertEqual(0, message.optional_float)
   1233     self.assertEqual('', message.optional_string)
   1234     self.assertEqual(False, message.optional_bool)
   1235     self.assertEqual(0, message.optional_nested_message.bb)
   1236 
   1237   def testAssignUnknownEnum(self):
   1238     """Assigning an unknown enum value is allowed and preserves the value."""
   1239     m = unittest_proto3_arena_pb2.TestAllTypes()
   1240 
   1241     m.optional_nested_enum = 1234567
   1242     self.assertEqual(1234567, m.optional_nested_enum)
   1243     m.repeated_nested_enum.append(22334455)
   1244     self.assertEqual(22334455, m.repeated_nested_enum[0])
   1245     # Assignment is a different code path than append for the C++ impl.
   1246     m.repeated_nested_enum[0] = 7654321
   1247     self.assertEqual(7654321, m.repeated_nested_enum[0])
   1248     serialized = m.SerializeToString()
   1249 
   1250     m2 = unittest_proto3_arena_pb2.TestAllTypes()
   1251     m2.ParseFromString(serialized)
   1252     self.assertEqual(1234567, m2.optional_nested_enum)
   1253     self.assertEqual(7654321, m2.repeated_nested_enum[0])
   1254 
   1255   # Map isn't really a proto3-only feature. But there is no proto2 equivalent
   1256   # of google/protobuf/map_unittest.proto right now, so it's not easy to
   1257   # test both with the same test like we do for the other proto2/proto3 tests.
   1258   # (google/protobuf/map_protobuf_unittest.proto is very different in the set
   1259   # of messages and fields it contains).
   1260   def testScalarMapDefaults(self):
   1261     msg = map_unittest_pb2.TestMap()
   1262 
   1263     # Scalars start out unset.
   1264     self.assertFalse(-123 in msg.map_int32_int32)
   1265     self.assertFalse(-2**33 in msg.map_int64_int64)
   1266     self.assertFalse(123 in msg.map_uint32_uint32)
   1267     self.assertFalse(2**33 in msg.map_uint64_uint64)
   1268     self.assertFalse(123 in msg.map_int32_double)
   1269     self.assertFalse(False in msg.map_bool_bool)
   1270     self.assertFalse('abc' in msg.map_string_string)
   1271     self.assertFalse(111 in msg.map_int32_bytes)
   1272     self.assertFalse(888 in msg.map_int32_enum)
   1273 
   1274     # Accessing an unset key returns the default.
   1275     self.assertEqual(0, msg.map_int32_int32[-123])
   1276     self.assertEqual(0, msg.map_int64_int64[-2**33])
   1277     self.assertEqual(0, msg.map_uint32_uint32[123])
   1278     self.assertEqual(0, msg.map_uint64_uint64[2**33])
   1279     self.assertEqual(0.0, msg.map_int32_double[123])
   1280     self.assertTrue(isinstance(msg.map_int32_double[123], float))
   1281     self.assertEqual(False, msg.map_bool_bool[False])
   1282     self.assertTrue(isinstance(msg.map_bool_bool[False], bool))
   1283     self.assertEqual('', msg.map_string_string['abc'])
   1284     self.assertEqual(b'', msg.map_int32_bytes[111])
   1285     self.assertEqual(0, msg.map_int32_enum[888])
   1286 
   1287     # It also sets the value in the map
   1288     self.assertTrue(-123 in msg.map_int32_int32)
   1289     self.assertTrue(-2**33 in msg.map_int64_int64)
   1290     self.assertTrue(123 in msg.map_uint32_uint32)
   1291     self.assertTrue(2**33 in msg.map_uint64_uint64)
   1292     self.assertTrue(123 in msg.map_int32_double)
   1293     self.assertTrue(False in msg.map_bool_bool)
   1294     self.assertTrue('abc' in msg.map_string_string)
   1295     self.assertTrue(111 in msg.map_int32_bytes)
   1296     self.assertTrue(888 in msg.map_int32_enum)
   1297 
   1298     self.assertIsInstance(msg.map_string_string['abc'], six.text_type)
   1299 
   1300     # Accessing an unset key still throws TypeError if the type of the key
   1301     # is incorrect.
   1302     with self.assertRaises(TypeError):
   1303       msg.map_string_string[123]
   1304 
   1305     with self.assertRaises(TypeError):
   1306       123 in msg.map_string_string
   1307 
   1308   def testMapGet(self):
   1309     # Need to test that get() properly returns the default, even though the dict
   1310     # has defaultdict-like semantics.
   1311     msg = map_unittest_pb2.TestMap()
   1312 
   1313     self.assertIsNone(msg.map_int32_int32.get(5))
   1314     self.assertEqual(10, msg.map_int32_int32.get(5, 10))
   1315     self.assertIsNone(msg.map_int32_int32.get(5))
   1316 
   1317     msg.map_int32_int32[5] = 15
   1318     self.assertEqual(15, msg.map_int32_int32.get(5))
   1319 
   1320     self.assertIsNone(msg.map_int32_foreign_message.get(5))
   1321     self.assertEqual(10, msg.map_int32_foreign_message.get(5, 10))
   1322 
   1323     submsg = msg.map_int32_foreign_message[5]
   1324     self.assertIs(submsg, msg.map_int32_foreign_message.get(5))
   1325 
   1326   def testScalarMap(self):
   1327     msg = map_unittest_pb2.TestMap()
   1328 
   1329     self.assertEqual(0, len(msg.map_int32_int32))
   1330     self.assertFalse(5 in msg.map_int32_int32)
   1331 
   1332     msg.map_int32_int32[-123] = -456
   1333     msg.map_int64_int64[-2**33] = -2**34
   1334     msg.map_uint32_uint32[123] = 456
   1335     msg.map_uint64_uint64[2**33] = 2**34
   1336     msg.map_string_string['abc'] = '123'
   1337     msg.map_int32_enum[888] = 2
   1338 
   1339     self.assertEqual([], msg.FindInitializationErrors())
   1340 
   1341     self.assertEqual(1, len(msg.map_string_string))
   1342 
   1343     # Bad key.
   1344     with self.assertRaises(TypeError):
   1345       msg.map_string_string[123] = '123'
   1346 
   1347     # Verify that trying to assign a bad key doesn't actually add a member to
   1348     # the map.
   1349     self.assertEqual(1, len(msg.map_string_string))
   1350 
   1351     # Bad value.
   1352     with self.assertRaises(TypeError):
   1353       msg.map_string_string['123'] = 123
   1354 
   1355     serialized = msg.SerializeToString()
   1356     msg2 = map_unittest_pb2.TestMap()
   1357     msg2.ParseFromString(serialized)
   1358 
   1359     # Bad key.
   1360     with self.assertRaises(TypeError):
   1361       msg2.map_string_string[123] = '123'
   1362 
   1363     # Bad value.
   1364     with self.assertRaises(TypeError):
   1365       msg2.map_string_string['123'] = 123
   1366 
   1367     self.assertEqual(-456, msg2.map_int32_int32[-123])
   1368     self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
   1369     self.assertEqual(456, msg2.map_uint32_uint32[123])
   1370     self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
   1371     self.assertEqual('123', msg2.map_string_string['abc'])
   1372     self.assertEqual(2, msg2.map_int32_enum[888])
   1373 
   1374   def testStringUnicodeConversionInMap(self):
   1375     msg = map_unittest_pb2.TestMap()
   1376 
   1377     unicode_obj = u'\u1234'
   1378     bytes_obj = unicode_obj.encode('utf8')
   1379 
   1380     msg.map_string_string[bytes_obj] = bytes_obj
   1381 
   1382     (key, value) = list(msg.map_string_string.items())[0]
   1383 
   1384     self.assertEqual(key, unicode_obj)
   1385     self.assertEqual(value, unicode_obj)
   1386 
   1387     self.assertIsInstance(key, six.text_type)
   1388     self.assertIsInstance(value, six.text_type)
   1389 
   1390   def testMessageMap(self):
   1391     msg = map_unittest_pb2.TestMap()
   1392 
   1393     self.assertEqual(0, len(msg.map_int32_foreign_message))
   1394     self.assertFalse(5 in msg.map_int32_foreign_message)
   1395 
   1396     msg.map_int32_foreign_message[123]
   1397     # get_or_create() is an alias for getitem.
   1398     msg.map_int32_foreign_message.get_or_create(-456)
   1399 
   1400     self.assertEqual(2, len(msg.map_int32_foreign_message))
   1401     self.assertIn(123, msg.map_int32_foreign_message)
   1402     self.assertIn(-456, msg.map_int32_foreign_message)
   1403     self.assertEqual(2, len(msg.map_int32_foreign_message))
   1404 
   1405     # Bad key.
   1406     with self.assertRaises(TypeError):
   1407       msg.map_int32_foreign_message['123']
   1408 
   1409     # Can't assign directly to submessage.
   1410     with self.assertRaises(ValueError):
   1411       msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123]
   1412 
   1413     # Verify that trying to assign a bad key doesn't actually add a member to
   1414     # the map.
   1415     self.assertEqual(2, len(msg.map_int32_foreign_message))
   1416 
   1417     serialized = msg.SerializeToString()
   1418     msg2 = map_unittest_pb2.TestMap()
   1419     msg2.ParseFromString(serialized)
   1420 
   1421     self.assertEqual(2, len(msg2.map_int32_foreign_message))
   1422     self.assertIn(123, msg2.map_int32_foreign_message)
   1423     self.assertIn(-456, msg2.map_int32_foreign_message)
   1424     self.assertEqual(2, len(msg2.map_int32_foreign_message))
   1425 
   1426   def testMergeFrom(self):
   1427     msg = map_unittest_pb2.TestMap()
   1428     msg.map_int32_int32[12] = 34
   1429     msg.map_int32_int32[56] = 78
   1430     msg.map_int64_int64[22] = 33
   1431     msg.map_int32_foreign_message[111].c = 5
   1432     msg.map_int32_foreign_message[222].c = 10
   1433 
   1434     msg2 = map_unittest_pb2.TestMap()
   1435     msg2.map_int32_int32[12] = 55
   1436     msg2.map_int64_int64[88] = 99
   1437     msg2.map_int32_foreign_message[222].c = 15
   1438 
   1439     msg2.MergeFrom(msg)
   1440 
   1441     self.assertEqual(34, msg2.map_int32_int32[12])
   1442     self.assertEqual(78, msg2.map_int32_int32[56])
   1443     self.assertEqual(33, msg2.map_int64_int64[22])
   1444     self.assertEqual(99, msg2.map_int64_int64[88])
   1445     self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
   1446     self.assertEqual(10, msg2.map_int32_foreign_message[222].c)
   1447 
   1448     # Verify that there is only one entry per key, even though the MergeFrom
   1449     # may have internally created multiple entries for a single key in the
   1450     # list representation.
   1451     as_dict = {}
   1452     for key in msg2.map_int32_foreign_message:
   1453       self.assertFalse(key in as_dict)
   1454       as_dict[key] = msg2.map_int32_foreign_message[key].c
   1455 
   1456     self.assertEqual({111: 5, 222: 10}, as_dict)
   1457 
   1458     # Special case: test that delete of item really removes the item, even if
   1459     # there might have physically been duplicate keys due to the previous merge.
   1460     # This is only a special case for the C++ implementation which stores the
   1461     # map as an array.
   1462     del msg2.map_int32_int32[12]
   1463     self.assertFalse(12 in msg2.map_int32_int32)
   1464 
   1465     del msg2.map_int32_foreign_message[222]
   1466     self.assertFalse(222 in msg2.map_int32_foreign_message)
   1467 
   1468   def testMergeFromBadType(self):
   1469     msg = map_unittest_pb2.TestMap()
   1470     with self.assertRaisesRegexp(
   1471         TypeError,
   1472         r'Parameter to MergeFrom\(\) must be instance of same class: expected '
   1473         r'.*TestMap got int\.'):
   1474       msg.MergeFrom(1)
   1475 
   1476   def testCopyFromBadType(self):
   1477     msg = map_unittest_pb2.TestMap()
   1478     with self.assertRaisesRegexp(
   1479         TypeError,
   1480         r'Parameter to [A-Za-z]*From\(\) must be instance of same class: '
   1481         r'expected .*TestMap got int\.'):
   1482       msg.CopyFrom(1)
   1483 
   1484   def testIntegerMapWithLongs(self):
   1485     msg = map_unittest_pb2.TestMap()
   1486     msg.map_int32_int32[long(-123)] = long(-456)
   1487     msg.map_int64_int64[long(-2**33)] = long(-2**34)
   1488     msg.map_uint32_uint32[long(123)] = long(456)
   1489     msg.map_uint64_uint64[long(2**33)] = long(2**34)
   1490 
   1491     serialized = msg.SerializeToString()
   1492     msg2 = map_unittest_pb2.TestMap()
   1493     msg2.ParseFromString(serialized)
   1494 
   1495     self.assertEqual(-456, msg2.map_int32_int32[-123])
   1496     self.assertEqual(-2**34, msg2.map_int64_int64[-2**33])
   1497     self.assertEqual(456, msg2.map_uint32_uint32[123])
   1498     self.assertEqual(2**34, msg2.map_uint64_uint64[2**33])
   1499 
   1500   def testMapAssignmentCausesPresence(self):
   1501     msg = map_unittest_pb2.TestMapSubmessage()
   1502     msg.test_map.map_int32_int32[123] = 456
   1503 
   1504     serialized = msg.SerializeToString()
   1505     msg2 = map_unittest_pb2.TestMapSubmessage()
   1506     msg2.ParseFromString(serialized)
   1507 
   1508     self.assertEqual(msg, msg2)
   1509 
   1510     # Now test that various mutations of the map properly invalidate the
   1511     # cached size of the submessage.
   1512     msg.test_map.map_int32_int32[888] = 999
   1513     serialized = msg.SerializeToString()
   1514     msg2.ParseFromString(serialized)
   1515     self.assertEqual(msg, msg2)
   1516 
   1517     msg.test_map.map_int32_int32.clear()
   1518     serialized = msg.SerializeToString()
   1519     msg2.ParseFromString(serialized)
   1520     self.assertEqual(msg, msg2)
   1521 
   1522   def testMapAssignmentCausesPresenceForSubmessages(self):
   1523     msg = map_unittest_pb2.TestMapSubmessage()
   1524     msg.test_map.map_int32_foreign_message[123].c = 5
   1525 
   1526     serialized = msg.SerializeToString()
   1527     msg2 = map_unittest_pb2.TestMapSubmessage()
   1528     msg2.ParseFromString(serialized)
   1529 
   1530     self.assertEqual(msg, msg2)
   1531 
   1532     # Now test that various mutations of the map properly invalidate the
   1533     # cached size of the submessage.
   1534     msg.test_map.map_int32_foreign_message[888].c = 7
   1535     serialized = msg.SerializeToString()
   1536     msg2.ParseFromString(serialized)
   1537     self.assertEqual(msg, msg2)
   1538 
   1539     msg.test_map.map_int32_foreign_message[888].MergeFrom(
   1540         msg.test_map.map_int32_foreign_message[123])
   1541     serialized = msg.SerializeToString()
   1542     msg2.ParseFromString(serialized)
   1543     self.assertEqual(msg, msg2)
   1544 
   1545     msg.test_map.map_int32_foreign_message.clear()
   1546     serialized = msg.SerializeToString()
   1547     msg2.ParseFromString(serialized)
   1548     self.assertEqual(msg, msg2)
   1549 
   1550   def testModifyMapWhileIterating(self):
   1551     msg = map_unittest_pb2.TestMap()
   1552 
   1553     string_string_iter = iter(msg.map_string_string)
   1554     int32_foreign_iter = iter(msg.map_int32_foreign_message)
   1555 
   1556     msg.map_string_string['abc'] = '123'
   1557     msg.map_int32_foreign_message[5].c = 5
   1558 
   1559     with self.assertRaises(RuntimeError):
   1560       for key in string_string_iter:
   1561         pass
   1562 
   1563     with self.assertRaises(RuntimeError):
   1564       for key in int32_foreign_iter:
   1565         pass
   1566 
   1567   def testSubmessageMap(self):
   1568     msg = map_unittest_pb2.TestMap()
   1569 
   1570     submsg = msg.map_int32_foreign_message[111]
   1571     self.assertIs(submsg, msg.map_int32_foreign_message[111])
   1572     self.assertIsInstance(submsg, unittest_pb2.ForeignMessage)
   1573 
   1574     submsg.c = 5
   1575 
   1576     serialized = msg.SerializeToString()
   1577     msg2 = map_unittest_pb2.TestMap()
   1578     msg2.ParseFromString(serialized)
   1579 
   1580     self.assertEqual(5, msg2.map_int32_foreign_message[111].c)
   1581 
   1582     # Doesn't allow direct submessage assignment.
   1583     with self.assertRaises(ValueError):
   1584       msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage()
   1585 
   1586   def testMapIteration(self):
   1587     msg = map_unittest_pb2.TestMap()
   1588 
   1589     for k, v in msg.map_int32_int32.items():
   1590       # Should not be reached.
   1591       self.assertTrue(False)
   1592 
   1593     msg.map_int32_int32[2] = 4
   1594     msg.map_int32_int32[3] = 6
   1595     msg.map_int32_int32[4] = 8
   1596     self.assertEqual(3, len(msg.map_int32_int32))
   1597 
   1598     matching_dict = {2: 4, 3: 6, 4: 8}
   1599     self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict)
   1600 
   1601   def testMapItems(self):
   1602     # Map items used to have strange behaviors when use c extension. Because
   1603     # [] may reorder the map and invalidate any exsting iterators.
   1604     # TODO(jieluo): Check if [] reordering the map is a bug or intended
   1605     # behavior.
   1606     msg = map_unittest_pb2.TestMap()
   1607     msg.map_string_string['local_init_op'] = ''
   1608     msg.map_string_string['trainable_variables'] = ''
   1609     msg.map_string_string['variables'] = ''
   1610     msg.map_string_string['init_op'] = ''
   1611     msg.map_string_string['summaries'] = ''
   1612     items1 = msg.map_string_string.items()
   1613     items2 = msg.map_string_string.items()
   1614     self.assertEqual(items1, items2)
   1615 
   1616   def testMapIterationClearMessage(self):
   1617     # Iterator needs to work even if message and map are deleted.
   1618     msg = map_unittest_pb2.TestMap()
   1619 
   1620     msg.map_int32_int32[2] = 4
   1621     msg.map_int32_int32[3] = 6
   1622     msg.map_int32_int32[4] = 8
   1623 
   1624     it = msg.map_int32_int32.items()
   1625     del msg
   1626 
   1627     matching_dict = {2: 4, 3: 6, 4: 8}
   1628     self.assertMapIterEquals(it, matching_dict)
   1629 
   1630   def testMapConstruction(self):
   1631     msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4})
   1632     self.assertEqual(2, msg.map_int32_int32[1])
   1633     self.assertEqual(4, msg.map_int32_int32[3])
   1634 
   1635     msg = map_unittest_pb2.TestMap(
   1636         map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)})
   1637     self.assertEqual(5, msg.map_int32_foreign_message[3].c)
   1638 
   1639   def testMapValidAfterFieldCleared(self):
   1640     # Map needs to work even if field is cleared.
   1641     # For the C++ implementation this tests the correctness of
   1642     # ScalarMapContainer::Release()
   1643     msg = map_unittest_pb2.TestMap()
   1644     int32_map = msg.map_int32_int32
   1645 
   1646     int32_map[2] = 4
   1647     int32_map[3] = 6
   1648     int32_map[4] = 8
   1649 
   1650     msg.ClearField('map_int32_int32')
   1651     self.assertEqual(b'', msg.SerializeToString())
   1652     matching_dict = {2: 4, 3: 6, 4: 8}
   1653     self.assertMapIterEquals(int32_map.items(), matching_dict)
   1654 
   1655   def testMessageMapValidAfterFieldCleared(self):
   1656     # Map needs to work even if field is cleared.
   1657     # For the C++ implementation this tests the correctness of
   1658     # ScalarMapContainer::Release()
   1659     msg = map_unittest_pb2.TestMap()
   1660     int32_foreign_message = msg.map_int32_foreign_message
   1661 
   1662     int32_foreign_message[2].c = 5
   1663 
   1664     msg.ClearField('map_int32_foreign_message')
   1665     self.assertEqual(b'', msg.SerializeToString())
   1666     self.assertTrue(2 in int32_foreign_message.keys())
   1667 
   1668   def testMapIterInvalidatedByClearField(self):
   1669     # Map iterator is invalidated when field is cleared.
   1670     # But this case does need to not crash the interpreter.
   1671     # For the C++ implementation this tests the correctness of
   1672     # ScalarMapContainer::Release()
   1673     msg = map_unittest_pb2.TestMap()
   1674 
   1675     it = iter(msg.map_int32_int32)
   1676 
   1677     msg.ClearField('map_int32_int32')
   1678     with self.assertRaises(RuntimeError):
   1679       for _ in it:
   1680         pass
   1681 
   1682     it = iter(msg.map_int32_foreign_message)
   1683     msg.ClearField('map_int32_foreign_message')
   1684     with self.assertRaises(RuntimeError):
   1685       for _ in it:
   1686         pass
   1687 
   1688   def testMapDelete(self):
   1689     msg = map_unittest_pb2.TestMap()
   1690 
   1691     self.assertEqual(0, len(msg.map_int32_int32))
   1692 
   1693     msg.map_int32_int32[4] = 6
   1694     self.assertEqual(1, len(msg.map_int32_int32))
   1695 
   1696     with self.assertRaises(KeyError):
   1697       del msg.map_int32_int32[88]
   1698 
   1699     del msg.map_int32_int32[4]
   1700     self.assertEqual(0, len(msg.map_int32_int32))
   1701 
   1702   def testMapsAreMapping(self):
   1703     msg = map_unittest_pb2.TestMap()
   1704     self.assertIsInstance(msg.map_int32_int32, collections.Mapping)
   1705     self.assertIsInstance(msg.map_int32_int32, collections.MutableMapping)
   1706     self.assertIsInstance(msg.map_int32_foreign_message, collections.Mapping)
   1707     self.assertIsInstance(msg.map_int32_foreign_message,
   1708                           collections.MutableMapping)
   1709 
   1710   def testMapFindInitializationErrorsSmokeTest(self):
   1711     msg = map_unittest_pb2.TestMap()
   1712     msg.map_string_string['abc'] = '123'
   1713     msg.map_int32_int32[35] = 64
   1714     msg.map_string_foreign_message['foo'].c = 5
   1715     self.assertEqual(0, len(msg.FindInitializationErrors()))
   1716 
   1717 
   1718 
   1719 class ValidTypeNamesTest(unittest.TestCase):
   1720 
   1721   def assertImportFromName(self, msg, base_name):
   1722     # Parse <type 'module.class_name'> to extra 'some.name' as a string.
   1723     tp_name = str(type(msg)).split("'")[1]
   1724     valid_names = ('Repeated%sContainer' % base_name,
   1725                    'Repeated%sFieldContainer' % base_name)
   1726     self.assertTrue(any(tp_name.endswith(v) for v in valid_names),
   1727                     '%r does end with any of %r' % (tp_name, valid_names))
   1728 
   1729     parts = tp_name.split('.')
   1730     class_name = parts[-1]
   1731     module_name = '.'.join(parts[:-1])
   1732     __import__(module_name, fromlist=[class_name])
   1733 
   1734   def testTypeNamesCanBeImported(self):
   1735     # If import doesn't work, pickling won't work either.
   1736     pb = unittest_pb2.TestAllTypes()
   1737     self.assertImportFromName(pb.repeated_int32, 'Scalar')
   1738     self.assertImportFromName(pb.repeated_nested_message, 'Composite')
   1739 
   1740 class PackedFieldTest(unittest.TestCase):
   1741 
   1742   def setMessage(self, message):
   1743     message.repeated_int32.append(1)
   1744     message.repeated_int64.append(1)
   1745     message.repeated_uint32.append(1)
   1746     message.repeated_uint64.append(1)
   1747     message.repeated_sint32.append(1)
   1748     message.repeated_sint64.append(1)
   1749     message.repeated_fixed32.append(1)
   1750     message.repeated_fixed64.append(1)
   1751     message.repeated_sfixed32.append(1)
   1752     message.repeated_sfixed64.append(1)
   1753     message.repeated_float.append(1.0)
   1754     message.repeated_double.append(1.0)
   1755     message.repeated_bool.append(True)
   1756     message.repeated_nested_enum.append(1)
   1757 
   1758   def testPackedFields(self):
   1759     message = packed_field_test_pb2.TestPackedTypes()
   1760     self.setMessage(message)
   1761     golden_data = (b'\x0A\x01\x01'
   1762                    b'\x12\x01\x01'
   1763                    b'\x1A\x01\x01'
   1764                    b'\x22\x01\x01'
   1765                    b'\x2A\x01\x02'
   1766                    b'\x32\x01\x02'
   1767                    b'\x3A\x04\x01\x00\x00\x00'
   1768                    b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00'
   1769                    b'\x4A\x04\x01\x00\x00\x00'
   1770                    b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00'
   1771                    b'\x5A\x04\x00\x00\x80\x3f'
   1772                    b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f'
   1773                    b'\x6A\x01\x01'
   1774                    b'\x72\x01\x01')
   1775     self.assertEqual(golden_data, message.SerializeToString())
   1776 
   1777   def testUnpackedFields(self):
   1778     message = packed_field_test_pb2.TestUnpackedTypes()
   1779     self.setMessage(message)
   1780     golden_data = (b'\x08\x01'
   1781                    b'\x10\x01'
   1782                    b'\x18\x01'
   1783                    b'\x20\x01'
   1784                    b'\x28\x02'
   1785                    b'\x30\x02'
   1786                    b'\x3D\x01\x00\x00\x00'
   1787                    b'\x41\x01\x00\x00\x00\x00\x00\x00\x00'
   1788                    b'\x4D\x01\x00\x00\x00'
   1789                    b'\x51\x01\x00\x00\x00\x00\x00\x00\x00'
   1790                    b'\x5D\x00\x00\x80\x3f'
   1791                    b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f'
   1792                    b'\x68\x01'
   1793                    b'\x70\x01')
   1794     self.assertEqual(golden_data, message.SerializeToString())
   1795 
   1796 
   1797 @unittest.skipIf(api_implementation.Type() != 'cpp',
   1798                  'explicit tests of the C++ implementation')
   1799 class OversizeProtosTest(unittest.TestCase):
   1800 
   1801   def setUp(self):
   1802     self.file_desc = """
   1803       name: "f/f.msg2"
   1804       package: "f"
   1805       message_type {
   1806         name: "msg1"
   1807         field {
   1808           name: "payload"
   1809           number: 1
   1810           label: LABEL_OPTIONAL
   1811           type: TYPE_STRING
   1812         }
   1813       }
   1814       message_type {
   1815         name: "msg2"
   1816         field {
   1817           name: "field"
   1818           number: 1
   1819           label: LABEL_OPTIONAL
   1820           type: TYPE_MESSAGE
   1821           type_name: "msg1"
   1822         }
   1823       }
   1824     """
   1825     pool = descriptor_pool.DescriptorPool()
   1826     desc = descriptor_pb2.FileDescriptorProto()
   1827     text_format.Parse(self.file_desc, desc)
   1828     pool.Add(desc)
   1829     self.proto_cls = message_factory.MessageFactory(pool).GetPrototype(
   1830         pool.FindMessageTypeByName('f.msg2'))
   1831     self.p = self.proto_cls()
   1832     self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1)
   1833     self.p_serialized = self.p.SerializeToString()
   1834 
   1835   def testAssertOversizeProto(self):
   1836     from google.protobuf.pyext._message import SetAllowOversizeProtos
   1837     SetAllowOversizeProtos(False)
   1838     q = self.proto_cls()
   1839     try:
   1840       q.ParseFromString(self.p_serialized)
   1841     except message.DecodeError as e:
   1842       self.assertEqual(str(e), 'Error parsing message')
   1843 
   1844   def testSucceedOversizeProto(self):
   1845     from google.protobuf.pyext._message import SetAllowOversizeProtos
   1846     SetAllowOversizeProtos(True)
   1847     q = self.proto_cls()
   1848     q.ParseFromString(self.p_serialized)
   1849     self.assertEqual(self.p.field.payload, q.field.payload)
   1850 
   1851 if __name__ == '__main__':
   1852   unittest.main()
   1853