Home | History | Annotate | Download | only in FFT
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2009 Mark Borgerding mark a borgerding net
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 namespace Eigen {
     11 
     12 namespace internal {
     13 
  // This FFT implementation was derived from kissfft: http://sourceforge.net/projects/kissfft
  // Copyright 2003-2009 Mark Borgerding
     16 
     17 template <typename _Scalar>
     18 struct kiss_cpx_fft
     19 {
     20   typedef _Scalar Scalar;
     21   typedef std::complex<Scalar> Complex;
     22   std::vector<Complex> m_twiddles;
     23   std::vector<int> m_stageRadix;
     24   std::vector<int> m_stageRemainder;
     25   std::vector<Complex> m_scratchBuf;
     26   bool m_inverse;
     27 
     28   inline
     29     void make_twiddles(int nfft,bool inverse)
     30     {
     31       m_inverse = inverse;
     32       m_twiddles.resize(nfft);
     33       Scalar phinc =  (inverse?2:-2)* acos( (Scalar) -1)  / nfft;
     34       for (int i=0;i<nfft;++i)
     35         m_twiddles[i] = exp( Complex(0,i*phinc) );
     36     }
     37 
     38   void factorize(int nfft)
     39   {
     40     //start factoring out 4's, then 2's, then 3,5,7,9,...
     41     int n= nfft;
     42     int p=4;
     43     do {
     44       while (n % p) {
     45         switch (p) {
     46           case 4: p = 2; break;
     47           case 2: p = 3; break;
     48           default: p += 2; break;
     49         }
     50         if (p*p>n)
     51           p=n;// impossible to have a factor > sqrt(n)
     52       }
     53       n /= p;
     54       m_stageRadix.push_back(p);
     55       m_stageRemainder.push_back(n);
     56       if ( p > 5 )
     57         m_scratchBuf.resize(p); // scratchbuf will be needed in bfly_generic
     58     }while(n>1);
     59   }
     60 
     61   template <typename _Src>
     62     inline
     63     void work( int stage,Complex * xout, const _Src * xin, size_t fstride,size_t in_stride)
     64     {
     65       int p = m_stageRadix[stage];
     66       int m = m_stageRemainder[stage];
     67       Complex * Fout_beg = xout;
     68       Complex * Fout_end = xout + p*m;
     69 
     70       if (m>1) {
     71         do{
     72           // recursive call:
     73           // DFT of size m*p performed by doing
     74           // p instances of smaller DFTs of size m,
     75           // each one takes a decimated version of the input
     76           work(stage+1, xout , xin, fstride*p,in_stride);
     77           xin += fstride*in_stride;
     78         }while( (xout += m) != Fout_end );
     79       }else{
     80         do{
     81           *xout = *xin;
     82           xin += fstride*in_stride;
     83         }while(++xout != Fout_end );
     84       }
     85       xout=Fout_beg;
     86 
     87       // recombine the p smaller DFTs
     88       switch (p) {
     89         case 2: bfly2(xout,fstride,m); break;
     90         case 3: bfly3(xout,fstride,m); break;
     91         case 4: bfly4(xout,fstride,m); break;
     92         case 5: bfly5(xout,fstride,m); break;
     93         default: bfly_generic(xout,fstride,m,p); break;
     94       }
     95     }
     96 
     97   inline
     98     void bfly2( Complex * Fout, const size_t fstride, int m)
     99     {
    100       for (int k=0;k<m;++k) {
    101         Complex t = Fout[m+k] * m_twiddles[k*fstride];
    102         Fout[m+k] = Fout[k] - t;
    103         Fout[k] += t;
    104       }
    105     }
    106 
    107   inline
    108     void bfly4( Complex * Fout, const size_t fstride, const size_t m)
    109     {
    110       Complex scratch[6];
    111       int negative_if_inverse = m_inverse * -2 +1;
    112       for (size_t k=0;k<m;++k) {
    113         scratch[0] = Fout[k+m] * m_twiddles[k*fstride];
    114         scratch[1] = Fout[k+2*m] * m_twiddles[k*fstride*2];
    115         scratch[2] = Fout[k+3*m] * m_twiddles[k*fstride*3];
    116         scratch[5] = Fout[k] - scratch[1];
    117 
    118         Fout[k] += scratch[1];
    119         scratch[3] = scratch[0] + scratch[2];
    120         scratch[4] = scratch[0] - scratch[2];
    121         scratch[4] = Complex( scratch[4].imag()*negative_if_inverse , -scratch[4].real()* negative_if_inverse );
    122 
    123         Fout[k+2*m]  = Fout[k] - scratch[3];
    124         Fout[k] += scratch[3];
    125         Fout[k+m] = scratch[5] + scratch[4];
    126         Fout[k+3*m] = scratch[5] - scratch[4];
    127       }
    128     }
    129 
    130   inline
    131     void bfly3( Complex * Fout, const size_t fstride, const size_t m)
    132     {
    133       size_t k=m;
    134       const size_t m2 = 2*m;
    135       Complex *tw1,*tw2;
    136       Complex scratch[5];
    137       Complex epi3;
    138       epi3 = m_twiddles[fstride*m];
    139 
    140       tw1=tw2=&m_twiddles[0];
    141 
    142       do{
    143         scratch[1]=Fout[m] * *tw1;
    144         scratch[2]=Fout[m2] * *tw2;
    145 
    146         scratch[3]=scratch[1]+scratch[2];
    147         scratch[0]=scratch[1]-scratch[2];
    148         tw1 += fstride;
    149         tw2 += fstride*2;
    150         Fout[m] = Complex( Fout->real() - Scalar(.5)*scratch[3].real() , Fout->imag() - Scalar(.5)*scratch[3].imag() );
    151         scratch[0] *= epi3.imag();
    152         *Fout += scratch[3];
    153         Fout[m2] = Complex(  Fout[m].real() + scratch[0].imag() , Fout[m].imag() - scratch[0].real() );
    154         Fout[m] += Complex( -scratch[0].imag(),scratch[0].real() );
    155         ++Fout;
    156       }while(--k);
    157     }
    158 
    159   inline
    160     void bfly5( Complex * Fout, const size_t fstride, const size_t m)
    161     {
    162       Complex *Fout0,*Fout1,*Fout2,*Fout3,*Fout4;
    163       size_t u;
    164       Complex scratch[13];
    165       Complex * twiddles = &m_twiddles[0];
    166       Complex *tw;
    167       Complex ya,yb;
    168       ya = twiddles[fstride*m];
    169       yb = twiddles[fstride*2*m];
    170 
    171       Fout0=Fout;
    172       Fout1=Fout0+m;
    173       Fout2=Fout0+2*m;
    174       Fout3=Fout0+3*m;
    175       Fout4=Fout0+4*m;
    176 
    177       tw=twiddles;
    178       for ( u=0; u<m; ++u ) {
    179         scratch[0] = *Fout0;
    180 
    181         scratch[1]  = *Fout1 * tw[u*fstride];
    182         scratch[2]  = *Fout2 * tw[2*u*fstride];
    183         scratch[3]  = *Fout3 * tw[3*u*fstride];
    184         scratch[4]  = *Fout4 * tw[4*u*fstride];
    185 
    186         scratch[7] = scratch[1] + scratch[4];
    187         scratch[10] = scratch[1] - scratch[4];
    188         scratch[8] = scratch[2] + scratch[3];
    189         scratch[9] = scratch[2] - scratch[3];
    190 
    191         *Fout0 +=  scratch[7];
    192         *Fout0 +=  scratch[8];
    193 
    194         scratch[5] = scratch[0] + Complex(
    195             (scratch[7].real()*ya.real() ) + (scratch[8].real() *yb.real() ),
    196             (scratch[7].imag()*ya.real()) + (scratch[8].imag()*yb.real())
    197             );
    198 
    199         scratch[6] = Complex(
    200             (scratch[10].imag()*ya.imag()) + (scratch[9].imag()*yb.imag()),
    201             -(scratch[10].real()*ya.imag()) - (scratch[9].real()*yb.imag())
    202             );
    203 
    204         *Fout1 = scratch[5] - scratch[6];
    205         *Fout4 = scratch[5] + scratch[6];
    206 
    207         scratch[11] = scratch[0] +
    208           Complex(
    209               (scratch[7].real()*yb.real()) + (scratch[8].real()*ya.real()),
    210               (scratch[7].imag()*yb.real()) + (scratch[8].imag()*ya.real())
    211               );
    212 
    213         scratch[12] = Complex(
    214             -(scratch[10].imag()*yb.imag()) + (scratch[9].imag()*ya.imag()),
    215             (scratch[10].real()*yb.imag()) - (scratch[9].real()*ya.imag())
    216             );
    217 
    218         *Fout2=scratch[11]+scratch[12];
    219         *Fout3=scratch[11]-scratch[12];
    220 
    221         ++Fout0;++Fout1;++Fout2;++Fout3;++Fout4;
    222       }
    223     }
    224 
    225   /* perform the butterfly for one stage of a mixed radix FFT */
    226   inline
    227     void bfly_generic(
    228         Complex * Fout,
    229         const size_t fstride,
    230         int m,
    231         int p
    232         )
    233     {
    234       int u,k,q1,q;
    235       Complex * twiddles = &m_twiddles[0];
    236       Complex t;
    237       int Norig = static_cast<int>(m_twiddles.size());
    238       Complex * scratchbuf = &m_scratchBuf[0];
    239 
    240       for ( u=0; u<m; ++u ) {
    241         k=u;
    242         for ( q1=0 ; q1<p ; ++q1 ) {
    243           scratchbuf[q1] = Fout[ k  ];
    244           k += m;
    245         }
    246 
    247         k=u;
    248         for ( q1=0 ; q1<p ; ++q1 ) {
    249           int twidx=0;
    250           Fout[ k ] = scratchbuf[0];
    251           for (q=1;q<p;++q ) {
    252             twidx += static_cast<int>(fstride) * k;
    253             if (twidx>=Norig) twidx-=Norig;
    254             t=scratchbuf[q] * twiddles[twidx];
    255             Fout[ k ] += t;
    256           }
    257           k += m;
    258         }
    259       }
    260     }
    261 };
    262 
    263 template <typename _Scalar>
    264 struct kissfft_impl
    265 {
    266   typedef _Scalar Scalar;
    267   typedef std::complex<Scalar> Complex;
    268 
    269   void clear()
    270   {
    271     m_plans.clear();
    272     m_realTwiddles.clear();
    273   }
    274 
    275   inline
    276     void fwd( Complex * dst,const Complex *src,int nfft)
    277     {
    278       get_plan(nfft,false).work(0, dst, src, 1,1);
    279     }
    280 
    281   inline
    282     void fwd2( Complex * dst,const Complex *src,int n0,int n1)
    283     {
    284         EIGEN_UNUSED_VARIABLE(dst);
    285         EIGEN_UNUSED_VARIABLE(src);
    286         EIGEN_UNUSED_VARIABLE(n0);
    287         EIGEN_UNUSED_VARIABLE(n1);
    288     }
    289 
    290   inline
    291     void inv2( Complex * dst,const Complex *src,int n0,int n1)
    292     {
    293         EIGEN_UNUSED_VARIABLE(dst);
    294         EIGEN_UNUSED_VARIABLE(src);
    295         EIGEN_UNUSED_VARIABLE(n0);
    296         EIGEN_UNUSED_VARIABLE(n1);
    297     }
    298 
    299   // real-to-complex forward FFT
    300   // perform two FFTs of src even and src odd
    301   // then twiddle to recombine them into the half-spectrum format
    302   // then fill in the conjugate symmetric half
    303   inline
    304     void fwd( Complex * dst,const Scalar * src,int nfft)
    305     {
    306       if ( nfft&3  ) {
    307         // use generic mode for odd
    308         m_tmpBuf1.resize(nfft);
    309         get_plan(nfft,false).work(0, &m_tmpBuf1[0], src, 1,1);
    310         std::copy(m_tmpBuf1.begin(),m_tmpBuf1.begin()+(nfft>>1)+1,dst );
    311       }else{
    312         int ncfft = nfft>>1;
    313         int ncfft2 = nfft>>2;
    314         Complex * rtw = real_twiddles(ncfft2);
    315 
    316         // use optimized mode for even real
    317         fwd( dst, reinterpret_cast<const Complex*> (src), ncfft);
    318         Complex dc = dst[0].real() +  dst[0].imag();
    319         Complex nyquist = dst[0].real() -  dst[0].imag();
    320         int k;
    321         for ( k=1;k <= ncfft2 ; ++k ) {
    322           Complex fpk = dst[k];
    323           Complex fpnk = conj(dst[ncfft-k]);
    324           Complex f1k = fpk + fpnk;
    325           Complex f2k = fpk - fpnk;
    326           Complex tw= f2k * rtw[k-1];
    327           dst[k] =  (f1k + tw) * Scalar(.5);
    328           dst[ncfft-k] =  conj(f1k -tw)*Scalar(.5);
    329         }
    330         dst[0] = dc;
    331         dst[ncfft] = nyquist;
    332       }
    333     }
    334 
    335   // inverse complex-to-complex
    336   inline
    337     void inv(Complex * dst,const Complex  *src,int nfft)
    338     {
    339       get_plan(nfft,true).work(0, dst, src, 1,1);
    340     }
    341 
    342   // half-complex to scalar
    343   inline
    344     void inv( Scalar * dst,const Complex * src,int nfft)
    345     {
    346       if (nfft&3) {
    347         m_tmpBuf1.resize(nfft);
    348         m_tmpBuf2.resize(nfft);
    349         std::copy(src,src+(nfft>>1)+1,m_tmpBuf1.begin() );
    350         for (int k=1;k<(nfft>>1)+1;++k)
    351           m_tmpBuf1[nfft-k] = conj(m_tmpBuf1[k]);
    352         inv(&m_tmpBuf2[0],&m_tmpBuf1[0],nfft);
    353         for (int k=0;k<nfft;++k)
    354           dst[k] = m_tmpBuf2[k].real();
    355       }else{
    356         // optimized version for multiple of 4
    357         int ncfft = nfft>>1;
    358         int ncfft2 = nfft>>2;
    359         Complex * rtw = real_twiddles(ncfft2);
    360         m_tmpBuf1.resize(ncfft);
    361         m_tmpBuf1[0] = Complex( src[0].real() + src[ncfft].real(), src[0].real() - src[ncfft].real() );
    362         for (int k = 1; k <= ncfft / 2; ++k) {
    363           Complex fk = src[k];
    364           Complex fnkc = conj(src[ncfft-k]);
    365           Complex fek = fk + fnkc;
    366           Complex tmp = fk - fnkc;
    367           Complex fok = tmp * conj(rtw[k-1]);
    368           m_tmpBuf1[k] = fek + fok;
    369           m_tmpBuf1[ncfft-k] = conj(fek - fok);
    370         }
    371         get_plan(ncfft,true).work(0, reinterpret_cast<Complex*>(dst), &m_tmpBuf1[0], 1,1);
    372       }
    373     }
    374 
    375   protected:
    376   typedef kiss_cpx_fft<Scalar> PlanData;
    377   typedef std::map<int,PlanData> PlanMap;
    378 
    379   PlanMap m_plans;
    380   std::map<int, std::vector<Complex> > m_realTwiddles;
    381   std::vector<Complex> m_tmpBuf1;
    382   std::vector<Complex> m_tmpBuf2;
    383 
    384   inline
    385     int PlanKey(int nfft, bool isinverse) const { return (nfft<<1) | int(isinverse); }
    386 
    387   inline
    388     PlanData & get_plan(int nfft, bool inverse)
    389     {
    390       // TODO look for PlanKey(nfft, ! inverse) and conjugate the twiddles
    391       PlanData & pd = m_plans[ PlanKey(nfft,inverse) ];
    392       if ( pd.m_twiddles.size() == 0 ) {
    393         pd.make_twiddles(nfft,inverse);
    394         pd.factorize(nfft);
    395       }
    396       return pd;
    397     }
    398 
    399   inline
    400     Complex * real_twiddles(int ncfft2)
    401     {
    402       std::vector<Complex> & twidref = m_realTwiddles[ncfft2];// creates new if not there
    403       if ( (int)twidref.size() != ncfft2 ) {
    404         twidref.resize(ncfft2);
    405         int ncfft= ncfft2<<1;
    406         Scalar pi =  acos( Scalar(-1) );
    407         for (int k=1;k<=ncfft2;++k)
    408           twidref[k-1] = exp( Complex(0,-pi * (Scalar(k) / ncfft + Scalar(.5)) ) );
    409       }
    410       return &twidref[0];
    411     }
    412 };
    413 
    414 } // end namespace internal
    415 
    416 } // end namespace Eigen
    417 
    418 /* vim: set filetype=cpp et sw=2 ts=2 ai: */
    419