Home | History | Annotate | Download | only in test
      1 // This file is part of Eigen, a lightweight C++ template library
      2 // for linear algebra.
      3 //
      4 // Copyright (C) 2009 Mark Borgerding mark a borgerding net
      5 //
      6 // This Source Code Form is subject to the terms of the Mozilla
      7 // Public License v. 2.0. If a copy of the MPL was not distributed
      8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
      9 
     10 #include "main.h"
     11 #include <unsupported/Eigen/FFT>
     12 
     13 template <typename T>
     14 std::complex<T> RandomCpx() { return std::complex<T>( (T)(rand()/(T)RAND_MAX - .5), (T)(rand()/(T)RAND_MAX - .5) ); }
     15 
     16 using namespace std;
     17 using namespace Eigen;
     18 
     19 
     20 template < typename T>
     21 complex<long double>  promote(complex<T> x) { return complex<long double>(x.real(),x.imag()); }
     22 
     23 complex<long double>  promote(float x) { return complex<long double>( x); }
     24 complex<long double>  promote(double x) { return complex<long double>( x); }
     25 complex<long double>  promote(long double x) { return complex<long double>( x); }
     26 
     27 
     28     template <typename VT1,typename VT2>
     29     long double fft_rmse( const VT1 & fftbuf,const VT2 & timebuf)
     30     {
     31         long double totalpower=0;
     32         long double difpower=0;
     33         long double pi = acos((long double)-1 );
     34         for (size_t k0=0;k0<(size_t)fftbuf.size();++k0) {
     35             complex<long double> acc = 0;
     36             long double phinc = -2.*k0* pi / timebuf.size();
     37             for (size_t k1=0;k1<(size_t)timebuf.size();++k1) {
     38                 acc +=  promote( timebuf[k1] ) * exp( complex<long double>(0,k1*phinc) );
     39             }
     40             totalpower += numext::abs2(acc);
     41             complex<long double> x = promote(fftbuf[k0]);
     42             complex<long double> dif = acc - x;
     43             difpower += numext::abs2(dif);
     44             //cerr << k0 << "\t" << acc << "\t" <<  x << "\t" << sqrt(numext::abs2(dif)) << endl;
     45         }
     46         cerr << "rmse:" << sqrt(difpower/totalpower) << endl;
     47         return sqrt(difpower/totalpower);
     48     }
     49 
     50     template <typename VT1,typename VT2>
     51     long double dif_rmse( const VT1 buf1,const VT2 buf2)
     52     {
     53         long double totalpower=0;
     54         long double difpower=0;
     55         size_t n = (min)( buf1.size(),buf2.size() );
     56         for (size_t k=0;k<n;++k) {
     57             totalpower += (numext::abs2( buf1[k] ) + numext::abs2(buf2[k]) )/2.;
     58             difpower += numext::abs2(buf1[k] - buf2[k]);
     59         }
     60         return sqrt(difpower/totalpower);
     61     }
     62 
     63 enum { StdVectorContainer, EigenVectorContainer };
     64 
     65 template<int Container, typename Scalar> struct VectorType;
     66 
     67 template<typename Scalar> struct VectorType<StdVectorContainer,Scalar>
     68 {
     69   typedef vector<Scalar> type;
     70 };
     71 
     72 template<typename Scalar> struct VectorType<EigenVectorContainer,Scalar>
     73 {
     74   typedef Matrix<Scalar,Dynamic,1> type;
     75 };
     76 
     77 template <int Container, typename T>
     78 void test_scalar_generic(int nfft)
     79 {
     80     typedef typename FFT<T>::Complex Complex;
     81     typedef typename FFT<T>::Scalar Scalar;
     82     typedef typename VectorType<Container,Scalar>::type ScalarVector;
     83     typedef typename VectorType<Container,Complex>::type ComplexVector;
     84 
     85     FFT<T> fft;
     86     ScalarVector tbuf(nfft);
     87     ComplexVector freqBuf;
     88     for (int k=0;k<nfft;++k)
     89         tbuf[k]= (T)( rand()/(double)RAND_MAX - .5);
     90 
     91     // make sure it DOESN'T give the right full spectrum answer
     92     // if we've asked for half-spectrum
     93     fft.SetFlag(fft.HalfSpectrum );
     94     fft.fwd( freqBuf,tbuf);
     95     VERIFY((size_t)freqBuf.size() == (size_t)( (nfft>>1)+1) );
     96     VERIFY( fft_rmse(freqBuf,tbuf) < test_precision<T>()  );// gross check
     97 
     98     fft.ClearFlag(fft.HalfSpectrum );
     99     fft.fwd( freqBuf,tbuf);
    100     VERIFY( (size_t)freqBuf.size() == (size_t)nfft);
    101     VERIFY( fft_rmse(freqBuf,tbuf) < test_precision<T>()  );// gross check
    102 
    103     if (nfft&1)
    104         return; // odd FFTs get the wrong size inverse FFT
    105 
    106     ScalarVector tbuf2;
    107     fft.inv( tbuf2 , freqBuf);
    108     VERIFY( dif_rmse(tbuf,tbuf2) < test_precision<T>()  );// gross check
    109 
    110 
    111     // verify that the Unscaled flag takes effect
    112     ScalarVector tbuf3;
    113     fft.SetFlag(fft.Unscaled);
    114 
    115     fft.inv( tbuf3 , freqBuf);
    116 
    117     for (int k=0;k<nfft;++k)
    118         tbuf3[k] *= T(1./nfft);
    119 
    120 
    121     //for (size_t i=0;i<(size_t) tbuf.size();++i)
    122     //    cout << "freqBuf=" << freqBuf[i] << " in2=" << tbuf3[i] << " -  in=" << tbuf[i] << " => " << (tbuf3[i] - tbuf[i] ) <<  endl;
    123 
    124     VERIFY( dif_rmse(tbuf,tbuf3) < test_precision<T>()  );// gross check
    125 
    126     // verify that ClearFlag works
    127     fft.ClearFlag(fft.Unscaled);
    128     fft.inv( tbuf2 , freqBuf);
    129     VERIFY( dif_rmse(tbuf,tbuf2) < test_precision<T>()  );// gross check
    130 }
    131 
    132 template <typename T>
    133 void test_scalar(int nfft)
    134 {
    135   test_scalar_generic<StdVectorContainer,T>(nfft);
    136   //test_scalar_generic<EigenVectorContainer,T>(nfft);
    137 }
    138 
    139 
    140 template <int Container, typename T>
    141 void test_complex_generic(int nfft)
    142 {
    143     typedef typename FFT<T>::Complex Complex;
    144     typedef typename VectorType<Container,Complex>::type ComplexVector;
    145 
    146     FFT<T> fft;
    147 
    148     ComplexVector inbuf(nfft);
    149     ComplexVector outbuf;
    150     ComplexVector buf3;
    151     for (int k=0;k<nfft;++k)
    152         inbuf[k]= Complex( (T)(rand()/(double)RAND_MAX - .5), (T)(rand()/(double)RAND_MAX - .5) );
    153     fft.fwd( outbuf , inbuf);
    154 
    155     VERIFY( fft_rmse(outbuf,inbuf) < test_precision<T>()  );// gross check
    156     fft.inv( buf3 , outbuf);
    157 
    158     VERIFY( dif_rmse(inbuf,buf3) < test_precision<T>()  );// gross check
    159 
    160     // verify that the Unscaled flag takes effect
    161     ComplexVector buf4;
    162     fft.SetFlag(fft.Unscaled);
    163     fft.inv( buf4 , outbuf);
    164     for (int k=0;k<nfft;++k)
    165         buf4[k] *= T(1./nfft);
    166     VERIFY( dif_rmse(inbuf,buf4) < test_precision<T>()  );// gross check
    167 
    168     // verify that ClearFlag works
    169     fft.ClearFlag(fft.Unscaled);
    170     fft.inv( buf3 , outbuf);
    171     VERIFY( dif_rmse(inbuf,buf3) < test_precision<T>()  );// gross check
    172 }
    173 
    174 template <typename T>
    175 void test_complex(int nfft)
    176 {
    177   test_complex_generic<StdVectorContainer,T>(nfft);
    178   test_complex_generic<EigenVectorContainer,T>(nfft);
    179 }
    180 /*
    181 template <typename T,int nrows,int ncols>
    182 void test_complex2d()
    183 {
    184     typedef typename Eigen::FFT<T>::Complex Complex;
    185     FFT<T> fft;
    186     Eigen::Matrix<Complex,nrows,ncols> src,src2,dst,dst2;
    187 
    188     src = Eigen::Matrix<Complex,nrows,ncols>::Random();
    189     //src =  Eigen::Matrix<Complex,nrows,ncols>::Identity();
    190 
    191     for (int k=0;k<ncols;k++) {
    192         Eigen::Matrix<Complex,nrows,1> tmpOut;
    193         fft.fwd( tmpOut,src.col(k) );
    194         dst2.col(k) = tmpOut;
    195     }
    196 
    197     for (int k=0;k<nrows;k++) {
    198         Eigen::Matrix<Complex,1,ncols> tmpOut;
    199         fft.fwd( tmpOut,  dst2.row(k) );
    200         dst2.row(k) = tmpOut;
    201     }
    202 
    203     fft.fwd2(dst.data(),src.data(),ncols,nrows);
    204     fft.inv2(src2.data(),dst.data(),ncols,nrows);
    205     VERIFY( (src-src2).norm() < test_precision<T>() );
    206     VERIFY( (dst-dst2).norm() < test_precision<T>() );
    207 }
    208 */
    209 
    210 
    211 void test_return_by_value(int len)
    212 {
    213     VectorXf in;
    214     VectorXf in1;
    215     in.setRandom( len );
    216     VectorXcf out1,out2;
    217     FFT<float> fft;
    218 
    219     fft.SetFlag(fft.HalfSpectrum );
    220 
    221     fft.fwd(out1,in);
    222     out2 = fft.fwd(in);
    223     VERIFY( (out1-out2).norm() < test_precision<float>() );
    224     in1 = fft.inv(out1);
    225     VERIFY( (in1-in).norm() < test_precision<float>() );
    226 }
    227 
    228 void test_FFTW()
    229 {
    230   CALL_SUBTEST( test_return_by_value(32) );
    231   //CALL_SUBTEST( ( test_complex2d<float,4,8> () ) ); CALL_SUBTEST( ( test_complex2d<double,4,8> () ) );
    232   //CALL_SUBTEST( ( test_complex2d<long double,4,8> () ) );
    233   CALL_SUBTEST( test_complex<float>(32) ); CALL_SUBTEST( test_complex<double>(32) );
    234   CALL_SUBTEST( test_complex<float>(256) ); CALL_SUBTEST( test_complex<double>(256) );
    235   CALL_SUBTEST( test_complex<float>(3*8) ); CALL_SUBTEST( test_complex<double>(3*8) );
    236   CALL_SUBTEST( test_complex<float>(5*32) ); CALL_SUBTEST( test_complex<double>(5*32) );
    237   CALL_SUBTEST( test_complex<float>(2*3*4) ); CALL_SUBTEST( test_complex<double>(2*3*4) );
    238   CALL_SUBTEST( test_complex<float>(2*3*4*5) ); CALL_SUBTEST( test_complex<double>(2*3*4*5) );
    239   CALL_SUBTEST( test_complex<float>(2*3*4*5*7) ); CALL_SUBTEST( test_complex<double>(2*3*4*5*7) );
    240 
    241   CALL_SUBTEST( test_scalar<float>(32) ); CALL_SUBTEST( test_scalar<double>(32) );
    242   CALL_SUBTEST( test_scalar<float>(45) ); CALL_SUBTEST( test_scalar<double>(45) );
    243   CALL_SUBTEST( test_scalar<float>(50) ); CALL_SUBTEST( test_scalar<double>(50) );
    244   CALL_SUBTEST( test_scalar<float>(256) ); CALL_SUBTEST( test_scalar<double>(256) );
    245   CALL_SUBTEST( test_scalar<float>(2*3*4*5*7) ); CALL_SUBTEST( test_scalar<double>(2*3*4*5*7) );
    246 
    247   #ifdef EIGEN_HAS_FFTWL
    248   CALL_SUBTEST( test_complex<long double>(32) );
    249   CALL_SUBTEST( test_complex<long double>(256) );
    250   CALL_SUBTEST( test_complex<long double>(3*8) );
    251   CALL_SUBTEST( test_complex<long double>(5*32) );
    252   CALL_SUBTEST( test_complex<long double>(2*3*4) );
    253   CALL_SUBTEST( test_complex<long double>(2*3*4*5) );
    254   CALL_SUBTEST( test_complex<long double>(2*3*4*5*7) );
    255 
    256   CALL_SUBTEST( test_scalar<long double>(32) );
    257   CALL_SUBTEST( test_scalar<long double>(45) );
    258   CALL_SUBTEST( test_scalar<long double>(50) );
    259   CALL_SUBTEST( test_scalar<long double>(256) );
    260   CALL_SUBTEST( test_scalar<long double>(2*3*4*5*7) );
    261   #endif
    262 }
    263