1 /* 2 * Copyright (c) 2008-2016 Stefan Krah. All rights reserved. 3 * 4 * Redistribution and use in source and binary forms, with or without 5 * modification, are permitted provided that the following conditions 6 * are met: 7 * 8 * 1. Redistributions of source code must retain the above copyright 9 * notice, this list of conditions and the following disclaimer. 10 * 11 * 2. Redistributions in binary form must reproduce the above copyright 12 * notice, this list of conditions and the following disclaimer in the 13 * documentation and/or other materials provided with the distribution. 14 * 15 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND 16 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 18 * ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE 19 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 20 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS 21 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) 22 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT 23 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY 24 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF 25 * SUCH DAMAGE. 26 */ 27 28 29 #include "mpdecimal.h" 30 #include <assert.h> 31 #include "numbertheory.h" 32 #include "sixstep.h" 33 #include "transpose.h" 34 #include "umodarith.h" 35 #include "fourstep.h" 36 37 38 /* Bignum: Cache efficient Matrix Fourier Transform for arrays of the 39 form 3 * 2**n (See literature/matrix-transform.txt). */ 40 41 42 #ifndef PPRO 43 static inline void 44 std_size3_ntt(mpd_uint_t *x1, mpd_uint_t *x2, mpd_uint_t *x3, 45 mpd_uint_t w3table[3], mpd_uint_t umod) 46 { 47 mpd_uint_t r1, r2; 48 mpd_uint_t w; 49 mpd_uint_t s, tmp; 50 51 52 /* k = 0 -> w = 1 */ 53 s = *x1; 54 s = addmod(s, *x2, umod); 55 s = addmod(s, *x3, umod); 56 57 r1 = s; 58 59 /* k = 1 */ 60 s = *x1; 61 62 w = w3table[1]; 63 tmp = MULMOD(*x2, w); 64 s = addmod(s, tmp, umod); 65 66 w = w3table[2]; 67 tmp = MULMOD(*x3, w); 68 s = addmod(s, tmp, umod); 69 70 r2 = s; 71 72 /* k = 2 */ 73 s = *x1; 74 75 w = w3table[2]; 76 tmp = MULMOD(*x2, w); 77 s = addmod(s, tmp, umod); 78 79 w = w3table[1]; 80 tmp = MULMOD(*x3, w); 81 s = addmod(s, tmp, umod); 82 83 *x3 = s; 84 *x2 = r2; 85 *x1 = r1; 86 } 87 #else /* PPRO */ 88 static inline void 89 ppro_size3_ntt(mpd_uint_t *x1, mpd_uint_t *x2, mpd_uint_t *x3, mpd_uint_t w3table[3], 90 mpd_uint_t umod, double *dmod, uint32_t dinvmod[3]) 91 { 92 mpd_uint_t r1, r2; 93 mpd_uint_t w; 94 mpd_uint_t s, tmp; 95 96 97 /* k = 0 -> w = 1 */ 98 s = *x1; 99 s = addmod(s, *x2, umod); 100 s = addmod(s, *x3, umod); 101 102 r1 = s; 103 104 /* k = 1 */ 105 s = *x1; 106 107 w = w3table[1]; 108 tmp = ppro_mulmod(*x2, w, dmod, dinvmod); 109 s = addmod(s, tmp, umod); 110 111 w = w3table[2]; 112 tmp = ppro_mulmod(*x3, w, dmod, dinvmod); 113 s = addmod(s, tmp, umod); 114 115 r2 = s; 116 117 /* k = 2 */ 118 s = *x1; 119 120 w = w3table[2]; 121 tmp = ppro_mulmod(*x2, w, dmod, dinvmod); 122 s = addmod(s, tmp, umod); 123 124 w = w3table[1]; 125 tmp = ppro_mulmod(*x3, w, dmod, dinvmod); 126 s = addmod(s, tmp, umod); 127 128 *x3 = s; 129 *x2 = r2; 130 *x1 = r1; 131 } 132 #endif 133 134 135 /* forward transform, sign = -1; transform length = 3 * 2**n */ 136 int 137 four_step_fnt(mpd_uint_t *a, mpd_size_t n, int modnum) 138 { 139 mpd_size_t R = 3; /* number of rows */ 140 mpd_size_t C = n / 3; /* number of columns */ 141 mpd_uint_t w3table[3]; 142 mpd_uint_t kernel, w0, w1, wstep; 143 mpd_uint_t *s, *p0, *p1, *p2; 144 mpd_uint_t umod; 145 #ifdef PPRO 146 double dmod; 147 uint32_t dinvmod[3]; 148 #endif 149 mpd_size_t i, k; 150 151 152 assert(n >= 48); 153 assert(n <= 3*MPD_MAXTRANSFORM_2N); 154 155 156 /* Length R transform on the columns. */ 157 SETMODULUS(modnum); 158 _mpd_init_w3table(w3table, -1, modnum); 159 for (p0=a, p1=p0+C, p2=p0+2*C; p0<a+C; p0++,p1++,p2++) { 160 161 SIZE3_NTT(p0, p1, p2, w3table); 162 } 163 164 /* Multiply each matrix element (addressed by i*C+k) by r**(i*k). */ 165 kernel = _mpd_getkernel(n, -1, modnum); 166 for (i = 1; i < R; i++) { 167 w0 = 1; /* r**(i*0): initial value for k=0 */ 168 w1 = POWMOD(kernel, i); /* r**(i*1): initial value for k=1 */ 169 wstep = MULMOD(w1, w1); /* r**(2*i) */ 170 for (k = 0; k < C-1; k += 2) { 171 mpd_uint_t x0 = a[i*C+k]; 172 mpd_uint_t x1 = a[i*C+k+1]; 173 MULMOD2(&x0, w0, &x1, w1); 174 MULMOD2C(&w0, &w1, wstep); /* r**(i*(k+2)) = r**(i*k) * r**(2*i) */ 175 a[i*C+k] = x0; 176 a[i*C+k+1] = x1; 177 } 178 } 179 180 /* Length C transform on the rows. */ 181 for (s = a; s < a+n; s += C) { 182 if (!six_step_fnt(s, C, modnum)) { 183 return 0; 184 } 185 } 186 187 #if 0 188 /* An unordered transform is sufficient for convolution. */ 189 /* Transpose the matrix. */ 190 transpose_3xpow2(a, R, C); 191 #endif 192 193 return 1; 194 } 195 196 /* backward transform, sign = 1; transform length = 3 * 2**n */ 197 int 198 inv_four_step_fnt(mpd_uint_t *a, mpd_size_t n, int modnum) 199 { 200 mpd_size_t R = 3; /* number of rows */ 201 mpd_size_t C = n / 3; /* number of columns */ 202 mpd_uint_t w3table[3]; 203 mpd_uint_t kernel, w0, w1, wstep; 204 mpd_uint_t *s, *p0, *p1, *p2; 205 mpd_uint_t umod; 206 #ifdef PPRO 207 double dmod; 208 uint32_t dinvmod[3]; 209 #endif 210 mpd_size_t i, k; 211 212 213 assert(n >= 48); 214 assert(n <= 3*MPD_MAXTRANSFORM_2N); 215 216 217 #if 0 218 /* An unordered transform is sufficient for convolution. */ 219 /* Transpose the matrix, producing an R*C matrix. */ 220 transpose_3xpow2(a, C, R); 221 #endif 222 223 /* Length C transform on the rows. */ 224 for (s = a; s < a+n; s += C) { 225 if (!inv_six_step_fnt(s, C, modnum)) { 226 return 0; 227 } 228 } 229 230 /* Multiply each matrix element (addressed by i*C+k) by r**(i*k). */ 231 SETMODULUS(modnum); 232 kernel = _mpd_getkernel(n, 1, modnum); 233 for (i = 1; i < R; i++) { 234 w0 = 1; 235 w1 = POWMOD(kernel, i); 236 wstep = MULMOD(w1, w1); 237 for (k = 0; k < C; k += 2) { 238 mpd_uint_t x0 = a[i*C+k]; 239 mpd_uint_t x1 = a[i*C+k+1]; 240 MULMOD2(&x0, w0, &x1, w1); 241 MULMOD2C(&w0, &w1, wstep); 242 a[i*C+k] = x0; 243 a[i*C+k+1] = x1; 244 } 245 } 246 247 /* Length R transform on the columns. */ 248 _mpd_init_w3table(w3table, 1, modnum); 249 for (p0=a, p1=p0+C, p2=p0+2*C; p0<a+C; p0++,p1++,p2++) { 250 251 SIZE3_NTT(p0, p1, p2, w3table); 252 } 253 254 return 1; 255 } 256 257 258