Home | History | Annotate | Download | only in internal
      1 #! /usr/bin/env python
      2 # -*- coding: utf-8 -*-
      3 #
      4 # Protocol Buffers - Google's data interchange format
      5 # Copyright 2008 Google Inc.  All rights reserved.
      6 # https://developers.google.com/protocol-buffers/
      7 #
      8 # Redistribution and use in source and binary forms, with or without
      9 # modification, are permitted provided that the following conditions are
     10 # met:
     11 #
     12 #     * Redistributions of source code must retain the above copyright
     13 # notice, this list of conditions and the following disclaimer.
     14 #     * Redistributions in binary form must reproduce the above
     15 # copyright notice, this list of conditions and the following disclaimer
     16 # in the documentation and/or other materials provided with the
     17 # distribution.
     18 #     * Neither the name of Google Inc. nor the names of its
     19 # contributors may be used to endorse or promote products derived from
     20 # this software without specific prior written permission.
     21 #
     22 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     23 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     24 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     25 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     26 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     27 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     28 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     29 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     30 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     31 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     32 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     33 
     34 """Test for preservation of unknown fields in the pure Python implementation."""
     35 
     36 __author__ = 'bohdank (at] google.com (Bohdan Koval)'
     37 
     38 try:
     39   import unittest2 as unittest  #PY26
     40 except ImportError:
     41   import unittest
     42 from google.protobuf import unittest_mset_pb2
     43 from google.protobuf import unittest_pb2
     44 from google.protobuf import unittest_proto3_arena_pb2
     45 from google.protobuf.internal import api_implementation
     46 from google.protobuf.internal import encoder
     47 from google.protobuf.internal import message_set_extensions_pb2
     48 from google.protobuf.internal import missing_enum_values_pb2
     49 from google.protobuf.internal import test_util
     50 from google.protobuf.internal import type_checkers
     51 
     52 
     53 def SkipIfCppImplementation(func):
     54   return unittest.skipIf(
     55       api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
     56       'C++ implementation does not expose unknown fields to Python')(func)
     57 
     58 
     59 class UnknownFieldsTest(unittest.TestCase):
     60 
     61   def setUp(self):
     62     self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
     63     self.all_fields = unittest_pb2.TestAllTypes()
     64     test_util.SetAllFields(self.all_fields)
     65     self.all_fields_data = self.all_fields.SerializeToString()
     66     self.empty_message = unittest_pb2.TestEmptyMessage()
     67     self.empty_message.ParseFromString(self.all_fields_data)
     68 
     69   def testSerialize(self):
     70     data = self.empty_message.SerializeToString()
     71 
     72     # Don't use assertEqual because we don't want to dump raw binary data to
     73     # stdout.
     74     self.assertTrue(data == self.all_fields_data)
     75 
     76   def testSerializeProto3(self):
     77     # Verify that proto3 doesn't preserve unknown fields.
     78     message = unittest_proto3_arena_pb2.TestEmptyMessage()
     79     message.ParseFromString(self.all_fields_data)
     80     self.assertEqual(0, len(message.SerializeToString()))
     81 
     82   def testByteSize(self):
     83     self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())
     84 
     85   def testListFields(self):
     86     # Make sure ListFields doesn't return unknown fields.
     87     self.assertEqual(0, len(self.empty_message.ListFields()))
     88 
     89   def testSerializeMessageSetWireFormatUnknownExtension(self):
     90     # Create a message using the message set wire format with an unknown
     91     # message.
     92     raw = unittest_mset_pb2.RawMessageSet()
     93 
     94     # Add an unknown extension.
     95     item = raw.item.add()
     96     item.type_id = 98418603
     97     message1 = message_set_extensions_pb2.TestMessageSetExtension1()
     98     message1.i = 12345
     99     item.message = message1.SerializeToString()
    100 
    101     serialized = raw.SerializeToString()
    102 
    103     # Parse message using the message set wire format.
    104     proto = message_set_extensions_pb2.TestMessageSet()
    105     proto.MergeFromString(serialized)
    106 
    107     # Verify that the unknown extension is serialized unchanged
    108     reserialized = proto.SerializeToString()
    109     new_raw = unittest_mset_pb2.RawMessageSet()
    110     new_raw.MergeFromString(reserialized)
    111     self.assertEqual(raw, new_raw)
    112 
    113   def testEquals(self):
    114     message = unittest_pb2.TestEmptyMessage()
    115     message.ParseFromString(self.all_fields_data)
    116     self.assertEqual(self.empty_message, message)
    117 
    118     self.all_fields.ClearField('optional_string')
    119     message.ParseFromString(self.all_fields.SerializeToString())
    120     self.assertNotEqual(self.empty_message, message)
    121 
    122   def testDiscardUnknownFields(self):
    123     self.empty_message.DiscardUnknownFields()
    124     self.assertEqual(b'', self.empty_message.SerializeToString())
    125     # Test message field and repeated message field.
    126     message = unittest_pb2.TestAllTypes()
    127     other_message = unittest_pb2.TestAllTypes()
    128     other_message.optional_string = 'discard'
    129     message.optional_nested_message.ParseFromString(
    130         other_message.SerializeToString())
    131     message.repeated_nested_message.add().ParseFromString(
    132         other_message.SerializeToString())
    133     self.assertNotEqual(
    134         b'', message.optional_nested_message.SerializeToString())
    135     self.assertNotEqual(
    136         b'', message.repeated_nested_message[0].SerializeToString())
    137     message.DiscardUnknownFields()
    138     self.assertEqual(b'', message.optional_nested_message.SerializeToString())
    139     self.assertEqual(
    140         b'', message.repeated_nested_message[0].SerializeToString())
    141 
    142 
    143 class UnknownFieldsAccessorsTest(unittest.TestCase):
    144 
    145   def setUp(self):
    146     self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
    147     self.all_fields = unittest_pb2.TestAllTypes()
    148     test_util.SetAllFields(self.all_fields)
    149     self.all_fields_data = self.all_fields.SerializeToString()
    150     self.empty_message = unittest_pb2.TestEmptyMessage()
    151     self.empty_message.ParseFromString(self.all_fields_data)
    152     if api_implementation.Type() != 'cpp':
    153       # _unknown_fields is an implementation detail.
    154       self.unknown_fields = self.empty_message._unknown_fields
    155 
    156   # All the tests that use GetField() check an implementation detail of the
    157   # Python implementation, which stores unknown fields as serialized strings.
    158   # These tests are skipped by the C++ implementation: it's enough to check that
    159   # the message is correctly serialized.
    160 
    161   def GetField(self, name):
    162     field_descriptor = self.descriptor.fields_by_name[name]
    163     wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
    164     field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
    165     result_dict = {}
    166     for tag_bytes, value in self.unknown_fields:
    167       if tag_bytes == field_tag:
    168         decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
    169         decoder(value, 0, len(value), self.all_fields, result_dict)
    170     return result_dict[field_descriptor]
    171 
    172   @SkipIfCppImplementation
    173   def testEnum(self):
    174     value = self.GetField('optional_nested_enum')
    175     self.assertEqual(self.all_fields.optional_nested_enum, value)
    176 
    177   @SkipIfCppImplementation
    178   def testRepeatedEnum(self):
    179     value = self.GetField('repeated_nested_enum')
    180     self.assertEqual(self.all_fields.repeated_nested_enum, value)
    181 
    182   @SkipIfCppImplementation
    183   def testVarint(self):
    184     value = self.GetField('optional_int32')
    185     self.assertEqual(self.all_fields.optional_int32, value)
    186 
    187   @SkipIfCppImplementation
    188   def testFixed32(self):
    189     value = self.GetField('optional_fixed32')
    190     self.assertEqual(self.all_fields.optional_fixed32, value)
    191 
    192   @SkipIfCppImplementation
    193   def testFixed64(self):
    194     value = self.GetField('optional_fixed64')
    195     self.assertEqual(self.all_fields.optional_fixed64, value)
    196 
    197   @SkipIfCppImplementation
    198   def testLengthDelimited(self):
    199     value = self.GetField('optional_string')
    200     self.assertEqual(self.all_fields.optional_string, value)
    201 
    202   @SkipIfCppImplementation
    203   def testGroup(self):
    204     value = self.GetField('optionalgroup')
    205     self.assertEqual(self.all_fields.optionalgroup, value)
    206 
    207   def testCopyFrom(self):
    208     message = unittest_pb2.TestEmptyMessage()
    209     message.CopyFrom(self.empty_message)
    210     self.assertEqual(message.SerializeToString(), self.all_fields_data)
    211 
    212   def testMergeFrom(self):
    213     message = unittest_pb2.TestAllTypes()
    214     message.optional_int32 = 1
    215     message.optional_uint32 = 2
    216     source = unittest_pb2.TestEmptyMessage()
    217     source.ParseFromString(message.SerializeToString())
    218 
    219     message.ClearField('optional_int32')
    220     message.optional_int64 = 3
    221     message.optional_uint32 = 4
    222     destination = unittest_pb2.TestEmptyMessage()
    223     destination.ParseFromString(message.SerializeToString())
    224 
    225     destination.MergeFrom(source)
    226     # Check that the fields where correctly merged, even stored in the unknown
    227     # fields set.
    228     message.ParseFromString(destination.SerializeToString())
    229     self.assertEqual(message.optional_int32, 1)
    230     self.assertEqual(message.optional_uint32, 2)
    231     self.assertEqual(message.optional_int64, 3)
    232 
    233   def testClear(self):
    234     self.empty_message.Clear()
    235     # All cleared, even unknown fields.
    236     self.assertEqual(self.empty_message.SerializeToString(), b'')
    237 
    238   def testUnknownExtensions(self):
    239     message = unittest_pb2.TestEmptyMessageWithExtensions()
    240     message.ParseFromString(self.all_fields_data)
    241     self.assertEqual(message.SerializeToString(), self.all_fields_data)
    242 
    243 
    244 class UnknownEnumValuesTest(unittest.TestCase):
    245 
    246   def setUp(self):
    247     self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR
    248 
    249     self.message = missing_enum_values_pb2.TestEnumValues()
    250     self.message.optional_nested_enum = (
    251       missing_enum_values_pb2.TestEnumValues.ZERO)
    252     self.message.repeated_nested_enum.extend([
    253       missing_enum_values_pb2.TestEnumValues.ZERO,
    254       missing_enum_values_pb2.TestEnumValues.ONE,
    255       ])
    256     self.message.packed_nested_enum.extend([
    257       missing_enum_values_pb2.TestEnumValues.ZERO,
    258       missing_enum_values_pb2.TestEnumValues.ONE,
    259       ])
    260     self.message_data = self.message.SerializeToString()
    261     self.missing_message = missing_enum_values_pb2.TestMissingEnumValues()
    262     self.missing_message.ParseFromString(self.message_data)
    263     if api_implementation.Type() != 'cpp':
    264       # _unknown_fields is an implementation detail.
    265       self.unknown_fields = self.missing_message._unknown_fields
    266 
    267   # All the tests that use GetField() check an implementation detail of the
    268   # Python implementation, which stores unknown fields as serialized strings.
    269   # These tests are skipped by the C++ implementation: it's enough to check that
    270   # the message is correctly serialized.
    271 
    272   def GetField(self, name):
    273     field_descriptor = self.descriptor.fields_by_name[name]
    274     wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
    275     field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
    276     result_dict = {}
    277     for tag_bytes, value in self.unknown_fields:
    278       if tag_bytes == field_tag:
    279         decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[
    280           tag_bytes][0]
    281         decoder(value, 0, len(value), self.message, result_dict)
    282     return result_dict[field_descriptor]
    283 
    284   def testUnknownParseMismatchEnumValue(self):
    285     just_string = missing_enum_values_pb2.JustString()
    286     just_string.dummy = 'blah'
    287 
    288     missing = missing_enum_values_pb2.TestEnumValues()
    289     # The parse is invalid, storing the string proto into the set of
    290     # unknown fields.
    291     missing.ParseFromString(just_string.SerializeToString())
    292 
    293     # Fetching the enum field shouldn't crash, instead returning the
    294     # default value.
    295     self.assertEqual(missing.optional_nested_enum, 0)
    296 
    297   @SkipIfCppImplementation
    298   def testUnknownEnumValue(self):
    299     self.assertFalse(self.missing_message.HasField('optional_nested_enum'))
    300     value = self.GetField('optional_nested_enum')
    301     self.assertEqual(self.message.optional_nested_enum, value)
    302 
    303   @SkipIfCppImplementation
    304   def testUnknownRepeatedEnumValue(self):
    305     value = self.GetField('repeated_nested_enum')
    306     self.assertEqual(self.message.repeated_nested_enum, value)
    307 
    308   @SkipIfCppImplementation
    309   def testUnknownPackedEnumValue(self):
    310     value = self.GetField('packed_nested_enum')
    311     self.assertEqual(self.message.packed_nested_enum, value)
    312 
    313   def testRoundTrip(self):
    314     new_message = missing_enum_values_pb2.TestEnumValues()
    315     new_message.ParseFromString(self.missing_message.SerializeToString())
    316     self.assertEqual(self.message, new_message)
    317 
    318 
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()
    321