Home | History | Annotate | Download | only in perf
      1 #include "perf_precomp.hpp"
      2 
      3 using namespace std;
      4 using namespace cv;
      5 using namespace perf;
      6 using std::tr1::make_tuple;
      7 using std::tr1::get;
      8 
      9 CV_ENUM(NormType, NORM_L1, NORM_L2, NORM_L2SQR, NORM_HAMMING, NORM_HAMMING2)
     10 
     11 typedef std::tr1::tuple<NormType, MatType, bool> Norm_Destination_CrossCheck_t;
     12 typedef perf::TestBaseWithParam<Norm_Destination_CrossCheck_t> Norm_Destination_CrossCheck;
     13 
     14 typedef std::tr1::tuple<NormType, bool> Norm_CrossCheck_t;
     15 typedef perf::TestBaseWithParam<Norm_CrossCheck_t> Norm_CrossCheck;
     16 
     17 typedef std::tr1::tuple<MatType, bool> Source_CrossCheck_t;
     18 typedef perf::TestBaseWithParam<Source_CrossCheck_t> Source_CrossCheck;
     19 
     20 void generateData( Mat& query, Mat& train, const int sourceType );
     21 
     22 PERF_TEST_P(Norm_Destination_CrossCheck, batchDistance_8U,
     23             testing::Combine(testing::Values((int)NORM_L1, (int)NORM_L2SQR),
     24                              testing::Values(CV_32S, CV_32F),
     25                              testing::Bool()
     26                              )
     27             )
     28 {
     29     NormType normType = get<0>(GetParam());
     30     int destinationType = get<1>(GetParam());
     31     bool isCrossCheck = get<2>(GetParam());
     32     int knn = isCrossCheck ? 1 : 0;
     33 
     34     Mat queryDescriptors;
     35     Mat trainDescriptors;
     36     Mat dist;
     37     Mat ndix;
     38 
     39     generateData(queryDescriptors, trainDescriptors, CV_8U);
     40 
     41     TEST_CYCLE()
     42     {
     43         batchDistance(queryDescriptors, trainDescriptors, dist, destinationType, (isCrossCheck) ? ndix : noArray(),
     44                       normType, knn, Mat(), 0, isCrossCheck);
     45     }
     46 
     47     SANITY_CHECK(dist);
     48     if (isCrossCheck) SANITY_CHECK(ndix);
     49 }
     50 
     51 PERF_TEST_P(Norm_CrossCheck, batchDistance_Dest_32S,
     52             testing::Combine(testing::Values((int)NORM_HAMMING, (int)NORM_HAMMING2),
     53                              testing::Bool()
     54                              )
     55             )
     56 {
     57     NormType normType = get<0>(GetParam());
     58     bool isCrossCheck = get<1>(GetParam());
     59     int knn = isCrossCheck ? 1 : 0;
     60 
     61     Mat queryDescriptors;
     62     Mat trainDescriptors;
     63     Mat dist;
     64     Mat ndix;
     65 
     66     generateData(queryDescriptors, trainDescriptors, CV_8U);
     67 
     68     TEST_CYCLE()
     69     {
     70         batchDistance(queryDescriptors, trainDescriptors, dist, CV_32S, (isCrossCheck) ? ndix : noArray(),
     71                       normType, knn, Mat(), 0, isCrossCheck);
     72     }
     73 
     74     SANITY_CHECK(dist);
     75     if (isCrossCheck) SANITY_CHECK(ndix);
     76 }
     77 
     78 PERF_TEST_P(Source_CrossCheck, batchDistance_L2,
     79             testing::Combine(testing::Values(CV_8U, CV_32F),
     80                              testing::Bool()
     81                              )
     82             )
     83 {
     84     int sourceType = get<0>(GetParam());
     85     bool isCrossCheck = get<1>(GetParam());
     86     int knn = isCrossCheck ? 1 : 0;
     87 
     88     Mat queryDescriptors;
     89     Mat trainDescriptors;
     90     Mat dist;
     91     Mat ndix;
     92 
     93     generateData(queryDescriptors, trainDescriptors, sourceType);
     94 
     95     declare.time(50);
     96     TEST_CYCLE()
     97     {
     98         batchDistance(queryDescriptors, trainDescriptors, dist, CV_32F, (isCrossCheck) ? ndix : noArray(),
     99                       NORM_L2, knn, Mat(), 0, isCrossCheck);
    100     }
    101 
    102     SANITY_CHECK(dist);
    103     if (isCrossCheck) SANITY_CHECK(ndix);
    104 }
    105 
    106 PERF_TEST_P(Norm_CrossCheck, batchDistance_32F,
    107             testing::Combine(testing::Values((int)NORM_L1, (int)NORM_L2SQR),
    108                              testing::Bool()
    109                              )
    110             )
    111 {
    112     NormType normType = get<0>(GetParam());
    113     bool isCrossCheck = get<1>(GetParam());
    114     int knn = isCrossCheck ? 1 : 0;
    115 
    116     Mat queryDescriptors;
    117     Mat trainDescriptors;
    118     Mat dist;
    119     Mat ndix;
    120 
    121     generateData(queryDescriptors, trainDescriptors, CV_32F);
    122     declare.time(100);
    123 
    124     TEST_CYCLE()
    125     {
    126         batchDistance(queryDescriptors, trainDescriptors, dist, CV_32F, (isCrossCheck) ? ndix : noArray(),
    127                       normType, knn, Mat(), 0, isCrossCheck);
    128     }
    129 
    130     SANITY_CHECK(dist, 1e-4);
    131     if (isCrossCheck) SANITY_CHECK(ndix);
    132 }
    133 
    134 void generateData( Mat& query, Mat& train, const int sourceType )
    135 {
    136     const int dim = 500;
    137     const int queryDescCount = 300; // must be even number because we split train data in some cases in two
    138     const int countFactor = 4; // do not change it
    139     RNG& rng = theRNG();
    140 
    141     // Generate query descriptors randomly.
    142     // Descriptor vector elements are integer values.
    143     Mat buf( queryDescCount, dim, CV_32SC1 );
    144     rng.fill( buf, RNG::UNIFORM, Scalar::all(0), Scalar(3) );
    145     buf.convertTo( query, sourceType );
    146 
    147     // Generate train decriptors as follows:
    148     // copy each query descriptor to train set countFactor times
    149     // and perturb some one element of the copied descriptors in
    150     // in ascending order. General boundaries of the perturbation
    151     // are (0.f, 1.f).
    152     train.create( query.rows*countFactor, query.cols, sourceType );
    153     float step = (sourceType == CV_8U ? 256.f : 1.f) / countFactor;
    154     for( int qIdx = 0; qIdx < query.rows; qIdx++ )
    155     {
    156         Mat queryDescriptor = query.row(qIdx);
    157         for( int c = 0; c < countFactor; c++ )
    158         {
    159             int tIdx = qIdx * countFactor + c;
    160             Mat trainDescriptor = train.row(tIdx);
    161             queryDescriptor.copyTo( trainDescriptor );
    162             int elem = rng(dim);
    163             float diff = rng.uniform( step*c, step*(c+1) );
    164             trainDescriptor.col(elem) += diff;
    165         }
    166     }
    167 }
    168