Home | History | Annotate | Download | only in protobuf
      1 # Protocol Buffers - Google's data interchange format
      2 # Copyright 2008 Google Inc.  All rights reserved.
      3 # http://code.google.com/p/protobuf/
      4 #
      5 # Redistribution and use in source and binary forms, with or without
      6 # modification, are permitted provided that the following conditions are
      7 # met:
      8 #
      9 #     * Redistributions of source code must retain the above copyright
     10 # notice, this list of conditions and the following disclaimer.
     11 #     * Redistributions in binary form must reproduce the above
     12 # copyright notice, this list of conditions and the following disclaimer
     13 # in the documentation and/or other materials provided with the
     14 # distribution.
     15 #     * Neither the name of Google Inc. nor the names of its
     16 # contributors may be used to endorse or promote products derived from
     17 # this software without specific prior written permission.
     18 #
     19 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     20 # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     21 # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     22 # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
     23 # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
     24 # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
     25 # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     26 # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     27 # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     28 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
     29 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     30 
     31 """Provides DescriptorPool to use as a container for proto2 descriptors.
     32 
     33 The DescriptorPool is used in conjunction with a DescriptorDatabase to maintain
     34 a collection of protocol buffer descriptors for use when dynamically creating
     35 message types at runtime.
     36 
     37 For most applications protocol buffers should be used via modules generated by
     38 the protocol buffer compiler tool. This should only be used when the type of
     39 protocol buffers used in an application or library cannot be predetermined.
     40 
     41 Below is a straightforward example on how to use this class:
     42 
     43   pool = DescriptorPool()
     44   file_descriptor_protos = [ ... ]
     45   for file_descriptor_proto in file_descriptor_protos:
     46     pool.Add(file_descriptor_proto)
     47   my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType')
     48 
     49 The message descriptor can be used in conjunction with the message_factory
     50 module in order to create a protocol buffer class that can be encoded and
     51 decoded.
     52 """
     53 
     54 __author__ = 'matthewtoia (at] google.com (Matt Toia)'
     55 
     56 from google.protobuf import descriptor_pb2
     57 from google.protobuf import descriptor
     58 from google.protobuf import descriptor_database
     59 
     60 
     61 class DescriptorPool(object):
     62   """A collection of protobufs dynamically constructed by descriptor protos."""
     63 
     64   def __init__(self, descriptor_db=None):
     65     """Initializes a Pool of proto buffs.
     66 
     67     The descriptor_db argument to the constructor is provided to allow
     68     specialized file descriptor proto lookup code to be triggered on demand. An
     69     example would be an implementation which will read and compile a file
     70     specified in a call to FindFileByName() and not require the call to Add()
     71     at all. Results from this database will be cached internally here as well.
     72 
     73     Args:
     74       descriptor_db: A secondary source of file descriptors.
     75     """
     76 
     77     self._internal_db = descriptor_database.DescriptorDatabase()
     78     self._descriptor_db = descriptor_db
     79     self._descriptors = {}
     80     self._enum_descriptors = {}
     81     self._file_descriptors = {}
     82 
     83   def Add(self, file_desc_proto):
     84     """Adds the FileDescriptorProto and its types to this pool.
     85 
     86     Args:
     87       file_desc_proto: The FileDescriptorProto to add.
     88     """
     89 
     90     self._internal_db.Add(file_desc_proto)
     91 
     92   def FindFileByName(self, file_name):
     93     """Gets a FileDescriptor by file name.
     94 
     95     Args:
     96       file_name: The path to the file to get a descriptor for.
     97 
     98     Returns:
     99       A FileDescriptor for the named file.
    100 
    101     Raises:
    102       KeyError: if the file can not be found in the pool.
    103     """
    104 
    105     try:
    106       file_proto = self._internal_db.FindFileByName(file_name)
    107     except KeyError as error:
    108       if self._descriptor_db:
    109         file_proto = self._descriptor_db.FindFileByName(file_name)
    110       else:
    111         raise error
    112     if not file_proto:
    113       raise KeyError('Cannot find a file named %s' % file_name)
    114     return self._ConvertFileProtoToFileDescriptor(file_proto)
    115 
    116   def FindFileContainingSymbol(self, symbol):
    117     """Gets the FileDescriptor for the file containing the specified symbol.
    118 
    119     Args:
    120       symbol: The name of the symbol to search for.
    121 
    122     Returns:
    123       A FileDescriptor that contains the specified symbol.
    124 
    125     Raises:
    126       KeyError: if the file can not be found in the pool.
    127     """
    128 
    129     try:
    130       file_proto = self._internal_db.FindFileContainingSymbol(symbol)
    131     except KeyError as error:
    132       if self._descriptor_db:
    133         file_proto = self._descriptor_db.FindFileContainingSymbol(symbol)
    134       else:
    135         raise error
    136     if not file_proto:
    137       raise KeyError('Cannot find a file containing %s' % symbol)
    138     return self._ConvertFileProtoToFileDescriptor(file_proto)
    139 
    140   def FindMessageTypeByName(self, full_name):
    141     """Loads the named descriptor from the pool.
    142 
    143     Args:
    144       full_name: The full name of the descriptor to load.
    145 
    146     Returns:
    147       The descriptor for the named type.
    148     """
    149 
    150     full_name = full_name.lstrip('.')  # fix inconsistent qualified name formats
    151     if full_name not in self._descriptors:
    152       self.FindFileContainingSymbol(full_name)
    153     return self._descriptors[full_name]
    154 
    155   def FindEnumTypeByName(self, full_name):
    156     """Loads the named enum descriptor from the pool.
    157 
    158     Args:
    159       full_name: The full name of the enum descriptor to load.
    160 
    161     Returns:
    162       The enum descriptor for the named type.
    163     """
    164 
    165     full_name = full_name.lstrip('.')  # fix inconsistent qualified name formats
    166     if full_name not in self._enum_descriptors:
    167       self.FindFileContainingSymbol(full_name)
    168     return self._enum_descriptors[full_name]
    169 
    170   def _ConvertFileProtoToFileDescriptor(self, file_proto):
    171     """Creates a FileDescriptor from a proto or returns a cached copy.
    172 
    173     This method also has the side effect of loading all the symbols found in
    174     the file into the appropriate dictionaries in the pool.
    175 
    176     Args:
    177       file_proto: The proto to convert.
    178 
    179     Returns:
    180       A FileDescriptor matching the passed in proto.
    181     """
    182 
    183     if file_proto.name not in self._file_descriptors:
    184       file_descriptor = descriptor.FileDescriptor(
    185           name=file_proto.name,
    186           package=file_proto.package,
    187           options=file_proto.options,
    188           serialized_pb=file_proto.SerializeToString())
    189       scope = {}
    190       dependencies = list(self._GetDeps(file_proto))
    191 
    192       for dependency in dependencies:
    193         dep_desc = self.FindFileByName(dependency.name)
    194         dep_proto = descriptor_pb2.FileDescriptorProto.FromString(
    195             dep_desc.serialized_pb)
    196         package = '.' + dep_proto.package
    197         package_prefix = package + '.'
    198 
    199         def _strip_package(symbol):
    200           if symbol.startswith(package_prefix):
    201             return symbol[len(package_prefix):]
    202           return symbol
    203 
    204         symbols = list(self._ExtractSymbols(dep_proto.message_type, package))
    205         scope.update(symbols)
    206         scope.update((_strip_package(k), v) for k, v in symbols)
    207 
    208         symbols = list(self._ExtractEnums(dep_proto.enum_type, package))
    209         scope.update(symbols)
    210         scope.update((_strip_package(k), v) for k, v in symbols)
    211 
    212       for message_type in file_proto.message_type:
    213         message_desc = self._ConvertMessageDescriptor(
    214             message_type, file_proto.package, file_descriptor, scope)
    215         file_descriptor.message_types_by_name[message_desc.name] = message_desc
    216       for enum_type in file_proto.enum_type:
    217         self._ConvertEnumDescriptor(enum_type, file_proto.package,
    218                                     file_descriptor, None, scope)
    219       for desc_proto in self._ExtractMessages(file_proto.message_type):
    220         self._SetFieldTypes(desc_proto, scope)
    221 
    222       for desc_proto in file_proto.message_type:
    223         desc = scope[desc_proto.name]
    224         file_descriptor.message_types_by_name[desc_proto.name] = desc
    225       self.Add(file_proto)
    226       self._file_descriptors[file_proto.name] = file_descriptor
    227 
    228     return self._file_descriptors[file_proto.name]
    229 
    230   def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None,
    231                                 scope=None):
    232     """Adds the proto to the pool in the specified package.
    233 
    234     Args:
    235       desc_proto: The descriptor_pb2.DescriptorProto protobuf message.
    236       package: The package the proto should be located in.
    237       file_desc: The file containing this message.
    238       scope: Dict mapping short and full symbols to message and enum types.
    239 
    240     Returns:
    241       The added descriptor.
    242     """
    243 
    244     if package:
    245       desc_name = '.'.join((package, desc_proto.name))
    246     else:
    247       desc_name = desc_proto.name
    248 
    249     if file_desc is None:
    250       file_name = None
    251     else:
    252       file_name = file_desc.name
    253 
    254     if scope is None:
    255       scope = {}
    256 
    257     nested = [
    258         self._ConvertMessageDescriptor(nested, desc_name, file_desc, scope)
    259         for nested in desc_proto.nested_type]
    260     enums = [
    261         self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope)
    262         for enum in desc_proto.enum_type]
    263     fields = [self._MakeFieldDescriptor(field, desc_name, index)
    264               for index, field in enumerate(desc_proto.field)]
    265     extensions = [self._MakeFieldDescriptor(extension, desc_name, True)
    266                   for index, extension in enumerate(desc_proto.extension)]
    267     extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range]
    268     if extension_ranges:
    269       is_extendable = True
    270     else:
    271       is_extendable = False
    272     desc = descriptor.Descriptor(
    273         name=desc_proto.name,
    274         full_name=desc_name,
    275         filename=file_name,
    276         containing_type=None,
    277         fields=fields,
    278         nested_types=nested,
    279         enum_types=enums,
    280         extensions=extensions,
    281         options=desc_proto.options,
    282         is_extendable=is_extendable,
    283         extension_ranges=extension_ranges,
    284         file=file_desc,
    285         serialized_start=None,
    286         serialized_end=None)
    287     for nested in desc.nested_types:
    288       nested.containing_type = desc
    289     for enum in desc.enum_types:
    290       enum.containing_type = desc
    291     scope[desc_proto.name] = desc
    292     scope['.' + desc_name] = desc
    293     self._descriptors[desc_name] = desc
    294     return desc
    295 
    296   def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
    297                              containing_type=None, scope=None):
    298     """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.
    299 
    300     Args:
    301       enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message.
    302       package: Optional package name for the new message EnumDescriptor.
    303       file_desc: The file containing the enum descriptor.
    304       containing_type: The type containing this enum.
    305       scope: Scope containing available types.
    306 
    307     Returns:
    308       The added descriptor
    309     """
    310 
    311     if package:
    312       enum_name = '.'.join((package, enum_proto.name))
    313     else:
    314       enum_name = enum_proto.name
    315 
    316     if file_desc is None:
    317       file_name = None
    318     else:
    319       file_name = file_desc.name
    320 
    321     values = [self._MakeEnumValueDescriptor(value, index)
    322               for index, value in enumerate(enum_proto.value)]
    323     desc = descriptor.EnumDescriptor(name=enum_proto.name,
    324                                      full_name=enum_name,
    325                                      filename=file_name,
    326                                      file=file_desc,
    327                                      values=values,
    328                                      containing_type=containing_type,
    329                                      options=enum_proto.options)
    330     scope[enum_proto.name] = desc
    331     scope['.%s' % enum_name] = desc
    332     self._enum_descriptors[enum_name] = desc
    333     return desc
    334 
    335   def _MakeFieldDescriptor(self, field_proto, message_name, index,
    336                            is_extension=False):
    337     """Creates a field descriptor from a FieldDescriptorProto.
    338 
    339     For message and enum type fields, this method will do a look up
    340     in the pool for the appropriate descriptor for that type. If it
    341     is unavailable, it will fall back to the _source function to
    342     create it. If this type is still unavailable, construction will
    343     fail.
    344 
    345     Args:
    346       field_proto: The proto describing the field.
    347       message_name: The name of the containing message.
    348       index: Index of the field
    349       is_extension: Indication that this field is for an extension.
    350 
    351     Returns:
    352       An initialized FieldDescriptor object
    353     """
    354 
    355     if message_name:
    356       full_name = '.'.join((message_name, field_proto.name))
    357     else:
    358       full_name = field_proto.name
    359 
    360     return descriptor.FieldDescriptor(
    361         name=field_proto.name,
    362         full_name=full_name,
    363         index=index,
    364         number=field_proto.number,
    365         type=field_proto.type,
    366         cpp_type=None,
    367         message_type=None,
    368         enum_type=None,
    369         containing_type=None,
    370         label=field_proto.label,
    371         has_default_value=False,
    372         default_value=None,
    373         is_extension=is_extension,
    374         extension_scope=None,
    375         options=field_proto.options)
    376 
    377   def _SetFieldTypes(self, desc_proto, scope):
    378     """Sets the field's type, cpp_type, message_type and enum_type.
    379 
    380     Args:
    381       desc_proto: The message descriptor to update.
    382       scope: Enclosing scope of available types.
    383     """
    384 
    385     desc = scope[desc_proto.name]
    386     for field_proto, field_desc in zip(desc_proto.field, desc.fields):
    387       if field_proto.type_name:
    388         type_name = field_proto.type_name
    389         if type_name not in scope:
    390           type_name = '.' + type_name
    391         desc = scope[type_name]
    392       else:
    393         desc = None
    394 
    395       if not field_proto.HasField('type'):
    396         if isinstance(desc, descriptor.Descriptor):
    397           field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE
    398         else:
    399           field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM
    400 
    401       field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType(
    402           field_proto.type)
    403 
    404       if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE
    405           or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP):
    406         field_desc.message_type = desc
    407 
    408       if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
    409         field_desc.enum_type = desc
    410 
    411       if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED:
    412         field_desc.has_default = False
    413         field_desc.default_value = []
    414       elif field_proto.HasField('default_value'):
    415         field_desc.has_default = True
    416         if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or
    417             field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT):
    418           field_desc.default_value = float(field_proto.default_value)
    419         elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING:
    420           field_desc.default_value = field_proto.default_value
    421         elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL:
    422           field_desc.default_value = field_proto.default_value.lower() == 'true'
    423         elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM:
    424           field_desc.default_value = field_desc.enum_type.values_by_name[
    425               field_proto.default_value].index
    426         else:
    427           field_desc.default_value = int(field_proto.default_value)
    428       else:
    429         field_desc.has_default = False
    430         field_desc.default_value = None
    431 
    432       field_desc.type = field_proto.type
    433 
    434     for nested_type in desc_proto.nested_type:
    435       self._SetFieldTypes(nested_type, scope)
    436 
    437   def _MakeEnumValueDescriptor(self, value_proto, index):
    438     """Creates a enum value descriptor object from a enum value proto.
    439 
    440     Args:
    441       value_proto: The proto describing the enum value.
    442       index: The index of the enum value.
    443 
    444     Returns:
    445       An initialized EnumValueDescriptor object.
    446     """
    447 
    448     return descriptor.EnumValueDescriptor(
    449         name=value_proto.name,
    450         index=index,
    451         number=value_proto.number,
    452         options=value_proto.options,
    453         type=None)
    454 
    455   def _ExtractSymbols(self, desc_protos, package):
    456     """Pulls out all the symbols from descriptor protos.
    457 
    458     Args:
    459       desc_protos: The protos to extract symbols from.
    460       package: The package containing the descriptor type.
    461     Yields:
    462       A two element tuple of the type name and descriptor object.
    463     """
    464 
    465     for desc_proto in desc_protos:
    466       if package:
    467         message_name = '.'.join((package, desc_proto.name))
    468       else:
    469         message_name = desc_proto.name
    470       message_desc = self.FindMessageTypeByName(message_name)
    471       yield (message_name, message_desc)
    472       for symbol in self._ExtractSymbols(desc_proto.nested_type, message_name):
    473         yield symbol
    474       for symbol in self._ExtractEnums(desc_proto.enum_type, message_name):
    475         yield symbol
    476 
    477   def _ExtractEnums(self, enum_protos, package):
    478     """Pulls out all the symbols from enum protos.
    479 
    480     Args:
    481       enum_protos: The protos to extract symbols from.
    482       package: The package containing the enum type.
    483 
    484     Yields:
    485       A two element tuple of the type name and enum descriptor object.
    486     """
    487 
    488     for enum_proto in enum_protos:
    489       if package:
    490         enum_name = '.'.join((package, enum_proto.name))
    491       else:
    492         enum_name = enum_proto.name
    493       enum_desc = self.FindEnumTypeByName(enum_name)
    494       yield (enum_name, enum_desc)
    495 
    496   def _ExtractMessages(self, desc_protos):
    497     """Pulls out all the message protos from descriptos.
    498 
    499     Args:
    500       desc_protos: The protos to extract symbols from.
    501 
    502     Yields:
    503       Descriptor protos.
    504     """
    505 
    506     for desc_proto in desc_protos:
    507       yield desc_proto
    508       for message in self._ExtractMessages(desc_proto.nested_type):
    509         yield message
    510 
    511   def _GetDeps(self, file_proto):
    512     """Recursively finds dependencies for file protos.
    513 
    514     Args:
    515       file_proto: The proto to get dependencies from.
    516 
    517     Yields:
    518       Each direct and indirect dependency.
    519     """
    520 
    521     for dependency in file_proto.dependency:
    522       dep_desc = self.FindFileByName(dependency)
    523       dep_proto = descriptor_pb2.FileDescriptorProto.FromString(
    524           dep_desc.serialized_pb)
    525       yield dep_proto
    526       for parent_dep in self._GetDeps(dep_proto):
    527         yield parent_dep
    528