Home | History | Annotate | Download | only in nanohub
      1 /*
      2  * Copyright (C) 2016 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include <stdint.h>
     18 #include <stdbool.h>
     19 #include <string.h>
     20 #include <nanohub/rsa.h>
     21 
     22 
     23 static bool biModIterative(uint32_t *num, const uint32_t *denum, uint32_t *tmp, uint32_t *state1, uint32_t *state2, uint32_t step)
     24 //num %= denum where num is RSA_LEN * 2 and denum is RSA_LEN and tmp is RSA_LEN + limb_sz
     25 //will need to be called till it returns true (up to RSA_LEN * 2 + 2 times)
     26 {
     27     uint32_t bitsh = *state1, limbsh = *state2;
     28     bool ret = false;
     29     int64_t t;
     30     int32_t i;
     31 
     32     //first step is init
     33     if (!step) {
     34         //initially set it up left shifted as far as possible
     35         memcpy(tmp + 1, denum, RSA_BYTES);
     36         tmp[0] = 0;
     37         bitsh = 32;
     38         limbsh = RSA_LIMBS - 1;
     39         goto out;
     40     }
     41 
     42     //second is shifting denum
     43     if (step == 1) {
     44         while (!(tmp[RSA_LIMBS] & 0x80000000)) {
     45             for (i = RSA_LIMBS; i > 0; i--) {
     46                 tmp[i] <<= 1;
     47                 if (tmp[i - 1] & 0x80000000)
     48                     tmp[i]++;
     49             }
     50             //no need to adjust tmp[0] as it is still zero
     51             bitsh++;
     52         }
     53         goto out;
     54     }
     55 
     56     //all future steps do the division
     57 
     58     //check if we should subtract (uses less space than subtracting and unroling it later)
     59     for (i = RSA_LIMBS; i >= 0; i--) {
     60         if (num[limbsh + i] < tmp[i])
     61             goto dont_subtract;
     62         if (num[limbsh + i] > tmp[i])
     63             break;
     64     }
     65 
     66     //subtract
     67     t = 0;
     68     for (i = 0; i <= RSA_LIMBS; i++) {
     69         t += (uint64_t)num[limbsh + i];
     70         t -= (uint64_t)tmp[i];
     71         num[limbsh + i] = t;
     72         t >>= 32;
     73     }
     74 
     75     //carry the subtraction's carry to the end
     76     for (i = RSA_LIMBS + limbsh + 1; i < RSA_LIMBS * 2; i++) {
     77         t += (uint64_t)num[i];
     78         num[i] = t;
     79         t >>= 32;
     80     }
     81 
     82 dont_subtract:
     83     //handle bitshifts/refills
     84     if (!bitsh) {                          // tmp = denum << 32
     85         if (!limbsh) {
     86             ret = true;
     87             goto out;
     88         }
     89 
     90         memcpy(tmp + 1, denum, RSA_BYTES);
     91         tmp[0] = 0;
     92         bitsh = 32;
     93         limbsh--;
     94     }
     95     else {                                 // tmp >>= 1
     96         for (i = 0; i < RSA_LIMBS; i++) {
     97             tmp[i] >>= 1;
     98             if (tmp[i + 1] & 1)
     99                 tmp[i] += 0x80000000;
    100         }
    101         tmp[i] >>= 1;
    102         bitsh--;
    103     }
    104 
    105 
    106 out:
    107     *state1 = bitsh;
    108     *state2 = limbsh;
    109     return ret;
    110 }
    111 
    112 static void biMulIterative(uint32_t *ret, const uint32_t *a, const uint32_t *b, uint32_t step) //ret = a * b, call with step = [0..RSA_LIMBS)
    113 {
    114     uint32_t j, c;
    115     uint64_t r;
    116 
    117     //zero the result on first call
    118     if (!step)
    119         memset(ret, 0, RSA_BYTES * 2);
    120 
    121     //produce a partial sum & add it in
    122     c = 0;
    123     for (j = 0; j < RSA_LIMBS; j++) {
    124         r = (uint64_t)a[step] * b[j] + c + ret[step + j];
    125         ret[step + j] = r;
    126         c = r >> 32;
    127     }
    128 
    129     //carry the carry to the end
    130     for (j = step + RSA_LIMBS; j < RSA_LIMBS * 2; j++) {
    131         r = (uint64_t)ret[j] + c;
    132         ret[j] = r;
    133         c = r >> 32;
    134     }
    135 }
    136 
    137 /*
    138  * Piecewise RSA:
    139  * normal RSA public op with 65537 exponent does 34 operations. 17 muls and 17 mods, as follows:
    140  * 16x {mul, mod} to calculate a ^ 65536 mod c
    141  * 1x {mul, mod} to calculate a ^ 65537 mod c
    142  * we break up each mul and mod itself into more steps. mul needs RSA_LIMBS steps, and mod needs up to RSA_LEN * 2 + 2 steps
    143  * so if we allocate RSA_LEN * 3 step values to mod, each mul-mod pair will use <= RSA_LEN * 4 step values
    144  * and the whole operation will need <= RSA_LEN * 4 * 34 step values, which fits into a uint32. cool. In fact
    145  * some values will be skipped, but this makes life easier, really. Call this func with *stepP = 0, and keep calling till
    146  * output stepP is zero. We'll call each of the RSA_LEN * 4 pieces a gigastep, and have 17 of them as seen above. Each
    147  * will be logically separated into 4 megasteps. First will contain the MUL, last 3 the MOD and maybe the memcpy.
    148  * In the first 16 gigasteps, the very last step of the gigastep will be used for the memcpy call.
    149  *
    150  * The initial non-iterative RSA logic looks as follows, shown here for clarity:
    151  *
    152  *   memcpy(state->tmpB, a, RSA_BYTES);
    153  *   for (i = 0; i < 16; i++) {
    154  *       biMul(state->tmpA, state->tmpB, state->tmpB);
    155  *       biMod(state->tmpA, c, state->tmpB);
    156  *       memcpy(state->tmpB, state->tmpA, RSA_BYTES);
    157  *   }
    158  *
    159  *   //calculate a ^ 65537 mod c into state->tmpA [ at this point this means do state->tmpA = (state->tmpB * a) % c ]
    160  *   biMul(state->tmpA, state->tmpB, a);
    161  *   biMod(state->tmpA, c, state->tmpB);
    162  *
    163  *   //return result
    164  *   return state->tmpA;
    165  *
    166  */
    167 
    168 const uint32_t* rsaPubOpIterative(struct RsaState* state, const uint32_t *a, const uint32_t *c, uint32_t *state1, uint32_t *state2, uint32_t *stepP)
    169 {
    170     uint32_t step = *stepP, gigastep, gigastepBase, gigastepSubstep, megaSubstep;
    171 
    172     //step 0: copy a -> tmpB
    173     if (!step) {
    174         memcpy(state->tmpB, a, RSA_BYTES);
    175         step = 1;
    176     }
    177     else { //subsequent steps: do real work
    178 
    179 
    180         gigastep = (step - 1) / (RSA_LEN * 4);
    181         gigastepSubstep = (step - 1) % (RSA_LEN * 4);
    182         gigastepBase = gigastep * (RSA_LEN * 4);
    183         megaSubstep = gigastepSubstep / RSA_LEN;
    184 
    185         if (!megaSubstep) { // first megastep of the gigastep - MUL
    186             biMulIterative(state->tmpA, state->tmpB, gigastep == 16 ? a : state->tmpB, gigastepSubstep);
    187             if (gigastepSubstep == RSA_LIMBS - 1) //MUL is done - do mod next
    188                 step = gigastepBase + RSA_LEN + 1;
    189             else                                  //More of MUL is left to do
    190                 step++;
    191         }
    192         else if (gigastepSubstep != RSA_LEN * 4 - 1){   // second part of gigastep - MOD
    193             if (biModIterative(state->tmpA, c, state->tmpB, state1, state2, gigastepSubstep - RSA_LEN)) { //MOD is done
    194                 if (gigastep == 16) // we're done
    195                     step = 0;
    196                 else              // last part of the gigastep is a copy
    197                     step = gigastepBase + RSA_LEN * 4 - 1 + 1;
    198             }
    199             else
    200                 step++;
    201         }
    202         else {   //last part - memcpy
    203             memcpy(state->tmpB, state->tmpA, RSA_BYTES);
    204             step++;
    205         }
    206     }
    207 
    208     *stepP = step;
    209     return state->tmpA;
    210 }
    211 
    212 #if defined(RSA_SUPPORT_PRIV_OP_LOWRAM) || defined (RSA_SUPPORT_PRIV_OP_BIGRAM)
    213 #include <stdio.h>
    /*
     * Blocking wrapper around rsaPubOpIterative: runs the iterative public-key
     * operation to completion, reports the number of iterations on stderr and
     * returns the pointer produced by the final iteration.
     */
    const uint32_t* rsaPubOp(struct RsaState* state, const uint32_t *a, const uint32_t *c)
    {
        uint32_t state1 = 0, state2 = 0;
        uint32_t step = 0;
        uint32_t calls = 0;
        const uint32_t *result;

        //the step counter returns to zero once the whole operation has finished
        for (;;) {
            result = rsaPubOpIterative(state, a, c, &state1, &state2, &step);
            calls++;
            if (step == 0)
                break;
        }

        fprintf(stderr, "steps: %u\n", calls);

        return result;
    }
    228 
    /*
     * num %= denum, computed by driving biModIterative from step 0 until it
     * reports completion. tmp is scratch space used by biModIterative.
     */
    static void biMod(uint32_t *num, const uint32_t *denum, uint32_t *tmp)
    {
        uint32_t state1 = 0, state2 = 0;
        uint32_t step = 0;

        while (!biModIterative(num, denum, tmp, &state1, &state2, step))
            step++;
    }
    235 
    236 static void biMul(uint32_t *ret, const uint32_t *a, const uint32_t *b)
    237 {
    238     uint32_t step;
    239 
    240     for (step = 0; step < RSA_LIMBS; step++)
    241         biMulIterative(ret, a, b, step);
    242 }
    243 
    244 const uint32_t* rsaPrivOp(struct RsaState* state, const uint32_t *a, const uint32_t *b, const uint32_t *c)
    245 {
    246     uint32_t i;
    247 
    248     memcpy(state->tmpC, a, RSA_BYTES);  //tC will hold our powers of a
    249 
    250     memset(state->tmpA, 0, RSA_BYTES * 2); //tA will hold result
    251     state->tmpA[0] = 1;
    252 
    253     for (i = 0; i < RSA_LEN; i++) {
    254         //if the bit is set, multiply the current power of A into result
    255         if (b[i / 32] & (1 << (i % 32))) {
    256             memcpy(state->tmpB, state->tmpA, RSA_BYTES);
    257             biMul(state->tmpA, state->tmpB, state->tmpC);
    258             biMod(state->tmpA, c, state->tmpB);
    259         }
    260 
    261         //calculate the next power of a and modulus it
    262 #if defined(RSA_SUPPORT_PRIV_OP_LOWRAM)
    263         memcpy(state->tmpB, state->tmpA, RSA_BYTES); //save tA
    264         biMul(state->tmpA, state->tmpC, state->tmpC);
    265         biMod(state->tmpA, c, state->tmpC);
    266         memcpy(state->tmpC, state->tmpA, RSA_BYTES);
    267         memcpy(state->tmpA, state->tmpB, RSA_BYTES); //restore tA
    268 #elif defined (RSA_SUPPORT_PRIV_OP_BIGRAM)
    269         memcpy(state->tmpB, state->tmpC, RSA_BYTES);
    270         biMul(state->tmpC, state->tmpB, state->tmpB);
    271         biMod(state->tmpC, c, state->tmpB);
    272 #endif
    273     }
    274 
    275     return state->tmpA;
    276 }
    277 #endif
    278 
    279 
    280 
    281 
    282 
    283 
    284 
    285 
    286