Home | History | Annotate | Download | only in ec
      1 package org.bouncycastle.math.ec;
      2 
      3 import org.bouncycastle.util.Arrays;
      4 
      5 import java.math.BigInteger;
      6 
      7 class IntArray
      8 {
      9     // TODO make m fixed for the IntArray, and hence compute T once and for all
     10 
     11     private int[] m_ints;
     12 
     13     public IntArray(int intLen)
     14     {
     15         m_ints = new int[intLen];
     16     }
     17 
     18     public IntArray(int[] ints)
     19     {
     20         m_ints = ints;
     21     }
     22 
     23     public IntArray(BigInteger bigInt)
     24     {
     25         this(bigInt, 0);
     26     }
     27 
     28     public IntArray(BigInteger bigInt, int minIntLen)
     29     {
     30         if (bigInt.signum() == -1)
     31         {
     32             throw new IllegalArgumentException("Only positive Integers allowed");
     33         }
     34         if (bigInt.equals(ECConstants.ZERO))
     35         {
     36             m_ints = new int[] { 0 };
     37             return;
     38         }
     39 
     40         byte[] barr = bigInt.toByteArray();
     41         int barrLen = barr.length;
     42         int barrStart = 0;
     43         if (barr[0] == 0)
     44         {
     45             // First byte is 0 to enforce highest (=sign) bit is zero.
     46             // In this case ignore barr[0].
     47             barrLen--;
     48             barrStart = 1;
     49         }
     50         int intLen = (barrLen + 3) / 4;
     51         if (intLen < minIntLen)
     52         {
     53             m_ints = new int[minIntLen];
     54         }
     55         else
     56         {
     57             m_ints = new int[intLen];
     58         }
     59 
     60         int iarrJ = intLen - 1;
     61         int rem = barrLen % 4 + barrStart;
     62         int temp = 0;
     63         int barrI = barrStart;
     64         if (barrStart < rem)
     65         {
     66             for (; barrI < rem; barrI++)
     67             {
     68                 temp <<= 8;
     69                 int barrBarrI = barr[barrI];
     70                 if (barrBarrI < 0)
     71                 {
     72                     barrBarrI += 256;
     73                 }
     74                 temp |= barrBarrI;
     75             }
     76             m_ints[iarrJ--] = temp;
     77         }
     78 
     79         for (; iarrJ >= 0; iarrJ--)
     80         {
     81             temp = 0;
     82             for (int i = 0; i < 4; i++)
     83             {
     84                 temp <<= 8;
     85                 int barrBarrI = barr[barrI++];
     86                 if (barrBarrI < 0)
     87                 {
     88                     barrBarrI += 256;
     89                 }
     90                 temp |= barrBarrI;
     91             }
     92             m_ints[iarrJ] = temp;
     93         }
     94     }
     95 
     96     public boolean isZero()
     97     {
     98         return m_ints.length == 0
     99             || (m_ints[0] == 0 && getUsedLength() == 0);
    100     }
    101 
    102     public int getUsedLength()
    103     {
    104         int highestIntPos = m_ints.length;
    105 
    106         if (highestIntPos < 1)
    107         {
    108             return 0;
    109         }
    110 
    111         // Check if first element will act as sentinel
    112         if (m_ints[0] != 0)
    113         {
    114             while (m_ints[--highestIntPos] == 0)
    115             {
    116             }
    117             return highestIntPos + 1;
    118         }
    119 
    120         do
    121         {
    122             if (m_ints[--highestIntPos] != 0)
    123             {
    124                 return highestIntPos + 1;
    125             }
    126         }
    127         while (highestIntPos > 0);
    128 
    129         return 0;
    130     }
    131 
    132     public int bitLength()
    133     {
    134         // JDK 1.5: see Integer.numberOfLeadingZeros()
    135         int intLen = getUsedLength();
    136         if (intLen == 0)
    137         {
    138             return 0;
    139         }
    140 
    141         int last = intLen - 1;
    142         int highest = m_ints[last];
    143         int bits = (last << 5) + 1;
    144 
    145         // A couple of binary search steps
    146         if ((highest & 0xffff0000) != 0)
    147         {
    148             if ((highest & 0xff000000) != 0)
    149             {
    150                 bits += 24;
    151                 highest >>>= 24;
    152             }
    153             else
    154             {
    155                 bits += 16;
    156                 highest >>>= 16;
    157             }
    158         }
    159         else if (highest > 0x000000ff)
    160         {
    161             bits += 8;
    162             highest >>>= 8;
    163         }
    164 
    165         while (highest != 1)
    166         {
    167             ++bits;
    168             highest >>>= 1;
    169         }
    170 
    171         return bits;
    172     }
    173 
    174     private int[] resizedInts(int newLen)
    175     {
    176         int[] newInts = new int[newLen];
    177         int oldLen = m_ints.length;
    178         int copyLen = oldLen < newLen ? oldLen : newLen;
    179         System.arraycopy(m_ints, 0, newInts, 0, copyLen);
    180         return newInts;
    181     }
    182 
    183     public BigInteger toBigInteger()
    184     {
    185         int usedLen = getUsedLength();
    186         if (usedLen == 0)
    187         {
    188             return ECConstants.ZERO;
    189         }
    190 
    191         int highestInt = m_ints[usedLen - 1];
    192         byte[] temp = new byte[4];
    193         int barrI = 0;
    194         boolean trailingZeroBytesDone = false;
    195         for (int j = 3; j >= 0; j--)
    196         {
    197             byte thisByte = (byte) (highestInt >>> (8 * j));
    198             if (trailingZeroBytesDone || (thisByte != 0))
    199             {
    200                 trailingZeroBytesDone = true;
    201                 temp[barrI++] = thisByte;
    202             }
    203         }
    204 
    205         int barrLen = 4 * (usedLen - 1) + barrI;
    206         byte[] barr = new byte[barrLen];
    207         for (int j = 0; j < barrI; j++)
    208         {
    209             barr[j] = temp[j];
    210         }
    211         // Highest value int is done now
    212 
    213         for (int iarrJ = usedLen - 2; iarrJ >= 0; iarrJ--)
    214         {
    215             for (int j = 3; j >= 0; j--)
    216             {
    217                 barr[barrI++] = (byte) (m_ints[iarrJ] >>> (8 * j));
    218             }
    219         }
    220         return new BigInteger(1, barr);
    221     }
    222 
    223     public void shiftLeft()
    224     {
    225         int usedLen = getUsedLength();
    226         if (usedLen == 0)
    227         {
    228             return;
    229         }
    230         if (m_ints[usedLen - 1] < 0)
    231         {
    232             // highest bit of highest used byte is set, so shifting left will
    233             // make the IntArray one byte longer
    234             usedLen++;
    235             if (usedLen > m_ints.length)
    236             {
    237                 // make the m_ints one byte longer, because we need one more
    238                 // byte which is not available in m_ints
    239                 m_ints = resizedInts(m_ints.length + 1);
    240             }
    241         }
    242 
    243         boolean carry = false;
    244         for (int i = 0; i < usedLen; i++)
    245         {
    246             // nextCarry is true if highest bit is set
    247             boolean nextCarry = m_ints[i] < 0;
    248             m_ints[i] <<= 1;
    249             if (carry)
    250             {
    251                 // set lowest bit
    252                 m_ints[i] |= 1;
    253             }
    254             carry = nextCarry;
    255         }
    256     }
    257 
    258     public IntArray shiftLeft(int n)
    259     {
    260         int usedLen = getUsedLength();
    261         if (usedLen == 0)
    262         {
    263             return this;
    264         }
    265 
    266         if (n == 0)
    267         {
    268             return this;
    269         }
    270 
    271         if (n > 31)
    272         {
    273             throw new IllegalArgumentException("shiftLeft() for max 31 bits "
    274                 + ", " + n + "bit shift is not possible");
    275         }
    276 
    277         int[] newInts = new int[usedLen + 1];
    278 
    279         int nm32 = 32 - n;
    280         newInts[0] = m_ints[0] << n;
    281         for (int i = 1; i < usedLen; i++)
    282         {
    283             newInts[i] = (m_ints[i] << n) | (m_ints[i - 1] >>> nm32);
    284         }
    285         newInts[usedLen] = m_ints[usedLen - 1] >>> nm32;
    286 
    287         return new IntArray(newInts);
    288     }
    289 
    290     public void addShifted(IntArray other, int shift)
    291     {
    292         int usedLenOther = other.getUsedLength();
    293         int newMinUsedLen = usedLenOther + shift;
    294         if (newMinUsedLen > m_ints.length)
    295         {
    296             m_ints = resizedInts(newMinUsedLen);
    297             //System.out.println("Resize required");
    298         }
    299 
    300         for (int i = 0; i < usedLenOther; i++)
    301         {
    302             m_ints[i + shift] ^= other.m_ints[i];
    303         }
    304     }
    305 
    306     public int getLength()
    307     {
    308         return m_ints.length;
    309     }
    310 
    311     public boolean testBit(int n)
    312     {
    313         // theInt = n / 32
    314         int theInt = n >> 5;
    315         // theBit = n % 32
    316         int theBit = n & 0x1F;
    317         int tester = 1 << theBit;
    318         return ((m_ints[theInt] & tester) != 0);
    319     }
    320 
    321     public void flipBit(int n)
    322     {
    323         // theInt = n / 32
    324         int theInt = n >> 5;
    325         // theBit = n % 32
    326         int theBit = n & 0x1F;
    327         int flipper = 1 << theBit;
    328         m_ints[theInt] ^= flipper;
    329     }
    330 
    331     public void setBit(int n)
    332     {
    333         // theInt = n / 32
    334         int theInt = n >> 5;
    335         // theBit = n % 32
    336         int theBit = n & 0x1F;
    337         int setter = 1 << theBit;
    338         m_ints[theInt] |= setter;
    339     }
    340 
    341     public IntArray multiply(IntArray other, int m)
    342     {
    343         // Lenght of c is 2m bits rounded up to the next int (32 bit)
    344         int t = (m + 31) >> 5;
    345         if (m_ints.length < t)
    346         {
    347             m_ints = resizedInts(t);
    348         }
    349 
    350         IntArray b = new IntArray(other.resizedInts(other.getLength() + 1));
    351         IntArray c = new IntArray((m + m + 31) >> 5);
    352         // IntArray c = new IntArray(t + t);
    353         int testBit = 1;
    354         for (int k = 0; k < 32; k++)
    355         {
    356             for (int j = 0; j < t; j++)
    357             {
    358                 if ((m_ints[j] & testBit) != 0)
    359                 {
    360                     // The kth bit of m_ints[j] is set
    361                     c.addShifted(b, j);
    362                 }
    363             }
    364             testBit <<= 1;
    365             b.shiftLeft();
    366         }
    367         return c;
    368     }
    369 
    370     // public IntArray multiplyLeftToRight(IntArray other, int m) {
    371     // // Lenght of c is 2m bits rounded up to the next int (32 bit)
    372     // int t = (m + 31) / 32;
    373     // if (m_ints.length < t) {
    374     // m_ints = resizedInts(t);
    375     // }
    376     //
    377     // IntArray b = new IntArray(other.resizedInts(other.getLength() + 1));
    378     // IntArray c = new IntArray((m + m + 31) / 32);
    379     // // IntArray c = new IntArray(t + t);
    380     // int testBit = 1 << 31;
    381     // for (int k = 31; k >= 0; k--) {
    382     // for (int j = 0; j < t; j++) {
    383     // if ((m_ints[j] & testBit) != 0) {
    384     // // The kth bit of m_ints[j] is set
    385     // c.addShifted(b, j);
    386     // }
    387     // }
    388     // testBit >>>= 1;
    389     // if (k > 0) {
    390     // c.shiftLeft();
    391     // }
    392     // }
    393     // return c;
    394     // }
    395 
    396     // TODO note, redPol.length must be 3 for TPB and 5 for PPB
    397     public void reduce(int m, int[] redPol)
    398     {
    399         for (int i = m + m - 2; i >= m; i--)
    400         {
    401             if (testBit(i))
    402             {
    403                 int bit = i - m;
    404                 flipBit(bit);
    405                 flipBit(i);
    406                 int l = redPol.length;
    407                 while (--l >= 0)
    408                 {
    409                     flipBit(redPol[l] + bit);
    410                 }
    411             }
    412         }
    413         m_ints = resizedInts((m + 31) >> 5);
    414     }
    415 
    416     public IntArray square(int m)
    417     {
    418         // TODO make the table static final
    419         final int[] table = { 0x0, 0x1, 0x4, 0x5, 0x10, 0x11, 0x14, 0x15, 0x40,
    420             0x41, 0x44, 0x45, 0x50, 0x51, 0x54, 0x55 };
    421 
    422         int t = (m + 31) >> 5;
    423         if (m_ints.length < t)
    424         {
    425             m_ints = resizedInts(t);
    426         }
    427 
    428         IntArray c = new IntArray(t + t);
    429 
    430         // TODO twice the same code, put in separate private method
    431         for (int i = 0; i < t; i++)
    432         {
    433             int v0 = 0;
    434             for (int j = 0; j < 4; j++)
    435             {
    436                 v0 = v0 >>> 8;
    437                 int u = (m_ints[i] >>> (j * 4)) & 0xF;
    438                 int w = table[u] << 24;
    439                 v0 |= w;
    440             }
    441             c.m_ints[i + i] = v0;
    442 
    443             v0 = 0;
    444             int upper = m_ints[i] >>> 16;
    445             for (int j = 0; j < 4; j++)
    446             {
    447                 v0 = v0 >>> 8;
    448                 int u = (upper >>> (j * 4)) & 0xF;
    449                 int w = table[u] << 24;
    450                 v0 |= w;
    451             }
    452             c.m_ints[i + i + 1] = v0;
    453         }
    454         return c;
    455     }
    456 
    457     public boolean equals(Object o)
    458     {
    459         if (!(o instanceof IntArray))
    460         {
    461             return false;
    462         }
    463         IntArray other = (IntArray) o;
    464         int usedLen = getUsedLength();
    465         if (other.getUsedLength() != usedLen)
    466         {
    467             return false;
    468         }
    469         for (int i = 0; i < usedLen; i++)
    470         {
    471             if (m_ints[i] != other.m_ints[i])
    472             {
    473                 return false;
    474             }
    475         }
    476         return true;
    477     }
    478 
    479     public int hashCode()
    480     {
    481         int usedLen = getUsedLength();
    482         int hash = 1;
    483         for (int i = 0; i < usedLen; i++)
    484         {
    485             hash = hash * 31 + m_ints[i];
    486         }
    487         return hash;
    488     }
    489 
    490     public Object clone()
    491     {
    492         return new IntArray(Arrays.clone(m_ints));
    493     }
    494 
    495     public String toString()
    496     {
    497         int usedLen = getUsedLength();
    498         if (usedLen == 0)
    499         {
    500             return "0";
    501         }
    502 
    503         StringBuffer sb = new StringBuffer(Integer
    504             .toBinaryString(m_ints[usedLen - 1]));
    505         for (int iarrJ = usedLen - 2; iarrJ >= 0; iarrJ--)
    506         {
    507             String hexString = Integer.toBinaryString(m_ints[iarrJ]);
    508 
    509             // Add leading zeroes, except for highest significant int
    510             for (int i = hexString.length(); i < 8; i++)
    511             {
    512                 hexString = "0" + hexString;
    513             }
    514             sb.append(hexString);
    515         }
    516         return sb.toString();
    517     }
    518 }
    519