Home | History | Annotate | Download | only in protorpc
      1 #!/usr/bin/env python
      2 #
      3 # Copyright 2010 Google Inc.
      4 #
      5 # Licensed under the Apache License, Version 2.0 (the "License");
      6 # you may not use this file except in compliance with the License.
      7 # You may obtain a copy of the License at
      8 #
      9 #     http://www.apache.org/licenses/LICENSE-2.0
     10 #
     11 # Unless required by applicable law or agreed to in writing, software
     12 # distributed under the License is distributed on an "AS IS" BASIS,
     13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 # See the License for the specific language governing permissions and
     15 # limitations under the License.
     16 #
     17 
     18 """Protocol buffer support for message types.
     19 
     20 For more details about protocol buffer encoding and decoding please see:
     21 
     22   http://code.google.com/apis/protocolbuffers/docs/encoding.html
     23 
     24 Public Exceptions:
     25   messages.DecodeError: Raised when a decode error occurs from incorrect protobuf format.
     26 
     27 Public Functions:
     28   encode_message: Encodes a message in to a protocol buffer string.
     29   decode_message: Decode from a protocol buffer string to a message.
     30 """
     31 import six
     32 
     33 __author__ = 'rafek (at] google.com (Rafe Kaplan)'
     34 
     35 
     36 import array
     37 
     38 from . import message_types
     39 from . import messages
     40 from . import util
     41 from .google_imports import ProtocolBuffer
     42 
     43 
# Names exported by this module.
__all__ = ['ALTERNATIVE_CONTENT_TYPES',
           'CONTENT_TYPE',
           'encode_message',
           'decode_message',
          ]

# Content type string associated with this encoding.
# NOTE(review): presumably matched against request/response content types by
# the RPC transport layer -- confirm against callers.
CONTENT_TYPE = 'application/octet-stream'

# Additional content type strings associated with this encoding.
ALTERNATIVE_CONTENT_TYPES = ['application/x-google-protobuf']
     53 
     54 
     55 class _Encoder(ProtocolBuffer.Encoder):
     56   """Extension of protocol buffer encoder.
     57 
     58   Original protocol buffer encoder does not have complete set of methods
     59   for handling required encoding.  This class adds them.
     60   """
     61 
     62   # TODO(rafek): Implement the missing encoding types.
     63   def no_encoding(self, value):
     64     """No encoding available for type.
     65 
     66     Args:
     67       value: Value to encode.
     68 
     69     Raises:
     70       NotImplementedError at all times.
     71     """
     72     raise NotImplementedError()
     73 
     74   def encode_enum(self, value):
     75     """Encode an enum value.
     76 
     77     Args:
     78       value: Enum to encode.
     79     """
     80     self.putVarInt32(value.number)
     81 
     82   def encode_message(self, value):
     83     """Encode a Message in to an embedded message.
     84 
     85     Args:
     86       value: Message instance to encode.
     87     """
     88     self.putPrefixedString(encode_message(value))
     89 
     90 
     91   def encode_unicode_string(self, value):
     92     """Helper to properly pb encode unicode strings to UTF-8.
     93 
     94     Args:
     95       value: String value to encode.
     96     """
     97     if isinstance(value, six.text_type):
     98       value = value.encode('utf-8')
     99     self.putPrefixedString(value)
    100 
    101 
    102 class _Decoder(ProtocolBuffer.Decoder):
    103   """Extension of protocol buffer decoder.
    104 
    105   Original protocol buffer decoder does not have complete set of methods
    106   for handling required decoding.  This class adds them.
    107   """
    108 
    109   # TODO(rafek): Implement the missing encoding types.
    110   def no_decoding(self):
    111     """No decoding available for type.
    112 
    113     Raises:
    114       NotImplementedError at all times.
    115     """
    116     raise NotImplementedError()
    117 
    118   def decode_string(self):
    119     """Decode a unicode string.
    120 
    121     Returns:
    122       Next value in stream as a unicode string.
    123     """
    124     return self.getPrefixedString().decode('UTF-8')
    125 
    126   def decode_boolean(self):
    127     """Decode a boolean value.
    128 
    129     Returns:
    130       Next value in stream as a boolean.
    131     """
    132     return bool(self.getBoolean())
    133 
    134 
# Number of low-order bits of an encoded tag that hold the wire type.  The
# remaining high-order bits hold the field number (see how tags are built in
# encode_message and split apart in decode_message).
_WIRE_TYPE_BITS = 3
# Bit mask that selects the wire type bits of an encoded tag.
_WIRE_TYPE_MASK = 7
    138 
    139 
