Home | History | Annotate | Download | only in ec
      1 package org.bouncycastle.math.ec;
      2 
      3 import java.math.BigInteger;
      4 
      5 import org.bouncycastle.math.ec.endo.ECEndomorphism;
      6 import org.bouncycastle.math.ec.endo.GLVEndomorphism;
      7 import org.bouncycastle.math.field.FiniteField;
      8 import org.bouncycastle.math.field.PolynomialExtensionField;
      9 
     10 public class ECAlgorithms
     11 {
     12     public static boolean isF2mCurve(ECCurve c)
     13     {
     14         return isF2mField(c.getField());
     15     }
     16 
     17     public static boolean isF2mField(FiniteField field)
     18     {
     19         return field.getDimension() > 1 && field.getCharacteristic().equals(ECConstants.TWO)
     20             && field instanceof PolynomialExtensionField;
     21     }
     22 
     23     public static boolean isFpCurve(ECCurve c)
     24     {
     25         return isFpField(c.getField());
     26     }
     27 
     28     public static boolean isFpField(FiniteField field)
     29     {
     30         return field.getDimension() == 1;
     31     }
     32 
     33     public static ECPoint sumOfMultiplies(ECPoint[] ps, BigInteger[] ks)
     34     {
     35         if (ps == null || ks == null || ps.length != ks.length || ps.length < 1)
     36         {
     37             throw new IllegalArgumentException("point and scalar arrays should be non-null, and of equal, non-zero, length");
     38         }
     39 
     40         int count = ps.length;
     41         switch (count)
     42         {
     43         case 1:
     44             return ps[0].multiply(ks[0]);
     45         case 2:
     46             return sumOfTwoMultiplies(ps[0], ks[0], ps[1], ks[1]);
     47         default:
     48             break;
     49         }
     50 
     51         ECPoint p = ps[0];
     52         ECCurve c = p.getCurve();
     53 
     54         ECPoint[] imported = new ECPoint[count];
     55         imported[0] = p;
     56         for (int i = 1; i < count; ++i)
     57         {
     58             imported[i] = importPoint(c, ps[i]);
     59         }
     60 
     61         ECEndomorphism endomorphism = c.getEndomorphism();
     62         if (endomorphism instanceof GLVEndomorphism)
     63         {
     64             return validatePoint(implSumOfMultipliesGLV(imported, ks, (GLVEndomorphism)endomorphism));
     65         }
     66 
     67         return validatePoint(implSumOfMultiplies(imported, ks));
     68     }
     69 
     70     public static ECPoint sumOfTwoMultiplies(ECPoint P, BigInteger a,
     71         ECPoint Q, BigInteger b)
     72     {
     73         ECCurve cp = P.getCurve();
     74         Q = importPoint(cp, Q);
     75 
     76         // Point multiplication for Koblitz curves (using WTNAF) beats Shamir's trick
     77         if (cp instanceof ECCurve.AbstractF2m)
     78         {
     79             ECCurve.AbstractF2m f2mCurve = (ECCurve.AbstractF2m)cp;
     80             if (f2mCurve.isKoblitz())
     81             {
     82                 return validatePoint(P.multiply(a).add(Q.multiply(b)));
     83             }
     84         }
     85 
     86         ECEndomorphism endomorphism = cp.getEndomorphism();
     87         if (endomorphism instanceof GLVEndomorphism)
     88         {
     89             return validatePoint(
     90                 implSumOfMultipliesGLV(new ECPoint[]{ P, Q }, new BigInteger[]{ a, b }, (GLVEndomorphism)endomorphism));
     91         }
     92 
     93         return validatePoint(implShamirsTrickWNaf(P, a, Q, b));
     94     }
     95 
     96     /*
     97      * "Shamir's Trick", originally due to E. G. Straus
     98      * (Addition chains of vectors. American Mathematical Monthly,
     99      * 71(7):806-808, Aug./Sept. 1964)
    100      * <pre>
    101      * Input: The points P, Q, scalar k = (km?, ... , k1, k0)
    102      * and scalar l = (lm?, ... , l1, l0).
    103      * Output: R = k * P + l * Q.
    104      * 1: Z <- P + Q
    105      * 2: R <- O
    106      * 3: for i from m-1 down to 0 do
    107      * 4:        R <- R + R        {point doubling}
    108      * 5:        if (ki = 1) and (li = 0) then R <- R + P end if
    109      * 6:        if (ki = 0) and (li = 1) then R <- R + Q end if
    110      * 7:        if (ki = 1) and (li = 1) then R <- R + Z end if
    111      * 8: end for
    112      * 9: return R
    113      * </pre>
    114      */
    115     public static ECPoint shamirsTrick(ECPoint P, BigInteger k,
    116         ECPoint Q, BigInteger l)
    117     {
    118         ECCurve cp = P.getCurve();
    119         Q = importPoint(cp, Q);
    120 
    121         return validatePoint(implShamirsTrickJsf(P, k, Q, l));
    122     }
    123 
    124     public static ECPoint importPoint(ECCurve c, ECPoint p)
    125     {
    126         ECCurve cp = p.getCurve();
    127         if (!c.equals(cp))
    128         {
    129             throw new IllegalArgumentException("Point must be on the same curve");
    130         }
    131         return c.importPoint(p);
    132     }
    133 
    134     public static void montgomeryTrick(ECFieldElement[] zs, int off, int len)
    135     {
    136         montgomeryTrick(zs, off, len, null);
    137     }
    138 
    139     public static void montgomeryTrick(ECFieldElement[] zs, int off, int len, ECFieldElement scale)
    140     {
    141         /*
    142          * Uses the "Montgomery Trick" to invert many field elements, with only a single actual
    143          * field inversion. See e.g. the paper:
    144          * "Fast Multi-scalar Multiplication Methods on Elliptic Curves with Precomputation Strategy Using Montgomery Trick"
    145          * by Katsuyuki Okeya, Kouichi Sakurai.
    146          */
    147 
    148         ECFieldElement[] c = new ECFieldElement[len];
    149         c[0] = zs[off];
    150 
    151         int i = 0;
    152         while (++i < len)
    153         {
    154             c[i] = c[i - 1].multiply(zs[off + i]);
    155         }
    156 
    157         --i;
    158 
    159         if (scale != null)
    160         {
    161             c[i] = c[i].multiply(scale);
    162         }
    163 
    164         ECFieldElement u = c[i].invert();
    165 
    166         while (i > 0)
    167         {
    168             int j = off + i--;
    169             ECFieldElement tmp = zs[j];
    170             zs[j] = c[i].multiply(u);
    171             u = u.multiply(tmp);
    172         }
    173 
    174         zs[off] = u;
    175     }
    176 
    177     /**
    178      * Simple shift-and-add multiplication. Serves as reference implementation
    179      * to verify (possibly faster) implementations, and for very small scalars.
    180      *
    181      * @param p
    182      *            The point to multiply.
    183      * @param k
    184      *            The multiplier.
    185      * @return The result of the point multiplication <code>kP</code>.
    186      */
    187     public static ECPoint referenceMultiply(ECPoint p, BigInteger k)
    188     {
    189         BigInteger x = k.abs();
    190         ECPoint q = p.getCurve().getInfinity();
    191         int t = x.bitLength();
    192         if (t > 0)
    193         {
    194             if (x.testBit(0))
    195             {
    196                 q = p;
    197             }
    198             for (int i = 1; i < t; i++)
    199             {
    200                 p = p.twice();
    201                 if (x.testBit(i))
    202                 {
    203                     q = q.add(p);
    204                 }
    205             }
    206         }
    207         return k.signum() < 0 ? q.negate() : q;
    208     }
    209 
    210     public static ECPoint validatePoint(ECPoint p)
    211     {
    212         if (!p.isValid())
    213         {
    214             throw new IllegalArgumentException("Invalid point");
    215         }
    216 
    217         return p;
    218     }
    219 
    220     static ECPoint implShamirsTrickJsf(ECPoint P, BigInteger k,
    221         ECPoint Q, BigInteger l)
    222     {
    223         ECCurve curve = P.getCurve();
    224         ECPoint infinity = curve.getInfinity();
    225 
    226         // TODO conjugate co-Z addition (ZADDC) can return both of these
    227         ECPoint PaddQ = P.add(Q);
    228         ECPoint PsubQ = P.subtract(Q);
    229 
    230         ECPoint[] points = new ECPoint[]{ Q, PsubQ, P, PaddQ };
    231         curve.normalizeAll(points);
    232 
    233         ECPoint[] table = new ECPoint[] {
    234             points[3].negate(), points[2].negate(), points[1].negate(),
    235             points[0].negate(), infinity, points[0],
    236             points[1], points[2], points[3] };
    237 
    238         byte[] jsf = WNafUtil.generateJSF(k, l);
    239 
    240         ECPoint R = infinity;
    241 
    242         int i = jsf.length;
    243         while (--i >= 0)
    244         {
    245             int jsfi = jsf[i];
    246 
    247             // NOTE: The shifting ensures the sign is extended correctly
    248             int kDigit = ((jsfi << 24) >> 28), lDigit = ((jsfi << 28) >> 28);
    249 
    250             int index = 4 + (kDigit * 3) + lDigit;
    251             R = R.twicePlus(table[index]);
    252         }
    253 
    254         return R;
    255     }
    256 
    257     static ECPoint implShamirsTrickWNaf(ECPoint P, BigInteger k,
    258         ECPoint Q, BigInteger l)
    259     {
    260         boolean negK = k.signum() < 0, negL = l.signum() < 0;
    261 
    262         k = k.abs();
    263         l = l.abs();
    264 
    265         int widthP = Math.max(2, Math.min(16, WNafUtil.getWindowSize(k.bitLength())));
    266         int widthQ = Math.max(2, Math.min(16, WNafUtil.getWindowSize(l.bitLength())));
    267 
    268         WNafPreCompInfo infoP = WNafUtil.precompute(P, widthP, true);
    269         WNafPreCompInfo infoQ = WNafUtil.precompute(Q, widthQ, true);
    270 
    271         ECPoint[] preCompP = negK ? infoP.getPreCompNeg() : infoP.getPreComp();
    272         ECPoint[] preCompQ = negL ? infoQ.getPreCompNeg() : infoQ.getPreComp();
    273         ECPoint[] preCompNegP = negK ? infoP.getPreComp() : infoP.getPreCompNeg();
    274         ECPoint[] preCompNegQ = negL ? infoQ.getPreComp() : infoQ.getPreCompNeg();
    275 
    276         byte[] wnafP = WNafUtil.generateWindowNaf(widthP, k);
    277         byte[] wnafQ = WNafUtil.generateWindowNaf(widthQ, l);
    278 
    279         return implShamirsTrickWNaf(preCompP, preCompNegP, wnafP, preCompQ, preCompNegQ, wnafQ);
    280     }
    281 
    282     static ECPoint implShamirsTrickWNaf(ECPoint P, BigInteger k, ECPointMap pointMapQ, BigInteger l)
    283     {
    284         boolean negK = k.signum() < 0, negL = l.signum() < 0;
    285 
    286         k = k.abs();
    287         l = l.abs();
    288 
    289         int width = Math.max(2, Math.min(16, WNafUtil.getWindowSize(Math.max(k.bitLength(), l.bitLength()))));
    290 
    291         ECPoint Q = WNafUtil.mapPointWithPrecomp(P, width, true, pointMapQ);
    292         WNafPreCompInfo infoP = WNafUtil.getWNafPreCompInfo(P);
    293         WNafPreCompInfo infoQ = WNafUtil.getWNafPreCompInfo(Q);
    294 
    295         ECPoint[] preCompP = negK ? infoP.getPreCompNeg() : infoP.getPreComp();
    296         ECPoint[] preCompQ = negL ? infoQ.getPreCompNeg() : infoQ.getPreComp();
    297         ECPoint[] preCompNegP = negK ? infoP.getPreComp() : infoP.getPreCompNeg();
    298         ECPoint[] preCompNegQ = negL ? infoQ.getPreComp() : infoQ.getPreCompNeg();
    299 
    300         byte[] wnafP = WNafUtil.generateWindowNaf(width, k);
    301         byte[] wnafQ = WNafUtil.generateWindowNaf(width, l);
    302 
    303         return implShamirsTrickWNaf(preCompP, preCompNegP, wnafP, preCompQ, preCompNegQ, wnafQ);
    304     }
    305 
    306     private static ECPoint implShamirsTrickWNaf(ECPoint[] preCompP, ECPoint[] preCompNegP, byte[] wnafP,
    307         ECPoint[] preCompQ, ECPoint[] preCompNegQ, byte[] wnafQ)
    308     {
    309         int len = Math.max(wnafP.length, wnafQ.length);
    310 
    311         ECCurve curve = preCompP[0].getCurve();
    312         ECPoint infinity = curve.getInfinity();
    313 
    314         ECPoint R = infinity;
    315         int zeroes = 0;
    316 
    317         for (int i = len - 1; i >= 0; --i)
    318         {
    319             int wiP = i < wnafP.length ? wnafP[i] : 0;
    320             int wiQ = i < wnafQ.length ? wnafQ[i] : 0;
    321 
    322             if ((wiP | wiQ) == 0)
    323             {
    324                 ++zeroes;
    325                 continue;
    326             }
    327 
    328             ECPoint r = infinity;
    329             if (wiP != 0)
    330             {
    331                 int nP = Math.abs(wiP);
    332                 ECPoint[] tableP = wiP < 0 ? preCompNegP : preCompP;
    333                 r = r.add(tableP[nP >>> 1]);
    334             }
    335             if (wiQ != 0)
    336             {
    337                 int nQ = Math.abs(wiQ);
    338                 ECPoint[] tableQ = wiQ < 0 ? preCompNegQ : preCompQ;
    339                 r = r.add(tableQ[nQ >>> 1]);
    340             }
    341 
    342             if (zeroes > 0)
    343             {
    344                 R = R.timesPow2(zeroes);
    345                 zeroes = 0;
    346             }
    347 
    348             R = R.twicePlus(r);
    349         }
    350 
    351         if (zeroes > 0)
    352         {
    353             R = R.timesPow2(zeroes);
    354         }
    355 
    356         return R;
    357     }
    358 
    359     static ECPoint implSumOfMultiplies(ECPoint[] ps, BigInteger[] ks)
    360     {
    361         int count = ps.length;
    362         boolean[] negs = new boolean[count];
    363         WNafPreCompInfo[] infos = new WNafPreCompInfo[count];
    364         byte[][] wnafs = new byte[count][];
    365 
    366         for (int i = 0; i < count; ++i)
    367         {
    368             BigInteger ki = ks[i]; negs[i] = ki.signum() < 0; ki = ki.abs();
    369 
    370             int width = Math.max(2, Math.min(16, WNafUtil.getWindowSize(ki.bitLength())));
    371             infos[i] = WNafUtil.precompute(ps[i], width, true);
    372             wnafs[i] = WNafUtil.generateWindowNaf(width, ki);
    373         }
    374 
    375         return implSumOfMultiplies(negs, infos, wnafs);
    376     }
    377 
    378     static ECPoint implSumOfMultipliesGLV(ECPoint[] ps, BigInteger[] ks, GLVEndomorphism glvEndomorphism)
    379     {
    380         BigInteger n = ps[0].getCurve().getOrder();
    381 
    382         int len = ps.length;
    383 
    384         BigInteger[] abs = new BigInteger[len << 1];
    385         for (int i = 0, j = 0; i < len; ++i)
    386         {
    387             BigInteger[] ab = glvEndomorphism.decomposeScalar(ks[i].mod(n));
    388             abs[j++] = ab[0];
    389             abs[j++] = ab[1];
    390         }
    391 
    392         ECPointMap pointMap = glvEndomorphism.getPointMap();
    393         if (glvEndomorphism.hasEfficientPointMap())
    394         {
    395             return ECAlgorithms.implSumOfMultiplies(ps, pointMap, abs);
    396         }
    397 
    398         ECPoint[] pqs = new ECPoint[len << 1];
    399         for (int i = 0, j = 0; i < len; ++i)
    400         {
    401             ECPoint p = ps[i], q = pointMap.map(p);
    402             pqs[j++] = p;
    403             pqs[j++] = q;
    404         }
    405 
    406         return ECAlgorithms.implSumOfMultiplies(pqs, abs);
    407 
    408     }
    409 
    410     static ECPoint implSumOfMultiplies(ECPoint[] ps, ECPointMap pointMap, BigInteger[] ks)
    411     {
    412         int halfCount = ps.length, fullCount = halfCount << 1;
    413 
    414         boolean[] negs = new boolean[fullCount];
    415         WNafPreCompInfo[] infos = new WNafPreCompInfo[fullCount];
    416         byte[][] wnafs = new byte[fullCount][];
    417 
    418         for (int i = 0; i < halfCount; ++i)
    419         {
    420             int j0 = i << 1, j1 = j0 + 1;
    421 
    422             BigInteger kj0 = ks[j0]; negs[j0] = kj0.signum() < 0; kj0 = kj0.abs();
    423             BigInteger kj1 = ks[j1]; negs[j1] = kj1.signum() < 0; kj1 = kj1.abs();
    424 
    425             int width = Math.max(2, Math.min(16, WNafUtil.getWindowSize(Math.max(kj0.bitLength(), kj1.bitLength()))));
    426 
    427             ECPoint P = ps[i], Q = WNafUtil.mapPointWithPrecomp(P, width, true, pointMap);
    428             infos[j0] = WNafUtil.getWNafPreCompInfo(P);
    429             infos[j1] = WNafUtil.getWNafPreCompInfo(Q);
    430             wnafs[j0] = WNafUtil.generateWindowNaf(width, kj0);
    431             wnafs[j1] = WNafUtil.generateWindowNaf(width, kj1);
    432         }
    433 
    434         return implSumOfMultiplies(negs, infos, wnafs);
    435     }
    436 
    437     private static ECPoint implSumOfMultiplies(boolean[] negs, WNafPreCompInfo[] infos, byte[][] wnafs)
    438     {
    439         int len = 0, count = wnafs.length;
    440         for (int i = 0; i < count; ++i)
    441         {
    442             len = Math.max(len, wnafs[i].length);
    443         }
    444 
    445         ECCurve curve = infos[0].getPreComp()[0].getCurve();
    446         ECPoint infinity = curve.getInfinity();
    447 
    448         ECPoint R = infinity;
    449         int zeroes = 0;
    450 
    451         for (int i = len - 1; i >= 0; --i)
    452         {
    453             ECPoint r = infinity;
    454 
    455             for (int j = 0; j < count; ++j)
    456             {
    457                 byte[] wnaf = wnafs[j];
    458                 int wi = i < wnaf.length ? wnaf[i] : 0;
    459                 if (wi != 0)
    460                 {
    461                     int n = Math.abs(wi);
    462                     WNafPreCompInfo info = infos[j];
    463                     ECPoint[] table = (wi < 0 == negs[j]) ? info.getPreComp() : info.getPreCompNeg();
    464                     r = r.add(table[n >>> 1]);
    465                 }
    466             }
    467 
    468             if (r == infinity)
    469             {
    470                 ++zeroes;
    471                 continue;
    472             }
    473 
    474             if (zeroes > 0)
    475             {
    476                 R = R.timesPow2(zeroes);
    477                 zeroes = 0;
    478             }
    479 
    480             R = R.twicePlus(r);
    481         }
    482 
    483         if (zeroes > 0)
    484         {
    485             R = R.timesPow2(zeroes);
    486         }
    487 
    488         return R;
    489     }
    490 }
    491