Home | History | Annotate | Download | only in ABITest
      1 """Utilities for enumeration of finite and countably infinite sets.
      2 """
      3 ###
      4 # Countable iteration
      5 
      6 # Simplifies some calculations
      7 class Aleph0(int):
      8     _singleton = None
      9     def __new__(type):
     10         if type._singleton is None:
     11             type._singleton = int.__new__(type)
     12         return type._singleton
     13     def __repr__(self): return '<aleph0>'
     14     def __str__(self): return 'inf'
     15     
     16     def __cmp__(self, b):
     17         return 1
     18 
     19     def __sub__(self, b):
     20         raise ValueError,"Cannot subtract aleph0"
     21     __rsub__ = __sub__
     22 
     23     def __add__(self, b): 
     24         return self
     25     __radd__ = __add__
     26 
     27     def __mul__(self, b): 
     28         if b == 0: return b            
     29         return self
     30     __rmul__ = __mul__
     31 
     32     def __floordiv__(self, b):
     33         if b == 0: raise ZeroDivisionError
     34         return self
     35     __rfloordiv__ = __floordiv__
     36     __truediv__ = __floordiv__
     37     __rtuediv__ = __floordiv__
     38     __div__ = __floordiv__
     39     __rdiv__ = __floordiv__
     40 
     41     def __pow__(self, b):
     42         if b == 0: return 1
     43         return self
     44 aleph0 = Aleph0()
     45 
     46 def base(line):
     47     return line*(line+1)//2
     48 
     49 def pairToN((x,y)):
     50     line,index = x+y,y
     51     return base(line)+index
     52 
     53 def getNthPairInfo(N):
     54     # Avoid various singularities
     55     if N==0:
     56         return (0,0)
     57 
     58     # Gallop to find bounds for line
     59     line = 1
     60     next = 2
     61     while base(next)<=N:
     62         line = next
     63         next = line << 1
     64     
     65     # Binary search for starting line
     66     lo = line
     67     hi = line<<1
     68     while lo + 1 != hi:
     69         #assert base(lo) <= N < base(hi)
     70         mid = (lo + hi)>>1
     71         if base(mid)<=N:
     72             lo = mid
     73         else:
     74             hi = mid
     75 
     76     line = lo
     77     return line, N - base(line)
     78 
     79 def getNthPair(N):
     80     line,index = getNthPairInfo(N)
     81     return (line - index, index)
     82 
     83 def getNthPairBounded(N,W=aleph0,H=aleph0,useDivmod=False):
     84     """getNthPairBounded(N, W, H) -> (x, y)
     85     
     86     Return the N-th pair such that 0 <= x < W and 0 <= y < H."""
     87 
     88     if W <= 0 or H <= 0:
     89         raise ValueError,"Invalid bounds"
     90     elif N >= W*H:
     91         raise ValueError,"Invalid input (out of bounds)"
     92 
     93     # Simple case...
     94     if W is aleph0 and H is aleph0:
     95         return getNthPair(N)
     96 
     97     # Otherwise simplify by assuming W < H
     98     if H < W:
     99         x,y = getNthPairBounded(N,H,W,useDivmod=useDivmod)
    100         return y,x
    101 
    102     if useDivmod:
    103         return N%W,N//W
    104     else:
    105         # Conceptually we want to slide a diagonal line across a
    106         # rectangle. This gives more interesting results for large
    107         # bounds than using divmod.
    108         
    109         # If in lower left, just return as usual
    110         cornerSize = base(W)
    111         if N < cornerSize:
    112             return getNthPair(N)
    113 
    114         # Otherwise if in upper right, subtract from corner
    115         if H is not aleph0:
    116             M = W*H - N - 1
    117             if M < cornerSize:
    118                 x,y = getNthPair(M)
    119                 return (W-1-x,H-1-y)
    120 
    121         # Otherwise, compile line and index from number of times we
    122         # wrap.
    123         N = N - cornerSize
    124         index,offset = N%W,N//W
    125         # p = (W-1, 1+offset) + (-1,1)*index
    126         return (W-1-index, 1+offset+index)
    127 def getNthPairBoundedChecked(N,W=aleph0,H=aleph0,useDivmod=False,GNP=getNthPairBounded):
    128     x,y = GNP(N,W,H,useDivmod)
    129     assert 0 <= x < W and 0 <= y < H
    130     return x,y
    131 
    132 def getNthNTuple(N, W, H=aleph0, useLeftToRight=False):
    133     """getNthNTuple(N, W, H) -> (x_0, x_1, ..., x_W)
    134 
    135     Return the N-th W-tuple, where for 0 <= x_i < H."""
    136 
    137     if useLeftToRight:
    138         elts = [None]*W
    139         for i in range(W):
    140             elts[i],N = getNthPairBounded(N, H)
    141         return tuple(elts)
    142     else:
    143         if W==0:
    144             return ()
    145         elif W==1:
    146             return (N,)
    147         elif W==2:
    148             return getNthPairBounded(N, H, H)
    149         else:
    150             LW,RW = W//2, W - (W//2)
    151             L,R = getNthPairBounded(N, H**LW, H**RW)
    152             return (getNthNTuple(L,LW,H=H,useLeftToRight=useLeftToRight) + 
    153                     getNthNTuple(R,RW,H=H,useLeftToRight=useLeftToRight))
    154 def getNthNTupleChecked(N, W, H=aleph0, useLeftToRight=False, GNT=getNthNTuple):
    155     t = GNT(N,W,H,useLeftToRight)
    156     assert len(t) == W
    157     for i in t:
    158         assert i < H
    159     return t
    160 
    161 def getNthTuple(N, maxSize=aleph0, maxElement=aleph0, useDivmod=False, useLeftToRight=False):
    162     """getNthTuple(N, maxSize, maxElement) -> x
    163 
    164     Return the N-th tuple where len(x) < maxSize and for y in x, 0 <=
    165     y < maxElement."""
    166 
    167     # All zero sized tuples are isomorphic, don't ya know.
    168     if N == 0:
    169         return ()
    170     N -= 1
    171     if maxElement is not aleph0:
    172         if maxSize is aleph0:
    173             raise NotImplementedError,'Max element size without max size unhandled'
    174         bounds = [maxElement**i for i in range(1, maxSize+1)]
    175         S,M = getNthPairVariableBounds(N, bounds)
    176     else:
    177         S,M = getNthPairBounded(N, maxSize, useDivmod=useDivmod)
    178     return getNthNTuple(M, S+1, maxElement, useLeftToRight=useLeftToRight)
    179 def getNthTupleChecked(N, maxSize=aleph0, maxElement=aleph0, 
    180                        useDivmod=False, useLeftToRight=False, GNT=getNthTuple):
    181     # FIXME: maxsize is inclusive
    182     t = GNT(N,maxSize,maxElement,useDivmod,useLeftToRight)
    183     assert len(t) <= maxSize
    184     for i in t:
    185         assert i < maxElement
    186     return t
    187 
    188 def getNthPairVariableBounds(N, bounds):
    189     """getNthPairVariableBounds(N, bounds) -> (x, y)
    190 
    191     Given a finite list of bounds (which may be finite or aleph0),
    192     return the N-th pair such that 0 <= x < len(bounds) and 0 <= y <
    193     bounds[x]."""
    194 
    195     if not bounds:
    196         raise ValueError,"Invalid bounds"
    197     if not (0 <= N < sum(bounds)):
    198         raise ValueError,"Invalid input (out of bounds)"
    199 
    200     level = 0
    201     active = range(len(bounds))
    202     active.sort(key=lambda i: bounds[i])
    203     prevLevel = 0
    204     for i,index in enumerate(active):
    205         level = bounds[index]
    206         W = len(active) - i
    207         if level is aleph0:
    208             H = aleph0
    209         else:
    210             H = level - prevLevel
    211         levelSize = W*H
    212         if N<levelSize: # Found the level
    213             idelta,delta = getNthPairBounded(N, W, H)
    214             return active[i+idelta],prevLevel+delta
    215         else:
    216             N -= levelSize
    217             prevLevel = level
    218     else:
    219         raise RuntimError,"Unexpected loop completion"
    220 
    221 def getNthPairVariableBoundsChecked(N, bounds, GNVP=getNthPairVariableBounds):
    222     x,y = GNVP(N,bounds)
    223     assert 0 <= x < len(bounds) and 0 <= y < bounds[x]
    224     return (x,y)
    225 
    226 ###
    227 
    228 def testPairs():
    229     W = 3
    230     H = 6
    231     a = [['  ' for x in range(10)] for y in range(10)]
    232     b = [['  ' for x in range(10)] for y in range(10)]
    233     for i in range(min(W*H,40)):
    234         x,y = getNthPairBounded(i,W,H)
    235         x2,y2 = getNthPairBounded(i,W,H,useDivmod=True)
    236         print i,(x,y),(x2,y2)
    237         a[y][x] = '%2d'%i
    238         b[y2][x2] = '%2d'%i
    239 
    240     print '-- a --'
    241     for ln in a[::-1]:
    242         if ''.join(ln).strip():
    243             print '  '.join(ln)
    244     print '-- b --'
    245     for ln in b[::-1]:
    246         if ''.join(ln).strip():
    247             print '  '.join(ln)
    248 
    249 def testPairsVB():
    250     bounds = [2,2,4,aleph0,5,aleph0]
    251     a = [['  ' for x in range(15)] for y in range(15)]
    252     b = [['  ' for x in range(15)] for y in range(15)]
    253     for i in range(min(sum(bounds),40)):
    254         x,y = getNthPairVariableBounds(i, bounds)
    255         print i,(x,y)
    256         a[y][x] = '%2d'%i
    257 
    258     print '-- a --'
    259     for ln in a[::-1]:
    260         if ''.join(ln).strip():
    261             print '  '.join(ln)
    262 
    263 ###
    264 
    265 # Toggle to use checked versions of enumeration routines.
    266 if False:
    267     getNthPairVariableBounds = getNthPairVariableBoundsChecked
    268     getNthPairBounded = getNthPairBoundedChecked
    269     getNthNTuple = getNthNTupleChecked
    270     getNthTuple = getNthTupleChecked
    271 
    272 if __name__ == '__main__':
    273     testPairs()
    274 
    275     testPairsVB()
    276 
    277