Home | History | Annotate | Download | only in ABITest
      1 """Flexible enumeration of C types."""
      2 
      3 from Enumeration import *
      4 
      5 # TODO:
      6 
      7 #  - struct improvements (flexible arrays, packed &
      8 #    unpacked, alignment)
      9 #  - objective-c qualified id
     10 #  - anonymous / transparent unions
     11 #  - VLAs
     12 #  - block types
     13 #  - K&R functions
     14 #  - pass arguments of different types (test extension, transparent union)
     15 #  - varargs
     16 
     17 ###
     18 # Actual type types
     19 
     20 class Type:
     21     def isBitField(self):
     22         return False
     23 
     24     def isPaddingBitField(self):
     25         return False
     26 
     27     def getTypeName(self, printer):
     28         name = 'T%d' % len(printer.types)
     29         typedef = self.getTypedefDef(name, printer)
     30         printer.addDeclaration(typedef)
     31         return name
     32 
     33 class BuiltinType(Type):
     34     def __init__(self, name, size, bitFieldSize=None):
     35         self.name = name
     36         self.size = size
     37         self.bitFieldSize = bitFieldSize
     38 
     39     def isBitField(self):
     40         return self.bitFieldSize is not None
     41 
     42     def isPaddingBitField(self):
     43         return self.bitFieldSize is 0
     44 
     45     def getBitFieldSize(self):
     46         assert self.isBitField()
     47         return self.bitFieldSize
     48 
     49     def getTypeName(self, printer):
     50         return self.name
     51 
     52     def sizeof(self):
     53         return self.size
     54 
     55     def __str__(self):
     56         return self.name
     57 
     58 class EnumType(Type):
     59     def __init__(self, index, enumerators):
     60         self.index = index
     61         self.enumerators = enumerators
     62 
     63     def getEnumerators(self):
     64         result = ''
     65         for i, init in enumerate(self.enumerators):
     66             if i > 0:
     67                 result = result + ', '
     68             result = result + 'enum%dval%d' % (self.index, i)
     69             if init:
     70                 result = result + ' = %s' % (init)
     71 
     72         return result
     73 
     74     def __str__(self):
     75         return 'enum { %s }' % (self.getEnumerators())
     76 
     77     def getTypedefDef(self, name, printer):
     78         return 'typedef enum %s { %s } %s;'%(name, self.getEnumerators(), name)
     79 
     80 class RecordType(Type):
     81     def __init__(self, index, isUnion, fields):
     82         self.index = index
     83         self.isUnion = isUnion
     84         self.fields = fields
     85         self.name = None
     86 
     87     def __str__(self):
     88         def getField(t):
     89             if t.isBitField():
     90                 return "%s : %d;" % (t, t.getBitFieldSize())
     91             else:
     92                 return "%s;" % t
     93 
     94         return '%s { %s }'%(('struct','union')[self.isUnion],
     95                             ' '.join(map(getField, self.fields)))
     96 
     97     def getTypedefDef(self, name, printer):
     98         def getField((i, t)):
     99             if t.isBitField():
    100                 if t.isPaddingBitField():
    101                     return '%s : 0;'%(printer.getTypeName(t),)
    102                 else:
    103                     return '%s field%d : %d;'%(printer.getTypeName(t),i,
    104                                                t.getBitFieldSize())
    105             else:
    106                 return '%s field%d;'%(printer.getTypeName(t),i)
    107         fields = map(getField, enumerate(self.fields))
    108         # Name the struct for more readable LLVM IR.
    109         return 'typedef %s %s { %s } %s;'%(('struct','union')[self.isUnion],
    110                                            name, ' '.join(fields), name)
    111                                            
    112 class ArrayType(Type):
    113     def __init__(self, index, isVector, elementType, size):
    114         if isVector:
    115             # Note that for vectors, this is the size in bytes.
    116             assert size > 0
    117         else:
    118             assert size is None or size >= 0
    119         self.index = index
    120         self.isVector = isVector
    121         self.elementType = elementType
    122         self.size = size
    123         if isVector:
    124             eltSize = self.elementType.sizeof()
    125             assert not (self.size % eltSize)
    126             self.numElements = self.size // eltSize
    127         else:
    128             self.numElements = self.size
    129 
    130     def __str__(self):
    131         if self.isVector:
    132             return 'vector (%s)[%d]'%(self.elementType,self.size)
    133         elif self.size is not None:
    134             return '(%s)[%d]'%(self.elementType,self.size)
    135         else:
    136             return '(%s)[]'%(self.elementType,)
    137 
    138     def getTypedefDef(self, name, printer):
    139         elementName = printer.getTypeName(self.elementType)
    140         if self.isVector:
    141             return 'typedef %s %s __attribute__ ((vector_size (%d)));'%(elementName,
    142                                                                         name,
    143                                                                         self.size)
    144         else:
    145             if self.size is None:
    146                 sizeStr = ''
    147             else:
    148                 sizeStr = str(self.size)
    149             return 'typedef %s %s[%s];'%(elementName, name, sizeStr)
    150 
    151 class ComplexType(Type):
    152     def __init__(self, index, elementType):
    153         self.index = index
    154         self.elementType = elementType
    155 
    156     def __str__(self):
    157         return '_Complex (%s)'%(self.elementType)
    158 
    159     def getTypedefDef(self, name, printer):
    160         return 'typedef _Complex %s %s;'%(printer.getTypeName(self.elementType), name)
    161 
    162 class FunctionType(Type):
    163     def __init__(self, index, returnType, argTypes):
    164         self.index = index
    165         self.returnType = returnType
    166         self.argTypes = argTypes
    167 
    168     def __str__(self):
    169         if self.returnType is None:
    170             rt = 'void'
    171         else:
    172             rt = str(self.returnType)
    173         if not self.argTypes:
    174             at = 'void'
    175         else:
    176             at = ', '.join(map(str, self.argTypes))
    177         return '%s (*)(%s)'%(rt, at)
    178 
    179     def getTypedefDef(self, name, printer):
    180         if self.returnType is None:
    181             rt = 'void'
    182         else:
    183             rt = str(self.returnType)
    184         if not self.argTypes:
    185             at = 'void'
    186         else:
    187             at = ', '.join(map(str, self.argTypes))
    188         return 'typedef %s (*%s)(%s);'%(rt, name, at)
    189 
    190 ###
    191 # Type enumerators
    192 
    193 class TypeGenerator(object):
    194     def __init__(self):
    195         self.cache = {}
    196 
    197     def setCardinality(self):
    198         abstract
    199 
    200     def get(self, N):
    201         T = self.cache.get(N)
    202         if T is None:
    203             assert 0 <= N < self.cardinality
    204             T = self.cache[N] = self.generateType(N)
    205         return T
    206 
    207     def generateType(self, N):
    208         abstract
    209 
    210 class FixedTypeGenerator(TypeGenerator):
    211     def __init__(self, types):
    212         TypeGenerator.__init__(self)
    213         self.types = types
    214         self.setCardinality()
    215 
    216     def setCardinality(self):
    217         self.cardinality = len(self.types)
    218 
    219     def generateType(self, N):
    220         return self.types[N]
    221 
    222 # Factorial
    223 def fact(n):
    224     result = 1
    225     while n > 0:
    226         result = result * n
    227         n = n - 1
    228     return result
    229 
    230 # Compute the number of combinations (n choose k)
    231 def num_combinations(n, k): 
    232     return fact(n) / (fact(k) * fact(n - k))
    233 
    234 # Enumerate the combinations choosing k elements from the list of values
    235 def combinations(values, k):
    236     # From ActiveState Recipe 190465: Generator for permutations,
    237     # combinations, selections of a sequence
    238     if k==0: yield []
    239     else:
    240         for i in xrange(len(values)-k+1):
    241             for cc in combinations(values[i+1:],k-1):
    242                 yield [values[i]]+cc
    243 
    244 class EnumTypeGenerator(TypeGenerator):
    245     def __init__(self, values, minEnumerators, maxEnumerators):
    246         TypeGenerator.__init__(self)
    247         self.values = values
    248         self.minEnumerators = minEnumerators
    249         self.maxEnumerators = maxEnumerators
    250         self.setCardinality()
    251 
    252     def setCardinality(self):
    253         self.cardinality = 0
    254         for num in range(self.minEnumerators, self.maxEnumerators + 1):
    255             self.cardinality += num_combinations(len(self.values), num)
    256 
    257     def generateType(self, n):
    258         # Figure out the number of enumerators in this type
    259         numEnumerators = self.minEnumerators
    260         valuesCovered = 0
    261         while numEnumerators < self.maxEnumerators:
    262             comb = num_combinations(len(self.values), numEnumerators)
    263             if valuesCovered + comb > n:
    264                 break
    265             numEnumerators = numEnumerators + 1
    266             valuesCovered += comb
    267 
    268         # Find the requested combination of enumerators and build a
    269         # type from it.
    270         i = 0
    271         for enumerators in combinations(self.values, numEnumerators):
    272             if i == n - valuesCovered:
    273                 return EnumType(n, enumerators)
    274                 
    275             i = i + 1
    276 
    277         assert False
    278 
    279 class ComplexTypeGenerator(TypeGenerator):
    280     def __init__(self, typeGen):
    281         TypeGenerator.__init__(self)
    282         self.typeGen = typeGen
    283         self.setCardinality()
    284     
    285     def setCardinality(self):
    286         self.cardinality = self.typeGen.cardinality
    287 
    288     def generateType(self, N):
    289         return ComplexType(N, self.typeGen.get(N))
    290 
    291 class VectorTypeGenerator(TypeGenerator):
    292     def __init__(self, typeGen, sizes):
    293         TypeGenerator.__init__(self)
    294         self.typeGen = typeGen
    295         self.sizes = tuple(map(int,sizes))
    296         self.setCardinality()
    297 
    298     def setCardinality(self):
    299         self.cardinality = len(self.sizes)*self.typeGen.cardinality
    300 
    301     def generateType(self, N):
    302         S,T = getNthPairBounded(N, len(self.sizes), self.typeGen.cardinality)
    303         return ArrayType(N, True, self.typeGen.get(T), self.sizes[S])
    304 
    305 class FixedArrayTypeGenerator(TypeGenerator):
    306     def __init__(self, typeGen, sizes):
    307         TypeGenerator.__init__(self)
    308         self.typeGen = typeGen
    309         self.sizes = tuple(size)
    310         self.setCardinality()
    311 
    312     def setCardinality(self):
    313         self.cardinality = len(self.sizes)*self.typeGen.cardinality
    314 
    315     def generateType(self, N):
    316         S,T = getNthPairBounded(N, len(self.sizes), self.typeGen.cardinality)
    317         return ArrayType(N, false, self.typeGen.get(T), self.sizes[S])
    318 
    319 class ArrayTypeGenerator(TypeGenerator):
    320     def __init__(self, typeGen, maxSize, useIncomplete=False, useZero=False):
    321         TypeGenerator.__init__(self)
    322         self.typeGen = typeGen
    323         self.useIncomplete = useIncomplete
    324         self.useZero = useZero
    325         self.maxSize = int(maxSize)
    326         self.W = useIncomplete + useZero + self.maxSize
    327         self.setCardinality()
    328 
    329     def setCardinality(self):
    330         self.cardinality = self.W * self.typeGen.cardinality
    331 
    332     def generateType(self, N):
    333         S,T = getNthPairBounded(N, self.W, self.typeGen.cardinality)
    334         if self.useIncomplete:
    335             if S==0:
    336                 size = None
    337                 S = None
    338             else:
    339                 S = S - 1
    340         if S is not None:
    341             if self.useZero:
    342                 size = S
    343             else:
    344                 size = S + 1        
    345         return ArrayType(N, False, self.typeGen.get(T), size)
    346 
    347 class RecordTypeGenerator(TypeGenerator):
    348     def __init__(self, typeGen, useUnion, maxSize):
    349         TypeGenerator.__init__(self)
    350         self.typeGen = typeGen
    351         self.useUnion = bool(useUnion)
    352         self.maxSize = int(maxSize)
    353         self.setCardinality()
    354 
    355     def setCardinality(self):
    356         M = 1 + self.useUnion
    357         if self.maxSize is aleph0:
    358             S =  aleph0 * self.typeGen.cardinality
    359         else:
    360             S = 0
    361             for i in range(self.maxSize+1):
    362                 S += M * (self.typeGen.cardinality ** i)
    363         self.cardinality = S
    364 
    365     def generateType(self, N):
    366         isUnion,I = False,N
    367         if self.useUnion:
    368             isUnion,I = (I&1),I>>1
    369         fields = map(self.typeGen.get,getNthTuple(I,self.maxSize,self.typeGen.cardinality))
    370         return RecordType(N, isUnion, fields)
    371 
    372 class FunctionTypeGenerator(TypeGenerator):
    373     def __init__(self, typeGen, useReturn, maxSize):
    374         TypeGenerator.__init__(self)
    375         self.typeGen = typeGen
    376         self.useReturn = useReturn
    377         self.maxSize = maxSize
    378         self.setCardinality()
    379     
    380     def setCardinality(self):
    381         if self.maxSize is aleph0:
    382             S = aleph0 * self.typeGen.cardinality()
    383         elif self.useReturn:
    384             S = 0
    385             for i in range(1,self.maxSize+1+1):
    386                 S += self.typeGen.cardinality ** i
    387         else:
    388             S = 0
    389             for i in range(self.maxSize+1):
    390                 S += self.typeGen.cardinality ** i
    391         self.cardinality = S
    392     
    393     def generateType(self, N):
    394         if self.useReturn:
    395             # Skip the empty tuple
    396             argIndices = getNthTuple(N+1, self.maxSize+1, self.typeGen.cardinality)
    397             retIndex,argIndices = argIndices[0],argIndices[1:]
    398             retTy = self.typeGen.get(retIndex)
    399         else:
    400             retTy = None
    401             argIndices = getNthTuple(N, self.maxSize, self.typeGen.cardinality)
    402         args = map(self.typeGen.get, argIndices)
    403         return FunctionType(N, retTy, args)
    404 
    405 class AnyTypeGenerator(TypeGenerator):
    406     def __init__(self):
    407         TypeGenerator.__init__(self)
    408         self.generators = []
    409         self.bounds = []
    410         self.setCardinality()
    411         self._cardinality = None
    412         
    413     def getCardinality(self):
    414         if self._cardinality is None:
    415             return aleph0
    416         else:
    417             return self._cardinality
    418     def setCardinality(self):
    419         self.bounds = [g.cardinality for g in self.generators]
    420         self._cardinality = sum(self.bounds)
    421     cardinality = property(getCardinality, None)
    422 
    423     def addGenerator(self, g):
    424         self.generators.append(g)
    425         for i in range(100):
    426             prev = self._cardinality
    427             self._cardinality = None
    428             for g in self.generators:
    429                 g.setCardinality()
    430             self.setCardinality()
    431             if (self._cardinality is aleph0) or prev==self._cardinality:
    432                 break
    433         else:
    434             raise RuntimeError,"Infinite loop in setting cardinality"
    435 
    436     def generateType(self, N):
    437         index,M = getNthPairVariableBounds(N, self.bounds)
    438         return self.generators[index].get(M)
    439 
    440 def test():
    441     fbtg = FixedTypeGenerator([BuiltinType('char', 4),
    442                                BuiltinType('char', 4, 0),
    443                                BuiltinType('int',  4, 5)])
    444 
    445     fields1 = AnyTypeGenerator()
    446     fields1.addGenerator( fbtg )
    447 
    448     fields0 = AnyTypeGenerator()
    449     fields0.addGenerator( fbtg )
    450 #    fields0.addGenerator( RecordTypeGenerator(fields1, False, 4) )
    451 
    452     btg = FixedTypeGenerator([BuiltinType('char', 4),
    453                               BuiltinType('int',  4)])
    454     etg = EnumTypeGenerator([None, '-1', '1', '1u'], 0, 3)
    455 
    456     atg = AnyTypeGenerator()
    457     atg.addGenerator( btg )
    458     atg.addGenerator( RecordTypeGenerator(fields0, False, 4) )
    459     atg.addGenerator( etg )
    460     print 'Cardinality:',atg.cardinality
    461     for i in range(100):
    462         if i == atg.cardinality:
    463             try:
    464                 atg.get(i)
    465                 raise RuntimeError,"Cardinality was wrong"
    466             except AssertionError:
    467                 break
    468         print '%4d: %s'%(i, atg.get(i))
    469 
    470 if __name__ == '__main__':
    471     test()
    472