Home | History | Annotate | Download | only in internal
      1 # Protocol Buffers - Google's data interchange format
      2 # Copyright 2008 Google Inc.  All rights reserved.
      3 # https://developers.google.com/protocol-buffers/
      4 #
      5 # Redistribution and use in source and binary forms, with or without
      6 # modification, are permitted provided that the following conditions are
      7 # met:
      8 #
      9 #     * Redistributions of source code must retain the above copyright
     10 # notice, this list of conditions and the following disclaimer.
     11 #     * Redistributions in binary form must reproduce the above
     12 # copyright notice, this list of conditions and the following disclaimer
     13 # in the documentation and/or other materials provided with the
     14 # distribution.
     15 #     * Neither the name of Google Inc. nor the names of its
     16 # contributors may be used to endorse or promote products derived from
     17 # this software without specific prior written permission.
     18 #
     19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     20 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     21 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     22 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     23 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     24 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     25 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     26 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     27 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     28 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     29 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     30 
     31 """Code for decoding protocol buffer primitives.
     32 
     33 This code is very similar to encoder.py -- read the docs for that module first.
     34 
     35 A "decoder" is a function with the signature:
     36   Decode(buffer, pos, end, message, field_dict)
     37 The arguments are:
     38   buffer:     The string containing the encoded message.
     39   pos:        The current position in the string.
     40   end:        The position in the string where the current message ends.  May be
     41               less than len(buffer) if we're reading a sub-message.
     42   message:    The message object into which we're parsing.
     43   field_dict: message._fields (avoids a hashtable lookup).
     44 The decoder reads the field and stores it into field_dict, returning the new
     45 buffer position.  A decoder for a repeated field may proactively decode all of
     46 the elements of that field, if they appear consecutively.
     47 
     48 Note that decoders may throw any of the following:
     49   IndexError:  Indicates a truncated message.
     50   struct.error:  Unpacking of a fixed-width field failed.
     51   message.DecodeError:  Other errors.
     52 
     53 Decoders are expected to raise an exception if they are called with pos > end.
     54 This allows callers to be lax about bounds checking:  it's fine to read past
     55 "end" as long as you are sure that someone else will notice and throw an
     56 exception later on.
     57 
     58 Something up the call stack is expected to catch IndexError and struct.error
     59 and convert them to message.DecodeError.
     60 
     61 Decoders are constructed using decoder constructors with the signature:
     62   MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
     63 The arguments are:
     64   field_number:  The field number of the field we want to decode.
     65   is_repeated:   Is the field a repeated field? (bool)
     66   is_packed:     Is the field a packed field? (bool)
     67   key:           The key to use when looking up the field within field_dict.
     68                  (This is actually the FieldDescriptor but nothing in this
     69                  file should depend on that.)
     70   new_default:   A function which takes a message object as a parameter and
     71                  returns a new instance of the default value for this field.
     72                  (This is called for repeated fields and sub-messages, when an
     73                  instance does not already exist.)
     74 
     75 As with encoders, we define a decoder constructor for every type of field.
     76 Then, for every field of every message class we construct an actual decoder.
     77 That decoder goes into a dict indexed by tag, so when we decode a message
     78 we repeatedly read a tag, look up the corresponding decoder, and invoke it.
     79 """
     80 
     81 __author__ = 'kenton (at] google.com (Kenton Varda)'
     82 
     83 import struct
     84 
     85 import six
     86 
if six.PY3:
  # Python 3 has no separate `long` type.  Alias it to `int` so the 64-bit
  # varint decoders below can name `long` on both Python 2 and Python 3.
  long = int
     89 
     90 from google.protobuf.internal import encoder
     91 from google.protobuf.internal import wire_format
     92 from google.protobuf import message
     93 
     94 
     95 # This will overflow and thus become IEEE-754 "infinity".  We would use
     96 # "float('inf')" but it doesn't work on Windows pre-Python-2.6.
     97 _POS_INF = 1e10000
     98 _NEG_INF = -_POS_INF
     99 _NAN = _POS_INF * 0
    100 
    101 
