Home | History | Annotate | Download | only in FFT
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2009 Mark Borgerding mark a borgerding net
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 namespace Eigen {
     11 
     12 namespace internal {
     13 
     14   // This FFT implementation was derived from kissfft http://sourceforge.net/projects/kissfft
     15   // Copyright 2003-2009 Mark Borgerding
     16 
     17 template <typename _Scalar>
     18 struct kiss_cpx_fft
     19 {
     20   typedef _Scalar Scalar;
     21   typedef std::complex<Scalar> Complex;
     22   std::vector<Complex> m_twiddles;
     23   std::vector<int> m_stageRadix;
     24   std::vector<int> m_stageRemainder;
     25   std::vector<Complex> m_scratchBuf;
     26   bool m_inverse;
     27 
     28   inline
     29     void make_twiddles(int nfft,bool inverse)
     30     {
     31       using std::acos;
     32       m_inverse = inverse;
     33       m_twiddles.resize(nfft);
     34       Scalar phinc =  (inverse?2:-2)* acos( (Scalar) -1)  / nfft;
     35       for (int i=0;i<nfft;++i)
     36         m_twiddles[i] = exp( Complex(0,i*phinc) );
     37     }
     38 
     39   void factorize(int nfft)
     40   {
     41     //start factoring out 4's, then 2's, then 3,5,7,9,...
     42     int n= nfft;
     43     int p=4;
     44     do {
     45       while (n % p) {
     46         switch (p) {
     47           case 4: p = 2; break;
     48           case 2: p = 3; break;
     49           default: p += 2; break;
     50         }
     51         if (p*p>n)
     52           p=n;// impossible to have a factor > sqrt(n)
     53       }
     54       n /= p;
     55       m_stageRadix.push_back(p);
     56       m_stageRemainder.push_back(n);
     57       if ( p > 5 )
     58         m_scratchBuf.resize(p); // scratchbuf will be needed in bfly_generic
     59     }while(n>1);
     60   }
     61 
     62   template <typename _Src>
     63     inline
     64     void work( int stage,Complex * xout, const _Src * xin, size_t fstride,size_t in_stride)
     65     {
     66       int p = m_stageRadix[stage];
     67       int m = m_stageRemainder[stage];
     68       Complex * Fout_beg = xout;
     69       Complex * Fout_end = xout + p*m;
     70 
     71       if (m>1) {
     72         do{
     73           // recursive call:
     74           // DFT of size m*p performed by doing
     75           // p instances of smaller DFTs of size m,
     76           // each one takes a decimated version of the input
     77           work(stage+1, xout , xin, fstride*p,in_stride);
     78           xin += fstride*in_stride;
     79         }while( (xout += m) != Fout_end );
     80       }else{
     81         do{
     82           *xout = *xin;
     83           xin += fstride*in_stride;
     84         }while(++xout != Fout_end );
     85       }
     86       xout=Fout_beg;
     87 
     88       // recombine the p smaller DFTs
     89       switch (p) {
     90         case 2: bfly2(xout,fstride,m); break;
     91         case 3: bfly3(xout,fstride,m); break;
     92         case 4: bfly4(xout,fstride,m); break;
     93         case 5: bfly5(xout,fstride,m); break;
     94         default: bfly_generic(xout,fstride,m,p); break;
     95       }
     96     }
     97 
     98   inline
     99     void bfly2( Complex * Fout, const size_t fstride, int m)
    100     {
    101       for (int k=0;k<m;++k) {
    102         Complex t = Fout[m+k] * m_twiddles[k*fstride];
    103         Fout[m+k] = Fout[k] - t;
    104         Fout[k] += t;
    105       }
    106     }
    107 
    108   inline
    109     void bfly4( Complex * Fout, const size_t fstride, const size_t m)
    110     {
    111       Complex scratch[6];
    112       int negative_if_inverse = m_inverse * -2 +1;
    113       for (size_t k=0;k<m;++k) {
    114         scratch[0] = Fout[k+m] * m_twiddles[k*fstride];
    115         scratch[1] = Fout[k+2*m] * m_twiddles[k*fstride*2];
    116         scratch[2] = Fout[k+3*m] * m_twiddles[k*fstride*3];
    117         scratch[5] = Fout[k] - scratch[1];
    118 
    119         Fout[k] += scratch[1];
    120         scratch[3] = scratch[0] + scratch[2];
    121         scratch[4] = scratch[0] - scratch[2];
    122         scratch[4] = Complex( scratch[4].imag()*negative_if_inverse , -scratch[4].real()* negative_if_inverse );
    123 
    124         Fout[k+2*m]  = Fout[k] - scratch[3];
    125         Fout[k] += scratch[3];
    126         Fout[k+m] = scratch[5] + scratch[4];
    127         Fout[k+3*m] = scratch[5] - scratch[4];
    128       }
    129     }
    130 
    131   inline
    132     void bfly3( Complex * Fout, const size_t fstride, const size_t m)
    133     {
    134       size_t k=m;
    135       const size_t m2 = 2*m;
    136       Complex *tw1,*tw2;
    137       Complex scratch[5];
    138       Complex epi3;
    139       epi3 = m_twiddles[fstride*m];
    140 
    141       tw1=tw2=&m_twiddles[0];
    142 
    143       do{
    144         scratch[1]=Fout[m] * *tw1;
    145         scratch[2]=Fout[m2] * *tw2;
    146 
    147         scratch[3]=scratch[1]+scratch[2];
    148         scratch[0]=scratch[1]-scratch[2];
    149         tw1 += fstride;
    150         tw2 += fstride*2;
    151         Fout[m] = Complex( Fout->real() - Scalar(.5)*scratch[3].real() , Fout->imag() - Scalar(.5)*scratch[3].imag() );
    152         scratch[0] *= epi3.imag();
    153         *Fout += scratch[3];
    154         Fout[m2] = Complex(  Fout[m].real() + scratch[0].imag() , Fout[m].imag() - scratch[0].real() );
    155         Fout[m] += Complex( -scratch[0].imag(),scratch[0].real() );
    156         ++Fout;
    157       }while(--k);
    158     }
    159 
    160   inline
    161     void bfly5( Complex * Fout, const size_t fstride, const size_t m)
    162     {
    163       Complex *Fout0,*Fout1,*Fout2,*Fout3,*Fout4;
    164       size_t u;
    165       Complex scratch[13];
    166       Complex * twiddles = &m_twiddles[0];
    167       Complex *tw;
    168       Complex ya,yb;
    169       ya = twiddles[fstride*m];
    170       yb = twiddles[fstride*2*m];
    171 
    172       Fout0=Fout;
    173       Fout1=Fout0+m;
    174       Fout2=Fout0+2*m;
    175       Fout3=Fout0+3*m;
    176       Fout4=Fout0+4*m;
    177 
    178       tw=twiddles;
    179       for ( u=0; u<m; ++u ) {
    180         scratch[0] = *Fout0;
    181 
    182         scratch[1]  = *Fout1 * tw[u*fstride];
    183         scratch[2]  = *Fout2 * tw[2*u*fstride];
    184         scratch[3]  = *Fout3 * tw[3*u*fstride];
    185         scratch[4]  = *Fout4 * tw[4*u*fstride];
    186 
    187         scratch[7] = scratch[1] + scratch[4];
    188         scratch[10] = scratch[1] - scratch[4];
    189         scratch[8] = scratch[2] + scratch[3];
    190         scratch[9] = scratch[2] - scratch[3];
    191 
    192         *Fout0 +=  scratch[7];
    193         *Fout0 +=  scratch[8];
    194 
    195         scratch[5] = scratch[0] + Complex(
    196             (scratch[7].real()*ya.real() ) + (scratch[8].real() *yb.real() ),
    197             (scratch[7].imag()*ya.real()) + (scratch[8].imag()*yb.real())
    198             );
    199 
    200         scratch[6] = Complex(
    201             (scratch[10].imag()*ya.imag()) + (scratch[9].imag()*yb.imag()),
    202             -(scratch[10].real()*ya.imag()) - (scratch[9].real()*yb.imag())
    203             );
    204 
    205         *Fout1 = scratch[5] - scratch[6];
    206         *Fout4 = scratch[5] + scratch[6];
    207 
    208         scratch[11] = scratch[0] +
    209           Complex(
    210               (scratch[7].real()*yb.real()) + (scratch[8].real()*ya.real()),
    211               (scratch[7].imag()*yb.real()) + (scratch[8].imag()*ya.real())
    212               );
    213 
    214         scratch[12] = Complex(
    215             -(scratch[10].imag()*yb.imag()) + (scratch[9].imag()*ya.imag()),
    216             (scratch[10].real()*yb.imag()) - (scratch[9].real()*ya.imag())
    217             );
    218 
    219         *Fout2=scratch[11]+scratch[12];
    220         *Fout3=scratch[11]-scratch[12];
    221 
    222         ++Fout0;++Fout1;++Fout2;++Fout3;++Fout4;
    223       }
    224     }
    225 
    226   /* perform the butterfly for one stage of a mixed radix FFT */
    227   inline
    228     void bfly_generic(
    229         Complex * Fout,
    230         const size_t fstride,
    231         int m,
    232         int p
    233         )
    234     {
    235       int u,k,q1,q;
    236       Complex * twiddles = &m_twiddles[0];
    237       Complex t;
    238       int Norig = static_cast<int>(m_twiddles.size());
    239       Complex * scratchbuf = &m_scratchBuf[0];
    240 
    241       for ( u=0; u<m; ++u ) {
    242         k=u;
    243         for ( q1=0 ; q1<p ; ++q1 ) {
    244           scratchbuf[q1] = Fout[ k  ];
    245           k += m;
    246         }
    247 
    248         k=u;
    249         for ( q1=0 ; q1<p ; ++q1 ) {
    250           int twidx=0;
    251           Fout[ k ] = scratchbuf[0];
    252           for (q=1;q<p;++q ) {
    253             twidx += static_cast<int>(fstride) * k;
    254             if (twidx>=Norig) twidx-=Norig;
    255             t=scratchbuf[q] * twiddles[twidx];
    256             Fout[ k ] += t;
    257           }
    258           k += m;
    259         }
    260       }
    261     }
    262 };
    263 
    264 template <typename _Scalar>
    265 struct kissfft_impl
    266 {
    267   typedef _Scalar Scalar;
    268   typedef std::complex<Scalar> Complex;
    269 
    270   void clear()
    271   {
    272     m_plans.clear();
    273     m_realTwiddles.clear();
    274   }
    275 
    276   inline
    277     void fwd( Complex * dst,const Complex *src,int nfft)
    278     {
    279       get_plan(nfft,false).work(0, dst, src, 1,1);
    280     }
    281 
    282   inline
    283     void fwd2( Complex * dst,const Complex *src,int n0,int n1)
    284     {
    285         EIGEN_UNUSED_VARIABLE(dst);
    286         EIGEN_UNUSED_VARIABLE(src);
    287         EIGEN_UNUSED_VARIABLE(n0);
    288         EIGEN_UNUSED_VARIABLE(n1);
    289     }
    290 
    291   inline
    292     void inv2( Complex * dst,const Complex *src,int n0,int n1)
    293     {
    294         EIGEN_UNUSED_VARIABLE(dst);
    295         EIGEN_UNUSED_VARIABLE(src);
    296         EIGEN_UNUSED_VARIABLE(n0);
    297         EIGEN_UNUSED_VARIABLE(n1);
    298     }
    299 
    300   // real-to-complex forward FFT
    301   // perform two FFTs of src even and src odd
    302   // then twiddle to recombine them into the half-spectrum format
    303   // then fill in the conjugate symmetric half
    304   inline
    305     void fwd( Complex * dst,const Scalar * src,int nfft)
    306     {
    307       if ( nfft&3  ) {
    308         // use generic mode for odd
    309         m_tmpBuf1.resize(nfft);
    310         get_plan(nfft,false).work(0, &m_tmpBuf1[0], src, 1,1);
    311         std::copy(m_tmpBuf1.begin(),m_tmpBuf1.begin()+(nfft>>1)+1,dst );
    312       }else{
    313         int ncfft = nfft>>1;
    314         int ncfft2 = nfft>>2;
    315         Complex * rtw = real_twiddles(ncfft2);
    316 
    317         // use optimized mode for even real
    318         fwd( dst, reinterpret_cast<const Complex*> (src), ncfft);
    319         Complex dc = dst[0].real() +  dst[0].imag();
    320         Complex nyquist = dst[0].real() -  dst[0].imag();
    321         int k;
    322         for ( k=1;k <= ncfft2 ; ++k ) {
    323           Complex fpk = dst[k];
    324           Complex fpnk = conj(dst[ncfft-k]);
    325           Complex f1k = fpk + fpnk;
    326           Complex f2k = fpk - fpnk;
    327           Complex tw= f2k * rtw[k-1];
    328           dst[k] =  (f1k + tw) * Scalar(.5);
    329           dst[ncfft-k] =  conj(f1k -tw)*Scalar(.5);
    330         }
    331         dst[0] = dc;
    332         dst[ncfft] = nyquist;
    333       }
    334     }
    335 
    336   // inverse complex-to-complex
    337   inline
    338     void inv(Complex * dst,const Complex  *src,int nfft)
    339     {
    340       get_plan(nfft,true).work(0, dst, src, 1,1);
    341     }
    342 
    343   // half-complex to scalar
    344   inline
    345     void inv( Scalar * dst,const Complex * src,int nfft)
    346     {
    347       if (nfft&3) {
    348         m_tmpBuf1.resize(nfft);
    349         m_tmpBuf2.resize(nfft);
    350         std::copy(src,src+(nfft>>1)+1,m_tmpBuf1.begin() );
    351         for (int k=1;k<(nfft>>1)+1;++k)
    352           m_tmpBuf1[nfft-k] = conj(m_tmpBuf1[k]);
    353         inv(&m_tmpBuf2[0],&m_tmpBuf1[0],nfft);
    354         for (int k=0;k<nfft;++k)
    355           dst[k] = m_tmpBuf2[k].real();
    356       }else{
    357         // optimized version for multiple of 4
    358         int ncfft = nfft>>1;
    359         int ncfft2 = nfft>>2;
    360         Complex * rtw = real_twiddles(ncfft2);
    361         m_tmpBuf1.resize(ncfft);
    362         m_tmpBuf1[0] = Complex( src[0].real() + src[ncfft].real(), src[0].real() - src[ncfft].real() );
    363         for (int k = 1; k <= ncfft / 2; ++k) {
    364           Complex fk = src[k];
    365           Complex fnkc = conj(src[ncfft-k]);
    366           Complex fek = fk + fnkc;
    367           Complex tmp = fk - fnkc;
    368           Complex fok = tmp * conj(rtw[k-1]);
    369           m_tmpBuf1[k] = fek + fok;
    370           m_tmpBuf1[ncfft-k] = conj(fek - fok);
    371         }
    372         get_plan(ncfft,true).work(0, reinterpret_cast<Complex*>(dst), &m_tmpBuf1[0], 1,1);
    373       }
    374     }
    375 
    376   protected:
    377   typedef kiss_cpx_fft<Scalar> PlanData;
    378   typedef std::map<int,PlanData> PlanMap;
    379 
    380   PlanMap m_plans;
    381   std::map<int, std::vector<Complex> > m_realTwiddles;
    382   std::vector<Complex> m_tmpBuf1;
    383   std::vector<Complex> m_tmpBuf2;
    384 
    385   inline
    386     int PlanKey(int nfft, bool isinverse) const { return (nfft<<1) | int(isinverse); }
    387 
    388   inline
    389     PlanData & get_plan(int nfft, bool inverse)
    390     {
    391       // TODO look for PlanKey(nfft, ! inverse) and conjugate the twiddles
    392       PlanData & pd = m_plans[ PlanKey(nfft,inverse) ];
    393       if ( pd.m_twiddles.size() == 0 ) {
    394         pd.make_twiddles(nfft,inverse);
    395         pd.factorize(nfft);
    396       }
    397       return pd;
    398     }
    399 
    400   inline
    401     Complex * real_twiddles(int ncfft2)
    402     {
    403       using std::acos;
    404       std::vector<Complex> & twidref = m_realTwiddles[ncfft2];// creates new if not there
    405       if ( (int)twidref.size() != ncfft2 ) {
    406         twidref.resize(ncfft2);
    407         int ncfft= ncfft2<<1;
    408         Scalar pi =  acos( Scalar(-1) );
    409         for (int k=1;k<=ncfft2;++k)
    410           twidref[k-1] = exp( Complex(0,-pi * (Scalar(k) / ncfft + Scalar(.5)) ) );
    411       }
    412       return &twidref[0];
    413     }
    414 };
    415 
    416 } // end namespace internal
    417 
    418 } // end namespace Eigen
    419 
    420 /* vim: set filetype=cpp et sw=2 ts=2 ai: */
    421