Home | History | Annotate | Download | only in src
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                        Intel License Agreement
     11 //
     12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
     13 // Third party copyrights are property of their respective owners.
     14 //
     15 // Redistribution and use in source and binary forms, with or without modification,
     16 // are permitted provided that the following conditions are met:
     17 //
     18 //   * Redistribution's of source code must retain the above copyright notice,
     19 //     this list of conditions and the following disclaimer.
     20 //
     21 //   * Redistribution's in binary form must reproduce the above copyright notice,
     22 //     this list of conditions and the following disclaimer in the documentation
     23 //     and/or other materials provided with the distribution.
     24 //
     25 //   * The name of Intel Corporation may not be used to endorse or promote products
     26 //     derived from this software without specific prior written permission.
     27 //
     28 // This software is provided by the copyright holders and contributors "as is" and
     29 // any express or implied warranties, including, but not limited to, the implied
     30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     31 // In no event shall the Intel Corporation or contributors be liable for any direct,
     32 // indirect, incidental, special, exemplary, or consequential damages
     33 // (including, but not limited to, procurement of substitute goods or services;
     34 // loss of use, data, or profits; or business interruption) however caused
     35 // and on any theory of liability, whether in contract, strict liability,
     36 // or tort (including negligence or otherwise) arising in any way out of
     37 // the use of this software, even if advised of the possibility of such damage.
     38 //
     39 //M*/
     40 
     41 #include "_ml.h"
     42 
     43 /****************************************************************************************\
     44 *                          K-Nearest Neighbors Classifier                                *
     45 \****************************************************************************************/
     46 
     47 // k Nearest Neighbors
     48 CvKNearest::CvKNearest()
     49 {
     50     samples = 0;
     51     clear();
     52 }
     53 
     54 
     55 CvKNearest::~CvKNearest()
     56 {
     57     clear();
     58 }
     59 
     60 
     61 CvKNearest::CvKNearest( const CvMat* _train_data, const CvMat* _responses,
     62                         const CvMat* _sample_idx, bool _is_regression, int _max_k )
     63 {
     64     samples = 0;
     65     train( _train_data, _responses, _sample_idx, _is_regression, _max_k, false );
     66 }
     67 
     68 
     69 void CvKNearest::clear()
     70 {
     71     while( samples )
     72     {
     73         CvVectors* next_samples = samples->next;
     74         cvFree( &samples->data.fl );
     75         cvFree( &samples );
     76         samples = next_samples;
     77     }
     78     var_count = 0;
     79     total = 0;
     80     max_k = 0;
     81 }
     82 
     83 
     84 int CvKNearest::get_max_k() const { return max_k; }
     85 
     86 int CvKNearest::get_var_count() const { return var_count; }
     87 
     88 bool CvKNearest::is_regression() const { return regression; }
     89 
     90 int CvKNearest::get_sample_count() const { return total; }
     91 
     92 bool CvKNearest::train( const CvMat* _train_data, const CvMat* _responses,
     93                         const CvMat* _sample_idx, bool _is_regression,
     94                         int _max_k, bool _update_base )
     95 {
     96     bool ok = false;
     97     CvMat* responses = 0;
     98 
     99     CV_FUNCNAME( "CvKNearest::train" );
    100 
    101     __BEGIN__;
    102 
    103     CvVectors* _samples;
    104     float** _data;
    105     int _count, _dims, _dims_all, _rsize;
    106 
    107     if( !_update_base )
    108         clear();
    109 
    110     // Prepare training data and related parameters.
    111     // Treat categorical responses as ordered - to prevent class label compression and
    112     // to enable entering new classes in the updates
    113     CV_CALL( cvPrepareTrainData( "CvKNearest::train", _train_data, CV_ROW_SAMPLE,
    114         _responses, CV_VAR_ORDERED, 0, _sample_idx, true, (const float***)&_data,
    115         &_count, &_dims, &_dims_all, &responses, 0, 0 ));
    116 
    117     if( _update_base && _dims != var_count )
    118         CV_ERROR( CV_StsBadArg, "The newly added data have different dimensionality" );
    119 
    120     if( !_update_base )
    121     {
    122         if( _max_k < 1 )
    123             CV_ERROR( CV_StsOutOfRange, "max_k must be a positive number" );
    124 
    125         regression = _is_regression;
    126         var_count = _dims;
    127         max_k = _max_k;
    128     }
    129 
    130     _rsize = _count*sizeof(float);
    131     CV_CALL( _samples = (CvVectors*)cvAlloc( sizeof(*_samples) + _rsize ));
    132     _samples->next = samples;
    133     _samples->type = CV_32F;
    134     _samples->data.fl = _data;
    135     _samples->count = _count;
    136     total += _count;
    137 
    138     samples = _samples;
    139     memcpy( _samples + 1, responses->data.fl, _rsize );
    140 
    141     ok = true;
    142 
    143     __END__;
    144 
    145     return ok;
    146 }
    147 
    148 
    149 
    150 void CvKNearest::find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
    151                     float* neighbor_responses, const float** neighbors, float* dist ) const
    152 {
    153     int i, j, count = end - start, k1 = 0, k2 = 0, d = var_count;
    154     CvVectors* s = samples;
    155 
    156     for( ; s != 0; s = s->next )
    157     {
    158         int n = s->count;
    159         for( j = 0; j < n; j++ )
    160         {
    161             for( i = 0; i < count; i++ )
    162             {
    163                 double sum = 0;
    164                 Cv32suf si;
    165                 const float* v = s->data.fl[j];
    166                 const float* u = (float*)(_samples->data.ptr + _samples->step*(start + i));
    167                 Cv32suf* dd = (Cv32suf*)(dist + i*k);
    168                 float* nr;
    169                 const float** nn;
    170                 int t, ii, ii1;
    171 
    172                 for( t = 0; t <= d - 4; t += 4 )
    173                 {
    174                     double t0 = u[t] - v[t], t1 = u[t+1] - v[t+1];
    175                     double t2 = u[t+2] - v[t+2], t3 = u[t+3] - v[t+3];
    176                     sum += t0*t0 + t1*t1 + t2*t2 + t3*t3;
    177                 }
    178 
    179                 for( ; t < d; t++ )
    180                 {
    181                     double t0 = u[t] - v[t];
    182                     sum += t0*t0;
    183                 }
    184 
    185                 si.f = (float)sum;
    186                 for( ii = k1-1; ii >= 0; ii-- )
    187                     if( si.i > dd[ii].i )
    188                         break;
    189                 if( ii >= k-1 )
    190                     continue;
    191 
    192                 nr = neighbor_responses + i*k;
    193                 nn = neighbors ? neighbors + (start + i)*k : 0;
    194                 for( ii1 = k2 - 1; ii1 > ii; ii1-- )
    195                 {
    196                     dd[ii1+1].i = dd[ii1].i;
    197                     nr[ii1+1] = nr[ii1];
    198                     if( nn ) nn[ii1+1] = nn[ii1];
    199                 }
    200                 dd[ii+1].i = si.i;
    201                 nr[ii+1] = ((float*)(s + 1))[j];
    202                 if( nn )
    203                     nn[ii+1] = v;
    204             }
    205             k1 = MIN( k1+1, k );
    206             k2 = MIN( k1, k-1 );
    207         }
    208     }
    209 }
    210 
    211 
    212 float CvKNearest::write_results( int k, int k1, int start, int end,
    213     const float* neighbor_responses, const float* dist,
    214     CvMat* _results, CvMat* _neighbor_responses,
    215     CvMat* _dist, Cv32suf* sort_buf ) const
    216 {
    217     float result = 0.f;
    218     int i, j, j1, count = end - start;
    219     double inv_scale = 1./k1;
    220     int rstep = _results && !CV_IS_MAT_CONT(_results->type) ? _results->step/sizeof(result) : 1;
    221 
    222     for( i = 0; i < count; i++ )
    223     {
    224         const Cv32suf* nr = (const Cv32suf*)(neighbor_responses + i*k);
    225         float* dst;
    226         float r;
    227         if( _results || start+i == 0 )
    228         {
    229             if( regression )
    230             {
    231                 double s = 0;
    232                 for( j = 0; j < k1; j++ )
    233                     s += nr[j].f;
    234                 r = (float)(s*inv_scale);
    235             }
    236             else
    237             {
    238                 int prev_start = 0, best_count = 0, cur_count;
    239                 Cv32suf best_val;
    240 
    241                 for( j = 0; j < k1; j++ )
    242                     sort_buf[j].i = nr[j].i;
    243 
    244                 for( j = k1-1; j > 0; j-- )
    245                 {
    246                     bool swap_fl = false;
    247                     for( j1 = 0; j1 < j; j1++ )
    248                         if( sort_buf[j1].i > sort_buf[j1+1].i )
    249                         {
    250                             int t;
    251                             CV_SWAP( sort_buf[j1].i, sort_buf[j1+1].i, t );
    252                             swap_fl = true;
    253                         }
    254                     if( !swap_fl )
    255                         break;
    256                 }
    257 
    258                 best_val.i = 0;
    259                 for( j = 1; j <= k1; j++ )
    260                     if( j == k1 || sort_buf[j].i != sort_buf[j-1].i )
    261                     {
    262                         cur_count = j - prev_start;
    263                         if( best_count < cur_count )
    264                         {
    265                             best_count = cur_count;
    266                             best_val.i = sort_buf[j-1].i;
    267                         }
    268                         prev_start = j;
    269                     }
    270                 r = best_val.f;
    271             }
    272 
    273             if( start+i == 0 )
    274                 result = r;
    275 
    276             if( _results )
    277                 _results->data.fl[(start + i)*rstep] = r;
    278         }
    279 
    280         if( _neighbor_responses )
    281         {
    282             dst = (float*)(_neighbor_responses->data.ptr +
    283                 (start + i)*_neighbor_responses->step);
    284             for( j = 0; j < k1; j++ )
    285                 dst[j] = nr[j].f;
    286             for( ; j < k; j++ )
    287                 dst[j] = 0.f;
    288         }
    289 
    290         if( _dist )
    291         {
    292             dst = (float*)(_dist->data.ptr + (start + i)*_dist->step);
    293             for( j = 0; j < k1; j++ )
    294                 dst[j] = dist[j + i*k];
    295             for( ; j < k; j++ )
    296                 dst[j] = 0.f;
    297         }
    298     }
    299 
    300     return result;
    301 }
    302 
    303 
    304 
    305 float CvKNearest::find_nearest( const CvMat* _samples, int k, CvMat* _results,
    306     const float** _neighbors, CvMat* _neighbor_responses, CvMat* _dist ) const
    307 {
    308     float result = 0.f;
    309     bool local_alloc = false;
    310     float* buf = 0;
    311     const int max_blk_count = 128, max_buf_sz = 1 << 12;
    312 
    313     CV_FUNCNAME( "CvKNearest::find_nearest" );
    314 
    315     __BEGIN__;
    316 
    317     int i, count, count_scale, blk_count0, blk_count = 0, buf_sz, k1;
    318 
    319     if( !samples )
    320         CV_ERROR( CV_StsError, "The search tree must be constructed first using train method" );
    321 
    322     if( !CV_IS_MAT(_samples) ||
    323         CV_MAT_TYPE(_samples->type) != CV_32FC1 ||
    324         _samples->cols != var_count )
    325         CV_ERROR( CV_StsBadArg, "Input samples must be floating-point matrix (<num_samples>x<var_count>)" );
    326 
    327     if( _results && (!CV_IS_MAT(_results) ||
    328         _results->cols != 1 && _results->rows != 1 ||
    329         _results->cols + _results->rows - 1 != _samples->rows) )
    330         CV_ERROR( CV_StsBadArg,
    331         "The results must be 1d vector containing as much elements as the number of samples" );
    332 
    333     if( _results && CV_MAT_TYPE(_results->type) != CV_32FC1 &&
    334         (CV_MAT_TYPE(_results->type) != CV_32SC1 || regression))
    335         CV_ERROR( CV_StsUnsupportedFormat,
    336         "The results must be floating-point or integer (in case of classification) vector" );
    337 
    338     if( k < 1 || k > max_k )
    339         CV_ERROR( CV_StsOutOfRange, "k must be within 1..max_k range" );
    340 
    341     if( _neighbor_responses )
    342     {
    343         if( !CV_IS_MAT(_neighbor_responses) || CV_MAT_TYPE(_neighbor_responses->type) != CV_32FC1 ||
    344             _neighbor_responses->rows != _samples->rows || _neighbor_responses->cols != k )
    345             CV_ERROR( CV_StsBadArg,
    346             "The neighbor responses (if present) must be floating-point matrix of <num_samples> x <k> size" );
    347     }
    348 
    349     if( _dist )
    350     {
    351         if( !CV_IS_MAT(_dist) || CV_MAT_TYPE(_dist->type) != CV_32FC1 ||
    352             _dist->rows != _samples->rows || _dist->cols != k )
    353             CV_ERROR( CV_StsBadArg,
    354             "The distances from the neighbors (if present) must be floating-point matrix of <num_samples> x <k> size" );
    355     }
    356 
    357     count = _samples->rows;
    358     count_scale = k*2*sizeof(float);
    359     blk_count0 = MIN( count, max_blk_count );
    360     buf_sz = MIN( blk_count0 * count_scale, max_buf_sz );
    361     blk_count0 = MAX( buf_sz/count_scale, 1 );
    362     blk_count0 += blk_count0 % 2;
    363     blk_count0 = MIN( blk_count0, count );
    364     buf_sz = blk_count0 * count_scale + k*sizeof(float);
    365     k1 = get_sample_count();
    366     k1 = MIN( k1, k );
    367 
    368     if( buf_sz <= CV_MAX_LOCAL_SIZE )
    369     {
    370         buf = (float*)cvStackAlloc( buf_sz );
    371         local_alloc = true;
    372     }
    373     else
    374         CV_CALL( buf = (float*)cvAlloc( buf_sz ));
    375 
    376     for( i = 0; i < count; i += blk_count )
    377     {
    378         blk_count = MIN( count - i, blk_count0 );
    379         float* neighbor_responses = buf;
    380         float* dist = buf + blk_count*k;
    381         Cv32suf* sort_buf = (Cv32suf*)(dist + blk_count*k);
    382 
    383         find_neighbors_direct( _samples, k, i, i + blk_count,
    384                     neighbor_responses, _neighbors, dist );
    385 
    386         float r = write_results( k, k1, i, i + blk_count, neighbor_responses, dist,
    387                                  _results, _neighbor_responses, _dist, sort_buf );
    388         if( i == 0 )
    389             result = r;
    390     }
    391 
    392     __END__;
    393 
    394     if( !local_alloc )
    395         cvFree( &buf );
    396 
    397     return result;
    398 }
    399 
    400 /* End of file */
    401 
    402