Home | History | Annotate | Download | only in ABITest
      1 """Flexible enumeration of C types."""
      2 
      3 from Enumeration import *
      4 
      5 # TODO:
      6 
      7 #  - struct improvements (flexible arrays, packed &
      8 #    unpacked, alignment)
#  - Objective-C qualified id
     10 #  - anonymous / transparent unions
     11 #  - VLAs
     12 #  - block types
     13 #  - K&R functions
     14 #  - pass arguments of different types (test extension, transparent union)
     15 #  - varargs
     16 
     17 ###
# Type classes (the C types being enumerated)
     19 
     20 class Type:
     21     def isBitField(self):
     22         return False
     23 
     24     def isPaddingBitField(self):
     25         return False
     26 
     27     def getTypeName(self, printer):
     28         name = 'T%d' % len(printer.types)
     29         typedef = self.getTypedefDef(name, printer)
     30         printer.addDeclaration(typedef)
     31         return name
     32 
     33 class BuiltinType(Type):
     34     def __init__(self, name, size, bitFieldSize=None):
     35         self.name = name
     36         self.size = size
     37         self.bitFieldSize = bitFieldSize
     38 
     39     def isBitField(self):
     40         return self.bitFieldSize is not None
     41 
     42     def isPaddingBitField(self):
     43         return self.bitFieldSize is 0
     44 
     45     def getBitFieldSize(self):
     46         assert self.isBitField()
     47         return self.bitFieldSize
     48 
     49     def getTypeName(self, printer):
     50         return self.name
     51 
     52     def sizeof(self):
     53         return self.size
     54 
     55     def __str__(self):
     56         return self.name
     57 
     58 class EnumType(Type):
     59     unique_id = 0
     60 
     61     def __init__(self, index, enumerators):
     62         self.index = index
     63         self.enumerators = enumerators
     64         self.unique_id = self.__class__.unique_id
     65         self.__class__.unique_id += 1
     66 
     67     def getEnumerators(self):
     68         result = ''
     69         for i, init in enumerate(self.enumerators):
     70             if i > 0:
     71                 result = result + ', '
     72             result = result + 'enum%dval%d_%d' % (self.index, i, self.unique_id)
     73             if init:
     74                 result = result + ' = %s' % (init)
     75 
     76         return result
     77 
     78     def __str__(self):
     79         return 'enum { %s }' % (self.getEnumerators())
     80 
     81     def getTypedefDef(self, name, printer):
     82         return 'typedef enum %s { %s } %s;'%(name, self.getEnumerators(), name)
     83 
     84 class RecordType(Type):
     85     def __init__(self, index, isUnion, fields):
     86         self.index = index
     87         self.isUnion = isUnion
     88         self.fields = fields
     89         self.name = None
     90 
     91     def __str__(self):
     92         def getField(t):
     93             if t.isBitField():
     94                 return "%s : %d;" % (t, t.getBitFieldSize())
     95             else:
     96                 return "%s;" % t
     97 
     98         return '%s { %s }'%(('struct','union')[self.isUnion],
     99                             ' '.join(map(getField, self.fields)))
    100 
    101     def getTypedefDef(self, name, printer):
    102         def getField((i, t)):
    103             if t.isBitField():
    104                 if t.isPaddingBitField():
    105                     return '%s : 0;'%(printer.getTypeName(t),)
    106                 else:
    107                     return '%s field%d : %d;'%(printer.getTypeName(t),i,
    108                                                t.getBitFieldSize())
    109             else:
    110                 return '%s field%d;'%(printer.getTypeName(t),i)
    111         fields = map(getField, enumerate(self.fields))
    112         # Name the struct for more readable LLVM IR.
    113         return 'typedef %s %s { %s } %s;'%(('struct','union')[self.isUnion],
    114                                            name, ' '.join(fields), name)
    115                                            
    116 class ArrayType(Type):
    117     def __init__(self, index, isVector, elementType, size):
    118         if isVector:
    119             # Note that for vectors, this is the size in bytes.
    120             assert size > 0
    121         else:
    122             assert size is None or size >= 0
    123         self.index = index
    124         self.isVector = isVector
    125         self.elementType = elementType
    126         self.size = size
    127         if isVector:
    128             eltSize = self.elementType.sizeof()
    129             assert not (self.size % eltSize)
    130             self.numElements = self.size // eltSize
    131         else:
    132             self.numElements = self.size
    133 
    134     def __str__(self):
    135         if self.isVector:
    136             return 'vector (%s)[%d]'%(self.elementType,self.size)
    137         elif self.size is not None:
    138             return '(%s)[%d]'%(self.elementType,self.size)
    139         else:
    140             return '(%s)[]'%(self.elementType,)
    141 
    142     def getTypedefDef(self, name, printer):
    143         elementName = printer.getTypeName(self.elementType)
    144         if self.isVector:
    145             return 'typedef %s %s __attribute__ ((vector_size (%d)));'%(elementName,
    146                                                                         name,
    147                                                                         self.size)
    148         else:
    149             if self.size is None:
    150                 sizeStr = ''
    151             else:
    152                 sizeStr = str(self.size)
    153             return 'typedef %s %s[%s];'%(elementName, name, sizeStr)
    154 
    155 class ComplexType(Type):
    156     def __init__(self, index, elementType):
    157         self.index = index
    158         self.elementType = elementType
    159 
    160     def __str__(self):
    161         return '_Complex (%s)'%(self.elementType)
    162 
    163     def getTypedefDef(self, name, printer):
    164         return 'typedef _Complex %s %s;'%(printer.getTypeName(self.elementType), name)
    165 
    166 class FunctionType(Type):
    167     def __init__(self, index, returnType, argTypes):
    168         self.index = index
    169         self.returnType = returnType
    170         self.argTypes = argTypes
    171 
    172     def __str__(self):
    173         if self.returnType is None:
    174             rt = 'void'
    175         else:
    176             rt = str(self.returnType)
    177         if not self.argTypes:
    178             at = 'void'
    179         else:
    180             at = ', '.join(map(str, self.argTypes))
    181         return '%s (*)(%s)'%(rt, at)
    182 
    183     def getTypedefDef(self, name, printer):
    184         if self.returnType is None:
    185             rt = 'void'
    186         else:
    187             rt = str(self.returnType)
    188         if not self.argTypes:
    189             at = 'void'
    190         else:
    191             at = ', '.join(map(str, self.argTypes))
    192         return 'typedef %s (*%s)(%s);'%(rt, name, at)
    193 
    194 ###
    195 # Type enumerators
    196 
    197 class TypeGenerator(object):
    198     def __init__(self):
    199         self.cache = {}
    200 
    201     def setCardinality(self):
    202         abstract
    203 
    204     def get(self, N):
    205         T = self.cache.get(N)
    206         if T is None:
    207             assert 0 <= N < self.cardinality
    208             T = self.cache[N] = self.generateType(N)
    209         return T
    210 
    211     def generateType(self, N):
    212         abstract
    213 
    214 class FixedTypeGenerator(TypeGenerator):
    215     def __init__(self, types):
    216         TypeGenerator.__init__(self)
    217         self.types = types
    218         self.setCardinality()
    219 
    220     def setCardinality(self):
    221         self.cardinality = len(self.types)
    222 
    223     def generateType(self, N):
    224         return self.types[N]
    225 
    226 # Factorial
    227 def fact(n):
    228     result = 1
    229     while n > 0:
    230         result = result * n
    231         n = n - 1
    232     return result
    233 
    234 # Compute the number of combinations (n choose k)
    235 def num_combinations(n, k): 
    236     return fact(n) / (fact(k) * fact(n - k))
    237 
    238 # Enumerate the combinations choosing k elements from the list of values
    239 def combinations(values, k):
    240     # From ActiveState Recipe 190465: Generator for permutations,
    241     # combinations, selections of a sequence
    242     if k==0: yield []
    243     else:
    244         for i in xrange(len(values)-k+1):
    245             for cc in combinations(values[i+1:],k-1):
    246                 yield [values[i]]+cc
    247 
    248 class EnumTypeGenerator(TypeGenerator):
    249     def __init__(self, values, minEnumerators, maxEnumerators):
    250         TypeGenerator.__init__(self)
    251         self.values = values
    252         self.minEnumerators = minEnumerators
    253         self.maxEnumerators = maxEnumerators
    254         self.setCardinality()
    255 
    256     def setCardinality(self):
    257         self.cardinality = 0
    258         for num in range(self.minEnumerators, self.maxEnumerators + 1):
    259             self.cardinality += num_combinations(len(self.values), num)
    260 
    261     def generateType(self, n):
    262         # Figure out the number of enumerators in this type
    263         numEnumerators = self.minEnumerators
    264         valuesCovered = 0
    265         while numEnumerators < self.maxEnumerators:
    266             comb = num_combinations(len(self.values), numEnumerators)
    267             if valuesCovered + comb > n:
    268                 break
    269             numEnumerators = numEnumerators + 1
    270             valuesCovered += comb
    271 
    272         # Find the requested combination of enumerators and build a
    273         # type from it.
    274         i = 0
    275         for enumerators in combinations(self.values, numEnumerators):
    276             if i == n - valuesCovered:
    277                 return EnumType(n, enumerators)
    278                 
    279             i = i + 1
    280 
    281         assert False
    282 
    283 class ComplexTypeGenerator(TypeGenerator):
    284     def __init__(self, typeGen):
    285         TypeGenerator.__init__(self)
    286         self.typeGen = typeGen
    287         self.setCardinality()
    288     
    289     def setCardinality(self):
    290         self.cardinality = self.typeGen.cardinality
    291 
    292     def generateType(self, N):
    293         return ComplexType(N, self.typeGen.get(N))
    294 
    295 class VectorTypeGenerator(TypeGenerator):
    296     def __init__(self, typeGen, sizes):
    297         TypeGenerator.__init__(self)
    298         self.typeGen = typeGen
    299         self.sizes = tuple(map(int,sizes))
    300         self.setCardinality()
    301 
    302     def setCardinality(self):
    303         self.cardinality = len(self.sizes)*self.typeGen.cardinality
    304 
    305     def generateType(self, N):
    306         S,T = getNthPairBounded(N, len(self.sizes), self.typeGen.cardinality)
    307         return ArrayType(N, True, self.typeGen.get(T), self.sizes[S])
    308 
    309 class FixedArrayTypeGenerator(TypeGenerator):
    310     def __init__(self, typeGen, sizes):
    311         TypeGenerator.__init__(self)
    312         self.typeGen = typeGen
    313         self.sizes = tuple(size)
    314         self.setCardinality()
    315 
    316     def setCardinality(self):
    317         self.cardinality = len(self.sizes)*self.typeGen.cardinality
    318 
    319     def generateType(self, N):
    320         S,T = getNthPairBounded(N, len(self.sizes), self.typeGen.cardinality)
    321         return ArrayType(N, false, self.typeGen.get(T), self.sizes[S])
    322 
    323 class ArrayTypeGenerator(TypeGenerator):
    324     def __init__(self, typeGen, maxSize, useIncomplete=False, useZero=False):
    325         TypeGenerator.__init__(self)
    326         self.typeGen = typeGen
    327         self.useIncomplete = useIncomplete
    328         self.useZero = useZero
    329         self.maxSize = int(maxSize)
    330         self.W = useIncomplete + useZero + self.maxSize
    331         self.setCardinality()
    332 
    333     def setCardinality(self):
    334         self.cardinality = self.W * self.typeGen.cardinality
    335 
    336     def generateType(self, N):
    337         S,T = getNthPairBounded(N, self.W, self.typeGen.cardinality)
    338         if self.useIncomplete:
    339             if S==0:
    340                 size = None
    341                 S = None
    342             else:
    343                 S = S - 1
    344         if S is not None:
    345             if self.useZero:
    346                 size = S
    347             else:
    348                 size = S + 1        
    349         return ArrayType(N, False, self.typeGen.get(T), size)
    350 
    351 class RecordTypeGenerator(TypeGenerator):
    352     def __init__(self, typeGen, useUnion, maxSize):
    353         TypeGenerator.__init__(self)
    354         self.typeGen = typeGen
    355         self.useUnion = bool(useUnion)
    356         self.maxSize = int(maxSize)
    357         self.setCardinality()
    358 
    359     def setCardinality(self):
    360         M = 1 + self.useUnion
    361         if self.maxSize is aleph0:
    362             S =  aleph0 * self.typeGen.cardinality
    363         else:
    364             S = 0
    365             for i in range(self.maxSize+1):
    366                 S += M * (self.typeGen.cardinality ** i)
    367         self.cardinality = S
    368 
    369     def generateType(self, N):
    370         isUnion,I = False,N
    371         if self.useUnion:
    372             isUnion,I = (I&1),I>>1
    373         fields = map(self.typeGen.get,getNthTuple(I,self.maxSize,self.typeGen.cardinality))
    374         return RecordType(N, isUnion, fields)
    375 
    376 class FunctionTypeGenerator(TypeGenerator):
    377     def __init__(self, typeGen, useReturn, maxSize):
    378         TypeGenerator.__init__(self)
    379         self.typeGen = typeGen
    380         self.useReturn = useReturn
    381         self.maxSize = maxSize
    382         self.setCardinality()
    383     
    384     def setCardinality(self):
    385         if self.maxSize is aleph0:
    386             S = aleph0 * self.typeGen.cardinality()
    387         elif self.useReturn:
    388             S = 0
    389             for i in range(1,self.maxSize+1+1):
    390                 S += self.typeGen.cardinality ** i
    391         else:
    392             S = 0
    393             for i in range(self.maxSize+1):
    394                 S += self.typeGen.cardinality ** i
    395         self.cardinality = S
    396     
    397     def generateType(self, N):
    398         if self.useReturn:
    399             # Skip the empty tuple
    400             argIndices = getNthTuple(N+1, self.maxSize+1, self.typeGen.cardinality)
    401             retIndex,argIndices = argIndices[0],argIndices[1:]
    402             retTy = self.typeGen.get(retIndex)
    403         else:
    404             retTy = None
    405             argIndices = getNthTuple(N, self.maxSize, self.typeGen.cardinality)
    406         args = map(self.typeGen.get, argIndices)
    407         return FunctionType(N, retTy, args)
    408 
    409 class AnyTypeGenerator(TypeGenerator):
    410     def __init__(self):
    411         TypeGenerator.__init__(self)
    412         self.generators = []
    413         self.bounds = []
    414         self.setCardinality()
    415         self._cardinality = None
    416         
    417     def getCardinality(self):
    418         if self._cardinality is None:
    419             return aleph0
    420         else:
    421             return self._cardinality
    422     def setCardinality(self):
    423         self.bounds = [g.cardinality for g in self.generators]
    424         self._cardinality = sum(self.bounds)
    425     cardinality = property(getCardinality, None)
    426 
    427     def addGenerator(self, g):
    428         self.generators.append(g)
    429         for i in range(100):
    430             prev = self._cardinality
    431             self._cardinality = None
    432             for g in self.generators:
    433                 g.setCardinality()
    434             self.setCardinality()
    435             if (self._cardinality is aleph0) or prev==self._cardinality:
    436                 break
    437         else:
    438             raise RuntimeError,"Infinite loop in setting cardinality"
    439 
    440     def generateType(self, N):
    441         index,M = getNthPairVariableBounds(N, self.bounds)
    442         return self.generators[index].get(M)
    443 
    444 def test():
    445     fbtg = FixedTypeGenerator([BuiltinType('char', 4),
    446                                BuiltinType('char', 4, 0),
    447                                BuiltinType('int',  4, 5)])
    448 
    449     fields1 = AnyTypeGenerator()
    450     fields1.addGenerator( fbtg )
    451 
    452     fields0 = AnyTypeGenerator()
    453     fields0.addGenerator( fbtg )
    454 #    fields0.addGenerator( RecordTypeGenerator(fields1, False, 4) )
    455 
    456     btg = FixedTypeGenerator([BuiltinType('char', 4),
    457                               BuiltinType('int',  4)])
    458     etg = EnumTypeGenerator([None, '-1', '1', '1u'], 0, 3)
    459 
    460     atg = AnyTypeGenerator()
    461     atg.addGenerator( btg )
    462     atg.addGenerator( RecordTypeGenerator(fields0, False, 4) )
    463     atg.addGenerator( etg )
    464     print 'Cardinality:',atg.cardinality
    465     for i in range(100):
    466         if i == atg.cardinality:
    467             try:
    468                 atg.get(i)
    469                 raise RuntimeError,"Cardinality was wrong"
    470             except AssertionError:
    471                 break
    472         print '%4d: %s'%(i, atg.get(i))
    473 
    474 if __name__ == '__main__':
    475     test()
    476