/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2014, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"
#include "kdtree.hpp"

/****************************************************************************************\
*                              K-Nearest Neighbors Classifier                            *
\****************************************************************************************/

namespace cv {
namespace ml {

const String NAME_BRUTE_FORCE = "opencv_ml_knn";
const String NAME_KDTREE = "opencv_ml_knn_kd";

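// Shared implementation base: holds the stored training samples/responses and
// the parameters common to both search strategies (default K, classifier vs.
// regressor mode, and the Emax node budget used by the KD-tree search).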
class Impl
{
public:
    Impl()
    {
        defaultK = 10;
        isclassifier = true;
        Emax = INT_MAX;
    }

    virtual ~Impl() {}
    virtual String getModelName() const = 0;
    virtual int getType() const = 0;
    virtual float findNearest( InputArray _samples, int k,
                               OutputArray _results,
                               OutputArray _neighborResponses,
                               OutputArray _dists ) const = 0;

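    // Appends the new samples/responses to the stored training set when
    // UPDATE_MODEL is requested and a model already exists; otherwise replaces
    // the stored data. In both cases the search structure is rebuilt via doTrain().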
    bool train( const Ptr<TrainData>& data, int flags )
    {
        Mat new_samples = data->getTrainSamples(ROW_SAMPLE);
        Mat new_responses;
        data->getTrainResponses().convertTo(new_responses, CV_32F);
        bool update = (flags & ml::KNearest::UPDATE_MODEL) != 0 && !samples.empty();

        CV_Assert( new_samples.type() == CV_32F );

        if( !update )
        {
            clear();
        }
        else
        {
            CV_Assert( new_samples.cols == samples.cols &&
                       new_responses.cols == responses.cols );
        }

        samples.push_back(new_samples);
        responses.push_back(new_responses);

        doTrain(samples);

        return true;
    }

    virtual void doTrain(InputArray points) { (void)points; }

    void clear()
    {
        samples.release();
        responses.release();
    }

    void read( const FileNode& fn )
    {
        clear();
        isclassifier = (int)fn["is_classifier"] != 0;
        defaultK = (int)fn["default_k"];

        fn["samples"] >> samples;
        fn["responses"] >> responses;
    }

    void write( FileStorage& fs ) const
    {
        fs << "is_classifier" << (int)isclassifier;
        fs << "default_k" << defaultK;

        fs << "samples" << samples;
        fs << "responses" << responses;
    }

public:
    int defaultK;
    bool isclassifier;
    int Emax;

    Mat samples;
    Mat responses;
};

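// Brute-force search: every query row is compared against every stored
// training sample using squared Euclidean distance, keeping the k best matches.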
class BruteForceImpl : public Impl
{
public:
    String getModelName() const { return NAME_BRUTE_FORCE; }
    int getType() const { return ml::KNearest::BRUTE_FORCE; }

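    // Finds the k nearest training samples for the query rows in `range`.
    // The current k best distances/responses per query are kept in
    // insertion-sorted buffers; the final prediction is the mean neighbor
    // response (regression, or k == 1) or the most frequent neighbor response
    // after sorting (classification).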
    void findNearestCore( const Mat& _samples, int k0, const Range& range,
                          Mat* results, Mat* neighbor_responses,
                          Mat* dists, float* presult ) const
    {
        int testidx, baseidx, i, j, d = samples.cols, nsamples = samples.rows;
        int testcount = range.end - range.start;
        int k = std::min(k0, nsamples);

        AutoBuffer<float> buf(testcount*k*2);
        float* dbuf = buf;
        float* rbuf = dbuf + testcount*k;

        const float* rptr = responses.ptr<float>();

        for( testidx = 0; testidx < testcount; testidx++ )
        {
            for( i = 0; i < k; i++ )
            {
                dbuf[testidx*k + i] = FLT_MAX;
                rbuf[testidx*k + i] = 0.f;
            }
        }

        for( baseidx = 0; baseidx < nsamples; baseidx++ )
        {
            for( testidx = 0; testidx < testcount; testidx++ )
            {
                const float* v = samples.ptr<float>(baseidx);
                const float* u = _samples.ptr<float>(testidx + range.start);

                float s = 0;
                for( i = 0; i <= d - 4; i += 4 )
                {
                    float t0 = u[i] - v[i], t1 = u[i+1] - v[i+1];
                    float t2 = u[i+2] - v[i+2], t3 = u[i+3] - v[i+3];
                    s += t0*t0 + t1*t1 + t2*t2 + t3*t3;
                }

                for( ; i < d; i++ )
                {
                    float t0 = u[i] - v[i];
                    s += t0*t0;
                }

                Cv32suf si;
                si.f = (float)s;
                Cv32suf* dd = (Cv32suf*)(&dbuf[testidx*k]);
                float* nr = &rbuf[testidx*k];

                for( i = k; i > 0; i-- )
                    if( si.i >= dd[i-1].i )
                        break;
                if( i >= k )
                    continue;

                for( j = k-2; j >= i; j-- )
                {
                    dd[j+1].i = dd[j].i;
                    nr[j+1] = nr[j];
                }
                dd[i].i = si.i;
                nr[i] = rptr[baseidx];
            }
        }

        float result = 0.f;
        float inv_scale = 1.f/k;

        for( testidx = 0; testidx < testcount; testidx++ )
        {
            if( neighbor_responses )
            {
                float* nr = neighbor_responses->ptr<float>(testidx + range.start);
                for( j = 0; j < k; j++ )
                    nr[j] = rbuf[testidx*k + j];
                for( ; j < k0; j++ )
                    nr[j] = 0.f;
            }

            if( dists )
            {
                float* dptr = dists->ptr<float>(testidx + range.start);
                for( j = 0; j < k; j++ )
                    dptr[j] = dbuf[testidx*k + j];
                for( ; j < k0; j++ )
                    dptr[j] = 0.f;
            }

            if( results || testidx+range.start == 0 )
            {
                if( !isclassifier || k == 1 )
                {
                    float s = 0.f;
                    for( j = 0; j < k; j++ )
                        s += rbuf[testidx*k + j];
                    result = (float)(s*inv_scale);
                }
                else
                {
                    float* rp = rbuf + testidx*k;
                    for( j = k-1; j > 0; j-- )
                    {
                        bool swap_fl = false;
                        for( i = 0; i < j; i++ )
                        {
                            if( rp[i] > rp[i+1] )
                            {
                                std::swap(rp[i], rp[i+1]);
                                swap_fl = true;
                            }
                        }
                        if( !swap_fl )
                            break;
                    }

                    result = rp[0];
                    int prev_start = 0;
                    int best_count = 0;
                    for( j = 1; j <= k; j++ )
                    {
                        if( j == k || rp[j] != rp[j-1] )
                        {
                            int count = j - prev_start;
                            if( best_count < count )
                            {
                                best_count = count;
                                result = rp[j-1];
                            }
                            prev_start = j;
                        }
                    }
                }
                if( results )
                    results->at<float>(testidx + range.start) = result;
                if( presult && testidx+range.start == 0 )
                    *presult = result;
            }
        }
    }

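    // ParallelLoopBody wrapper so the query rows can be split across threads
    // by parallel_for_; each thread handles its range in chunks of at most
    // 256 rows, bounding the size of the per-chunk distance buffers.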
    struct findKNearestInvoker : public ParallelLoopBody
    {
        findKNearestInvoker(const BruteForceImpl* _p, int _k, const Mat& __samples,
                            Mat* __results, Mat* __neighbor_responses, Mat* __dists, float* _presult)
        {
            p = _p;
            k = _k;
            _samples = &__samples;
            _results = __results;
            _neighbor_responses = __neighbor_responses;
            _dists = __dists;
            presult = _presult;
        }

        void operator()( const Range& range ) const
        {
            int delta = std::min(range.end - range.start, 256);
            for( int start = range.start; start < range.end; start += delta )
            {
                p->findNearestCore( *_samples, k, Range(start, std::min(start + delta, range.end)),
                                    _results, _neighbor_responses, _dists, presult );
            }
        }

        const BruteForceImpl* p;
        int k;
        const Mat* _samples;
        Mat* _results;
        Mat* _neighbor_responses;
        Mat* _dists;
        float* presult;
    };

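    // Validates the query matrix, allocates the requested output arrays and
    // runs the parallel brute-force search; the return value is the prediction
    // computed for the first query row.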
    float findNearest( InputArray _samples, int k,
                       OutputArray _results,
                       OutputArray _neighborResponses,
                       OutputArray _dists ) const
    {
        float result = 0.f;
        CV_Assert( 0 < k );

        Mat test_samples = _samples.getMat();
        CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
        int testcount = test_samples.rows;

        if( testcount == 0 )
        {
            _results.release();
            _neighborResponses.release();
            _dists.release();
            return 0.f;
        }

        Mat res, nr, d, *pres = 0, *pnr = 0, *pd = 0;
        if( _results.needed() )
        {
            _results.create(testcount, 1, CV_32F);
            pres = &(res = _results.getMat());
        }
        if( _neighborResponses.needed() )
        {
            _neighborResponses.create(testcount, k, CV_32F);
            pnr = &(nr = _neighborResponses.getMat());
        }
        if( _dists.needed() )
        {
            _dists.create(testcount, k, CV_32F);
            pd = &(d = _dists.getMat());
        }

        findKNearestInvoker invoker(this, k, test_samples, pres, pnr, pd, &result);
        parallel_for_(Range(0, testcount), invoker);
        //invoker(Range(0, testcount));
        return result;
    }
};

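// KD-tree search: the training samples are indexed once in doTrain(), and each
// query row is answered by KDTree::findNearest, bounded by Emax node visits.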
class KDTreeImpl : public Impl
{
public:
    String getModelName() const { return NAME_KDTREE; }
    int getType() const { return ml::KNearest::KDTREE; }

    void doTrain(InputArray points)
    {
        tr.build(points);
    }

    float findNearest( InputArray _samples, int k,
                       OutputArray _results,
                       OutputArray _neighborResponses,
                       OutputArray _dists ) const
    {
        float result = 0.f;
        CV_Assert( 0 < k );

        Mat test_samples = _samples.getMat();
        CV_Assert( test_samples.type() == CV_32F && test_samples.cols == samples.cols );
        int testcount = test_samples.rows;

        if( testcount == 0 )
        {
            _results.release();
            _neighborResponses.release();
            _dists.release();
            return 0.f;
        }

        Mat res, nr, d;
        if( _results.needed() )
        {
            _results.create(testcount, 1, CV_32F);
            res = _results.getMat();
        }
        if( _neighborResponses.needed() )
        {
            _neighborResponses.create(testcount, k, CV_32F);
            nr = _neighborResponses.getMat();
        }
        if( _dists.needed() )
        {
            _dists.create(testcount, k, CV_32F);
            d = _dists.getMat();
        }

        for (int i=0; i<test_samples.rows; ++i)
        {
            Mat _res, _nr, _d;
            if (res.rows>i)
            {
                _res = res.row(i);
            }
            if (nr.rows>i)
            {
                _nr = nr.row(i);
            }
            if (d.rows>i)
            {
                _d = d.row(i);
            }
            tr.findNearest(test_samples.row(i), k, Emax, _res, _nr, _d, noArray());
        }

        return result; // currently always 0
    }

    KDTree tr;
};

//================================================================

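// Public KNearest model: a thin facade that forwards everything to the
// currently selected Impl (brute force by default, KD-tree on request).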
class KNearestImpl : public KNearest
{
    CV_IMPL_PROPERTY(int, DefaultK, impl->defaultK)
    CV_IMPL_PROPERTY(bool, IsClassifier, impl->isclassifier)
    CV_IMPL_PROPERTY(int, Emax, impl->Emax)

public:
    int getAlgorithmType() const
    {
        return impl->getType();
    }
    void setAlgorithmType(int val)
    {
        if (val != BRUTE_FORCE && val != KDTREE)
            val = BRUTE_FORCE;
        initImpl(val);
    }

public:
    KNearestImpl()
    {
        initImpl(BRUTE_FORCE);
    }
    ~KNearestImpl()
    {
    }

    bool isClassifier() const { return impl->isclassifier; }
    bool isTrained() const { return !impl->samples.empty(); }

    int getVarCount() const { return impl->samples.cols; }

    void write( FileStorage& fs ) const
    {
        impl->write(fs);
    }

    void read( const FileNode& fn )
    {
        int algorithmType = BRUTE_FORCE;
        if (fn.name() == NAME_KDTREE)
            algorithmType = KDTREE;
        initImpl(algorithmType);
        impl->read(fn);
    }

    float findNearest( InputArray samples, int k,
                       OutputArray results,
                       OutputArray neighborResponses=noArray(),
                       OutputArray dist=noArray() ) const
    {
        return impl->findNearest(samples, k, results, neighborResponses, dist);
    }

    float predict(InputArray inputs, OutputArray outputs, int) const
    {
        return impl->findNearest( inputs, impl->defaultK, outputs, noArray(), noArray() );
    }

    bool train( const Ptr<TrainData>& data, int flags )
    {
        return impl->train(data, flags);
    }

    String getDefaultName() const { return impl->getModelName(); }

protected:
    void initImpl(int algorithmType)
    {
        if (algorithmType != KDTREE)
            impl = makePtr<BruteForceImpl>();
        else
            impl = makePtr<KDTreeImpl>();
    }
    Ptr<Impl> impl;
};

Ptr<KNearest> KNearest::create()
{
    return makePtr<KNearestImpl>();
}
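/* A minimal usage sketch (illustrative only; the data and values below are
   hypothetical and not part of this file):

    using namespace cv;
    using namespace cv::ml;

    Mat trainData(100, 2, CV_32F), labels(100, 1, CV_32F);
    randu(trainData, 0, 10);
    for( int i = 0; i < labels.rows; i++ )
        labels.at<float>(i) = (float)(i % 2);      // two made-up classes

    Ptr<KNearest> knn = KNearest::create();
    knn->setDefaultK(5);
    knn->setIsClassifier(true);
    knn->train(TrainData::create(trainData, ROW_SAMPLE, labels));

    Mat query = (Mat_<float>(1, 2) << 3.f, 4.f), results;
    knn->findNearest(query, 5, results);           // voted label for the query
*/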
}
}

/* End of file */