Home | History | Annotate | Download | only in kiss_fft
#ifndef KISSFFT_CLASS_HH
#define KISSFFT_CLASS_HH
#include <cmath>
#include <complex>
#include <vector>
      4 
      5 namespace kissfft_utils {
      6 
      7 template <typename T_scalar>
      8 struct traits
      9 {
     10     typedef T_scalar scalar_type;
     11     typedef std::complex<scalar_type> cpx_type;
     12     void fill_twiddles( std::complex<T_scalar> * dst ,int nfft,bool inverse)
     13     {
     14         T_scalar phinc =  (inverse?2:-2)* acos( (T_scalar) -1)  / nfft;
     15         for (int i=0;i<nfft;++i)
     16             dst[i] = exp( std::complex<T_scalar>(0,i*phinc) );
     17     }
     18 
     19     void prepare(
     20             std::vector< std::complex<T_scalar> > & dst,
     21             int nfft,bool inverse,
     22             std::vector<int> & stageRadix,
     23             std::vector<int> & stageRemainder )
     24     {
     25         _twiddles.resize(nfft);
     26         fill_twiddles( &_twiddles[0],nfft,inverse);
     27         dst = _twiddles;
     28 
     29         //factorize
     30         //start factoring out 4's, then 2's, then 3,5,7,9,...
     31         int n= nfft;
     32         int p=4;
     33         do {
     34             while (n % p) {
     35                 switch (p) {
     36                     case 4: p = 2; break;
     37                     case 2: p = 3; break;
     38                     default: p += 2; break;
     39                 }
     40                 if (p*p>n)
     41                     p=n;// no more factors
     42             }
     43             n /= p;
     44             stageRadix.push_back(p);
     45             stageRemainder.push_back(n);
     46         }while(n>1);
     47     }
     48     std::vector<cpx_type> _twiddles;
     49 
     50 
     51     const cpx_type twiddle(int i) { return _twiddles[i]; }
     52 };
     53 
     54 }
     55 
     56 template <typename T_Scalar,
     57          typename T_traits=kissfft_utils::traits<T_Scalar>
     58          >
     59 class kissfft
     60 {
     61     public:
     62         typedef T_traits traits_type;
     63         typedef typename traits_type::scalar_type scalar_type;
     64         typedef typename traits_type::cpx_type cpx_type;
     65 
     66         kissfft(int nfft,bool inverse,const traits_type & traits=traits_type() )
     67             :_nfft(nfft),_inverse(inverse),_traits(traits)
     68         {
     69             _traits.prepare(_twiddles, _nfft,_inverse ,_stageRadix, _stageRemainder);
     70         }
     71 
     72         void transform(const cpx_type * src , cpx_type * dst)
     73         {
     74             kf_work(0, dst, src, 1,1);
     75         }
     76 
     77     private:
     78         void kf_work( int stage,cpx_type * Fout, const cpx_type * f, size_t fstride,size_t in_stride)
     79         {
     80             int p = _stageRadix[stage];
     81             int m = _stageRemainder[stage];
     82             cpx_type * Fout_beg = Fout;
     83             cpx_type * Fout_end = Fout + p*m;
     84 
     85             if (m==1) {
     86                 do{
     87                     *Fout = *f;
     88                     f += fstride*in_stride;
     89                 }while(++Fout != Fout_end );
     90             }else{
     91                 do{
     92                     // recursive call:
     93                     // DFT of size m*p performed by doing
     94                     // p instances of smaller DFTs of size m,
     95                     // each one takes a decimated version of the input
     96                     kf_work(stage+1, Fout , f, fstride*p,in_stride);
     97                     f += fstride*in_stride;
     98                 }while( (Fout += m) != Fout_end );
     99             }
    100 
    101             Fout=Fout_beg;
    102 
    103             // recombine the p smaller DFTs
    104             switch (p) {
    105                 case 2: kf_bfly2(Fout,fstride,m); break;
    106                 case 3: kf_bfly3(Fout,fstride,m); break;
    107                 case 4: kf_bfly4(Fout,fstride,m); break;
    108                 case 5: kf_bfly5(Fout,fstride,m); break;
    109                 default: kf_bfly_generic(Fout,fstride,m,p); break;
    110             }
    111         }
    112 
    113         // these were #define macros in the original kiss_fft
    114         void C_ADD( cpx_type & c,const cpx_type & a,const cpx_type & b) { c=a+b;}
    115         void C_MUL( cpx_type & c,const cpx_type & a,const cpx_type & b) { c=a*b;}
    116         void C_SUB( cpx_type & c,const cpx_type & a,const cpx_type & b) { c=a-b;}
    117         void C_ADDTO( cpx_type & c,const cpx_type & a) { c+=a;}
    118         void C_FIXDIV( cpx_type & ,int ) {} // NO-OP for float types
    119         scalar_type S_MUL( const scalar_type & a,const scalar_type & b) { return a*b;}
    120         scalar_type HALF_OF( const scalar_type & a) { return a*.5;}
    121         void C_MULBYSCALAR(cpx_type & c,const scalar_type & a) {c*=a;}
    122 
    123         void kf_bfly2( cpx_type * Fout, const size_t fstride, int m)
    124         {
    125             for (int k=0;k<m;++k) {
    126                 cpx_type t = Fout[m+k] * _traits.twiddle(k*fstride);
    127                 Fout[m+k] = Fout[k] - t;
    128                 Fout[k] += t;
    129             }
    130         }
    131 
    132         void kf_bfly4( cpx_type * Fout, const size_t fstride, const size_t m)
    133         {
    134             cpx_type scratch[7];
    135             int negative_if_inverse = _inverse * -2 +1;
    136             for (size_t k=0;k<m;++k) {
    137                 scratch[0] = Fout[k+m] * _traits.twiddle(k*fstride);
    138                 scratch[1] = Fout[k+2*m] * _traits.twiddle(k*fstride*2);
    139                 scratch[2] = Fout[k+3*m] * _traits.twiddle(k*fstride*3);
    140                 scratch[5] = Fout[k] - scratch[1];
    141 
    142                 Fout[k] += scratch[1];
    143                 scratch[3] = scratch[0] + scratch[2];
    144                 scratch[4] = scratch[0] - scratch[2];
    145                 scratch[4] = cpx_type( scratch[4].imag()*negative_if_inverse , -scratch[4].real()* negative_if_inverse );
    146 
    147                 Fout[k+2*m]  = Fout[k] - scratch[3];
    148                 Fout[k] += scratch[3];
    149                 Fout[k+m] = scratch[5] + scratch[4];
    150                 Fout[k+3*m] = scratch[5] - scratch[4];
    151             }
    152         }
    153 
    154         void kf_bfly3( cpx_type * Fout, const size_t fstride, const size_t m)
    155         {
    156             size_t k=m;
    157             const size_t m2 = 2*m;
    158             cpx_type *tw1,*tw2;
    159             cpx_type scratch[5];
    160             cpx_type epi3;
    161             epi3 = _twiddles[fstride*m];
    162 
    163             tw1=tw2=&_twiddles[0];
    164 
    165             do{
    166                 C_FIXDIV(*Fout,3); C_FIXDIV(Fout[m],3); C_FIXDIV(Fout[m2],3);
    167 
    168                 C_MUL(scratch[1],Fout[m] , *tw1);
    169                 C_MUL(scratch[2],Fout[m2] , *tw2);
    170 
    171                 C_ADD(scratch[3],scratch[1],scratch[2]);
    172                 C_SUB(scratch[0],scratch[1],scratch[2]);
    173                 tw1 += fstride;
    174                 tw2 += fstride*2;
    175 
    176                 Fout[m] = cpx_type( Fout->real() - HALF_OF(scratch[3].real() ) , Fout->imag() - HALF_OF(scratch[3].imag() ) );
    177 
    178                 C_MULBYSCALAR( scratch[0] , epi3.imag() );
    179 
    180                 C_ADDTO(*Fout,scratch[3]);
    181 
    182                 Fout[m2] = cpx_type(  Fout[m].real() + scratch[0].imag() , Fout[m].imag() - scratch[0].real() );
    183 
    184                 C_ADDTO( Fout[m] , cpx_type( -scratch[0].imag(),scratch[0].real() ) );
    185                 ++Fout;
    186             }while(--k);
    187         }
    188 
    189         void kf_bfly5( cpx_type * Fout, const size_t fstride, const size_t m)
    190         {
    191             cpx_type *Fout0,*Fout1,*Fout2,*Fout3,*Fout4;
    192             size_t u;
    193             cpx_type scratch[13];
    194             cpx_type * twiddles = &_twiddles[0];
    195             cpx_type *tw;
    196             cpx_type ya,yb;
    197             ya = twiddles[fstride*m];
    198             yb = twiddles[fstride*2*m];
    199 
    200             Fout0=Fout;
    201             Fout1=Fout0+m;
    202             Fout2=Fout0+2*m;
    203             Fout3=Fout0+3*m;
    204             Fout4=Fout0+4*m;
    205 
    206             tw=twiddles;
    207             for ( u=0; u<m; ++u ) {
    208                 C_FIXDIV( *Fout0,5); C_FIXDIV( *Fout1,5); C_FIXDIV( *Fout2,5); C_FIXDIV( *Fout3,5); C_FIXDIV( *Fout4,5);
    209                 scratch[0] = *Fout0;
    210 
    211                 C_MUL(scratch[1] ,*Fout1, tw[u*fstride]);
    212                 C_MUL(scratch[2] ,*Fout2, tw[2*u*fstride]);
    213                 C_MUL(scratch[3] ,*Fout3, tw[3*u*fstride]);
    214                 C_MUL(scratch[4] ,*Fout4, tw[4*u*fstride]);
    215 
    216                 C_ADD( scratch[7],scratch[1],scratch[4]);
    217                 C_SUB( scratch[10],scratch[1],scratch[4]);
    218                 C_ADD( scratch[8],scratch[2],scratch[3]);
    219                 C_SUB( scratch[9],scratch[2],scratch[3]);
    220 
    221                 C_ADDTO( *Fout0, scratch[7]);
    222                 C_ADDTO( *Fout0, scratch[8]);
    223 
    224                 scratch[5] = scratch[0] + cpx_type(
    225                         S_MUL(scratch[7].real(),ya.real() ) + S_MUL(scratch[8].real() ,yb.real() ),
    226                         S_MUL(scratch[7].imag(),ya.real()) + S_MUL(scratch[8].imag(),yb.real())
    227                         );
    228 
    229                 scratch[6] =  cpx_type(
    230                         S_MUL(scratch[10].imag(),ya.imag()) + S_MUL(scratch[9].imag(),yb.imag()),
    231                         -S_MUL(scratch[10].real(),ya.imag()) - S_MUL(scratch[9].real(),yb.imag())
    232                         );
    233 
    234                 C_SUB(*Fout1,scratch[5],scratch[6]);
    235                 C_ADD(*Fout4,scratch[5],scratch[6]);
    236 
    237                 scratch[11] = scratch[0] +
    238                     cpx_type(
    239                             S_MUL(scratch[7].real(),yb.real()) + S_MUL(scratch[8].real(),ya.real()),
    240                             S_MUL(scratch[7].imag(),yb.real()) + S_MUL(scratch[8].imag(),ya.real())
    241                             );
    242 
    243                 scratch[12] = cpx_type(
    244                         -S_MUL(scratch[10].imag(),yb.imag()) + S_MUL(scratch[9].imag(),ya.imag()),
    245                         S_MUL(scratch[10].real(),yb.imag()) - S_MUL(scratch[9].real(),ya.imag())
    246                         );
    247 
    248                 C_ADD(*Fout2,scratch[11],scratch[12]);
    249                 C_SUB(*Fout3,scratch[11],scratch[12]);
    250 
    251                 ++Fout0;++Fout1;++Fout2;++Fout3;++Fout4;
    252             }
    253         }
    254 
    255         /* perform the butterfly for one stage of a mixed radix FFT */
    256         void kf_bfly_generic(
    257                 cpx_type * Fout,
    258                 const size_t fstride,
    259                 int m,
    260                 int p
    261                 )
    262         {
    263             int u,k,q1,q;
    264             cpx_type * twiddles = &_twiddles[0];
    265             cpx_type t;
    266             int Norig = _nfft;
    267             cpx_type scratchbuf[p];
    268 
    269             for ( u=0; u<m; ++u ) {
    270                 k=u;
    271                 for ( q1=0 ; q1<p ; ++q1 ) {
    272                     scratchbuf[q1] = Fout[ k  ];
    273                     C_FIXDIV(scratchbuf[q1],p);
    274                     k += m;
    275                 }
    276 
    277                 k=u;
    278                 for ( q1=0 ; q1<p ; ++q1 ) {
    279                     int twidx=0;
    280                     Fout[ k ] = scratchbuf[0];
    281                     for (q=1;q<p;++q ) {
    282                         twidx += fstride * k;
    283                         if (twidx>=Norig) twidx-=Norig;
    284                         C_MUL(t,scratchbuf[q] , twiddles[twidx] );
    285                         C_ADDTO( Fout[ k ] ,t);
    286                     }
    287                     k += m;
    288                 }
    289             }
    290         }
    291 
    292         int _nfft;
    293         bool _inverse;
    294         std::vector<cpx_type> _twiddles;
    295         std::vector<int> _stageRadix;
    296         std::vector<int> _stageRemainder;
    297         traits_type _traits;
    298 };
    299 #endif
    300