# Module-level alias for message.DecodeError.  This is not an optimization.
# The decoder functions below take a parameter named "message", which would
# shadow the `message` module inside them, so they raise _DecodeError instead.
_DecodeError = message.DecodeError
    105 
    106 
    107 def _VarintDecoder(mask, result_type):
    108   """Return an encoder for a basic varint value (does not include tag).
    109 
    110   Decoded values will be bitwise-anded with the given mask before being
    111   returned, e.g. to limit them to 32 bits.  The returned decoder does not
    112   take the usual "end" parameter -- the caller is expected to do bounds checking
    113   after the fact (often the caller can defer such checking until later).  The
    114   decoder returns a (value, new_pos) pair.
    115   """
    116 
    117   def DecodeVarint(buffer, pos):
    118     result = 0
    119     shift = 0
    120     while 1:
    121       b = six.indexbytes(buffer, pos)
    122       result |= ((b & 0x7f) << shift)
    123       pos += 1
    124       if not (b & 0x80):
    125         result &= mask
    126         result = result_type(result)
    127         return (result, pos)
    128       shift += 7
    129       if shift >= 64:
    130         raise _DecodeError('Too many bytes when decoding varint.')
    131   return DecodeVarint
    132 
    133 
    134 def _SignedVarintDecoder(mask, result_type):
    135   """Like _VarintDecoder() but decodes signed values."""
    136 
    137   def DecodeVarint(buffer, pos):
    138     result = 0
    139     shift = 0
    140     while 1:
    141       b = six.indexbytes(buffer, pos)
    142       result |= ((b & 0x7f) << shift)
    143       pos += 1
    144       if not (b & 0x80):
    145         if result > 0x7fffffffffffffff:
    146           result -= (1 << 64)
    147           result |= ~mask
    148         else:
    149           result &= mask
    150         result = result_type(result)
    151         return (result, pos)
    152       shift += 7
    153       if shift >= 64:
    154         raise _DecodeError('Too many bytes when decoding varint.')
    155   return DecodeVarint
    156 
    157 # We force 32-bit values to int and 64-bit values to long to make
    158 # alternate implementations where the distinction is more significant
    159 # (e.g. the C++ implementation) simpler.
    160 
    161 _DecodeVarint = _VarintDecoder((1 << 64) - 1, long)
    162 _DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1, long)
    163 
    164 # Use these versions for values which must be limited to 32 bits.
    165 _DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int)
    166 _DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1, int)
    167 
    168 
    169 def ReadTag(buffer, pos):
    170   """Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple.
    171 
    172   We return the raw bytes of the tag rather than decoding them.  The raw
    173   bytes can then be used to look up the proper decoder.  This effectively allows
    174   us to trade some work that would be done in pure-python (decoding a varint)
    175   for work that is done in C (searching for a byte string in a hash table).
    176   In a low-level language it would be much cheaper to decode the varint and
    177   use that, but not in Python.
    178   """
    179 
    180   start = pos
    181   while six.indexbytes(buffer, pos) & 0x80:
    182     pos += 1
    183   pos += 1
    184   return (buffer[start:pos], pos)
    185 
    186 
    187 # --------------------------------------------------------------------
    188 
    189 
    190 def _SimpleDecoder(wire_type, decode_value):
    191   """Return a constructor for a decoder for fields of a particular type.
    192 
    193   Args:
    194       wire_type:  The field's wire type.
    195       decode_value:  A function which decodes an individual value, e.g.
    196         _DecodeVarint()
    197   """
    198 
    199   def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default):
    200     if is_packed:
    201       local_DecodeVarint = _DecodeVarint
    202       def DecodePackedField(buffer, pos, end, message, field_dict):
    203         value = field_dict.get(key)
    204         if value is None:
    205           value = field_dict.setdefault(key, new_default(message))
    206         (endpoint, pos) = local_DecodeVarint(buffer, pos)
    207         endpoint += pos
    208         if endpoint > end:
    209           raise _DecodeError('Truncated message.')
    210         while pos < endpoint:
    211           (element, pos) = decode_value(buffer, pos)
    212           value.append(element)
    213         if pos > endpoint:
    214           del value[-1]   # Discard corrupt value.
    215           raise _DecodeError('Packed element was truncated.')
    216         return pos
    217       return DecodePackedField
    218     elif is_repeated:
    219       tag_bytes = encoder.TagBytes(field_number, wire_type)
    220       tag_len = len(tag_bytes)
    221       def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    222         value = field_dict.get(key)
    223         if value is None:
    224           value = field_dict.setdefault(key, new_default(message))
    225         while 1:
    226           (element, new_pos) = decode_value(buffer, pos)
    227           value.append(element)
    228           # Predict that the next tag is another copy of the same repeated
    229           # field.
    230           pos = new_pos + tag_len
    231           if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
    232             # Prediction failed.  Return.
    233             if new_pos > end:
    234               raise _DecodeError('Truncated message.')
    235             return new_pos
    236       return DecodeRepeatedField
    237     else:
    238       def DecodeField(buffer, pos, end, message, field_dict):
    239         (field_dict[key], pos) = decode_value(buffer, pos)
    240         if pos > end:
    241           del field_dict[key]  # Discard corrupt value.
    242           raise _DecodeError('Truncated message.')
    243         return pos
    244       return DecodeField
    245 
    246   return SpecificDecoder
    247 
    248 
    249 def _ModifiedDecoder(wire_type, decode_value, modify_value):
    250   """Like SimpleDecoder but additionally invokes modify_value on every value
    251   before storing it.  Usually modify_value is ZigZagDecode.
    252   """
    253 
    254   # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
    255   # not enough to make a significant difference.
    256 
    257   def InnerDecode(buffer, pos):
    258     (result, new_pos) = decode_value(buffer, pos)
    259     return (modify_value(result), new_pos)
    260   return _SimpleDecoder(wire_type, InnerDecode)
    261 
    262 
    263 def _StructPackDecoder(wire_type, format):
    264   """Return a constructor for a decoder for a fixed-width field.
    265 
    266   Args:
    267       wire_type:  The field's wire type.
    268       format:  The format string to pass to struct.unpack().
    269   """
    270 
    271   value_size = struct.calcsize(format)
    272   local_unpack = struct.unpack
    273 
    274   # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
    275   # not enough to make a significant difference.
    276 
    277   # Note that we expect someone up-stack to catch struct.error and convert
    278   # it to _DecodeError -- this way we don't have to set up exception-
    279   # handling blocks every time we parse one value.
    280 
    281   def InnerDecode(buffer, pos):
    282     new_pos = pos + value_size
    283     result = local_unpack(format, buffer[pos:new_pos])[0]
    284     return (result, new_pos)
    285   return _SimpleDecoder(wire_type, InnerDecode)
    286 
    287 
    288 def _FloatDecoder():
    289   """Returns a decoder for a float field.
    290 
    291   This code works around a bug in struct.unpack for non-finite 32-bit
    292   floating-point values.
    293   """
    294 
    295   local_unpack = struct.unpack
    296 
    297   def InnerDecode(buffer, pos):
    298     # We expect a 32-bit value in little-endian byte order.  Bit 1 is the sign
    299     # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
    300     new_pos = pos + 4
    301     float_bytes = buffer[pos:new_pos]
    302 
    303     # If this value has all its exponent bits set, then it's non-finite.
    304     # In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
    305     # To avoid that, we parse it specially.
    306     if (float_bytes[3:4] in b'\x7F\xFF' and float_bytes[2:3] >= b'\x80'):
    307       # If at least one significand bit is set...
    308       if float_bytes[0:3] != b'\x00\x00\x80':
    309         return (_NAN, new_pos)
    310       # If sign bit is set...
    311       if float_bytes[3:4] == b'\xFF':
    312         return (_NEG_INF, new_pos)
    313       return (_POS_INF, new_pos)
    314 
    315     # Note that we expect someone up-stack to catch struct.error and convert
    316     # it to _DecodeError -- this way we don't have to set up exception-
    317     # handling blocks every time we parse one value.
    318     result = local_unpack('<f', float_bytes)[0]
    319     return (result, new_pos)
    320   return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
    321 
    322 
    323 def _DoubleDecoder():
    324   """Returns a decoder for a double field.
    325 
    326   This code works around a bug in struct.unpack for not-a-number.
    327   """
    328 
    329   local_unpack = struct.unpack
    330 
    331   def InnerDecode(buffer, pos):
    332     # We expect a 64-bit value in little-endian byte order.  Bit 1 is the sign
    333     # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
    334     new_pos = pos + 8
    335     double_bytes = buffer[pos:new_pos]
    336 
    337     # If this value has all its exponent bits set and at least one significand
    338     # bit set, it's not a number.  In Python 2.4, struct.unpack will treat it
    339     # as inf or -inf.  To avoid that, we treat it specially.
    340     if ((double_bytes[7:8] in b'\x7F\xFF')
    341         and (double_bytes[6:7] >= b'\xF0')
    342         and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
    343       return (_NAN, new_pos)
    344 
    345     # Note that we expect someone up-stack to catch struct.error and convert
    346     # it to _DecodeError -- this way we don't have to set up exception-
    347     # handling blocks every time we parse one value.
    348     result = local_unpack('<d', double_bytes)[0]
    349     return (result, new_pos)
    350   return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
    351 
    352 
    353 def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
    354   enum_type = key.enum_type
    355   if is_packed:
    356     local_DecodeVarint = _DecodeVarint
    357     def DecodePackedField(buffer, pos, end, message, field_dict):
    358       value = field_dict.get(key)
    359       if value is None:
    360         value = field_dict.setdefault(key, new_default(message))
    361       (endpoint, pos) = local_DecodeVarint(buffer, pos)
    362       endpoint += pos
    363       if endpoint > end:
    364         raise _DecodeError('Truncated message.')
    365       while pos < endpoint:
    366         value_start_pos = pos
    367         (element, pos) = _DecodeSignedVarint32(buffer, pos)
    368         if element in enum_type.values_by_number:
    369           value.append(element)
    370         else:
    371           if not message._unknown_fields:
    372             message._unknown_fields = []
    373           tag_bytes = encoder.TagBytes(field_number,
    374                                        wire_format.WIRETYPE_VARINT)
    375           message._unknown_fields.append(
    376               (tag_bytes, buffer[value_start_pos:pos]))
    377       if pos > endpoint:
    378         if element in enum_type.values_by_number:
    379           del value[-1]   # Discard corrupt value.
    380         else:
    381           del message._unknown_fields[-1]
    382         raise _DecodeError('Packed element was truncated.')
    383       return pos
    384     return DecodePackedField
    385   elif is_repeated:
    386     tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
    387     tag_len = len(tag_bytes)
    388     def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    389       value = field_dict.get(key)
    390       if value is None:
    391         value = field_dict.setdefault(key, new_default(message))
    392       while 1:
    393         (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
    394         if element in enum_type.values_by_number:
    395           value.append(element)
    396         else:
    397           if not message._unknown_fields:
    398             message._unknown_fields = []
    399           message._unknown_fields.append(
    400               (tag_bytes, buffer[pos:new_pos]))
    401         # Predict that the next tag is another copy of the same repeated
    402         # field.
    403         pos = new_pos + tag_len
    404         if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
    405           # Prediction failed.  Return.
    406           if new_pos > end:
    407             raise _DecodeError('Truncated message.')
    408           return new_pos
    409     return DecodeRepeatedField
    410   else:
    411     def DecodeField(buffer, pos, end, message, field_dict):
    412       value_start_pos = pos
    413       (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
    414       if pos > end:
    415         raise _DecodeError('Truncated message.')
    416       if enum_value in enum_type.values_by_number:
    417         field_dict[key] = enum_value
    418       else:
    419         if not message._unknown_fields:
    420           message._unknown_fields = []
    421         tag_bytes = encoder.TagBytes(field_number,
    422                                      wire_format.WIRETYPE_VARINT)
    423         message._unknown_fields.append(
    424           (tag_bytes, buffer[value_start_pos:pos]))
    425       return pos
    426     return DecodeField
    427 
    428 
    429 # --------------------------------------------------------------------
    430 
    431 
    432 Int32Decoder = _SimpleDecoder(
    433     wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)
    434 
    435 Int64Decoder = _SimpleDecoder(
    436     wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)
    437 
    438 UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
    439 UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)
    440 
    441 SInt32Decoder = _ModifiedDecoder(
    442     wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
    443 SInt64Decoder = _ModifiedDecoder(
    444     wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)
    445 
    446 # Note that Python conveniently guarantees that when using the '<' prefix on
    447 # formats, they will also have the same size across all platforms (as opposed
    448 # to without the prefix, where their sizes depend on the C compiler's basic
    449 # type sizes).
    450 Fixed32Decoder  = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
    451 Fixed64Decoder  = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
    452 SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
    453 SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
    454 FloatDecoder = _FloatDecoder()
    455 DoubleDecoder = _DoubleDecoder()
    456 
    457 BoolDecoder = _ModifiedDecoder(
    458     wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
    459 
    460 
    461 def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
    462   """Returns a decoder for a string field."""
    463 
    464   local_DecodeVarint = _DecodeVarint
    465   local_unicode = six.text_type
    466 
    467   def _ConvertToUnicode(byte_str):
    468     try:
    469       return local_unicode(byte_str, 'utf-8')
    470     except UnicodeDecodeError as e:
    471       # add more information to the error message and re-raise it.
    472       e.reason = '%s in field: %s' % (e, key.full_name)
    473       raise
    474 
    475   assert not is_packed
    476   if is_repeated:
    477     tag_bytes = encoder.TagBytes(field_number,
    478                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
    479     tag_len = len(tag_bytes)
    480     def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    481       value = field_dict.get(key)
    482       if value is None:
    483         value = field_dict.setdefault(key, new_default(message))
    484       while 1:
    485         (size, pos) = local_DecodeVarint(buffer, pos)
    486         new_pos = pos + size
    487         if new_pos > end:
    488           raise _DecodeError('Truncated string.')
    489         value.append(_ConvertToUnicode(buffer[pos:new_pos]))
    490         # Predict that the next tag is another copy of the same repeated field.
    491         pos = new_pos + tag_len
    492         if buffer[new_pos:pos] != tag_bytes or new_pos == end:
    493           # Prediction failed.  Return.
    494           return new_pos
    495     return DecodeRepeatedField
    496   else:
    497     def DecodeField(buffer, pos, end, message, field_dict):
    498       (size, pos) = local_DecodeVarint(buffer, pos)
    499       new_pos = pos + size
    500       if new_pos > end:
    501         raise _DecodeError('Truncated string.')
    502       field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
    503       return new_pos
    504     return DecodeField
    505 
    506 
    507 def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
    508   """Returns a decoder for a bytes field."""
    509 
    510   local_DecodeVarint = _DecodeVarint
    511 
    512   assert not is_packed
    513   if is_repeated:
    514     tag_bytes = encoder.TagBytes(field_number,
    515                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
    516     tag_len = len(tag_bytes)
    517     def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    518       value = field_dict.get(key)
    519       if value is None:
    520         value = field_dict.setdefault(key, new_default(message))
    521       while 1:
    522         (size, pos) = local_DecodeVarint(buffer, pos)
    523         new_pos = pos + size
    524         if new_pos > end:
    525           raise _DecodeError('Truncated string.')
    526         value.append(buffer[pos:new_pos])
    527         # Predict that the next tag is another copy of the same repeated field.
    528         pos = new_pos + tag_len
    529         if buffer[new_pos:pos] != tag_bytes or new_pos == end:
    530           # Prediction failed.  Return.
    531           return new_pos
    532     return DecodeRepeatedField
    533   else:
    534     def DecodeField(buffer, pos, end, message, field_dict):
    535       (size, pos) = local_DecodeVarint(buffer, pos)
    536       new_pos = pos + size
    537       if new_pos > end:
    538         raise _DecodeError('Truncated string.')
    539       field_dict[key] = buffer[pos:new_pos]
    540       return new_pos
    541     return DecodeField
    542 
    543 
    544 def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
    545   """Returns a decoder for a group field."""
    546 
    547   end_tag_bytes = encoder.TagBytes(field_number,
    548                                    wire_format.WIRETYPE_END_GROUP)
    549   end_tag_len = len(end_tag_bytes)
    550 
    551   assert not is_packed
    552   if is_repeated:
    553     tag_bytes = encoder.TagBytes(field_number,
    554                                  wire_format.WIRETYPE_START_GROUP)
    555     tag_len = len(tag_bytes)
    556     def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    557       value = field_dict.get(key)
    558       if value is None:
    559         value = field_dict.setdefault(key, new_default(message))
    560       while 1:
    561         value = field_dict.get(key)
    562         if value is None:
    563           value = field_dict.setdefault(key, new_default(message))
    564         # Read sub-message.
    565         pos = value.add()._InternalParse(buffer, pos, end)
    566         # Read end tag.
    567         new_pos = pos+end_tag_len
    568         if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
    569           raise _DecodeError('Missing group end tag.')
    570         # Predict that the next tag is another copy of the same repeated field.
    571         pos = new_pos + tag_len
    572         if buffer[new_pos:pos] != tag_bytes or new_pos == end:
    573           # Prediction failed.  Return.
    574           return new_pos
    575     return DecodeRepeatedField
    576   else:
    577     def DecodeField(buffer, pos, end, message, field_dict):
    578       value = field_dict.get(key)
    579       if value is None:
    580         value = field_dict.setdefault(key, new_default(message))
    581       # Read sub-message.
    582       pos = value._InternalParse(buffer, pos, end)
    583       # Read end tag.
    584       new_pos = pos+end_tag_len
    585       if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
    586         raise _DecodeError('Missing group end tag.')
    587       return new_pos
    588     return DecodeField
    589 
    590 
    591 def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
    592   """Returns a decoder for a message field."""
    593 
    594   local_DecodeVarint = _DecodeVarint
    595 
    596   assert not is_packed
    597   if is_repeated:
    598     tag_bytes = encoder.TagBytes(field_number,
    599                                  wire_format.WIRETYPE_LENGTH_DELIMITED)
    600     tag_len = len(tag_bytes)
    601     def DecodeRepeatedField(buffer, pos, end, message, field_dict):
    602       value = field_dict.get(key)
    603       if value is None:
    604         value = field_dict.setdefault(key, new_default(message))
    605       while 1:
    606         # Read length.
    607         (size, pos) = local_DecodeVarint(buffer, pos)
    608         new_pos = pos + size
    609         if new_pos > end:
    610           raise _DecodeError('Truncated message.')
    611         # Read sub-message.
    612         if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
    613           # The only reason _InternalParse would return early is if it
    614           # encountered an end-group tag.
    615           raise _DecodeError('Unexpected end-group tag.')
    616         # Predict that the next tag is another copy of the same repeated field.
    617         pos = new_pos + tag_len
    618         if buffer[new_pos:pos] != tag_bytes or new_pos == end:
    619           # Prediction failed.  Return.
    620           return new_pos
    621     return DecodeRepeatedField
    622   else:
    623     def DecodeField(buffer, pos, end, message, field_dict):
    624       value = field_dict.get(key)
    625       if value is None:
    626         value = field_dict.setdefault(key, new_default(message))
    627       # Read length.
    628       (size, pos) = local_DecodeVarint(buffer, pos)
    629       new_pos = pos + size
    630       if new_pos > end:
    631         raise _DecodeError('Truncated message.')
    632       # Read sub-message.
    633       if value._InternalParse(buffer, pos, new_pos) != new_pos:
    634         # The only reason _InternalParse would return early is if it encountered
    635         # an end-group tag.
    636         raise _DecodeError('Unexpected end-group tag.')
    637       return new_pos
    638     return DecodeField
    639 
    640 
    641 # --------------------------------------------------------------------
    642 
    643 MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
    644 
    645 def MessageSetItemDecoder(extensions_by_number):
    646   """Returns a decoder for a MessageSet item.
    647 
    648   The parameter is the _extensions_by_number map for the message class.
    649 
    650   The message set message looks like this:
    651     message MessageSet {
    652       repeated group Item = 1 {
    653         required int32 type_id = 2;
    654         required string message = 3;
    655       }
    656     }
    657   """
    658 
    659   type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
    660   message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
    661   item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)
    662 
    663   local_ReadTag = ReadTag
    664   local_DecodeVarint = _DecodeVarint
    665   local_SkipField = SkipField
    666 
    667   def DecodeItem(buffer, pos, end, message, field_dict):
    668     message_set_item_start = pos
    669     type_id = -1
    670     message_start = -1
    671     message_end = -1
    672 
    673     # Technically, type_id and message can appear in any order, so we need
    674     # a little loop here.
    675     while 1:
    676       (tag_bytes, pos) = local_ReadTag(buffer, pos)
    677       if tag_bytes == type_id_tag_bytes:
    678         (type_id, pos) = local_DecodeVarint(buffer, pos)
    679       elif tag_bytes == message_tag_bytes:
    680         (size, message_start) = local_DecodeVarint(buffer, pos)
    681         pos = message_end = message_start + size
    682       elif tag_bytes == item_end_tag_bytes:
    683         break
    684       else:
    685         pos = SkipField(buffer, pos, end, tag_bytes)
    686         if pos == -1:
    687           raise _DecodeError('Missing group end tag.')
    688 
    689     if pos > end:
    690       raise _DecodeError('Truncated message.')
    691 
    692     if type_id == -1:
    693       raise _DecodeError('MessageSet item missing type_id.')
    694     if message_start == -1:
    695       raise _DecodeError('MessageSet item missing message.')
    696 
    697     extension = extensions_by_number.get(type_id)
    698     if extension is not None:
    699       value = field_dict.get(extension)
    700       if value is None:
    701         value = field_dict.setdefault(
    702             extension, extension.message_type._concrete_class())
    703       if value._InternalParse(buffer, message_start,message_end) != message_end:
    704         # The only reason _InternalParse would return early is if it encountered
    705         # an end-group tag.
    706         raise _DecodeError('Unexpected end-group tag.')
    707     else:
    708       if not message._unknown_fields:
    709         message._unknown_fields = []
    710       message._unknown_fields.append((MESSAGE_SET_ITEM_TAG,
    711                                       buffer[message_set_item_start:pos]))
    712 
    713     return pos
    714 
    715   return DecodeItem
    716 
    717 # --------------------------------------------------------------------
    718 
    719 def MapDecoder(field_descriptor, new_default, is_message_map):
    720   """Returns a decoder for a map field."""
    721 
    722   key = field_descriptor
    723   tag_bytes = encoder.TagBytes(field_descriptor.number,
    724                                wire_format.WIRETYPE_LENGTH_DELIMITED)
    725   tag_len = len(tag_bytes)
    726   local_DecodeVarint = _DecodeVarint
    727   # Can't read _concrete_class yet; might not be initialized.
    728   message_type = field_descriptor.message_type
    729 
    730   def DecodeMap(buffer, pos, end, message, field_dict):
    731     submsg = message_type._concrete_class()
    732     value = field_dict.get(key)
    733     if value is None:
    734       value = field_dict.setdefault(key, new_default(message))
    735     while 1:
    736       # Read length.
    737       (size, pos) = local_DecodeVarint(buffer, pos)
    738       new_pos = pos + size
    739       if new_pos > end:
    740         raise _DecodeError('Truncated message.')
    741       # Read sub-message.
    742       submsg.Clear()
    743       if submsg._InternalParse(buffer, pos, new_pos) != new_pos:
    744         # The only reason _InternalParse would return early is if it
    745         # encountered an end-group tag.
    746         raise _DecodeError('Unexpected end-group tag.')
    747 
    748       if is_message_map:
    749         value[submsg.key].MergeFrom(submsg.value)
    750       else:
    751         value[submsg.key] = submsg.value
    752 
    753       # Predict that the next tag is another copy of the same repeated field.
    754       pos = new_pos + tag_len
    755       if buffer[new_pos:pos] != tag_bytes or new_pos == end:
    756         # Prediction failed.  Return.
    757         return new_pos
    758 
    759   return DecodeMap
    760 
    761 # --------------------------------------------------------------------
    762 # Optimization is not as heavy here because calls to SkipField() are rare,
    763 # except for handling end-group tags.
    764 
    765 def _SkipVarint(buffer, pos, end):
    766   """Skip a varint value.  Returns the new position."""
    767   # Previously ord(buffer[pos]) raised IndexError when pos is out of range.
    768   # With this code, ord(b'') raises TypeError.  Both are handled in
    769   # python_message.py to generate a 'Truncated message' error.
    770   while ord(buffer[pos:pos+1]) & 0x80:
    771     pos += 1
    772   pos += 1
    773   if pos > end:
    774     raise _DecodeError('Truncated message.')
    775   return pos
    776 
    777 def _SkipFixed64(buffer, pos, end):
    778   """Skip a fixed64 value.  Returns the new position."""
    779 
    780   pos += 8
    781   if pos > end:
    782     raise _DecodeError('Truncated message.')
    783   return pos
    784 
    785 def _SkipLengthDelimited(buffer, pos, end):
    786   """Skip a length-delimited value.  Returns the new position."""
    787 
    788   (size, pos) = _DecodeVarint(buffer, pos)
    789   pos += size
    790   if pos > end:
    791     raise _DecodeError('Truncated message.')
    792   return pos
    793 
    794 def _SkipGroup(buffer, pos, end):
    795   """Skip sub-group.  Returns the new position."""
    796 
    797   while 1:
    798     (tag_bytes, pos) = ReadTag(buffer, pos)
    799     new_pos = SkipField(buffer, pos, end, tag_bytes)
    800     if new_pos == -1:
    801       return pos
    802     pos = new_pos
    803 
    804 def _EndGroup(buffer, pos, end):
    805   """Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""
    806 
    807   return -1
    808 
    809 def _SkipFixed32(buffer, pos, end):
    810   """Skip a fixed32 value.  Returns the new position."""
    811 
    812   pos += 4
    813   if pos > end:
    814     raise _DecodeError('Truncated message.')
    815   return pos
    816 
    817 def _RaiseInvalidWireType(buffer, pos, end):
    818   """Skip function for unknown wire types.  Raises an exception."""
    819 
    820   raise _DecodeError('Tag had invalid wire type.')
    821 
    822 def _FieldSkipper():
    823   """Constructs the SkipField function."""
    824 
    825   WIRETYPE_TO_SKIPPER = [
    826       _SkipVarint,
    827       _SkipFixed64,
    828       _SkipLengthDelimited,
    829       _SkipGroup,
    830       _EndGroup,
    831       _SkipFixed32,
    832       _RaiseInvalidWireType,
    833       _RaiseInvalidWireType,
    834       ]
    835 
    836   wiretype_mask = wire_format.TAG_TYPE_MASK
    837 
    838   def SkipField(buffer, pos, end, tag_bytes):
    839     """Skips a field with the specified tag.
    840 
    841     |pos| should point to the byte immediately after the tag.
    842 
    843     Returns:
    844         The new position (after the tag value), or -1 if the tag is an end-group
    845         tag (in which case the calling loop should break).
    846     """
    847 
    848     # The wire type is always in the first byte since varints are little-endian.
    849     wire_type = ord(tag_bytes[0:1]) & wiretype_mask
    850     return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
    851 
    852   return SkipField
    853 
    854 SkipField = _FieldSkipper()
    855