Home | History | Annotate | Download | only in test
      1 # Copyright 2014 The Android Open Source Project
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 # http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 
     15 """A simple module for declaring C-like structures.
     16 
     17 Example usage:
     18 
     19 >>> # Declare a struct type by specifying name, field formats and field names.
     20 ... # Field formats are the same as those used in the struct module, except:
     21 ... # - S: Nested Struct.
     22 ... # - A: NULL-padded ASCII string. Like s, but printing ignores contiguous
     23 ... #      trailing NULL blocks at the end.
     24 ... import cstruct
     25 >>> NLMsgHdr = cstruct.Struct("NLMsgHdr", "=LHHLL", "length type flags seq pid")
     26 >>>
     27 >>>
     28 >>> # Create instances from tuples or raw bytes. Data past the end is ignored.
     29 ... n1 = NLMsgHdr((44, 32, 0x2, 0, 491))
     30 >>> print n1
     31 NLMsgHdr(length=44, type=32, flags=2, seq=0, pid=491)
     32 >>>
     33 >>> n2 = NLMsgHdr("\x2c\x00\x00\x00\x21\x00\x02\x00"
     34 ...               "\x00\x00\x00\x00\xfe\x01\x00\x00" + "junk at end")
     35 >>> print n2
     36 NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510)
     37 >>>
     38 >>> # Serialize to raw bytes.
     39 ... print n1.Pack().encode("hex")
     40 2c0000002000020000000000eb010000
     41 >>>
     42 >>> # Parse the beginning of a byte stream as a struct, and return the struct
     43 ... # and the remainder of the stream for further reading.
     44 ... data = ("\x2c\x00\x00\x00\x21\x00\x02\x00"
     45 ...         "\x00\x00\x00\x00\xfe\x01\x00\x00"
     46 ...         "more data")
     47 >>> cstruct.Read(data, NLMsgHdr)
     48 (NLMsgHdr(length=44, type=33, flags=2, seq=0, pid=510), 'more data')
     49 >>>
     50 >>> # Structs can contain one or more nested structs. The nested struct types
     51 ... # are specified in a list as an optional last argument. Nested structs may
     52 ... # contain nested structs.
     53 ... S = cstruct.Struct("S", "=BI", "byte1 int2")
     54 >>> N = cstruct.Struct("N", "!BSiS", "byte1 s2 int3 s4", [S, S])
     55 >>> NN = cstruct.Struct("NN", "SHS", "s1 word2 n3", [S, N])
     56 >>> nn = NN((S((1, 25000)), -29876, N((55, S((5, 6)), 1111, S((7, 8))))))
     57 >>> nn.n3.s2.int2 = 5
     58 >>>
     59 """
     60 
     61 import ctypes
     62 import string
     63 import struct
     64 
     65 
     66 def CalcSize(fmt):
     67   if "A" in fmt:
     68     fmt = fmt.replace("A", "s")
     69   return struct.calcsize(fmt)
     70 
     71 def CalcNumElements(fmt):
     72   prevlen = len(fmt)
     73   fmt = fmt.replace("S", "")
     74   numstructs = prevlen - len(fmt)
     75   size = CalcSize(fmt)
     76   elements = struct.unpack(fmt, "\x00" * size)
     77   return len(elements) + numstructs
     78 
     79 
     80 def Struct(name, fmt, fieldnames, substructs={}):
     81   """Function that returns struct classes."""
     82 
     83   class Meta(type):
     84 
     85     def __len__(cls):
     86       return cls._length
     87 
     88     def __init__(cls, unused_name, unused_bases, namespace):
     89       # Make the class object have the name that's passed in.
     90       type.__init__(cls, namespace["_name"], unused_bases, namespace)
     91 
     92   class CStruct(object):
     93     """Class representing a C-like structure."""
     94 
     95     __metaclass__ = Meta
     96 
     97     # Name of the struct.
     98     _name = name
     99     # List of field names.
    100     _fieldnames = fieldnames
    101     # Dict mapping field indices to nested struct classes.
    102     _nested = {}
    103     # List of string fields that are ASCII strings.
    104     _asciiz = set()
    105 
    106     if isinstance(_fieldnames, str):
    107       _fieldnames = _fieldnames.split(" ")
    108 
    109     # Parse fmt into _format, converting any S format characters to "XXs",
    110     # where XX is the length of the struct type's packed representation.
    111     _format = ""
    112     laststructindex = 0
    113     for i in xrange(len(fmt)):
    114       if fmt[i] == "S":
    115         # Nested struct. Record the index in our struct it should go into.
    116         index = CalcNumElements(fmt[:i])
    117         _nested[index] = substructs[laststructindex]
    118         laststructindex += 1
    119         _format += "%ds" % len(_nested[index])
    120       elif fmt[i] == "A":
    121         # Null-terminated ASCII string.
    122         index = CalcNumElements(fmt[:i])
    123         _asciiz.add(index)
    124         _format += "s"
    125       else:
    126          # Standard struct format character.
    127         _format += fmt[i]
    128 
    129     _length = CalcSize(_format)
    130 
    131     def _SetValues(self, values):
    132       super(CStruct, self).__setattr__("_values", list(values))
    133 
    134     def _Parse(self, data):
    135       data = data[:self._length]
    136       values = list(struct.unpack(self._format, data))
    137       for index, value in enumerate(values):
    138         if isinstance(value, str) and index in self._nested:
    139           values[index] = self._nested[index](value)
    140       self._SetValues(values)
    141 
    142     def __init__(self, values):
    143       # Initializing from a string.
    144       if isinstance(values, str):
    145         if len(values) < self._length:
    146           raise TypeError("%s requires string of length %d, got %d" %
    147                           (self._name, self._length, len(values)))
    148         self._Parse(values)
    149       else:
    150         # Initializing from a tuple.
    151         if len(values) != len(self._fieldnames):
    152           raise TypeError("%s has exactly %d fieldnames (%d given)" %
    153                           (self._name, len(self._fieldnames), len(values)))
    154         self._SetValues(values)
    155 
    156     def _FieldIndex(self, attr):
    157       try:
    158         return self._fieldnames.index(attr)
    159       except ValueError:
    160         raise AttributeError("'%s' has no attribute '%s'" %
    161                              (self._name, attr))
    162 
    163     def __getattr__(self, name):
    164       return self._values[self._FieldIndex(name)]
    165 
    166     def __setattr__(self, name, value):
    167       self._values[self._FieldIndex(name)] = value
    168 
    169     @classmethod
    170     def __len__(cls):
    171       return cls._length
    172 
    173     def __ne__(self, other):
    174       return not self.__eq__(other)
    175 
    176     def __eq__(self, other):
    177       return (isinstance(other, self.__class__) and
    178               self._name == other._name and
    179               self._fieldnames == other._fieldnames and
    180               self._values == other._values)
    181 
    182     @staticmethod
    183     def _MaybePackStruct(value):
    184       if hasattr(value, "__metaclass__"):# and value.__metaclass__ == Meta:
    185         return value.Pack()
    186       else:
    187         return value
    188 
    189     def Pack(self):
    190       values = [self._MaybePackStruct(v) for v in self._values]
    191       return struct.pack(self._format, *values)
    192 
    193     def __str__(self):
    194       def FieldDesc(index, name, value):
    195         if isinstance(value, str):
    196           if index in self._asciiz:
    197             value = value.rstrip("\x00")
    198           elif any(c not in string.printable for c in value):
    199             value = value.encode("hex")
    200         return "%s=%s" % (name, value)
    201 
    202       descriptions = [
    203           FieldDesc(i, n, v) for i, (n, v) in
    204           enumerate(zip(self._fieldnames, self._values))]
    205 
    206       return "%s(%s)" % (self._name, ", ".join(descriptions))
    207 
    208     def __repr__(self):
    209       return str(self)
    210 
    211     def CPointer(self):
    212       """Returns a C pointer to the serialized structure."""
    213       buf = ctypes.create_string_buffer(self.Pack())
    214       # Store the C buffer in the object so it doesn't get garbage collected.
    215       super(CStruct, self).__setattr__("_buffer", buf)
    216       return ctypes.addressof(self._buffer)
    217 
    218   return CStruct
    219 
    220 
    221 def Read(data, struct_type):
    222   length = len(struct_type)
    223   return struct_type(data), data[length:]
    224