# Maps variant to underlying wire type.  Many variants map to same type.
# Used to build the tag when encoding and to check that the wire type found
# in the stream matches the declared field when decoding.
_VARIANT_TO_WIRE_TYPE = {
    messages.Variant.DOUBLE: _Encoder.DOUBLE,
    messages.Variant.FLOAT: _Encoder.FLOAT,
    messages.Variant.INT64: _Encoder.NUMERIC,
    messages.Variant.UINT64: _Encoder.NUMERIC,
    messages.Variant.INT32:  _Encoder.NUMERIC,
    messages.Variant.BOOL: _Encoder.NUMERIC,
    messages.Variant.STRING: _Encoder.STRING,
    messages.Variant.MESSAGE: _Encoder.STRING,
    messages.Variant.BYTES: _Encoder.STRING,
    messages.Variant.UINT32: _Encoder.NUMERIC,
    messages.Variant.ENUM:  _Encoder.NUMERIC,
    messages.Variant.SINT32: _Encoder.NUMERIC,
    messages.Variant.SINT64: _Encoder.NUMERIC,
}
    156 
    157 
# Maps variant to encoder method.  The values are looked up on the class, so
# callers invoke them as method(encoder, value).
# UINT32, SINT32 and SINT64 map to no_encoding, which always raises
# NotImplementedError.  BYTES shares encode_unicode_string with STRING; that
# method only UTF-8 encodes text strings and writes other values unchanged.
_VARIANT_TO_ENCODER_MAP = {
    messages.Variant.DOUBLE: _Encoder.putDouble,
    messages.Variant.FLOAT: _Encoder.putFloat,
    messages.Variant.INT64: _Encoder.putVarInt64,
    messages.Variant.UINT64: _Encoder.putVarUint64,
    messages.Variant.INT32: _Encoder.putVarInt32,
    messages.Variant.BOOL: _Encoder.putBoolean,
    messages.Variant.STRING: _Encoder.encode_unicode_string,
    messages.Variant.MESSAGE: _Encoder.encode_message,
    messages.Variant.BYTES: _Encoder.encode_unicode_string,
    messages.Variant.UINT32: _Encoder.no_encoding,
    messages.Variant.ENUM: _Encoder.encode_enum,
    messages.Variant.SINT32: _Encoder.no_encoding,
    messages.Variant.SINT64: _Encoder.no_encoding,
}
    174 
    175 
# Basic wire format decoders.  Used for reading unknown values.
# decode_message also uses this table to reject wire types it does not
# support: a wire type missing from here is reported as a decode error.
_WIRE_TYPE_TO_DECODER_MAP = {
  _Encoder.NUMERIC: _Decoder.getVarInt64,
  _Encoder.DOUBLE: _Decoder.getDouble,
  _Encoder.STRING: _Decoder.getPrefixedString,
  _Encoder.FLOAT: _Decoder.getFloat,
}
    183 
    184 
# Map wire type to variant.  Used to find a variant for unknown values.
# The variant is stored alongside the unrecognized value so that
# encode_message can later pick an encoder for it.
_WIRE_TYPE_TO_VARIANT_MAP = {
  _Encoder.NUMERIC: messages.Variant.INT64,
  _Encoder.DOUBLE: messages.Variant.DOUBLE,
  _Encoder.STRING: messages.Variant.STRING,
  _Encoder.FLOAT: messages.Variant.FLOAT,
}
    192 
    193 
# Wire type to name mapping for error messages.
# Used by decode_message to report a mismatch between the wire type expected
# for a field and the wire type found in the stream.
_WIRE_TYPE_NAME = {
  _Encoder.NUMERIC: 'NUMERIC',
  _Encoder.DOUBLE: 'DOUBLE',
  _Encoder.STRING: 'STRING',
  _Encoder.FLOAT: 'FLOAT',
}
    201 
    202 
