Home | History | Annotate | Download | only in flann
      1 /***********************************************************************
      2  * Software License Agreement (BSD License)
      3  *
      4  * Copyright 2008-2009  Marius Muja (mariusm (at) cs.ubc.ca). All rights reserved.
      5  * Copyright 2008-2009  David G. Lowe (lowe (at) cs.ubc.ca). All rights reserved.
      6  *
      7  * THE BSD LICENSE
      8  *
      9  * Redistribution and use in source and binary forms, with or without
     10  * modification, are permitted provided that the following conditions
     11  * are met:
     12  *
     13  * 1. Redistributions of source code must retain the above copyright
     14  *    notice, this list of conditions and the following disclaimer.
     15  * 2. Redistributions in binary form must reproduce the above copyright
     16  *    notice, this list of conditions and the following disclaimer in the
     17  *    documentation and/or other materials provided with the distribution.
     18  *
     19  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
     20  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
     21  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
     22  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
     23  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
     24  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
     25  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
     26  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
     27  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
     28  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     29  *************************************************************************/
     30 
     31 #ifndef OPENCV_FLANN_INDEX_TESTING_H_
     32 #define OPENCV_FLANN_INDEX_TESTING_H_
     33 
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstring>
#include <limits>
#include <vector>
     37 
     38 #include "matrix.h"
     39 #include "nn_index.h"
     40 #include "result_set.h"
     41 #include "logger.h"
     42 #include "timer.h"
     43 
     44 
     45 namespace cvflann
     46 {
     47 
     48 inline int countCorrectMatches(int* neighbors, int* groundTruth, int n)
     49 {
     50     int count = 0;
     51     for (int i=0; i<n; ++i) {
     52         for (int k=0; k<n; ++k) {
     53             if (neighbors[i]==groundTruth[k]) {
     54                 count++;
     55                 break;
     56             }
     57         }
     58     }
     59     return count;
     60 }
     61 
     62 
     63 template <typename Distance>
     64 typename Distance::ResultType computeDistanceRaport(const Matrix<typename Distance::ElementType>& inputData, typename Distance::ElementType* target,
     65                                                     int* neighbors, int* groundTruth, int veclen, int n, const Distance& distance)
     66 {
     67     typedef typename Distance::ResultType DistanceType;
     68 
     69     DistanceType ret = 0;
     70     for (int i=0; i<n; ++i) {
     71         DistanceType den = distance(inputData[groundTruth[i]], target, veclen);
     72         DistanceType num = distance(inputData[neighbors[i]], target, veclen);
     73 
     74         if ((den==0)&&(num==0)) {
     75             ret += 1;
     76         }
     77         else {
     78             ret += num/den;
     79         }
     80     }
     81 
     82     return ret;
     83 }
     84 
     85 template <typename Distance>
     86 float search_with_ground_truth(NNIndex<Distance>& index, const Matrix<typename Distance::ElementType>& inputData,
     87                                const Matrix<typename Distance::ElementType>& testData, const Matrix<int>& matches, int nn, int checks,
     88                                float& time, typename Distance::ResultType& dist, const Distance& distance, int skipMatches)
     89 {
     90     typedef typename Distance::ResultType DistanceType;
     91 
     92     if (matches.cols<size_t(nn)) {
     93         Logger::info("matches.cols=%d, nn=%d\n",matches.cols,nn);
     94 
     95         throw FLANNException("Ground truth is not computed for as many neighbors as requested");
     96     }
     97 
     98     KNNResultSet<DistanceType> resultSet(nn+skipMatches);
     99     SearchParams searchParams(checks);
    100 
    101     std::vector<int> indices(nn+skipMatches);
    102     std::vector<DistanceType> dists(nn+skipMatches);
    103     int* neighbors = &indices[skipMatches];
    104 
    105     int correct = 0;
    106     DistanceType distR = 0;
    107     StartStopTimer t;
    108     int repeats = 0;
    109     while (t.value<0.2) {
    110         repeats++;
    111         t.start();
    112         correct = 0;
    113         distR = 0;
    114         for (size_t i = 0; i < testData.rows; i++) {
    115             resultSet.init(&indices[0], &dists[0]);
    116             index.findNeighbors(resultSet, testData[i], searchParams);
    117 
    118             correct += countCorrectMatches(neighbors,matches[i], nn);
    119             distR += computeDistanceRaport<Distance>(inputData, testData[i], neighbors, matches[i], (int)testData.cols, nn, distance);
    120         }
    121         t.stop();
    122     }
    123     time = float(t.value/repeats);
    124 
    125     float precicion = (float)correct/(nn*testData.rows);
    126 
    127     dist = distR/(testData.rows*nn);
    128 
    129     Logger::info("%8d %10.4g %10.5g %10.5g %10.5g\n",
    130                  checks, precicion, time, 1000.0 * time / testData.rows, dist);
    131 
    132     return precicion;
    133 }
    134 
    135 
    136 template <typename Distance>
    137 float test_index_checks(NNIndex<Distance>& index, const Matrix<typename Distance::ElementType>& inputData,
    138                         const Matrix<typename Distance::ElementType>& testData, const Matrix<int>& matches,
    139                         int checks, float& precision, const Distance& distance, int nn = 1, int skipMatches = 0)
    140 {
    141     typedef typename Distance::ResultType DistanceType;
    142 
    143     Logger::info("  Nodes  Precision(%)   Time(s)   Time/vec(ms)  Mean dist\n");
    144     Logger::info("---------------------------------------------------------\n");
    145 
    146     float time = 0;
    147     DistanceType dist = 0;
    148     precision = search_with_ground_truth(index, inputData, testData, matches, nn, checks, time, dist, distance, skipMatches);
    149 
    150     return time;
    151 }
    152 
    153 template <typename Distance>
    154 float test_index_precision(NNIndex<Distance>& index, const Matrix<typename Distance::ElementType>& inputData,
    155                            const Matrix<typename Distance::ElementType>& testData, const Matrix<int>& matches,
    156                            float precision, int& checks, const Distance& distance, int nn = 1, int skipMatches = 0)
    157 {
    158     typedef typename Distance::ResultType DistanceType;
    159     const float SEARCH_EPS = 0.001f;
    160 
    161     Logger::info("  Nodes  Precision(%)   Time(s)   Time/vec(ms)  Mean dist\n");
    162     Logger::info("---------------------------------------------------------\n");
    163 
    164     int c2 = 1;
    165     float p2;
    166     int c1 = 1;
    167     //float p1;
    168     float time;
    169     DistanceType dist;
    170 
    171     p2 = search_with_ground_truth(index, inputData, testData, matches, nn, c2, time, dist, distance, skipMatches);
    172 
    173     if (p2>precision) {
    174         Logger::info("Got as close as I can\n");
    175         checks = c2;
    176         return time;
    177     }
    178 
    179     while (p2<precision) {
    180         c1 = c2;
    181         //p1 = p2;
    182         c2 *=2;
    183         p2 = search_with_ground_truth(index, inputData, testData, matches, nn, c2, time, dist, distance, skipMatches);
    184     }
    185 
    186     int cx;
    187     float realPrecision;
    188     if (fabs(p2-precision)>SEARCH_EPS) {
    189         Logger::info("Start linear estimation\n");
    190         // after we got to values in the vecinity of the desired precision
    191         // use linear approximation get a better estimation
    192 
    193         cx = (c1+c2)/2;
    194         realPrecision = search_with_ground_truth(index, inputData, testData, matches, nn, cx, time, dist, distance, skipMatches);
    195         while (fabs(realPrecision-precision)>SEARCH_EPS) {
    196 
    197             if (realPrecision<precision) {
    198                 c1 = cx;
    199             }
    200             else {
    201                 c2 = cx;
    202             }
    203             cx = (c1+c2)/2;
    204             if (cx==c1) {
    205                 Logger::info("Got as close as I can\n");
    206                 break;
    207             }
    208             realPrecision = search_with_ground_truth(index, inputData, testData, matches, nn, cx, time, dist, distance, skipMatches);
    209         }
    210 
    211         c2 = cx;
    212         p2 = realPrecision;
    213 
    214     }
    215     else {
    216         Logger::info("No need for linear estimation\n");
    217         cx = c2;
    218         realPrecision = p2;
    219     }
    220 
    221     checks = cx;
    222     return time;
    223 }
    224 
    225 
    226 template <typename Distance>
    227 void test_index_precisions(NNIndex<Distance>& index, const Matrix<typename Distance::ElementType>& inputData,
    228                            const Matrix<typename Distance::ElementType>& testData, const Matrix<int>& matches,
    229                            float* precisions, int precisions_length, const Distance& distance, int nn = 1, int skipMatches = 0, float maxTime = 0)
    230 {
    231     typedef typename Distance::ResultType DistanceType;
    232 
    233     const float SEARCH_EPS = 0.001;
    234 
    235     // make sure precisions array is sorted
    236     std::sort(precisions, precisions+precisions_length);
    237 
    238     int pindex = 0;
    239     float precision = precisions[pindex];
    240 
    241     Logger::info("  Nodes  Precision(%)   Time(s)   Time/vec(ms)  Mean dist\n");
    242     Logger::info("---------------------------------------------------------\n");
    243 
    244     int c2 = 1;
    245     float p2;
    246 
    247     int c1 = 1;
    248     float p1;
    249 
    250     float time;
    251     DistanceType dist;
    252 
    253     p2 = search_with_ground_truth(index, inputData, testData, matches, nn, c2, time, dist, distance, skipMatches);
    254 
    255     // if precision for 1 run down the tree is already
    256     // better then some of the requested precisions, then
    257     // skip those
    258     while (precisions[pindex]<p2 && pindex<precisions_length) {
    259         pindex++;
    260     }
    261 
    262     if (pindex==precisions_length) {
    263         Logger::info("Got as close as I can\n");
    264         return;
    265     }
    266 
    267     for (int i=pindex; i<precisions_length; ++i) {
    268 
    269         precision = precisions[i];
    270         while (p2<precision) {
    271             c1 = c2;
    272             p1 = p2;
    273             c2 *=2;
    274             p2 = search_with_ground_truth(index, inputData, testData, matches, nn, c2, time, dist, distance, skipMatches);
    275             if ((maxTime> 0)&&(time > maxTime)&&(p2<precision)) return;
    276         }
    277 
    278         int cx;
    279         float realPrecision;
    280         if (fabs(p2-precision)>SEARCH_EPS) {
    281             Logger::info("Start linear estimation\n");
    282             // after we got to values in the vecinity of the desired precision
    283             // use linear approximation get a better estimation
    284 
    285             cx = (c1+c2)/2;
    286             realPrecision = search_with_ground_truth(index, inputData, testData, matches, nn, cx, time, dist, distance, skipMatches);
    287             while (fabs(realPrecision-precision)>SEARCH_EPS) {
    288 
    289                 if (realPrecision<precision) {
    290                     c1 = cx;
    291                 }
    292                 else {
    293                     c2 = cx;
    294                 }
    295                 cx = (c1+c2)/2;
    296                 if (cx==c1) {
    297                     Logger::info("Got as close as I can\n");
    298                     break;
    299                 }
    300                 realPrecision = search_with_ground_truth(index, inputData, testData, matches, nn, cx, time, dist, distance, skipMatches);
    301             }
    302 
    303             c2 = cx;
    304             p2 = realPrecision;
    305 
    306         }
    307         else {
    308             Logger::info("No need for linear estimation\n");
    309             cx = c2;
    310             realPrecision = p2;
    311         }
    312 
    313     }
    314 }
    315 
    316 }
    317 
    318 #endif //OPENCV_FLANN_INDEX_TESTING_H_
    319