Home | History | Annotate | Download | only in internal
      1 # Protocol Buffers - Google's data interchange format
      2 # Copyright 2008 Google Inc.  All rights reserved.
      3 # http://code.google.com/p/protobuf/
      4 #
      5 # Redistribution and use in source and binary forms, with or without
      6 # modification, are permitted provided that the following conditions are
      7 # met:
      8 #
      9 #     * Redistributions of source code must retain the above copyright
     10 # notice, this list of conditions and the following disclaimer.
     11 #     * Redistributions in binary form must reproduce the above
     12 # copyright notice, this list of conditions and the following disclaimer
     13 # in the documentation and/or other materials provided with the
     14 # distribution.
     15 #     * Neither the name of Google Inc. nor the names of its
     16 # contributors may be used to endorse or promote products derived from
     17 # this software without specific prior written permission.
     18 #
     19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     20 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     21 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     22 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     23 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     24 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     25 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     26 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     27 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     28 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     29 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     30 
     31 # This code is meant to work on Python 2.4 and above only.
     32 #
     33 # TODO(robinson): Helpers for verbose, common checks like seeing if a
     34 # descriptor's cpp_type is CPPTYPE_MESSAGE.
     35 
     36 """Contains a metaclass and helper functions used to create
     37 protocol message classes from Descriptor objects at runtime.
     38 
     39 Recall that a metaclass is the "type" of a class.
     40 (A class is to a metaclass what an instance is to a class.)
     41 
     42 In this case, we use the GeneratedProtocolMessageType metaclass
     43 to inject all the useful functionality into the classes
     44 output by the protocol compiler at compile-time.
     45 
     46 The upshot of all this is that the real implementation
     47 details for ALL pure-Python protocol buffers are *here in
     48 this file*.
     49 """
     50 
     51 __author__ = 'robinson (at] google.com (Will Robinson)'
     52 
     53 try:
     54   from cStringIO import StringIO
     55 except ImportError:
     56   from StringIO import StringIO
     57 import copy_reg
     58 import struct
     59 import weakref
     60 
     61 # We use "as" to avoid name collisions with variables.
     62 from google.protobuf.internal import containers
     63 from google.protobuf.internal import decoder
     64 from google.protobuf.internal import encoder
     65 from google.protobuf.internal import enum_type_wrapper
     66 from google.protobuf.internal import message_listener as message_listener_mod
     67 from google.protobuf.internal import type_checkers
     68 from google.protobuf.internal import wire_format
     69 from google.protobuf import descriptor as descriptor_mod
     70 from google.protobuf import message as message_mod
     71 from google.protobuf import text_format
     72 
     73 _FieldDescriptor = descriptor_mod.FieldDescriptor
     74 
     75 
     76 def NewMessage(bases, descriptor, dictionary):
     77   _AddClassAttributesForNestedExtensions(descriptor, dictionary)
     78   _AddSlots(descriptor, dictionary)
     79   return bases
     80 
     81 
     82 def InitMessage(descriptor, cls):
     83   cls._decoders_by_tag = {}
     84   cls._extensions_by_name = {}
     85   cls._extensions_by_number = {}
     86   if (descriptor.has_options and
     87       descriptor.GetOptions().message_set_wire_format):
     88     cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
     89         decoder.MessageSetItemDecoder(cls._extensions_by_number))
     90 
     91   # Attach stuff to each FieldDescriptor for quick lookup later on.
     92   for field in descriptor.fields:
     93     _AttachFieldHelpers(cls, field)
     94 
     95   _AddEnumValues(descriptor, cls)
     96   _AddInitMethod(descriptor, cls)
     97   _AddPropertiesForFields(descriptor, cls)
     98   _AddPropertiesForExtensions(descriptor, cls)
     99   _AddStaticMethods(cls)
    100   _AddMessageMethods(descriptor, cls)
    101   _AddPrivateHelperMethods(cls)
    102   copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__()))
    103 
    104 
    105 # Stateless helpers for GeneratedProtocolMessageType below.
    106 # Outside clients should not access these directly.
    107 #
    108 # I opted not to make any of these methods on the metaclass, to make it more
    109 # clear that I'm not really using any state there and to keep clients from
    110 # thinking that they have direct access to these construction helpers.
    111 
    112 
    113 def _PropertyName(proto_field_name):
    114   """Returns the name of the public property attribute which
    115   clients can use to get and (in some cases) set the value
    116   of a protocol message field.
    117 
    118   Args:
    119     proto_field_name: The protocol message field name, exactly
    120       as it appears (or would appear) in a .proto file.
    121   """
    122   # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
    123   # nnorwitz makes my day by writing:
    124   # """
    125   # FYI.  See the keyword module in the stdlib. This could be as simple as:
    126   #
    127   # if keyword.iskeyword(proto_field_name):
    128   #   return proto_field_name + "_"
    129   # return proto_field_name
    130   # """
    131   # Kenton says:  The above is a BAD IDEA.  People rely on being able to use
    132   #   getattr() and setattr() to reflectively manipulate field values.  If we
    133   #   rename the properties, then every such user has to also make sure to apply
    134   #   the same transformation.  Note that currently if you name a field "yield",
    135   #   you can still access it just fine using getattr/setattr -- it's not even
    136   #   that cumbersome to do so.
    137   # TODO(kenton):  Remove this method entirely if/when everyone agrees with my
    138   #   position.
    139   return proto_field_name
    140 
    141 
    142 def _VerifyExtensionHandle(message, extension_handle):
    143   """Verify that the given extension handle is valid."""
    144 
    145   if not isinstance(extension_handle, _FieldDescriptor):
    146     raise KeyError('HasExtension() expects an extension handle, got: %s' %
    147                    extension_handle)
    148 
    149   if not extension_handle.is_extension:
    150     raise KeyError('"%s" is not an extension.' % extension_handle.full_name)
    151 
    152   if not extension_handle.containing_type:
    153     raise KeyError('"%s" is missing a containing_type.'
    154                    % extension_handle.full_name)
    155 
    156   if extension_handle.containing_type is not message.DESCRIPTOR:
    157     raise KeyError('Extension "%s" extends message type "%s", but this '
    158                    'message is of type "%s".' %
    159                    (extension_handle.full_name,
    160                     extension_handle.containing_type.full_name,
    161                     message.DESCRIPTOR.full_name))
    162 
    163 
    164 def _AddSlots(message_descriptor, dictionary):
    165   """Adds a __slots__ entry to dictionary, containing the names of all valid
    166   attributes for this message type.
    167 
    168   Args:
    169     message_descriptor: A Descriptor instance describing this message type.
    170     dictionary: Class dictionary to which we'll add a '__slots__' entry.
    171   """
    172   dictionary['__slots__'] = ['_cached_byte_size',
    173                              '_cached_byte_size_dirty',
    174                              '_fields',
    175                              '_unknown_fields',
    176                              '_is_present_in_parent',
    177                              '_listener',
    178                              '_listener_for_children',
    179                              '__weakref__']
    180 
    181 
    182 def _IsMessageSetExtension(field):
    183   return (field.is_extension and
    184           field.containing_type.has_options and
    185           field.containing_type.GetOptions().message_set_wire_format and
    186           field.type == _FieldDescriptor.TYPE_MESSAGE and
    187           field.message_type == field.extension_scope and
    188           field.label == _FieldDescriptor.LABEL_OPTIONAL)
    189 
    190 
    191 def _AttachFieldHelpers(cls, field_descriptor):
    192   is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
    193   is_packed = (field_descriptor.has_options and
    194                field_descriptor.GetOptions().packed)
    195 
    196   if _IsMessageSetExtension(field_descriptor):
    197     field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
    198     sizer = encoder.MessageSetItemSizer(field_descriptor.number)
    199   else:
    200     field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
    201         field_descriptor.number, is_repeated, is_packed)
    202     sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
    203         field_descriptor.number, is_repeated, is_packed)
    204 
    205   field_descriptor._encoder = field_encoder
    206   field_descriptor._sizer = sizer
    207   field_descriptor._default_constructor = _DefaultValueConstructorForField(
    208       field_descriptor)
    209 
    210   def AddDecoder(wiretype, is_packed):
    211     tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
    212     cls._decoders_by_tag[tag_bytes] = (
    213         type_checkers.TYPE_TO_DECODER[field_descriptor.type](
    214             field_descriptor.number, is_repeated, is_packed,
    215             field_descriptor, field_descriptor._default_constructor))
    216 
    217   AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
    218              False)
    219 
    220   if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
    221     # To support wire compatibility of adding packed = true, add a decoder for
    222     # packed values regardless of the field's options.
    223     AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
    224 
    225 
    226 def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
    227   extension_dict = descriptor.extensions_by_name
    228   for extension_name, extension_field in extension_dict.iteritems():
    229     assert extension_name not in dictionary
    230     dictionary[extension_name] = extension_field
    231 
    232 
    233 def _AddEnumValues(descriptor, cls):
    234   """Sets class-level attributes for all enum fields defined in this message.
    235 
    236   Also exporting a class-level object that can name enum values.
    237 
    238   Args:
    239     descriptor: Descriptor object for this message type.
    240     cls: Class we're constructing for this message type.
    241   """
    242   for enum_type in descriptor.enum_types:
    243     setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type))
    244     for enum_value in enum_type.values:
    245       setattr(cls, enum_value.name, enum_value.number)
    246 
    247 
    248 def _DefaultValueConstructorForField(field):
    249   """Returns a function which returns a default value for a field.
    250 
    251   Args:
    252     field: FieldDescriptor object for this field.
    253 
    254   The returned function has one argument:
    255     message: Message instance containing this field, or a weakref proxy
    256       of same.
    257 
    258   That function in turn returns a default value for this field.  The default
    259     value may refer back to |message| via a weak reference.
    260   """
    261 
    262   if field.label == _FieldDescriptor.LABEL_REPEATED:
    263     if field.has_default_value and field.default_value != []:
    264       raise ValueError('Repeated field default value not empty list: %s' % (
    265           field.default_value))
    266     if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    267       # We can't look at _concrete_class yet since it might not have
    268       # been set.  (Depends on order in which we initialize the classes).
    269       message_type = field.message_type
    270       def MakeRepeatedMessageDefault(message):
    271         return containers.RepeatedCompositeFieldContainer(
    272             message._listener_for_children, field.message_type)
    273       return MakeRepeatedMessageDefault
    274     else:
    275       type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
    276       def MakeRepeatedScalarDefault(message):
    277         return containers.RepeatedScalarFieldContainer(
    278             message._listener_for_children, type_checker)
    279       return MakeRepeatedScalarDefault
    280 
    281   if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    282     # _concrete_class may not yet be initialized.
    283     message_type = field.message_type
    284     def MakeSubMessageDefault(message):
    285       result = message_type._concrete_class()
    286       result._SetListener(message._listener_for_children)
    287       return result
    288     return MakeSubMessageDefault
    289 
    290   def MakeScalarDefault(message):
    291     # TODO(protobuf-team): This may be broken since there may not be
    292     # default_value.  Combine with has_default_value somehow.
    293     return field.default_value
    294   return MakeScalarDefault
    295 
    296 
    297 def _AddInitMethod(message_descriptor, cls):
    298   """Adds an __init__ method to cls."""
    299   fields = message_descriptor.fields
    300   def init(self, **kwargs):
    301     self._cached_byte_size = 0
    302     self._cached_byte_size_dirty = len(kwargs) > 0
    303     self._fields = {}
    304     # _unknown_fields is () when empty for efficiency, and will be turned into
    305     # a list if fields are added.
    306     self._unknown_fields = ()
    307     self._is_present_in_parent = False
    308     self._listener = message_listener_mod.NullMessageListener()
    309     self._listener_for_children = _Listener(self)
    310     for field_name, field_value in kwargs.iteritems():
    311       field = _GetFieldByName(message_descriptor, field_name)
    312       if field is None:
    313         raise TypeError("%s() got an unexpected keyword argument '%s'" %
    314                         (message_descriptor.name, field_name))
    315       if field.label == _FieldDescriptor.LABEL_REPEATED:
    316         copy = field._default_constructor(self)
    317         if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:  # Composite
    318           for val in field_value:
    319             copy.add().MergeFrom(val)
    320         else:  # Scalar
    321           copy.extend(field_value)
    322         self._fields[field] = copy
    323       elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    324         copy = field._default_constructor(self)
    325         copy.MergeFrom(field_value)
    326         self._fields[field] = copy
    327       else:
    328         setattr(self, field_name, field_value)
    329 
    330   init.__module__ = None
    331   init.__doc__ = None
    332   cls.__init__ = init
    333 
    334 
    335 def _GetFieldByName(message_descriptor, field_name):
    336   """Returns a field descriptor by field name.
    337 
    338   Args:
    339     message_descriptor: A Descriptor describing all fields in message.
    340     field_name: The name of the field to retrieve.
    341   Returns:
    342     The field descriptor associated with the field name.
    343   """
    344   try:
    345     return message_descriptor.fields_by_name[field_name]
    346   except KeyError:
    347     raise ValueError('Protocol message has no "%s" field.' % field_name)
    348 
    349 
    350 def _AddPropertiesForFields(descriptor, cls):
    351   """Adds properties for all fields in this protocol message type."""
    352   for field in descriptor.fields:
    353     _AddPropertiesForField(field, cls)
    354 
    355   if descriptor.is_extendable:
    356     # _ExtensionDict is just an adaptor with no state so we allocate a new one
    357     # every time it is accessed.
    358     cls.Extensions = property(lambda self: _ExtensionDict(self))
    359 
    360 
    361 def _AddPropertiesForField(field, cls):
    362   """Adds a public property for a protocol message field.
    363   Clients can use this property to get and (in the case
    364   of non-repeated scalar fields) directly set the value
    365   of a protocol message field.
    366 
    367   Args:
    368     field: A FieldDescriptor for this field.
    369     cls: The class we're constructing.
    370   """
    371   # Catch it if we add other types that we should
    372   # handle specially here.
    373   assert _FieldDescriptor.MAX_CPPTYPE == 10
    374 
    375   constant_name = field.name.upper() + "_FIELD_NUMBER"
    376   setattr(cls, constant_name, field.number)
    377 
    378   if field.label == _FieldDescriptor.LABEL_REPEATED:
    379     _AddPropertiesForRepeatedField(field, cls)
    380   elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    381     _AddPropertiesForNonRepeatedCompositeField(field, cls)
    382   else:
    383     _AddPropertiesForNonRepeatedScalarField(field, cls)
    384 
    385 
    386 def _AddPropertiesForRepeatedField(field, cls):
    387   """Adds a public property for a "repeated" protocol message field.  Clients
    388   can use this property to get the value of the field, which will be either a
    389   _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see
    390   below).
    391 
    392   Note that when clients add values to these containers, we perform
    393   type-checking in the case of repeated scalar fields, and we also set any
    394   necessary "has" bits as a side-effect.
    395 
    396   Args:
    397     field: A FieldDescriptor for this field.
    398     cls: The class we're constructing.
    399   """
    400   proto_field_name = field.name
    401   property_name = _PropertyName(proto_field_name)
    402 
    403   def getter(self):
    404     field_value = self._fields.get(field)
    405     if field_value is None:
    406       # Construct a new object to represent this field.
    407       field_value = field._default_constructor(self)
    408 
    409       # Atomically check if another thread has preempted us and, if not, swap
    410       # in the new object we just created.  If someone has preempted us, we
    411       # take that object and discard ours.
    412       # WARNING:  We are relying on setdefault() being atomic.  This is true
    413       #   in CPython but we haven't investigated others.  This warning appears
    414       #   in several other locations in this file.
    415       field_value = self._fields.setdefault(field, field_value)
    416     return field_value
    417   getter.__module__ = None
    418   getter.__doc__ = 'Getter for %s.' % proto_field_name
    419 
    420   # We define a setter just so we can throw an exception with a more
    421   # helpful error message.
    422   def setter(self, new_value):
    423     raise AttributeError('Assignment not allowed to repeated field '
    424                          '"%s" in protocol message object.' % proto_field_name)
    425 
    426   doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
    427   setattr(cls, property_name, property(getter, setter, doc=doc))
    428 
    429 
    430 def _AddPropertiesForNonRepeatedScalarField(field, cls):
    431   """Adds a public property for a nonrepeated, scalar protocol message field.
    432   Clients can use this property to get and directly set the value of the field.
    433   Note that when the client sets the value of a field by using this property,
    434   all necessary "has" bits are set as a side-effect, and we also perform
    435   type-checking.
    436 
    437   Args:
    438     field: A FieldDescriptor for this field.
    439     cls: The class we're constructing.
    440   """
    441   proto_field_name = field.name
    442   property_name = _PropertyName(proto_field_name)
    443   type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
    444   default_value = field.default_value
    445   valid_values = set()
    446 
    447   def getter(self):
    448     # TODO(protobuf-team): This may be broken since there may not be
    449     # default_value.  Combine with has_default_value somehow.
    450     return self._fields.get(field, default_value)
    451   getter.__module__ = None
    452   getter.__doc__ = 'Getter for %s.' % proto_field_name
    453   def setter(self, new_value):
    454     type_checker.CheckValue(new_value)
    455     self._fields[field] = new_value
    456     # Check _cached_byte_size_dirty inline to improve performance, since scalar
    457     # setters are called frequently.
    458     if not self._cached_byte_size_dirty:
    459       self._Modified()
    460 
    461   setter.__module__ = None
    462   setter.__doc__ = 'Setter for %s.' % proto_field_name
    463 
    464   # Add a property to encapsulate the getter/setter.
    465   doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
    466   setattr(cls, property_name, property(getter, setter, doc=doc))
    467 
    468 
    469 def _AddPropertiesForNonRepeatedCompositeField(field, cls):
    470   """Adds a public property for a nonrepeated, composite protocol message field.
    471   A composite field is a "group" or "message" field.
    472 
    473   Clients can use this property to get the value of the field, but cannot
    474   assign to the property directly.
    475 
    476   Args:
    477     field: A FieldDescriptor for this field.
    478     cls: The class we're constructing.
    479   """
    480   # TODO(robinson): Remove duplication with similar method
    481   # for non-repeated scalars.
    482   proto_field_name = field.name
    483   property_name = _PropertyName(proto_field_name)
    484 
    485   # TODO(komarek): Can anyone explain to me why we cache the message_type this
    486   # way, instead of referring to field.message_type inside of getter(self)?
    487   # What if someone sets message_type later on (which makes for simpler
    488   # dyanmic proto descriptor and class creation code).
    489   message_type = field.message_type
    490 
    491   def getter(self):
    492     field_value = self._fields.get(field)
    493     if field_value is None:
    494       # Construct a new object to represent this field.
    495       field_value = message_type._concrete_class()  # use field.message_type?
    496       field_value._SetListener(self._listener_for_children)
    497 
    498       # Atomically check if another thread has preempted us and, if not, swap
    499       # in the new object we just created.  If someone has preempted us, we
    500       # take that object and discard ours.
    501       # WARNING:  We are relying on setdefault() being atomic.  This is true
    502       #   in CPython but we haven't investigated others.  This warning appears
    503       #   in several other locations in this file.
    504       field_value = self._fields.setdefault(field, field_value)
    505     return field_value
    506   getter.__module__ = None
    507   getter.__doc__ = 'Getter for %s.' % proto_field_name
    508 
    509   # We define a setter just so we can throw an exception with a more
    510   # helpful error message.
    511   def setter(self, new_value):
    512     raise AttributeError('Assignment not allowed to composite field '
    513                          '"%s" in protocol message object.' % proto_field_name)
    514 
    515   # Add a property to encapsulate the getter.
    516   doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
    517   setattr(cls, property_name, property(getter, setter, doc=doc))
    518 
    519 
    520 def _AddPropertiesForExtensions(descriptor, cls):
    521   """Adds properties for all fields in this protocol message type."""
    522   extension_dict = descriptor.extensions_by_name
    523   for extension_name, extension_field in extension_dict.iteritems():
    524     constant_name = extension_name.upper() + "_FIELD_NUMBER"
    525     setattr(cls, constant_name, extension_field.number)
    526 
    527 
    528 def _AddStaticMethods(cls):
    529   # TODO(robinson): This probably needs to be thread-safe(?)
    530   def RegisterExtension(extension_handle):
    531     extension_handle.containing_type = cls.DESCRIPTOR
    532     _AttachFieldHelpers(cls, extension_handle)
    533 
    534     # Try to insert our extension, failing if an extension with the same number
    535     # already exists.
    536     actual_handle = cls._extensions_by_number.setdefault(
    537         extension_handle.number, extension_handle)
    538     if actual_handle is not extension_handle:
    539       raise AssertionError(
    540           'Extensions "%s" and "%s" both try to extend message type "%s" with '
    541           'field number %d.' %
    542           (extension_handle.full_name, actual_handle.full_name,
    543            cls.DESCRIPTOR.full_name, extension_handle.number))
    544 
    545     cls._extensions_by_name[extension_handle.full_name] = extension_handle
    546 
    547     handle = extension_handle  # avoid line wrapping
    548     if _IsMessageSetExtension(handle):
    549       # MessageSet extension.  Also register under type name.
    550       cls._extensions_by_name[
    551           extension_handle.message_type.full_name] = extension_handle
    552 
    553   cls.RegisterExtension = staticmethod(RegisterExtension)
    554 
    555   def FromString(s):
    556     message = cls()
    557     message.MergeFromString(s)
    558     return message
    559   cls.FromString = staticmethod(FromString)
    560 
    561 
    562 def _IsPresent(item):
    563   """Given a (FieldDescriptor, value) tuple from _fields, return true if the
    564   value should be included in the list returned by ListFields()."""
    565 
    566   if item[0].label == _FieldDescriptor.LABEL_REPEATED:
    567     return bool(item[1])
    568   elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    569     return item[1]._is_present_in_parent
    570   else:
    571     return True
    572 
    573 
    574 def _AddListFieldsMethod(message_descriptor, cls):
    575   """Helper for _AddMessageMethods()."""
    576 
    577   def ListFields(self):
    578     all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
    579     all_fields.sort(key = lambda item: item[0].number)
    580     return all_fields
    581 
    582   cls.ListFields = ListFields
    583 
    584 
    585 def _AddHasFieldMethod(message_descriptor, cls):
    586   """Helper for _AddMessageMethods()."""
    587 
    588   singular_fields = {}
    589   for field in message_descriptor.fields:
    590     if field.label != _FieldDescriptor.LABEL_REPEATED:
    591       singular_fields[field.name] = field
    592 
    593   def HasField(self, field_name):
    594     try:
    595       field = singular_fields[field_name]
    596     except KeyError:
    597       raise ValueError(
    598           'Protocol message has no singular "%s" field.' % field_name)
    599 
    600     if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    601       value = self._fields.get(field)
    602       return value is not None and value._is_present_in_parent
    603     else:
    604       return field in self._fields
    605   cls.HasField = HasField
    606 
    607 
    608 def _AddClearFieldMethod(message_descriptor, cls):
    609   """Helper for _AddMessageMethods()."""
    610   def ClearField(self, field_name):
    611     try:
    612       field = message_descriptor.fields_by_name[field_name]
    613     except KeyError:
    614       raise ValueError('Protocol message has no "%s" field.' % field_name)
    615 
    616     if field in self._fields:
    617       # Note:  If the field is a sub-message, its listener will still point
    618       #   at us.  That's fine, because the worst than can happen is that it
    619       #   will call _Modified() and invalidate our byte size.  Big deal.
    620       del self._fields[field]
    621 
    622     # Always call _Modified() -- even if nothing was changed, this is
    623     # a mutating method, and thus calling it should cause the field to become
    624     # present in the parent message.
    625     self._Modified()
    626 
    627   cls.ClearField = ClearField
    628 
    629 
    630 def _AddClearExtensionMethod(cls):
    631   """Helper for _AddMessageMethods()."""
    632   def ClearExtension(self, extension_handle):
    633     _VerifyExtensionHandle(self, extension_handle)
    634 
    635     # Similar to ClearField(), above.
    636     if extension_handle in self._fields:
    637       del self._fields[extension_handle]
    638     self._Modified()
    639   cls.ClearExtension = ClearExtension
    640 
    641 
    642 def _AddClearMethod(message_descriptor, cls):
    643   """Helper for _AddMessageMethods()."""
    644   def Clear(self):
    645     # Clear fields.
    646     self._fields = {}
    647     self._unknown_fields = ()
    648     self._Modified()
    649   cls.Clear = Clear
    650 
    651 
    652 def _AddHasExtensionMethod(cls):
    653   """Helper for _AddMessageMethods()."""
    654   def HasExtension(self, extension_handle):
    655     _VerifyExtensionHandle(self, extension_handle)
    656     if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
    657       raise KeyError('"%s" is repeated.' % extension_handle.full_name)
    658 
    659     if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    660       value = self._fields.get(extension_handle)
    661       return value is not None and value._is_present_in_parent
    662     else:
    663       return extension_handle in self._fields
    664   cls.HasExtension = HasExtension
    665 
    666 
    667 def _AddEqualsMethod(message_descriptor, cls):
    668   """Helper for _AddMessageMethods()."""
    669   def __eq__(self, other):
    670     if (not isinstance(other, message_mod.Message) or
    671         other.DESCRIPTOR != self.DESCRIPTOR):
    672       return False
    673 
    674     if self is other:
    675       return True
    676 
    677     if not self.ListFields() == other.ListFields():
    678       return False
    679 
    680     # Sort unknown fields because their order shouldn't affect equality test.
    681     unknown_fields = list(self._unknown_fields)
    682     unknown_fields.sort()
    683     other_unknown_fields = list(other._unknown_fields)
    684     other_unknown_fields.sort()
    685 
    686     return unknown_fields == other_unknown_fields
    687 
    688   cls.__eq__ = __eq__
    689 
    690 
    691 def _AddStrMethod(message_descriptor, cls):
    692   """Helper for _AddMessageMethods()."""
    693   def __str__(self):
    694     return text_format.MessageToString(self)
    695   cls.__str__ = __str__
    696 
    697 
    698 def _AddUnicodeMethod(unused_message_descriptor, cls):
    699   """Helper for _AddMessageMethods()."""
    700 
    701   def __unicode__(self):
    702     return text_format.MessageToString(self, as_utf8=True).decode('utf-8')
    703   cls.__unicode__ = __unicode__
    704 
    705 
    706 def _AddSetListenerMethod(cls):
    707   """Helper for _AddMessageMethods()."""
    708   def SetListener(self, listener):
    709     if listener is None:
    710       self._listener = message_listener_mod.NullMessageListener()
    711     else:
    712       self._listener = listener
    713   cls._SetListener = SetListener
    714 
    715 
    716 def _BytesForNonRepeatedElement(value, field_number, field_type):
    717   """Returns the number of bytes needed to serialize a non-repeated element.
    718   The returned byte count includes space for tag information and any
    719   other additional space associated with serializing value.
    720 
    721   Args:
    722     value: Value we're serializing.
    723     field_number: Field number of this value.  (Since the field number
    724       is stored as part of a varint-encoded tag, this has an impact
    725       on the total bytes required to serialize the value).
    726     field_type: The type of the field.  One of the TYPE_* constants
    727       within FieldDescriptor.
    728   """
    729   try:
    730     fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
    731     return fn(field_number, value)
    732   except KeyError:
    733     raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
    734 
    735 
    736 def _AddByteSizeMethod(message_descriptor, cls):
    737   """Helper for _AddMessageMethods()."""
    738 
    739   def ByteSize(self):
    740     if not self._cached_byte_size_dirty:
    741       return self._cached_byte_size
    742 
    743     size = 0
    744     for field_descriptor, field_value in self.ListFields():
    745       size += field_descriptor._sizer(field_value)
    746 
    747     for tag_bytes, value_bytes in self._unknown_fields:
    748       size += len(tag_bytes) + len(value_bytes)
    749 
    750     self._cached_byte_size = size
    751     self._cached_byte_size_dirty = False
    752     self._listener_for_children.dirty = False
    753     return size
    754 
    755   cls.ByteSize = ByteSize
    756 
    757 
    758 def _AddSerializeToStringMethod(message_descriptor, cls):
    759   """Helper for _AddMessageMethods()."""
    760 
    761   def SerializeToString(self):
    762     # Check if the message has all of its required fields set.
    763     errors = []
    764     if not self.IsInitialized():
    765       raise message_mod.EncodeError(
    766           'Message %s is missing required fields: %s' % (
    767           self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors())))
    768     return self.SerializePartialToString()
    769   cls.SerializeToString = SerializeToString
    770 
    771 
    772 def _AddSerializePartialToStringMethod(message_descriptor, cls):
    773   """Helper for _AddMessageMethods()."""
    774 
    775   def SerializePartialToString(self):
    776     out = StringIO()
    777     self._InternalSerialize(out.write)
    778     return out.getvalue()
    779   cls.SerializePartialToString = SerializePartialToString
    780 
    781   def InternalSerialize(self, write_bytes):
    782     for field_descriptor, field_value in self.ListFields():
    783       field_descriptor._encoder(write_bytes, field_value)
    784     for tag_bytes, value_bytes in self._unknown_fields:
    785       write_bytes(tag_bytes)
    786       write_bytes(value_bytes)
    787   cls._InternalSerialize = InternalSerialize
    788 
    789 
    790 def _AddMergeFromStringMethod(message_descriptor, cls):
    791   """Helper for _AddMessageMethods()."""
    792   def MergeFromString(self, serialized):
    793     length = len(serialized)
    794     try:
    795       if self._InternalParse(serialized, 0, length) != length:
    796         # The only reason _InternalParse would return early is if it
    797         # encountered an end-group tag.
    798         raise message_mod.DecodeError('Unexpected end-group tag.')
    799     except IndexError:
    800       raise message_mod.DecodeError('Truncated message.')
    801     except struct.error, e:
    802       raise message_mod.DecodeError(e)
    803     return length   # Return this for legacy reasons.
    804   cls.MergeFromString = MergeFromString
    805 
    806   local_ReadTag = decoder.ReadTag
    807   local_SkipField = decoder.SkipField
    808   decoders_by_tag = cls._decoders_by_tag
    809 
    810   def InternalParse(self, buffer, pos, end):
    811     self._Modified()
    812     field_dict = self._fields
    813     unknown_field_list = self._unknown_fields
    814     while pos != end:
    815       (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
    816       field_decoder = decoders_by_tag.get(tag_bytes)
    817       if field_decoder is None:
    818         value_start_pos = new_pos
    819         new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
    820         if new_pos == -1:
    821           return pos
    822         if not unknown_field_list:
    823           unknown_field_list = self._unknown_fields = []
    824         unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos]))
    825         pos = new_pos
    826       else:
    827         pos = field_decoder(buffer, new_pos, end, self, field_dict)
    828     return pos
    829   cls._InternalParse = InternalParse
    830 
    831 
    832 def _AddIsInitializedMethod(message_descriptor, cls):
    833   """Adds the IsInitialized and FindInitializationError methods to the
    834   protocol message class."""
    835 
    836   required_fields = [field for field in message_descriptor.fields
    837                            if field.label == _FieldDescriptor.LABEL_REQUIRED]
    838 
    839   def IsInitialized(self, errors=None):
    840     """Checks if all required fields of a message are set.
    841 
    842     Args:
    843       errors:  A list which, if provided, will be populated with the field
    844                paths of all missing required fields.
    845 
    846     Returns:
    847       True iff the specified message has all required fields set.
    848     """
    849 
    850     # Performance is critical so we avoid HasField() and ListFields().
    851 
    852     for field in required_fields:
    853       if (field not in self._fields or
    854           (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
    855            not self._fields[field]._is_present_in_parent)):
    856         if errors is not None:
    857           errors.extend(self.FindInitializationErrors())
    858         return False
    859 
    860     for field, value in self._fields.iteritems():
    861       if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    862         if field.label == _FieldDescriptor.LABEL_REPEATED:
    863           for element in value:
    864             if not element.IsInitialized():
    865               if errors is not None:
    866                 errors.extend(self.FindInitializationErrors())
    867               return False
    868         elif value._is_present_in_parent and not value.IsInitialized():
    869           if errors is not None:
    870             errors.extend(self.FindInitializationErrors())
    871           return False
    872 
    873     return True
    874 
    875   cls.IsInitialized = IsInitialized
    876 
    877   def FindInitializationErrors(self):
    878     """Finds required fields which are not initialized.
    879 
    880     Returns:
    881       A list of strings.  Each string is a path to an uninitialized field from
    882       the top-level message, e.g. "foo.bar[5].baz".
    883     """
    884 
    885     errors = []  # simplify things
    886 
    887     for field in required_fields:
    888       if not self.HasField(field.name):
    889         errors.append(field.name)
    890 
    891     for field, value in self.ListFields():
    892       if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    893         if field.is_extension:
    894           name = "(%s)" % field.full_name
    895         else:
    896           name = field.name
    897 
    898         if field.label == _FieldDescriptor.LABEL_REPEATED:
    899           for i in xrange(len(value)):
    900             element = value[i]
    901             prefix = "%s[%d]." % (name, i)
    902             sub_errors = element.FindInitializationErrors()
    903             errors += [ prefix + error for error in sub_errors ]
    904         else:
    905           prefix = name + "."
    906           sub_errors = value.FindInitializationErrors()
    907           errors += [ prefix + error for error in sub_errors ]
    908 
    909     return errors
    910 
    911   cls.FindInitializationErrors = FindInitializationErrors
    912 
    913 
    914 def _AddMergeFromMethod(cls):
    915   LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
    916   CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE
    917 
    918   def MergeFrom(self, msg):
    919     if not isinstance(msg, cls):
    920       raise TypeError(
    921           "Parameter to MergeFrom() must be instance of same class: "
    922           "expected %s got %s." % (cls.__name__, type(msg).__name__))
    923 
    924     assert msg is not self
    925     self._Modified()
    926 
    927     fields = self._fields
    928 
    929     for field, value in msg._fields.iteritems():
    930       if field.label == LABEL_REPEATED:
    931         field_value = fields.get(field)
    932         if field_value is None:
    933           # Construct a new object to represent this field.
    934           field_value = field._default_constructor(self)
    935           fields[field] = field_value
    936         field_value.MergeFrom(value)
    937       elif field.cpp_type == CPPTYPE_MESSAGE:
    938         if value._is_present_in_parent:
    939           field_value = fields.get(field)
    940           if field_value is None:
    941             # Construct a new object to represent this field.
    942             field_value = field._default_constructor(self)
    943             fields[field] = field_value
    944           field_value.MergeFrom(value)
    945       else:
    946         self._fields[field] = value
    947 
    948     if msg._unknown_fields:
    949       if not self._unknown_fields:
    950         self._unknown_fields = []
    951       self._unknown_fields.extend(msg._unknown_fields)
    952 
    953   cls.MergeFrom = MergeFrom
    954 
    955 
    956 def _AddMessageMethods(message_descriptor, cls):
    957   """Adds implementations of all Message methods to cls."""
    958   _AddListFieldsMethod(message_descriptor, cls)
    959   _AddHasFieldMethod(message_descriptor, cls)
    960   _AddClearFieldMethod(message_descriptor, cls)
    961   if message_descriptor.is_extendable:
    962     _AddClearExtensionMethod(cls)
    963     _AddHasExtensionMethod(cls)
    964   _AddClearMethod(message_descriptor, cls)
    965   _AddEqualsMethod(message_descriptor, cls)
    966   _AddStrMethod(message_descriptor, cls)
    967   _AddUnicodeMethod(message_descriptor, cls)
    968   _AddSetListenerMethod(cls)
    969   _AddByteSizeMethod(message_descriptor, cls)
    970   _AddSerializeToStringMethod(message_descriptor, cls)
    971   _AddSerializePartialToStringMethod(message_descriptor, cls)
    972   _AddMergeFromStringMethod(message_descriptor, cls)
    973   _AddIsInitializedMethod(message_descriptor, cls)
    974   _AddMergeFromMethod(cls)
    975 
    976 
    977 def _AddPrivateHelperMethods(cls):
    978   """Adds implementation of private helper methods to cls."""
    979 
    980   def Modified(self):
    981     """Sets the _cached_byte_size_dirty bit to true,
    982     and propagates this to our listener iff this was a state change.
    983     """
    984 
    985     # Note:  Some callers check _cached_byte_size_dirty before calling
    986     #   _Modified() as an extra optimization.  So, if this method is ever
    987     #   changed such that it does stuff even when _cached_byte_size_dirty is
    988     #   already true, the callers need to be updated.
    989     if not self._cached_byte_size_dirty:
    990       self._cached_byte_size_dirty = True
    991       self._listener_for_children.dirty = True
    992       self._is_present_in_parent = True
    993       self._listener.Modified()
    994 
    995   cls._Modified = Modified
    996   cls.SetInParent = Modified
    997 
    998 
    999 class _Listener(object):
   1000 
   1001   """MessageListener implementation that a parent message registers with its
   1002   child message.
   1003 
   1004   In order to support semantics like:
   1005 
   1006     foo.bar.baz.qux = 23
   1007     assert foo.HasField('bar')
   1008 
   1009   ...child objects must have back references to their parents.
   1010   This helper class is at the heart of this support.
   1011   """
   1012 
   1013   def __init__(self, parent_message):
   1014     """Args:
   1015       parent_message: The message whose _Modified() method we should call when
   1016         we receive Modified() messages.
   1017     """
   1018     # This listener establishes a back reference from a child (contained) object
   1019     # to its parent (containing) object.  We make this a weak reference to avoid
   1020     # creating cyclic garbage when the client finishes with the 'parent' object
   1021     # in the tree.
   1022     if isinstance(parent_message, weakref.ProxyType):
   1023       self._parent_message_weakref = parent_message
   1024     else:
   1025       self._parent_message_weakref = weakref.proxy(parent_message)
   1026 
   1027     # As an optimization, we also indicate directly on the listener whether
   1028     # or not the parent message is dirty.  This way we can avoid traversing
   1029     # up the tree in the common case.
   1030     self.dirty = False
   1031 
   1032   def Modified(self):
   1033     if self.dirty:
   1034       return
   1035     try:
   1036       # Propagate the signal to our parents iff this is the first field set.
   1037       self._parent_message_weakref._Modified()
   1038     except ReferenceError:
   1039       # We can get here if a client has kept a reference to a child object,
   1040       # and is now setting a field on it, but the child's parent has been
   1041       # garbage-collected.  This is not an error.
   1042       pass
   1043 
   1044 
   1045 # TODO(robinson): Move elsewhere?  This file is getting pretty ridiculous...
   1046 # TODO(robinson): Unify error handling of "unknown extension" crap.
   1047 # TODO(robinson): Support iteritems()-style iteration over all
   1048 # extensions with the "has" bits turned on?
   1049 class _ExtensionDict(object):
   1050 
   1051   """Dict-like container for supporting an indexable "Extensions"
   1052   field on proto instances.
   1053 
   1054   Note that in all cases we expect extension handles to be
   1055   FieldDescriptors.
   1056   """
   1057 
   1058   def __init__(self, extended_message):
   1059     """extended_message: Message instance for which we are the Extensions dict.
   1060     """
   1061 
   1062     self._extended_message = extended_message
   1063 
   1064   def __getitem__(self, extension_handle):
   1065     """Returns the current value of the given extension handle."""
   1066 
   1067     _VerifyExtensionHandle(self._extended_message, extension_handle)
   1068 
   1069     result = self._extended_message._fields.get(extension_handle)
   1070     if result is not None:
   1071       return result
   1072 
   1073     if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
   1074       result = extension_handle._default_constructor(self._extended_message)
   1075     elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
   1076       result = extension_handle.message_type._concrete_class()
   1077       try:
   1078         result._SetListener(self._extended_message._listener_for_children)
   1079       except ReferenceError:
   1080         pass
   1081     else:
   1082       # Singular scalar -- just return the default without inserting into the
   1083       # dict.
   1084       return extension_handle.default_value
   1085 
   1086     # Atomically check if another thread has preempted us and, if not, swap
   1087     # in the new object we just created.  If someone has preempted us, we
   1088     # take that object and discard ours.
   1089     # WARNING:  We are relying on setdefault() being atomic.  This is true
   1090     #   in CPython but we haven't investigated others.  This warning appears
   1091     #   in several other locations in this file.
   1092     result = self._extended_message._fields.setdefault(
   1093         extension_handle, result)
   1094 
   1095     return result
   1096 
   1097   def __eq__(self, other):
   1098     if not isinstance(other, self.__class__):
   1099       return False
   1100 
   1101     my_fields = self._extended_message.ListFields()
   1102     other_fields = other._extended_message.ListFields()
   1103 
   1104     # Get rid of non-extension fields.
   1105     my_fields    = [ field for field in my_fields    if field.is_extension ]
   1106     other_fields = [ field for field in other_fields if field.is_extension ]
   1107 
   1108     return my_fields == other_fields
   1109 
   1110   def __ne__(self, other):
   1111     return not self == other
   1112 
   1113   def __hash__(self):
   1114     raise TypeError('unhashable object')
   1115 
   1116   # Note that this is only meaningful for non-repeated, scalar extension
   1117   # fields.  Note also that we may have to call _Modified() when we do
   1118   # successfully set a field this way, to set any necssary "has" bits in the
   1119   # ancestors of the extended message.
   1120   def __setitem__(self, extension_handle, value):
   1121     """If extension_handle specifies a non-repeated, scalar extension
   1122     field, sets the value of that field.
   1123     """
   1124 
   1125     _VerifyExtensionHandle(self._extended_message, extension_handle)
   1126 
   1127     if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
   1128         extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
   1129       raise TypeError(
   1130           'Cannot assign to extension "%s" because it is a repeated or '
   1131           'composite type.' % extension_handle.full_name)
   1132 
   1133     # It's slightly wasteful to lookup the type checker each time,
   1134     # but we expect this to be a vanishingly uncommon case anyway.
   1135     type_checker = type_checkers.GetTypeChecker(
   1136         extension_handle.cpp_type, extension_handle.type)
   1137     type_checker.CheckValue(value)
   1138     self._extended_message._fields[extension_handle] = value
   1139     self._extended_message._Modified()
   1140 
   1141   def _FindExtensionByName(self, name):
   1142     """Tries to find a known extension with the specified name.
   1143 
   1144     Args:
   1145       name: Extension full name.
   1146 
   1147     Returns:
   1148       Extension field descriptor.
   1149     """
   1150     return self._extended_message._extensions_by_name.get(name, None)
   1151