# Maps variant to decoder method.  The values are looked up on the class, so
# callers invoke them as method(decoder).
# MESSAGE is read as a raw prefixed string and ENUM as a plain integer;
# decode_message converts both to their final types afterwards.
# UINT32, SINT32 and SINT64 map to no_decoding, which always raises
# NotImplementedError.
_VARIANT_TO_DECODER_MAP = {
    messages.Variant.DOUBLE: _Decoder.getDouble,
    messages.Variant.FLOAT: _Decoder.getFloat,
    messages.Variant.INT64: _Decoder.getVarInt64,
    messages.Variant.UINT64: _Decoder.getVarUint64,
    messages.Variant.INT32:  _Decoder.getVarInt32,
    messages.Variant.BOOL: _Decoder.decode_boolean,
    messages.Variant.STRING: _Decoder.decode_string,
    messages.Variant.MESSAGE: _Decoder.getPrefixedString,
    messages.Variant.BYTES: _Decoder.getPrefixedString,
    messages.Variant.UINT32: _Decoder.no_decoding,
    messages.Variant.ENUM:  _Decoder.getVarInt32,
    messages.Variant.SINT32: _Decoder.no_decoding,
    messages.Variant.SINT64: _Decoder.no_decoding,
}
    219 
    220 
    221 def encode_message(message):
    222   """Encode Message instance to protocol buffer.
    223 
    224   Args:
    225     Message instance to encode in to protocol buffer.
    226 
    227   Returns:
    228     String encoding of Message instance in protocol buffer format.
    229 
    230   Raises:
    231     messages.ValidationError if message is not initialized.
    232   """
    233   message.check_initialized()
    234   encoder = _Encoder()
    235 
    236   # Get all fields, from the known fields we parsed and the unknown fields
    237   # we saved.  Note which ones were known, so we can process them differently.
    238   all_fields = [(field.number, field) for field in message.all_fields()]
    239   all_fields.extend((key, None)
    240                     for key in message.all_unrecognized_fields()
    241                     if isinstance(key, six.integer_types))
    242   all_fields.sort()
    243   for field_num, field in all_fields:
    244     if field:
    245       # Known field.
    246       value = message.get_assigned_value(field.name)
    247       if value is None:
    248         continue
    249       variant = field.variant
    250       repeated = field.repeated
    251     else:
    252       # Unrecognized field.
    253       value, variant = message.get_unrecognized_field_info(field_num)
    254       if not isinstance(variant, messages.Variant):
    255         continue
    256       repeated = isinstance(value, (list, tuple))
    257 
    258     tag = ((field_num << _WIRE_TYPE_BITS) | _VARIANT_TO_WIRE_TYPE[variant])
    259 
    260     # Write value to wire.
    261     if repeated:
    262       values = value
    263     else:
    264       values = [value]
    265     for next in values:
    266       encoder.putVarInt32(tag)
    267       if isinstance(field, messages.MessageField):
    268         next = field.value_to_message(next)
    269       field_encoder = _VARIANT_TO_ENCODER_MAP[variant]
    270       field_encoder(encoder, next)
    271 
    272   return encoder.buffer().tostring()
    273 
    274 
    275 def decode_message(message_type, encoded_message):
    276   """Decode protocol buffer to Message instance.
    277 
    278   Args:
    279     message_type: Message type to decode data to.
    280     encoded_message: Encoded version of message as string.
    281 
    282   Returns:
    283     Decoded instance of message_type.
    284 
    285   Raises:
    286     DecodeError if an error occurs during decoding, such as incompatible
    287       wire format for a field.
    288     messages.ValidationError if merged message is not initialized.
    289   """
    290   message = message_type()
    291   message_array = array.array('B')
    292   message_array.fromstring(encoded_message)
    293   try:
    294     decoder = _Decoder(message_array, 0, len(message_array))
    295 
    296     while decoder.avail() > 0:
    297       # Decode tag and variant information.
    298       encoded_tag = decoder.getVarInt32()
    299       tag = encoded_tag >> _WIRE_TYPE_BITS
    300       wire_type = encoded_tag & _WIRE_TYPE_MASK
    301       try:
    302         found_wire_type_decoder = _WIRE_TYPE_TO_DECODER_MAP[wire_type]
    303       except:
    304         raise messages.DecodeError('No such wire type %d' % wire_type)
    305 
    306       if tag < 1:
    307         raise messages.DecodeError('Invalid tag value %d' % tag)
    308 
    309       try:
    310         field = message.field_by_number(tag)
    311       except KeyError:
    312         # Unexpected tags are ok.
    313         field = None
    314         wire_type_decoder = found_wire_type_decoder
    315       else:
    316         expected_wire_type = _VARIANT_TO_WIRE_TYPE[field.variant]
    317         if expected_wire_type != wire_type:
    318           raise messages.DecodeError('Expected wire type %s but found %s' % (
    319               _WIRE_TYPE_NAME[expected_wire_type],
    320               _WIRE_TYPE_NAME[wire_type]))
    321 
    322         wire_type_decoder = _VARIANT_TO_DECODER_MAP[field.variant]
    323 
    324       value = wire_type_decoder(decoder)
    325 
    326       # Save unknown fields and skip additional processing.
    327       if not field:
    328         # When saving this, save it under the tag number (which should
    329         # be unique), and set the variant and value so we know how to
    330         # interpret the value later.
    331         variant = _WIRE_TYPE_TO_VARIANT_MAP.get(wire_type)
    332         if variant:
    333           message.set_unrecognized_field(tag, value, variant)
    334         continue
    335 
    336       # Special case Enum and Message types.
    337       if isinstance(field, messages.EnumField):
    338         try:
    339           value = field.type(value)
    340         except TypeError:
    341           raise messages.DecodeError('Invalid enum value %s' % value)
    342       elif isinstance(field, messages.MessageField):
    343         value = decode_message(field.message_type, value)
    344         value = field.value_from_message(value)
    345 
    346       # Merge value in to message.
    347       if field.repeated:
    348         values = getattr(message, field.name)
    349         if values is None:
    350           setattr(message, field.name, [value])
    351         else:
    352           values.append(value)
    353       else:
    354         setattr(message, field.name, value)
    355   except ProtocolBuffer.ProtocolBufferDecodeError as err:
    356     raise messages.DecodeError('Decoding error: %s' % str(err))
    357 
    358   message.check_initialized()
    359   return message
    360