Home | History | Annotate | Download | only in mbim_compliance
      1 # Copyright 2015 The Chromium OS Authors. All rights reserved.
      2 # Use of this source code is governed by a BSD-style license that can be
      3 # found in the LICENSE file.
      4 """
      5 All of the MBIM messages are created using the MBIMControlMessageMeta metaclass.
      6 The metaclass supports a hierarchy of message definitions so that each message
      7 definition extends the structure of the base class it inherits.
      8 
      9 (mbim_message.py)
     10 MBIMControlMessage|         (mbim_message_request.py)
     11                   |>MBIMControlMessageRequest |
     12                   |                           |>MBIMOpen
     13                   |                           |>MBIMClose
     14                   |                           |>MBIMCommand    |
     15                   |                           |                |>MBIMSetConnect
     16                   |                           |                |>...
     17                   |                           |
     18                   |                           |>MBIMHostError
     19                   |
     20                   |         (mbim_message_response.py)
     21                   |>MBIMControlMessageResponse|
     22                                               |>MBIMOpenDone
     23                                               |>MBIMCloseDone
     24                                               |>MBIMCommandDone|
     25                                               |                |>MBIMConnectInfo
     26                                               |                |>...
     27                                               |
     28                                               |>MBIMHostError
     29 """
     30 import array
     31 import logging
     32 import struct
     33 import sys
     34 from collections import namedtuple
     35 
     36 from autotest_lib.client.cros.cellular.mbim_compliance import mbim_errors
     37 
     38 
     39 # Type of message classes. The values of each field in the message is stored
     40 # as an attribute of the object created.
     41 # Request message classes accepts values for the attributes of the object.
     42 MESSAGE_TYPE_REQUEST = 1
     43 # Response message classes accepts raw_data which is parsed into attributes of
     44 # the object.
     45 MESSAGE_TYPE_RESPONSE = 2
     46 
     47 # Message field types.
     48 # Just a normal field type. No special properties.
     49 FIELD_TYPE_NORMAL = 1
     50 # Identify the payload ID for a message. This is used in  parsing of
     51 # response messages to help in identifying the child message class.
     52 FIELD_TYPE_PAYLOAD_ID = 2
     53 # Total length of the message including any payload_buffer it may contain.
     54 FIELD_TYPE_TOTAL_LEN = 3
     55 # Length of the payload contained in the payload_buffer.
     56 FIELD_TYPE_PAYLOAD_LEN = 4
     57 # Number of fragments of this message.
     58 FIELD_TYPE_NUM_FRAGMENTS = 5
     59 # Transaction ID of this message
     60 FIELD_TYPE_TRANSACTION_ID = 6
     61 
     62 
     63 def message_class_new(cls, **kwargs):
     64     """
     65     Creates a message instance with either the given field name/value
     66     pairs or raw data buffer.
     67 
     68     The total_length and transaction_id fields are automatically calculated
     69     if not explicitly provided in the message args.
     70 
     71     @param kwargs: Dictionary of (field_name, field_value) pairs or
     72                     raw_data=Packed binary array.
     73     @returns New message object created.
     74 
     75     """
     76     if 'raw_data' in kwargs and kwargs['raw_data']:
     77         # We unpack the raw data received into the appropriate fields
     78         # for this class. If there is some additional data present in
     79         # |raw_data| that does not fit the format of the structure,
     80         # they're stored in the variable sized |payload_buffer| field.
     81         raw_data = kwargs['raw_data']
     82         data_format = cls.get_field_format_string(get_all=True)
     83         unpack_length = cls.get_struct_len(get_all=True)
     84         data_length = len(raw_data)
     85         if data_length < unpack_length:
     86             mbim_errors.log_and_raise(
     87                     mbim_errors.MBIMComplianceControlMessageError,
     88                     'Length of Data (%d) to be parsed less than message'
     89                     ' structure length (%d)' %
     90                     (data_length, unpack_length))
     91         obj = super(cls, cls).__new__(cls, *struct.unpack_from(data_format,
     92                                                                raw_data))
     93         if data_length > unpack_length:
     94             setattr(obj, 'payload_buffer', raw_data[unpack_length:])
     95         else:
     96             setattr(obj, 'payload_buffer', None)
     97         return obj
     98     else:
     99         # Check if all the fields have been populated for this message
    100         # except for transaction ID and message length since these
    101         # are generated during init.
    102         field_values = []
    103         fields = cls.get_fields(get_all=True)
    104         defaults = cls.get_defaults(get_all=True)
    105         for _, field_name, field_type in fields:
    106             if field_name not in kwargs:
    107                 if field_type == FIELD_TYPE_TOTAL_LEN:
    108                     field_value = cls.get_struct_len(get_all=True)
    109                     if 'payload_buffer' in kwargs:
    110                         field_value += len(kwargs.get('payload_buffer'))
    111                 elif field_type == FIELD_TYPE_TRANSACTION_ID:
    112                     field_value = cls.get_next_transaction_id()
    113                 else:
    114                     field_value = defaults.get(field_name, None)
    115                 if field_value is None:
    116                     mbim_errors.log_and_raise(
    117                             mbim_errors.MBIMComplianceControlMessageError,
    118                             'Missing field value (%s) in %s' % (
    119                                     field_name, cls.__name__))
    120                 field_values.append(field_value)
    121             else:
    122                 field_values.append(kwargs.pop(field_name))
    123         obj = super(cls, cls).__new__(cls, *field_values)
    124         # We need to account for optional variable sized payload_buffer
    125         # in some messages which are not explicitly mentioned in the
    126         # |cls._FIELDS| attribute.
    127         if 'payload_buffer' in kwargs:
    128             setattr(obj, 'payload_buffer', kwargs.pop('payload_buffer'))
    129         else:
    130             setattr(obj, 'payload_buffer', None)
    131         if kwargs:
    132             mbim_errors.log_and_raise(
    133                     mbim_errors.MBIMComplianceControlMessageError,
    134                     'Unexpected fields (%s) in %s' % (
    135                             kwargs.keys(), cls.__name__))
    136         return obj
    137 
    138 
    139 class MBIMControlMessageMeta(type):
    140     """
    141     Metaclass for all the control message parsing/generation.
    142 
    143     The metaclass creates each class by concatenating all the message fields
    144     from it's base classes to create a hierarchy of messages.
    145     Thus the payload class of each message class becomes the subclass of that
    146     message.
    147 
    148     Message definition attributes->
    149     _FIELDS(optional): Used to define structure elements. The fields of a
    150                        message is the concatenation of the _FIELDS attribute
    151                        along with all the _FIELDS attribute from it's parent
    152                        classes.
    153     _DEFAULTS(optional): Field name/value pairs to be assigned to some
    154                          of the fields if they are fixed for a message type.
    155                          These are generally used to assign values to fields in
    156                          the parent class.
    157     _IDENTIFIERS(optional): Field name/value pairs to be used to idenitfy this
    158                             message during parsing from raw_data.
    159     _SECONDARY_FRAGMENTS(optional): Used to identify if this class can be
    160                                     fragmented and name of secondary class
    161                                     definition.
    162     MESSAGE_TYPE: Used to identify request/repsonse classes.
    163 
    164     Message internal attributes->
    165     _CONSOLIDATED_FIELDS: Consolidated list of all the fields defining this
    166                           message.
    167     _CONSOLIDATED_DEFAULTS: Consolidated list of all the default field
    168                             name/value pairs for this  message.
    169 
    170     """
    171     def __new__(mcs, name, bases, attrs):
    172         # The MBIMControlMessage base class, which inherits from 'object',
    173         # is merely used to establish the class hierarchy and is never
    174         # constructed on it's own.
    175         if object in bases:
    176             return super(MBIMControlMessageMeta, mcs).__new__(
    177                     mcs, name, bases, attrs)
    178 
    179         # Append the current class fields, defaults to any base parent class
    180         # fields.
    181         fields = []
    182         defaults = {}
    183         for base_class in bases:
    184             if hasattr(base_class, '_CONSOLIDATED_FIELDS'):
    185                 fields = getattr(base_class, '_CONSOLIDATED_FIELDS')
    186             if hasattr(base_class, '_CONSOLIDATED_DEFAULTS'):
    187                 defaults = getattr(base_class, '_CONSOLIDATED_DEFAULTS').copy()
    188         if '_FIELDS' in attrs:
    189             fields = fields + map(list, attrs['_FIELDS'])
    190         if '_DEFAULTS' in attrs:
    191             defaults.update(attrs['_DEFAULTS'])
    192         attrs['_CONSOLIDATED_FIELDS'] = fields
    193         attrs['_CONSOLIDATED_DEFAULTS'] = defaults
    194 
    195         if not fields:
    196             mbim_errors.log_and_raise(
    197                     mbim_errors.MBIMComplianceControlMessageError,
    198                     '%s message must have some fields defined' % name)
    199 
    200         attrs['__new__'] = message_class_new
    201         _, field_names, _ = zip(*fields)
    202         message_class = namedtuple(name, field_names)
    203         # Prepend the class created via namedtuple to |bases| in order to
    204         # correctly resolve the __new__ method while preserving the class
    205         # hierarchy.
    206         cls = super(MBIMControlMessageMeta, mcs).__new__(
    207                 mcs, name, (message_class,) + bases, attrs)
    208         return cls
    209 
    210 
    211 class MBIMControlMessage(object):
    212     """
    213     MBIMControlMessage base class.
    214 
    215     This class should not be instantiated or used directly.
    216 
    217     """
    218     __metaclass__ = MBIMControlMessageMeta
    219 
    220     _NEXT_TRANSACTION_ID = 0X00000000
    221 
    222 
    223     @classmethod
    224     def _find_subclasses(cls):
    225         """
    226         Helper function to find all the derived payload classes of this
    227         class.
    228 
    229         """
    230         return [c for c in cls.__subclasses__()]
    231 
    232 
    233     @classmethod
    234     def get_fields(cls, get_all=False):
    235         """
    236         Helper function to find all the fields of this class.
    237 
    238         Returns either the total message fields or only the current
    239         substructure fields in the nested message.
    240 
    241         @param get_all: Whether to return the total struct fields or sub struct
    242                          fields.
    243         @returns Fields of the structure.
    244 
    245         """
    246         if get_all:
    247             return cls._CONSOLIDATED_FIELDS
    248         else:
    249             return cls._FIELDS
    250 
    251 
    252     @classmethod
    253     def get_defaults(cls, get_all=False):
    254         """
    255         Helper function to find all the default field values of this class.
    256 
    257         Returns either the total message default field name/value pairs or only
    258         the current substructure defaults in the nested message.
    259 
    260         @param get_all: Whether to return the total struct defaults or sub
    261                          struct defaults.
    262         @returns Defaults of the structure.
    263 
    264         """
    265         if get_all:
    266             return cls._CONSOLIDATED_DEFAULTS
    267         else:
    268             return cls._DEFAULTS
    269 
    270 
    271     @classmethod
    272     def _get_identifiers(cls):
    273         """
    274         Helper function to find all the identifier field name/value pairs of
    275         this class.
    276 
    277         @returns All the idenitifiers of this class.
    278 
    279         """
    280         return getattr(cls, '_IDENTIFIERS', None)
    281 
    282 
    283     @classmethod
    284     def _find_field_names_of_type(cls, find_type, get_all=False):
    285         """
    286         Helper function to find all the field names which matches the field_type
    287         specified.
    288 
    289         params find_type: One of the FIELD_TYPE_* enum values specified above.
    290         @returns Corresponding field names if found, else None.
    291         """
    292         fields = cls.get_fields(get_all=get_all)
    293         field_names = []
    294         for _, field_name, field_type in fields:
    295             if field_type == find_type:
    296                 field_names.append(field_name)
    297         return field_names
    298 
    299 
    300     @classmethod
    301     def get_secondary_fragment(cls):
    302         """
    303         Helper function to retrieve the associated secondary fragment class.
    304 
    305         @returns |_SECONDARY_FRAGMENT| attribute of the class
    306 
    307         """
    308         return getattr(cls, '_SECONDARY_FRAGMENT', None)
    309 
    310 
    311     @classmethod
    312     def get_field_names(cls, get_all=True):
    313         """
    314         Helper function to return the field names of the message.
    315 
    316         @returns The field names of the message structure.
    317 
    318         """
    319         _, field_names, _ = zip(*cls.get_fields(get_all=get_all))
    320         return field_names
    321 
    322 
    323     @classmethod
    324     def get_field_formats(cls, get_all=True):
    325         """
    326         Helper function to return the field formats of the message.
    327 
    328         @returns The format of fields of the message structure.
    329 
    330         """
    331         field_formats, _, _ = zip(*cls.get_fields(get_all=get_all))
    332         return field_formats
    333 
    334 
    335     @classmethod
    336     def get_field_format_string(cls, get_all=True):
    337         """
    338         Helper function to return the field format string of the message.
    339 
    340         @returns The format string of the message structure.
    341 
    342         """
    343         format_string = '<' + ''.join(cls.get_field_formats(get_all=get_all))
    344         return format_string
    345 
    346 
    347     @classmethod
    348     def get_struct_len(cls, get_all=False):
    349         """
    350         Returns the length of the structure representing the message.
    351 
    352         Returns the length of either the total message or only the current
    353         substructure in the nested message.
    354 
    355         @param get_all: Whether to return the total struct length or sub struct
    356                 length.
    357         @returns Length of the structure.
    358 
    359         """
    360         return struct.calcsize(cls.get_field_format_string(get_all=get_all))
    361 
    362 
    363     @classmethod
    364     def find_primary_parent_fragment(cls):
    365         """
    366         Traverses up the message tree to find the primary fragment class
    367         at the same tree level as the secondary frag class associated with this
    368         message class. This should only be called on primary fragment derived
    369         classes!
    370 
    371         @returns Primary frag class associated with the message.
    372 
    373         """
    374         secondary_frag_cls = cls.get_secondary_fragment()
    375         secondary_frag_parent_cls = secondary_frag_cls.__bases__[1]
    376         message_cls = cls
    377         message_parent_cls = message_cls.__bases__[1]
    378         while message_parent_cls != secondary_frag_parent_cls:
    379             message_cls = message_parent_cls
    380             message_parent_cls = message_cls.__bases__[1]
    381         return message_cls
    382 
    383 
    384     @classmethod
    385     def get_next_transaction_id(cls):
    386         """
    387         Returns incrementing transaction ids on successive calls.
    388 
    389         @returns The tracsaction id for control message delivery.
    390 
    391         """
    392         if MBIMControlMessage._NEXT_TRANSACTION_ID > (sys.maxint - 2):
    393             MBIMControlMessage._NEXT_TRANSACTION_ID = 0x00000000
    394         MBIMControlMessage._NEXT_TRANSACTION_ID += 1
    395         return MBIMControlMessage._NEXT_TRANSACTION_ID
    396 
    397 
    398     def _get_fields_of_type(self, field_type, get_all=False):
    399         """
    400         Helper function to find all the field name/value of the specified type
    401         in the given object.
    402 
    403         @returns Corresponding map of field name/value pairs extracted from the
    404                 object.
    405 
    406         """
    407         field_names = self.__class__._find_field_names_of_type(field_type,
    408                                                                get_all=get_all)
    409         return {f: getattr(self, f) for f in field_names}
    410 
    411 
    412     def _get_payload_id_fields(self):
    413         """
    414         Helper function to find all the payload id field name/value in the given
    415         object.
    416 
    417         @returns Corresponding field name/value pairs extracted from the object.
    418 
    419         """
    420         return self._get_fields_of_type(FIELD_TYPE_PAYLOAD_ID)
    421 
    422 
    423     def get_payload_len(self):
    424         """
    425         Helper function to find the payload len field value in the given
    426         object.
    427 
    428         @returns Corresponding field value extracted from the object.
    429 
    430         """
    431         payload_len_fields = self._get_fields_of_type(FIELD_TYPE_PAYLOAD_LEN)
    432         if ((not payload_len_fields) or (len(payload_len_fields) > 1)):
    433             mbim_errors.log_and_raise(
    434                     mbim_errors.MBIMComplianceControlMessageError,
    435                     "Erorr in finding payload len field in message: %s" %
    436                     self.__class__.__name__)
    437         return payload_len_fields.values()[0]
    438 
    439 
    440     def get_total_len(self):
    441         """
    442         Helper function to find the total len field value in the given
    443         object.
    444 
    445         @returns Corresponding field value extracted from the object.
    446 
    447         """
    448         total_len_fields = self._get_fields_of_type(FIELD_TYPE_TOTAL_LEN,
    449                                                     get_all=True)
    450         if ((not total_len_fields) or (len(total_len_fields) > 1)):
    451             mbim_errors.log_and_raise(
    452                     mbim_errors.MBIMComplianceControlMessageError,
    453                     "Erorr in finding total len field in message: %s" %
    454                     self.__class__.__name__)
    455         return total_len_fields.values()[0]
    456 
    457 
    458     def get_num_fragments(self):
    459         """
    460         Helper function to find the fragment num field value in the given
    461         object.
    462 
    463         @returns Corresponding field value extracted from the object.
    464 
    465         """
    466         num_fragment_fields = self._get_fields_of_type(FIELD_TYPE_NUM_FRAGMENTS)
    467         if ((not num_fragment_fields) or (len(num_fragment_fields) > 1)):
    468             mbim_errors.log_and_raise(
    469                     mbim_errors.MBIMComplianceControlMessageError,
    470                     "Erorr in finding num fragments field in message: %s" %
    471                     self.__class__.__name__)
    472         return num_fragment_fields.values()[0]
    473 
    474 
    475     def find_payload_class(self):
    476         """
    477         Helper function to find the derived class which has the default
    478         |payload_id| fields matching the current message contents.
    479 
    480         @returns Corresponding class if found, else None.
    481 
    482         """
    483         cls = self.__class__
    484         for payload_cls in cls._find_subclasses():
    485             message_ids = self._get_payload_id_fields()
    486             subclass_ids = payload_cls._get_identifiers()
    487             if message_ids == subclass_ids:
    488                 return payload_cls
    489         return None
    490 
    491 
    492     def calculate_total_len(self):
    493         """
    494         Helper function to calculate the total len of a given message
    495         object.
    496 
    497         @returns Total length of the message.
    498 
    499         """
    500         message_class = self.__class__
    501         total_len = message_class.get_struct_len(get_all=True)
    502         if self.payload_buffer:
    503             total_len += len(self.payload_buffer)
    504         return total_len
    505 
    506 
    507     def pack(self, format_string, field_names):
    508         """
    509         Packs a list of fields based on their formats.
    510 
    511         @param format_string: The concatenated formats for the fields given in
    512                 |field_names|.
    513         @param field_names: The name of the fields to be packed.
    514         @returns The packet in binary array form.
    515 
    516         """
    517         field_values = [getattr(self, name) for name in field_names]
    518         return array.array('B', struct.pack(format_string, *field_values))
    519 
    520 
    521     def print_all_fields(self):
    522         """Prints all the field name, value pair of this message."""
    523         logging.debug('Class Name: %s', self.__class__.__name__)
    524         for field_name in self.__class__.get_field_names(get_all=True):
    525             logging.debug('Field Name: %s, Field Value: %s',
    526                            field_name, str(getattr(self, field_name)))
    527         if self.payload_buffer:
    528             logging.debug('Payload: %s', str(getattr(self, 'payload_buffer')))
    529 
    530 
    531     def create_raw_data(self):
    532         """
    533         Creates the raw binary data corresponding to the message struct.
    534 
    535         @param payload_buffer: Variable sized paylaod buffer to attach at the
    536                 end of the msg.
    537         @returns Packed byte array of the message.
    538 
    539         """
    540         message = self
    541         message_class = message.__class__
    542         format_string = message_class.get_field_format_string()
    543         field_names = message_class.get_field_names()
    544         packet = message.pack(format_string, field_names)
    545         if self.payload_buffer:
    546             packet.extend(self.payload_buffer)
    547         return packet
    548 
    549 
    550     def copy(self, **fields_to_alter):
    551         """
    552         Replaces the message tuple with updated field values.
    553 
    554         @param fields_to_alter: Field name/value pairs to be changed.
    555         @returns Updated message with the field values updated.
    556 
    557         """
    558         message = self._replace(**fields_to_alter)
    559         # Copy the associated payload_buffer field to the new tuple.
    560         message.payload_buffer = self.payload_buffer
    561         return message
    562