1 /* 2 * Copyright (C) 2016 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <stdint.h> 18 #include <stdbool.h> 19 #include <string.h> 20 #include <nanohub/rsa.h> 21 22 23 static bool biModIterative(uint32_t *num, const uint32_t *denum, uint32_t *tmp, uint32_t *state1, uint32_t *state2, uint32_t step) 24 //num %= denum where num is RSA_LEN * 2 and denum is RSA_LEN and tmp is RSA_LEN + limb_sz 25 //will need to be called till it returns true (up to RSA_LEN * 2 + 2 times) 26 { 27 uint32_t bitsh = *state1, limbsh = *state2; 28 bool ret = false; 29 int64_t t; 30 int32_t i; 31 32 //first step is init 33 if (!step) { 34 //initially set it up left shifted as far as possible 35 memcpy(tmp + 1, denum, RSA_BYTES); 36 tmp[0] = 0; 37 bitsh = 32; 38 limbsh = RSA_LIMBS - 1; 39 goto out; 40 } 41 42 //second is shifting denum 43 if (step == 1) { 44 while (!(tmp[RSA_LIMBS] & 0x80000000)) { 45 for (i = RSA_LIMBS; i > 0; i--) { 46 tmp[i] <<= 1; 47 if (tmp[i - 1] & 0x80000000) 48 tmp[i]++; 49 } 50 //no need to adjust tmp[0] as it is still zero 51 bitsh++; 52 } 53 goto out; 54 } 55 56 //all future steps do the division 57 58 //check if we should subtract (uses less space than subtracting and unroling it later) 59 for (i = RSA_LIMBS; i >= 0; i--) { 60 if (num[limbsh + i] < tmp[i]) 61 goto dont_subtract; 62 if (num[limbsh + i] > tmp[i]) 63 break; 64 } 65 66 //subtract 67 t = 0; 68 for (i = 0; i <= RSA_LIMBS; i++) { 69 t += (uint64_t)num[limbsh + i]; 70 t -= (uint64_t)tmp[i]; 71 num[limbsh + i] = t; 72 t >>= 32; 73 } 74 75 //carry the subtraction's carry to the end 76 for (i = RSA_LIMBS + limbsh + 1; i < RSA_LIMBS * 2; i++) { 77 t += (uint64_t)num[i]; 78 num[i] = t; 79 t >>= 32; 80 } 81 82 dont_subtract: 83 //handle bitshifts/refills 84 if (!bitsh) { // tmp = denum << 32 85 if (!limbsh) { 86 ret = true; 87 goto out; 88 } 89 90 memcpy(tmp + 1, denum, RSA_BYTES); 91 tmp[0] = 0; 92 bitsh = 32; 93 limbsh--; 94 } 95 else { // tmp >>= 1 96 for (i = 0; i < RSA_LIMBS; i++) { 97 tmp[i] >>= 1; 98 if (tmp[i + 1] & 1) 99 tmp[i] += 0x80000000; 100 } 101 tmp[i] >>= 1; 102 bitsh--; 103 } 104 105 106 out: 107 *state1 = bitsh; 108 *state2 = limbsh; 109 return ret; 110 } 111 112 static void biMulIterative(uint32_t *ret, const uint32_t *a, const uint32_t *b, uint32_t step) //ret = a * b, call with step = [0..RSA_LIMBS) 113 { 114 uint32_t j, c; 115 uint64_t r; 116 117 //zero the result on first call 118 if (!step) 119 memset(ret, 0, RSA_BYTES * 2); 120 121 //produce a partial sum & add it in 122 c = 0; 123 for (j = 0; j < RSA_LIMBS; j++) { 124 r = (uint64_t)a[step] * b[j] + c + ret[step + j]; 125 ret[step + j] = r; 126 c = r >> 32; 127 } 128 129 //carry the carry to the end 130 for (j = step + RSA_LIMBS; j < RSA_LIMBS * 2; j++) { 131 r = (uint64_t)ret[j] + c; 132 ret[j] = r; 133 c = r >> 32; 134 } 135 } 136 137 /* 138 * Piecewise RSA: 139 * normal RSA public op with 65537 exponent does 34 operations. 17 muls and 17 mods, as follows: 140 * 16x {mul, mod} to calculate a ^ 65536 mod c 141 * 1x {mul, mod} to calculate a ^ 65537 mod c 142 * we break up each mul and mod itself into more steps. mul needs RSA_LIMBS steps, and mod needs up to RSA_LEN * 2 + 2 steps 143 * so if we allocate RSA_LEN * 3 step values to mod, each mul-mod pair will use <= RSA_LEN * 4 step values 144 * and the whole opetaion will need <= RSA_LEN * 4 * 34 step values, which fits into a uint32. cool. In fact 145 * some values will be skipped, but this makes life easier, really. Call this func with *stepP = 0, and keep calling till 146 * output stepP is zero. We'll call each of the RSA_LEN * 4 pieces a gigastep, and have 17 of them as seen above. Each 147 * will be logically separated into 4 megasteps. First will contain the MUL, last 3 the MOD and maybe the memcpy. 148 * In the first 16 gigasteps, the very last step of the gigastep will be used for the memcpy call. 149 * 150 * The initial non-iterative RSA logic looks as follows, shown here for clarity: 151 * 152 * memcpy(state->tmpB, a, RSA_BYTES); 153 * for (i = 0; i < 16; i++) { 154 * biMul(state->tmpA, state->tmpB, state->tmpB); 155 * biMod(state->tmpA, c, state->tmpB); 156 * memcpy(state->tmpB, state->tmpA, RSA_BYTES); 157 * } 158 * 159 * //calculate a ^ 65537 mod c into state->tmpA [ at this point this means do state->tmpA = (state->tmpB * a) % c ] 160 * biMul(state->tmpA, state->tmpB, a); 161 * biMod(state->tmpA, c, state->tmpB); 162 * 163 * //return result 164 * return state->tmpA; 165 * 166 */ 167 168 const uint32_t* rsaPubOpIterative(struct RsaState* state, const uint32_t *a, const uint32_t *c, uint32_t *state1, uint32_t *state2, uint32_t *stepP) 169 { 170 uint32_t step = *stepP, gigastep, gigastepBase, gigastepSubstep, megaSubstep; 171 172 //step 0: copy a -> tmpB 173 if (!step) { 174 memcpy(state->tmpB, a, RSA_BYTES); 175 step = 1; 176 } 177 else { //subsequent steps: do real work 178 179 180 gigastep = (step - 1) / (RSA_LEN * 4); 181 gigastepSubstep = (step - 1) % (RSA_LEN * 4); 182 gigastepBase = gigastep * (RSA_LEN * 4); 183 megaSubstep = gigastepSubstep / RSA_LEN; 184 185 if (!megaSubstep) { // first megastep of the gigastep - MUL 186 biMulIterative(state->tmpA, state->tmpB, gigastep == 16 ? a : state->tmpB, gigastepSubstep); 187 if (gigastepSubstep == RSA_LIMBS - 1) //MUL is done - do mod next 188 step = gigastepBase + RSA_LEN + 1; 189 else //More of MUL is left to do 190 step++; 191 } 192 else if (gigastepSubstep != RSA_LEN * 4 - 1){ // second part of gigastep - MOD 193 if (biModIterative(state->tmpA, c, state->tmpB, state1, state2, gigastepSubstep - RSA_LEN)) { //MOD is done 194 if (gigastep == 16) // we're done 195 step = 0; 196 else // last part of the gigastep is a copy 197 step = gigastepBase + RSA_LEN * 4 - 1 + 1; 198 } 199 else 200 step++; 201 } 202 else { //last part - memcpy 203 memcpy(state->tmpB, state->tmpA, RSA_BYTES); 204 step++; 205 } 206 } 207 208 *stepP = step; 209 return state->tmpA; 210 } 211 212 #if defined(RSA_SUPPORT_PRIV_OP_LOWRAM) || defined (RSA_SUPPORT_PRIV_OP_BIGRAM) 213 #include <stdio.h> 214 const uint32_t* rsaPubOp(struct RsaState* state, const uint32_t *a, const uint32_t *c) 215 { 216 const uint32_t *ret; 217 uint32_t state1 = 0, state2 = 0, step = 0, ns = 0; 218 219 do { 220 ret = rsaPubOpIterative(state, a, c, &state1, &state2, &step); 221 ns++; 222 } while(step); 223 224 fprintf(stderr, "steps: %u\n", ns); 225 226 return ret; 227 } 228 229 static void biMod(uint32_t *num, const uint32_t *denum, uint32_t *tmp) 230 { 231 uint32_t state1 = 0, state2 = 0, step; 232 233 for (step = 0; !biModIterative(num, denum, tmp, &state1, &state2, step); step++); 234 } 235 236 static void biMul(uint32_t *ret, const uint32_t *a, const uint32_t *b) 237 { 238 uint32_t step; 239 240 for (step = 0; step < RSA_LIMBS; step++) 241 biMulIterative(ret, a, b, step); 242 } 243 244 const uint32_t* rsaPrivOp(struct RsaState* state, const uint32_t *a, const uint32_t *b, const uint32_t *c) 245 { 246 uint32_t i; 247 248 memcpy(state->tmpC, a, RSA_BYTES); //tC will hold our powers of a 249 250 memset(state->tmpA, 0, RSA_BYTES * 2); //tA will hold result 251 state->tmpA[0] = 1; 252 253 for (i = 0; i < RSA_LEN; i++) { 254 //if the bit is set, multiply the current power of A into result 255 if (b[i / 32] & (1 << (i % 32))) { 256 memcpy(state->tmpB, state->tmpA, RSA_BYTES); 257 biMul(state->tmpA, state->tmpB, state->tmpC); 258 biMod(state->tmpA, c, state->tmpB); 259 } 260 261 //calculate the next power of a and modulus it 262 #if defined(RSA_SUPPORT_PRIV_OP_LOWRAM) 263 memcpy(state->tmpB, state->tmpA, RSA_BYTES); //save tA 264 biMul(state->tmpA, state->tmpC, state->tmpC); 265 biMod(state->tmpA, c, state->tmpC); 266 memcpy(state->tmpC, state->tmpA, RSA_BYTES); 267 memcpy(state->tmpA, state->tmpB, RSA_BYTES); //restore tA 268 #elif defined (RSA_SUPPORT_PRIV_OP_BIGRAM) 269 memcpy(state->tmpB, state->tmpC, RSA_BYTES); 270 biMul(state->tmpC, state->tmpB, state->tmpB); 271 biMod(state->tmpC, c, state->tmpB); 272 #endif 273 } 274 275 return state->tmpA; 276 } 277 #endif 278 279 280 281 282 283 284 285 286