Home | History | Annotate | Download | only in libmpdec
      1 /*
      2  * Copyright (c) 2008-2016 Stefan Krah. All rights reserved.
      3  *
      4  * Redistribution and use in source and binary forms, with or without
      5  * modification, are permitted provided that the following conditions
      6  * are met:
      7  *
      8  * 1. Redistributions of source code must retain the above copyright
      9  *    notice, this list of conditions and the following disclaimer.
     10  *
     11  * 2. Redistributions in binary form must reproduce the above copyright
     12  *    notice, this list of conditions and the following disclaimer in the
     13  *    documentation and/or other materials provided with the distribution.
     14  *
     15  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS "AS IS" AND
     16  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
     17  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
     18  * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
     19  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
     20  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
     21  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
     22  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
     23  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
     24  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
     25  * SUCH DAMAGE.
     26  */
     27 
     28 
     29 #include "mpdecimal.h"
     30 #include <assert.h>
     31 #include "numbertheory.h"
     32 #include "sixstep.h"
     33 #include "transpose.h"
     34 #include "umodarith.h"
     35 #include "fourstep.h"
     36 
     37 
     38 /* Bignum: Cache efficient Matrix Fourier Transform for arrays of the
     39    form 3 * 2**n (See literature/matrix-transform.txt). */
     40 
     41 
     42 #ifndef PPRO
     43 static inline void
     44 std_size3_ntt(mpd_uint_t *x1, mpd_uint_t *x2, mpd_uint_t *x3,
     45               mpd_uint_t w3table[3], mpd_uint_t umod)
     46 {
     47     mpd_uint_t r1, r2;
     48     mpd_uint_t w;
     49     mpd_uint_t s, tmp;
     50 
     51 
     52     /* k = 0 -> w = 1 */
     53     s = *x1;
     54     s = addmod(s, *x2, umod);
     55     s = addmod(s, *x3, umod);
     56 
     57     r1 = s;
     58 
     59     /* k = 1 */
     60     s = *x1;
     61 
     62     w = w3table[1];
     63     tmp = MULMOD(*x2, w);
     64     s = addmod(s, tmp, umod);
     65 
     66     w = w3table[2];
     67     tmp = MULMOD(*x3, w);
     68     s = addmod(s, tmp, umod);
     69 
     70     r2 = s;
     71 
     72     /* k = 2 */
     73     s = *x1;
     74 
     75     w = w3table[2];
     76     tmp = MULMOD(*x2, w);
     77     s = addmod(s, tmp, umod);
     78 
     79     w = w3table[1];
     80     tmp = MULMOD(*x3, w);
     81     s = addmod(s, tmp, umod);
     82 
     83     *x3 = s;
     84     *x2 = r2;
     85     *x1 = r1;
     86 }
     87 #else /* PPRO */
     88 static inline void
     89 ppro_size3_ntt(mpd_uint_t *x1, mpd_uint_t *x2, mpd_uint_t *x3, mpd_uint_t w3table[3],
     90                mpd_uint_t umod, double *dmod, uint32_t dinvmod[3])
     91 {
     92     mpd_uint_t r1, r2;
     93     mpd_uint_t w;
     94     mpd_uint_t s, tmp;
     95 
     96 
     97     /* k = 0 -> w = 1 */
     98     s = *x1;
     99     s = addmod(s, *x2, umod);
    100     s = addmod(s, *x3, umod);
    101 
    102     r1 = s;
    103 
    104     /* k = 1 */
    105     s = *x1;
    106 
    107     w = w3table[1];
    108     tmp = ppro_mulmod(*x2, w, dmod, dinvmod);
    109     s = addmod(s, tmp, umod);
    110 
    111     w = w3table[2];
    112     tmp = ppro_mulmod(*x3, w, dmod, dinvmod);
    113     s = addmod(s, tmp, umod);
    114 
    115     r2 = s;
    116 
    117     /* k = 2 */
    118     s = *x1;
    119 
    120     w = w3table[2];
    121     tmp = ppro_mulmod(*x2, w, dmod, dinvmod);
    122     s = addmod(s, tmp, umod);
    123 
    124     w = w3table[1];
    125     tmp = ppro_mulmod(*x3, w, dmod, dinvmod);
    126     s = addmod(s, tmp, umod);
    127 
    128     *x3 = s;
    129     *x2 = r2;
    130     *x1 = r1;
    131 }
    132 #endif
    133 
    134 
    135 /* forward transform, sign = -1; transform length = 3 * 2**n */
    136 int
    137 four_step_fnt(mpd_uint_t *a, mpd_size_t n, int modnum)
    138 {
    139     mpd_size_t R = 3; /* number of rows */
    140     mpd_size_t C = n / 3; /* number of columns */
    141     mpd_uint_t w3table[3];
    142     mpd_uint_t kernel, w0, w1, wstep;
    143     mpd_uint_t *s, *p0, *p1, *p2;
    144     mpd_uint_t umod;
    145 #ifdef PPRO
    146     double dmod;
    147     uint32_t dinvmod[3];
    148 #endif
    149     mpd_size_t i, k;
    150 
    151 
    152     assert(n >= 48);
    153     assert(n <= 3*MPD_MAXTRANSFORM_2N);
    154 
    155 
    156     /* Length R transform on the columns. */
    157     SETMODULUS(modnum);
    158     _mpd_init_w3table(w3table, -1, modnum);
    159     for (p0=a, p1=p0+C, p2=p0+2*C; p0<a+C; p0++,p1++,p2++) {
    160 
    161         SIZE3_NTT(p0, p1, p2, w3table);
    162     }
    163 
    164     /* Multiply each matrix element (addressed by i*C+k) by r**(i*k). */
    165     kernel = _mpd_getkernel(n, -1, modnum);
    166     for (i = 1; i < R; i++) {
    167         w0 = 1;                  /* r**(i*0): initial value for k=0 */
    168         w1 = POWMOD(kernel, i);  /* r**(i*1): initial value for k=1 */
    169         wstep = MULMOD(w1, w1);  /* r**(2*i) */
    170         for (k = 0; k < C-1; k += 2) {
    171             mpd_uint_t x0 = a[i*C+k];
    172             mpd_uint_t x1 = a[i*C+k+1];
    173             MULMOD2(&x0, w0, &x1, w1);
    174             MULMOD2C(&w0, &w1, wstep);  /* r**(i*(k+2)) = r**(i*k) * r**(2*i) */
    175             a[i*C+k] = x0;
    176             a[i*C+k+1] = x1;
    177         }
    178     }
    179 
    180     /* Length C transform on the rows. */
    181     for (s = a; s < a+n; s += C) {
    182         if (!six_step_fnt(s, C, modnum)) {
    183             return 0;
    184         }
    185     }
    186 
    187 #if 0
    188     /* An unordered transform is sufficient for convolution. */
    189     /* Transpose the matrix. */
    190     transpose_3xpow2(a, R, C);
    191 #endif
    192 
    193     return 1;
    194 }
    195 
    196 /* backward transform, sign = 1; transform length = 3 * 2**n */
    197 int
    198 inv_four_step_fnt(mpd_uint_t *a, mpd_size_t n, int modnum)
    199 {
    200     mpd_size_t R = 3; /* number of rows */
    201     mpd_size_t C = n / 3; /* number of columns */
    202     mpd_uint_t w3table[3];
    203     mpd_uint_t kernel, w0, w1, wstep;
    204     mpd_uint_t *s, *p0, *p1, *p2;
    205     mpd_uint_t umod;
    206 #ifdef PPRO
    207     double dmod;
    208     uint32_t dinvmod[3];
    209 #endif
    210     mpd_size_t i, k;
    211 
    212 
    213     assert(n >= 48);
    214     assert(n <= 3*MPD_MAXTRANSFORM_2N);
    215 
    216 
    217 #if 0
    218     /* An unordered transform is sufficient for convolution. */
    219     /* Transpose the matrix, producing an R*C matrix. */
    220     transpose_3xpow2(a, C, R);
    221 #endif
    222 
    223     /* Length C transform on the rows. */
    224     for (s = a; s < a+n; s += C) {
    225         if (!inv_six_step_fnt(s, C, modnum)) {
    226             return 0;
    227         }
    228     }
    229 
    230     /* Multiply each matrix element (addressed by i*C+k) by r**(i*k). */
    231     SETMODULUS(modnum);
    232     kernel = _mpd_getkernel(n, 1, modnum);
    233     for (i = 1; i < R; i++) {
    234         w0 = 1;
    235         w1 = POWMOD(kernel, i);
    236         wstep = MULMOD(w1, w1);
    237         for (k = 0; k < C; k += 2) {
    238             mpd_uint_t x0 = a[i*C+k];
    239             mpd_uint_t x1 = a[i*C+k+1];
    240             MULMOD2(&x0, w0, &x1, w1);
    241             MULMOD2C(&w0, &w1, wstep);
    242             a[i*C+k] = x0;
    243             a[i*C+k+1] = x1;
    244         }
    245     }
    246 
    247     /* Length R transform on the columns. */
    248     _mpd_init_w3table(w3table, 1, modnum);
    249     for (p0=a, p1=p0+C, p2=p0+2*C; p0<a+C; p0++,p1++,p2++) {
    250 
    251         SIZE3_NTT(p0, p1, p2, w3table);
    252     }
    253 
    254     return 1;
    255 }
    256 
    257 
    258