Home | History | Annotate | Download | only in src
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                        Intel License Agreement
     11 //
     12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
     13 // Third party copyrights are property of their respective owners.
     14 //
     15 // Redistribution and use in source and binary forms, with or without modification,
     16 // are permitted provided that the following conditions are met:
     17 //
     18 //   * Redistribution's of source code must retain the above copyright notice,
     19 //     this list of conditions and the following disclaimer.
     20 //
     21 //   * Redistribution's in binary form must reproduce the above copyright notice,
     22 //     this list of conditions and the following disclaimer in the documentation
     23 //     and/or other materials provided with the distribution.
     24 //
     25 //   * The name of Intel Corporation may not be used to endorse or promote products
     26 //     derived from this software without specific prior written permission.
     27 //
     28 // This software is provided by the copyright holders and contributors "as is" and
     29 // any express or implied warranties, including, but not limited to, the implied
     30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     31 // In no event shall the Intel Corporation or contributors be liable for any direct,
     32 // indirect, incidental, special, exemplary, or consequential damages
     33 // (including, but not limited to, procurement of substitute goods or services;
     34 // loss of use, data, or profits; or business interruption) however caused
     35 // and on any theory of liability, whether in contract, strict liability,
     36 // or tort (including negligence or otherwise) arising in any way out of
     37 // the use of this software, even if advised of the possibility of such damage.
     38 //
     39 //M*/
     40 
     41 #include "_ml.h"
     42 
     43 /****************************************************************************************\
     44                                 COPYRIGHT NOTICE
     45                                 ----------------
     46 
     47   The code has been derived from libsvm library (version 2.6)
     48   (http://www.csie.ntu.edu.tw/~cjlin/libsvm).
     49 
     50   Here is the original copyright:
     51 ------------------------------------------------------------------------------------------
     52     Copyright (c) 2000-2003 Chih-Chung Chang and Chih-Jen Lin
     53     All rights reserved.
     54 
     55     Redistribution and use in source and binary forms, with or without
     56     modification, are permitted provided that the following conditions
     57     are met:
     58 
     59     1. Redistributions of source code must retain the above copyright
     60     notice, this list of conditions and the following disclaimer.
     61 
     62     2. Redistributions in binary form must reproduce the above copyright
     63     notice, this list of conditions and the following disclaimer in the
     64     documentation and/or other materials provided with the distribution.
     65 
     66     3. Neither name of copyright holders nor the names of its contributors
     67     may be used to endorse or promote products derived from this software
     68     without specific prior written permission.
     69 
     70 
     71     THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     72     ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     73     LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
     74     A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR
     75     CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
     76     EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
     77     PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
     78     PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
     79     LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
     80     NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
     81     SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
     82 \****************************************************************************************/
     83 
     84 #define CV_SVM_MIN_CACHE_SIZE  (40 << 20)  /* 40Mb */
     85 
     86 #include <stdarg.h>
     87 #include <ctype.h>
     88 
     89 #if _MSC_VER >= 1200
     90 #pragma warning( disable: 4514 ) /* unreferenced inline functions */
     91 #endif
     92 
     93 #if 1
     94 typedef float Qfloat;
     95 #define QFLOAT_TYPE CV_32F
     96 #else
     97 typedef double Qfloat;
     98 #define QFLOAT_TYPE CV_64F
     99 #endif
    100 
    101 // Param Grid
    102 bool CvParamGrid::check() const
    103 {
    104     bool ok = false;
    105 
    106     CV_FUNCNAME( "CvParamGrid::check" );
    107     __BEGIN__;
    108 
    109     if( min_val > max_val )
    110         CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be less then the upper one" );
    111     if( min_val < DBL_EPSILON )
    112         CV_ERROR( CV_StsBadArg, "Lower bound of the grid must be positive" );
    113     if( step < 1. + FLT_EPSILON )
    114         CV_ERROR( CV_StsBadArg, "Grid step must greater then 1" );
    115 
    116     ok = true;
    117 
    118     __END__;
    119 
    120     return ok;
    121 }
    122 
    123 CvParamGrid CvSVM::get_default_grid( int param_id )
    124 {
    125     CvParamGrid grid;
    126     if( param_id == CvSVM::C )
    127     {
    128         grid.min_val = 0.1;
    129         grid.max_val = 500;
    130         grid.step = 5; // total iterations = 5
    131     }
    132     else if( param_id == CvSVM::GAMMA )
    133     {
    134         grid.min_val = 1e-5;
    135         grid.max_val = 0.6;
    136         grid.step = 15; // total iterations = 4
    137     }
    138     else if( param_id == CvSVM::P )
    139     {
    140         grid.min_val = 0.01;
    141         grid.max_val = 100;
    142         grid.step = 7; // total iterations = 4
    143     }
    144     else if( param_id == CvSVM::NU )
    145     {
    146         grid.min_val = 0.01;
    147         grid.max_val = 0.2;
    148         grid.step = 3; // total iterations = 3
    149     }
    150     else if( param_id == CvSVM::COEF )
    151     {
    152         grid.min_val = 0.1;
    153         grid.max_val = 300;
    154         grid.step = 14; // total iterations = 3
    155     }
    156     else if( param_id == CvSVM::DEGREE )
    157     {
    158         grid.min_val = 0.01;
    159         grid.max_val = 4;
    160         grid.step = 7; // total iterations = 3
    161     }
    162     else
    163         cvError( CV_StsBadArg, "CvSVM::get_default_grid", "Invalid type of parameter "
    164             "(use one of CvSVM::C, CvSVM::GAMMA et al.)", __FILE__, __LINE__ );
    165     return grid;
    166 }
    167 
    168 // SVM training parameters
    169 CvSVMParams::CvSVMParams() :
    170     svm_type(CvSVM::C_SVC), kernel_type(CvSVM::RBF), degree(0),
    171     gamma(1), coef0(0), C(1), nu(0), p(0), class_weights(0)
    172 {
    173     term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 1000, FLT_EPSILON );
    174 }
    175 
    176 
    177 CvSVMParams::CvSVMParams( int _svm_type, int _kernel_type,
    178     double _degree, double _gamma, double _coef0,
    179     double _Con, double _nu, double _p,
    180     CvMat* _class_weights, CvTermCriteria _term_crit ) :
    181     svm_type(_svm_type), kernel_type(_kernel_type),
    182     degree(_degree), gamma(_gamma), coef0(_coef0),
    183     C(_Con), nu(_nu), p(_p), class_weights(_class_weights), term_crit(_term_crit)
    184 {
    185 }
    186 
    187 
    188 /////////////////////////////////////// SVM kernel ///////////////////////////////////////
    189 
    190 CvSVMKernel::CvSVMKernel()
    191 {
    192     clear();
    193 }
    194 
    195 
    196 void CvSVMKernel::clear()
    197 {
    198     params = 0;
    199     calc_func = 0;
    200 }
    201 
    202 
    203 CvSVMKernel::~CvSVMKernel()
    204 {
    205 }
    206 
    207 
    208 CvSVMKernel::CvSVMKernel( const CvSVMParams* _params, Calc _calc_func )
    209 {
    210     clear();
    211     create( _params, _calc_func );
    212 }
    213 
    214 
    215 bool CvSVMKernel::create( const CvSVMParams* _params, Calc _calc_func )
    216 {
    217     clear();
    218     params = _params;
    219     calc_func = _calc_func;
    220 
    221     if( !calc_func )
    222         calc_func = params->kernel_type == CvSVM::RBF ? &CvSVMKernel::calc_rbf :
    223                     params->kernel_type == CvSVM::POLY ? &CvSVMKernel::calc_poly :
    224                     params->kernel_type == CvSVM::SIGMOID ? &CvSVMKernel::calc_sigmoid :
    225                     &CvSVMKernel::calc_linear;
    226 
    227     return true;
    228 }
    229 
    230 
    231 void CvSVMKernel::calc_non_rbf_base( int vcount, int var_count, const float** vecs,
    232                                      const float* another, Qfloat* results,
    233                                      double alpha, double beta )
    234 {
    235     int j, k;
    236     for( j = 0; j < vcount; j++ )
    237     {
    238         const float* sample = vecs[j];
    239         double s = 0;
    240         for( k = 0; k <= var_count - 4; k += 4 )
    241             s += sample[k]*another[k] + sample[k+1]*another[k+1] +
    242                  sample[k+2]*another[k+2] + sample[k+3]*another[k+3];
    243         for( ; k < var_count; k++ )
    244             s += sample[k]*another[k];
    245         results[j] = (Qfloat)(s*alpha + beta);
    246     }
    247 }
    248 
    249 
    250 void CvSVMKernel::calc_linear( int vcount, int var_count, const float** vecs,
    251                                const float* another, Qfloat* results )
    252 {
    253     calc_non_rbf_base( vcount, var_count, vecs, another, results, 1, 0 );
    254 }
    255 
    256 
    257 void CvSVMKernel::calc_poly( int vcount, int var_count, const float** vecs,
    258                              const float* another, Qfloat* results )
    259 {
    260     CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
    261     calc_non_rbf_base( vcount, var_count, vecs, another, results, params->gamma, params->coef0 );
    262     cvPow( &R, &R, params->degree );
    263 }
    264 
    265 
    266 void CvSVMKernel::calc_sigmoid( int vcount, int var_count, const float** vecs,
    267                                 const float* another, Qfloat* results )
    268 {
    269     int j;
    270     calc_non_rbf_base( vcount, var_count, vecs, another, results,
    271                        -2*params->gamma, -2*params->coef0 );
    272     // TODO: speedup this
    273     for( j = 0; j < vcount; j++ )
    274     {
    275         Qfloat t = results[j];
    276         double e = exp(-fabs(t));
    277         if( t > 0 )
    278             results[j] = (Qfloat)((1. - e)/(1. + e));
    279         else
    280             results[j] = (Qfloat)((e - 1.)/(e + 1.));
    281     }
    282 }
    283 
    284 
    285 void CvSVMKernel::calc_rbf( int vcount, int var_count, const float** vecs,
    286                             const float* another, Qfloat* results )
    287 {
    288     CvMat R = cvMat( 1, vcount, QFLOAT_TYPE, results );
    289     double gamma = -params->gamma;
    290     int j, k;
    291 
    292     for( j = 0; j < vcount; j++ )
    293     {
    294         const float* sample = vecs[j];
    295         double s = 0;
    296 
    297         for( k = 0; k <= var_count - 4; k += 4 )
    298         {
    299             double t0 = sample[k] - another[k];
    300             double t1 = sample[k+1] - another[k+1];
    301 
    302             s += t0*t0 + t1*t1;
    303 
    304             t0 = sample[k+2] - another[k+2];
    305             t1 = sample[k+3] - another[k+3];
    306 
    307             s += t0*t0 + t1*t1;
    308         }
    309 
    310         for( ; k < var_count; k++ )
    311         {
    312             double t0 = sample[k] - another[k];
    313             s += t0*t0;
    314         }
    315         results[j] = (Qfloat)(s*gamma);
    316     }
    317 
    318     cvExp( &R, &R );
    319 }
    320 
    321 
    322 void CvSVMKernel::calc( int vcount, int var_count, const float** vecs,
    323                         const float* another, Qfloat* results )
    324 {
    325     const Qfloat max_val = (Qfloat)(FLT_MAX*1e-3);
    326     int j;
    327     (this->*calc_func)( vcount, var_count, vecs, another, results );
    328     for( j = 0; j < vcount; j++ )
    329     {
    330         if( results[j] > max_val )
    331             results[j] = max_val;
    332     }
    333 }
    334 
    335 
    336 // Generalized SMO+SVMlight algorithm
    337 // Solves:
    338 //
    339 //  min [0.5(\alpha^T Q \alpha) + b^T \alpha]
    340 //
    341 //      y^T \alpha = \delta
    342 //      y_i = +1 or -1
    343 //      0 <= alpha_i <= Cp for y_i = 1
    344 //      0 <= alpha_i <= Cn for y_i = -1
    345 //
    346 // Given:
    347 //
    348 //  Q, b, y, Cp, Cn, and an initial feasible point \alpha
    349 //  l is the size of vectors and matrices
    350 //  eps is the stopping criterion
    351 //
    352 // solution will be put in \alpha, objective value will be put in obj
    353 //
    354 
    355 void CvSVMSolver::clear()
    356 {
    357     G = 0;
    358     alpha = 0;
    359     y = 0;
    360     b = 0;
    361     buf[0] = buf[1] = 0;
    362     cvReleaseMemStorage( &storage );
    363     kernel = 0;
    364     select_working_set_func = 0;
    365     calc_rho_func = 0;
    366 
    367     rows = 0;
    368     samples = 0;
    369     get_row_func = 0;
    370 }
    371 
    372 
    373 CvSVMSolver::CvSVMSolver()
    374 {
    375     storage = 0;
    376     clear();
    377 }
    378 
    379 
    380 CvSVMSolver::~CvSVMSolver()
    381 {
    382     clear();
    383 }
    384 
    385 
    386 CvSVMSolver::CvSVMSolver( int _sample_count, int _var_count, const float** _samples, schar* _y,
    387                 int _alpha_count, double* _alpha, double _Cp, double _Cn,
    388                 CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
    389                 SelectWorkingSet _select_working_set, CalcRho _calc_rho )
    390 {
    391     storage = 0;
    392     create( _sample_count, _var_count, _samples, _y, _alpha_count, _alpha, _Cp, _Cn,
    393             _storage, _kernel, _get_row, _select_working_set, _calc_rho );
    394 }
    395 
    396 
    397 bool CvSVMSolver::create( int _sample_count, int _var_count, const float** _samples, schar* _y,
    398                 int _alpha_count, double* _alpha, double _Cp, double _Cn,
    399                 CvMemStorage* _storage, CvSVMKernel* _kernel, GetRow _get_row,
    400                 SelectWorkingSet _select_working_set, CalcRho _calc_rho )
    401 {
    402     bool ok = false;
    403     int i, svm_type;
    404 
    405     CV_FUNCNAME( "CvSVMSolver::create" );
    406 
    407     __BEGIN__;
    408 
    409     int rows_hdr_size;
    410 
    411     clear();
    412 
    413     sample_count = _sample_count;
    414     var_count = _var_count;
    415     samples = _samples;
    416     y = _y;
    417     alpha_count = _alpha_count;
    418     alpha = _alpha;
    419     kernel = _kernel;
    420 
    421     C[0] = _Cn;
    422     C[1] = _Cp;
    423     eps = kernel->params->term_crit.epsilon;
    424     max_iter = kernel->params->term_crit.max_iter;
    425     storage = cvCreateChildMemStorage( _storage );
    426 
    427     b = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(b[0]));
    428     alpha_status = (schar*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha_status[0]));
    429     G = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(G[0]));
    430     for( i = 0; i < 2; i++ )
    431         buf[i] = (Qfloat*)cvMemStorageAlloc( storage, sample_count*2*sizeof(buf[i][0]) );
    432     svm_type = kernel->params->svm_type;
    433 
    434     select_working_set_func = _select_working_set;
    435     if( !select_working_set_func )
    436         select_working_set_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
    437         &CvSVMSolver::select_working_set_nu_svm : &CvSVMSolver::select_working_set;
    438 
    439     calc_rho_func = _calc_rho;
    440     if( !calc_rho_func )
    441         calc_rho_func = svm_type == CvSVM::NU_SVC || svm_type == CvSVM::NU_SVR ?
    442             &CvSVMSolver::calc_rho_nu_svm : &CvSVMSolver::calc_rho;
    443 
    444     get_row_func = _get_row;
    445     if( !get_row_func )
    446         get_row_func = params->svm_type == CvSVM::EPS_SVR ||
    447                        params->svm_type == CvSVM::NU_SVR ? &CvSVMSolver::get_row_svr :
    448                        params->svm_type == CvSVM::C_SVC ||
    449                        params->svm_type == CvSVM::NU_SVC ? &CvSVMSolver::get_row_svc :
    450                        &CvSVMSolver::get_row_one_class;
    451 
    452     cache_line_size = sample_count*sizeof(Qfloat);
    453     // cache size = max(num_of_samples^2*sizeof(Qfloat)*0.25, 64Kb)
    454     // (assuming that for large training sets ~25% of Q matrix is used)
    455     cache_size = MAX( cache_line_size*sample_count/4, CV_SVM_MIN_CACHE_SIZE );
    456 
    457     // the size of Q matrix row headers
    458     rows_hdr_size = sample_count*sizeof(rows[0]);
    459     if( rows_hdr_size > storage->block_size )
    460         CV_ERROR( CV_StsOutOfRange, "Too small storage block size" );
    461 
    462     lru_list.prev = lru_list.next = &lru_list;
    463     rows = (CvSVMKernelRow*)cvMemStorageAlloc( storage, rows_hdr_size );
    464     memset( rows, 0, rows_hdr_size );
    465 
    466     ok = true;
    467 
    468     __END__;
    469 
    470     return ok;
    471 }
    472 
    473 
    474 float* CvSVMSolver::get_row_base( int i, bool* _existed )
    475 {
    476     int i1 = i < sample_count ? i : i - sample_count;
    477     CvSVMKernelRow* row = rows + i1;
    478     bool existed = row->data != 0;
    479     Qfloat* data;
    480 
    481     if( existed || cache_size <= 0 )
    482     {
    483         CvSVMKernelRow* del_row = existed ? row : lru_list.prev;
    484         data = del_row->data;
    485         assert( data != 0 );
    486 
    487         // delete row from the LRU list
    488         del_row->data = 0;
    489         del_row->prev->next = del_row->next;
    490         del_row->next->prev = del_row->prev;
    491     }
    492     else
    493     {
    494         data = (Qfloat*)cvMemStorageAlloc( storage, cache_line_size );
    495         cache_size -= cache_line_size;
    496     }
    497 
    498     // insert row into the LRU list
    499     row->data = data;
    500     row->prev = &lru_list;
    501     row->next = lru_list.next;
    502     row->prev->next = row->next->prev = row;
    503 
    504     if( !existed )
    505     {
    506         kernel->calc( sample_count, var_count, samples, samples[i1], row->data );
    507     }
    508 
    509     if( _existed )
    510         *_existed = existed;
    511 
    512     return row->data;
    513 }
    514 
    515 
    516 float* CvSVMSolver::get_row_svc( int i, float* row, float*, bool existed )
    517 {
    518     if( !existed )
    519     {
    520         const schar* _y = y;
    521         int j, len = sample_count;
    522         assert( _y && i < sample_count );
    523 
    524         if( _y[i] > 0 )
    525         {
    526             for( j = 0; j < len; j++ )
    527                 row[j] = _y[j]*row[j];
    528         }
    529         else
    530         {
    531             for( j = 0; j < len; j++ )
    532                 row[j] = -_y[j]*row[j];
    533         }
    534     }
    535     return row;
    536 }
    537 
    538 
    539 float* CvSVMSolver::get_row_one_class( int, float* row, float*, bool )
    540 {
    541     return row;
    542 }
    543 
    544 
    545 float* CvSVMSolver::get_row_svr( int i, float* row, float* dst, bool )
    546 {
    547     int j, len = sample_count;
    548     Qfloat* dst_pos = dst;
    549     Qfloat* dst_neg = dst + len;
    550     if( i >= len )
    551     {
    552         Qfloat* temp;
    553         CV_SWAP( dst_pos, dst_neg, temp );
    554     }
    555 
    556     for( j = 0; j < len; j++ )
    557     {
    558         Qfloat t = row[j];
    559         dst_pos[j] = t;
    560         dst_neg[j] = -t;
    561     }
    562     return dst;
    563 }
    564 
    565 
    566 
    567 float* CvSVMSolver::get_row( int i, float* dst )
    568 {
    569     bool existed = false;
    570     float* row = get_row_base( i, &existed );
    571     return (this->*get_row_func)( i, row, dst, existed );
    572 }
    573 
    574 
    575 #undef is_upper_bound
    576 #define is_upper_bound(i) (alpha_status[i] > 0)
    577 
    578 #undef is_lower_bound
    579 #define is_lower_bound(i) (alpha_status[i] < 0)
    580 
    581 #undef is_free
    582 #define is_free(i) (alpha_status[i] == 0)
    583 
    584 #undef get_C
    585 #define get_C(i) (C[y[i]>0])
    586 
    587 #undef update_alpha_status
    588 #define update_alpha_status(i) \
    589     alpha_status[i] = (schar)(alpha[i] >= get_C(i) ? 1 : alpha[i] <= 0 ? -1 : 0)
    590 
    591 #undef reconstruct_gradient
    592 #define reconstruct_gradient() /* empty for now */
    593 
    594 
    595 bool CvSVMSolver::solve_generic( CvSVMSolutionInfo& si )
    596 {
    597     int iter = 0;
    598     int i, j, k;
    599 
    600     // 1. initialize gradient and alpha status
    601     for( i = 0; i < alpha_count; i++ )
    602     {
    603         update_alpha_status(i);
    604         G[i] = b[i];
    605         if( fabs(G[i]) > 1e200 )
    606             return false;
    607     }
    608 
    609     for( i = 0; i < alpha_count; i++ )
    610     {
    611         if( !is_lower_bound(i) )
    612         {
    613             const Qfloat *Q_i = get_row( i, buf[0] );
    614             double alpha_i = alpha[i];
    615 
    616             for( j = 0; j < alpha_count; j++ )
    617                 G[j] += alpha_i*Q_i[j];
    618         }
    619     }
    620 
    621     // 2. optimization loop
    622     for(;;)
    623     {
    624         const Qfloat *Q_i, *Q_j;
    625         double C_i, C_j;
    626         double old_alpha_i, old_alpha_j, alpha_i, alpha_j;
    627         double delta_alpha_i, delta_alpha_j;
    628 
    629 #ifdef _DEBUG
    630         for( i = 0; i < alpha_count; i++ )
    631         {
    632             if( fabs(G[i]) > 1e+300 )
    633                 return false;
    634 
    635             if( fabs(alpha[i]) > 1e16 )
    636                 return false;
    637         }
    638 #endif
    639 
    640         if( (this->*select_working_set_func)( i, j ) != 0 || iter++ >= max_iter )
    641             break;
    642 
    643         Q_i = get_row( i, buf[0] );
    644         Q_j = get_row( j, buf[1] );
    645 
    646         C_i = get_C(i);
    647         C_j = get_C(j);
    648 
    649         alpha_i = old_alpha_i = alpha[i];
    650         alpha_j = old_alpha_j = alpha[j];
    651 
    652         if( y[i] != y[j] )
    653         {
    654             double denom = Q_i[i]+Q_j[j]+2*Q_i[j];
    655             double delta = (-G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
    656             double diff = alpha_i - alpha_j;
    657             alpha_i += delta;
    658             alpha_j += delta;
    659 
    660             if( diff > 0 && alpha_j < 0 )
    661             {
    662                 alpha_j = 0;
    663                 alpha_i = diff;
    664             }
    665             else if( diff <= 0 && alpha_i < 0 )
    666             {
    667                 alpha_i = 0;
    668                 alpha_j = -diff;
    669             }
    670 
    671             if( diff > C_i - C_j && alpha_i > C_i )
    672             {
    673                 alpha_i = C_i;
    674                 alpha_j = C_i - diff;
    675             }
    676             else if( diff <= C_i - C_j && alpha_j > C_j )
    677             {
    678                 alpha_j = C_j;
    679                 alpha_i = C_j + diff;
    680             }
    681         }
    682         else
    683         {
    684             double denom = Q_i[i]+Q_j[j]-2*Q_i[j];
    685             double delta = (G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
    686             double sum = alpha_i + alpha_j;
    687             alpha_i -= delta;
    688             alpha_j += delta;
    689 
    690             if( sum > C_i && alpha_i > C_i )
    691             {
    692                 alpha_i = C_i;
    693                 alpha_j = sum - C_i;
    694             }
    695             else if( sum <= C_i && alpha_j < 0)
    696             {
    697                 alpha_j = 0;
    698                 alpha_i = sum;
    699             }
    700 
    701             if( sum > C_j && alpha_j > C_j )
    702             {
    703                 alpha_j = C_j;
    704                 alpha_i = sum - C_j;
    705             }
    706             else if( sum <= C_j && alpha_i < 0 )
    707             {
    708                 alpha_i = 0;
    709                 alpha_j = sum;
    710             }
    711         }
    712 
    713         // update alpha
    714         alpha[i] = alpha_i;
    715         alpha[j] = alpha_j;
    716         update_alpha_status(i);
    717         update_alpha_status(j);
    718 
    719         // update G
    720         delta_alpha_i = alpha_i - old_alpha_i;
    721         delta_alpha_j = alpha_j - old_alpha_j;
    722 
    723         for( k = 0; k < alpha_count; k++ )
    724             G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
    725     }
    726 
    727     // calculate rho
    728     (this->*calc_rho_func)( si.rho, si.r );
    729 
    730     // calculate objective value
    731     for( i = 0, si.obj = 0; i < alpha_count; i++ )
    732         si.obj += alpha[i] * (G[i] + b[i]);
    733 
    734     si.obj *= 0.5;
    735 
    736     si.upper_bound_p = C[1];
    737     si.upper_bound_n = C[0];
    738 
    739     return true;
    740 }
    741 
    742 
    743 // return 1 if already optimal, return 0 otherwise
    744 bool
    745 CvSVMSolver::select_working_set( int& out_i, int& out_j )
    746 {
    747     // return i,j which maximize -grad(f)^T d , under constraint
    748     // if alpha_i == C, d != +1
    749     // if alpha_i == 0, d != -1
    750     double Gmax1 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = +1 }
    751     int Gmax1_idx = -1;
    752 
    753     double Gmax2 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = -1 }
    754     int Gmax2_idx = -1;
    755 
    756     int i;
    757 
    758     for( i = 0; i < alpha_count; i++ )
    759     {
    760         double t;
    761 
    762         if( y[i] > 0 )    // y = +1
    763         {
    764             if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
    765             {
    766                 Gmax1 = t;
    767                 Gmax1_idx = i;
    768             }
    769             if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
    770             {
    771                 Gmax2 = t;
    772                 Gmax2_idx = i;
    773             }
    774         }
    775         else        // y = -1
    776         {
    777             if( !is_upper_bound(i) && (t = -G[i]) > Gmax2 )  // d = +1
    778             {
    779                 Gmax2 = t;
    780                 Gmax2_idx = i;
    781             }
    782             if( !is_lower_bound(i) && (t = G[i]) > Gmax1 )  // d = -1
    783             {
    784                 Gmax1 = t;
    785                 Gmax1_idx = i;
    786             }
    787         }
    788     }
    789 
    790     out_i = Gmax1_idx;
    791     out_j = Gmax2_idx;
    792 
    793     return Gmax1 + Gmax2 < eps;
    794 }
    795 
    796 
    797 void
    798 CvSVMSolver::calc_rho( double& rho, double& r )
    799 {
    800     int i, nr_free = 0;
    801     double ub = DBL_MAX, lb = -DBL_MAX, sum_free = 0;
    802 
    803     for( i = 0; i < alpha_count; i++ )
    804     {
    805         double yG = y[i]*G[i];
    806 
    807         if( is_lower_bound(i) )
    808         {
    809             if( y[i] > 0 )
    810                 ub = MIN(ub,yG);
    811             else
    812                 lb = MAX(lb,yG);
    813         }
    814         else if( is_upper_bound(i) )
    815         {
    816             if( y[i] < 0)
    817                 ub = MIN(ub,yG);
    818             else
    819                 lb = MAX(lb,yG);
    820         }
    821         else
    822         {
    823             ++nr_free;
    824             sum_free += yG;
    825         }
    826     }
    827 
    828     rho = nr_free > 0 ? sum_free/nr_free : (ub + lb)*0.5;
    829     r = 0;
    830 }
    831 
    832 
    833 bool
    834 CvSVMSolver::select_working_set_nu_svm( int& out_i, int& out_j )
    835 {
    836     // return i,j which maximize -grad(f)^T d , under constraint
    837     // if alpha_i == C, d != +1
    838     // if alpha_i == 0, d != -1
    839     double Gmax1 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = +1 }
    840     int Gmax1_idx = -1;
    841 
    842     double Gmax2 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = -1 }
    843     int Gmax2_idx = -1;
    844 
    845     double Gmax3 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = +1 }
    846     int Gmax3_idx = -1;
    847 
    848     double Gmax4 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = -1 }
    849     int Gmax4_idx = -1;
    850 
    851     int i;
    852 
    853     for( i = 0; i < alpha_count; i++ )
    854     {
    855         double t;
    856 
    857         if( y[i] > 0 )    // y == +1
    858         {
    859             if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
    860             {
    861                 Gmax1 = t;
    862                 Gmax1_idx = i;
    863             }
    864             if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
    865             {
    866                 Gmax2 = t;
    867                 Gmax2_idx = i;
    868             }
    869         }
    870         else        // y == -1
    871         {
    872             if( !is_upper_bound(i) && (t = -G[i]) > Gmax3 )  // d = +1
    873             {
    874                 Gmax3 = t;
    875                 Gmax3_idx = i;
    876             }
    877             if( !is_lower_bound(i) && (t = G[i]) > Gmax4 )  // d = -1
    878             {
    879                 Gmax4 = t;
    880                 Gmax4_idx = i;
    881             }
    882         }
    883     }
    884 
    885     if( MAX(Gmax1 + Gmax2, Gmax3 + Gmax4) < eps )
    886         return 1;
    887 
    888     if( Gmax1 + Gmax2 > Gmax3 + Gmax4 )
    889     {
    890         out_i = Gmax1_idx;
    891         out_j = Gmax2_idx;
    892     }
    893     else
    894     {
    895         out_i = Gmax3_idx;
    896         out_j = Gmax4_idx;
    897     }
    898     return 0;
    899 }
    900 
    901 
    902 void
    903 CvSVMSolver::calc_rho_nu_svm( double& rho, double& r )
    904 {
    905     int nr_free1 = 0, nr_free2 = 0;
    906     double ub1 = DBL_MAX, ub2 = DBL_MAX;
    907     double lb1 = -DBL_MAX, lb2 = -DBL_MAX;
    908     double sum_free1 = 0, sum_free2 = 0;
    909     double r1, r2;
    910 
    911     int i;
    912 
    913     for( i = 0; i < alpha_count; i++ )
    914     {
    915         double G_i = G[i];
    916         if( y[i] > 0 )
    917         {
    918             if( is_lower_bound(i) )
    919                 ub1 = MIN( ub1, G_i );
    920             else if( is_upper_bound(i) )
    921                 lb1 = MAX( lb1, G_i );
    922             else
    923             {
    924                 ++nr_free1;
    925                 sum_free1 += G_i;
    926             }
    927         }
    928         else
    929         {
    930             if( is_lower_bound(i) )
    931                 ub2 = MIN( ub2, G_i );
    932             else if( is_upper_bound(i) )
    933                 lb2 = MAX( lb2, G_i );
    934             else
    935             {
    936                 ++nr_free2;
    937                 sum_free2 += G_i;
    938             }
    939         }
    940     }
    941 
    942     r1 = nr_free1 > 0 ? sum_free1/nr_free1 : (ub1 + lb1)*0.5;
    943     r2 = nr_free2 > 0 ? sum_free2/nr_free2 : (ub2 + lb2)*0.5;
    944 
    945     rho = (r1 - r2)*0.5;
    946     r = (r1 + r2)*0.5;
    947 }
    948 
    949 
    950 /*
    951 ///////////////////////// construct and solve various formulations ///////////////////////
    952 */
    953 
    954 bool CvSVMSolver::solve_c_svc( int _sample_count, int _var_count, const float** _samples, schar* _y,
    955                                double _Cp, double _Cn, CvMemStorage* _storage,
    956                                CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
    957 {
    958     int i;
    959 
    960     if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
    961                  _alpha, _Cp, _Cn, _storage, _kernel, &CvSVMSolver::get_row_svc,
    962                  &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
    963         return false;
    964 
    965     for( i = 0; i < sample_count; i++ )
    966     {
    967         alpha[i] = 0;
    968         b[i] = -1;
    969     }
    970 
    971     if( !solve_generic( _si ))
    972         return false;
    973 
    974     for( i = 0; i < sample_count; i++ )
    975         alpha[i] *= y[i];
    976 
    977     return true;
    978 }
    979 
    980 
    981 bool CvSVMSolver::solve_nu_svc( int _sample_count, int _var_count, const float** _samples, schar* _y,
    982                                 CvMemStorage* _storage, CvSVMKernel* _kernel,
    983                                 double* _alpha, CvSVMSolutionInfo& _si )
    984 {
    985     int i;
    986     double sum_pos, sum_neg, inv_r;
    987 
    988     if( !create( _sample_count, _var_count, _samples, _y, _sample_count,
    989                  _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svc,
    990                  &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
    991         return false;
    992 
    993     sum_pos = kernel->params->nu * sample_count * 0.5;
    994     sum_neg = kernel->params->nu * sample_count * 0.5;
    995 
    996     for( i = 0; i < sample_count; i++ )
    997     {
    998         if( y[i] > 0 )
    999         {
   1000             alpha[i] = MIN(1.0, sum_pos);
   1001             sum_pos -= alpha[i];
   1002         }
   1003         else
   1004         {
   1005             alpha[i] = MIN(1.0, sum_neg);
   1006             sum_neg -= alpha[i];
   1007         }
   1008         b[i] = 0;
   1009     }
   1010 
   1011     if( !solve_generic( _si ))
   1012         return false;
   1013 
   1014     inv_r = 1./_si.r;
   1015 
   1016     for( i = 0; i < sample_count; i++ )
   1017         alpha[i] *= y[i]*inv_r;
   1018 
   1019     _si.rho *= inv_r;
   1020     _si.obj *= (inv_r*inv_r);
   1021     _si.upper_bound_p = inv_r;
   1022     _si.upper_bound_n = inv_r;
   1023 
   1024     return true;
   1025 }
   1026 
   1027 
   1028 bool CvSVMSolver::solve_one_class( int _sample_count, int _var_count, const float** _samples,
   1029                                    CvMemStorage* _storage, CvSVMKernel* _kernel,
   1030                                    double* _alpha, CvSVMSolutionInfo& _si )
   1031 {
   1032     int i, n;
   1033     double nu = _kernel->params->nu;
   1034 
   1035     if( !create( _sample_count, _var_count, _samples, 0, _sample_count,
   1036                  _alpha, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_one_class,
   1037                  &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
   1038         return false;
   1039 
   1040     y = (schar*)cvMemStorageAlloc( storage, sample_count*sizeof(y[0]) );
   1041     n = cvRound( nu*sample_count );
   1042 
   1043     for( i = 0; i < sample_count; i++ )
   1044     {
   1045         y[i] = 1;
   1046         b[i] = 0;
   1047         alpha[i] = i < n ? 1 : 0;
   1048     }
   1049 
   1050     if( n < sample_count )
   1051         alpha[n] = nu * sample_count - n;
   1052     else
   1053         alpha[n-1] = nu * sample_count - (n-1);
   1054 
   1055     return solve_generic(_si);
   1056 }
   1057 
   1058 
   1059 bool CvSVMSolver::solve_eps_svr( int _sample_count, int _var_count, const float** _samples,
   1060                                  const float* _y, CvMemStorage* _storage,
   1061                                  CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
   1062 {
   1063     int i;
   1064     double p = _kernel->params->p, C = _kernel->params->C;
   1065 
   1066     if( !create( _sample_count, _var_count, _samples, 0,
   1067                  _sample_count*2, 0, C, C, _storage, _kernel, &CvSVMSolver::get_row_svr,
   1068                  &CvSVMSolver::select_working_set, &CvSVMSolver::calc_rho ))
   1069         return false;
   1070 
   1071     y = (schar*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
   1072     alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
   1073 
   1074     for( i = 0; i < sample_count; i++ )
   1075     {
   1076         alpha[i] = 0;
   1077         b[i] = p - _y[i];
   1078         y[i] = 1;
   1079 
   1080         alpha[i+sample_count] = 0;
   1081         b[i+sample_count] = p + _y[i];
   1082         y[i+sample_count] = -1;
   1083     }
   1084 
   1085     if( !solve_generic( _si ))
   1086         return false;
   1087 
   1088     for( i = 0; i < sample_count; i++ )
   1089         _alpha[i] = alpha[i] - alpha[i+sample_count];
   1090 
   1091     return true;
   1092 }
   1093 
   1094 
   1095 bool CvSVMSolver::solve_nu_svr( int _sample_count, int _var_count, const float** _samples,
   1096                                 const float* _y, CvMemStorage* _storage,
   1097                                 CvSVMKernel* _kernel, double* _alpha, CvSVMSolutionInfo& _si )
   1098 {
   1099     int i;
   1100     double C = _kernel->params->C, sum;
   1101 
   1102     if( !create( _sample_count, _var_count, _samples, 0,
   1103                  _sample_count*2, 0, 1., 1., _storage, _kernel, &CvSVMSolver::get_row_svr,
   1104                  &CvSVMSolver::select_working_set_nu_svm, &CvSVMSolver::calc_rho_nu_svm ))
   1105         return false;
   1106 
   1107     y = (schar*)cvMemStorageAlloc( storage, sample_count*2*sizeof(y[0]) );
   1108     alpha = (double*)cvMemStorageAlloc( storage, alpha_count*sizeof(alpha[0]) );
   1109     sum = C * _kernel->params->nu * sample_count * 0.5;
   1110 
   1111     for( i = 0; i < sample_count; i++ )
   1112     {
   1113         alpha[i] = alpha[i + sample_count] = MIN(sum, C);
   1114         sum -= alpha[i];
   1115 
   1116         b[i] = -_y[i];
   1117         y[i] = 1;
   1118 
   1119         b[i + sample_count] = _y[i];
   1120         y[i + sample_count] = -1;
   1121     }
   1122 
   1123     if( !solve_generic( _si ))
   1124         return false;
   1125 
   1126     for( i = 0; i < sample_count; i++ )
   1127         _alpha[i] = alpha[i] - alpha[i+sample_count];
   1128 
   1129     return true;
   1130 }
   1131 
   1132 
   1133 //////////////////////////////////////////////////////////////////////////////////////////
   1134 
   1135 CvSVM::CvSVM()
   1136 {
   1137     decision_func = 0;
   1138     class_labels = 0;
   1139     class_weights = 0;
   1140     storage = 0;
   1141     var_idx = 0;
   1142     kernel = 0;
   1143     solver = 0;
   1144     default_model_name = "my_svm";
   1145 
   1146     clear();
   1147 }
   1148 
   1149 
   1150 CvSVM::~CvSVM()
   1151 {
   1152     clear();
   1153 }
   1154 
   1155 
   1156 void CvSVM::clear()
   1157 {
   1158     cvFree( &decision_func );
   1159     cvReleaseMat( &class_labels );
   1160     cvReleaseMat( &class_weights );
   1161     cvReleaseMemStorage( &storage );
   1162     cvReleaseMat( &var_idx );
   1163     delete kernel;
   1164     delete solver;
   1165     kernel = 0;
   1166     solver = 0;
   1167     var_all = 0;
   1168     sv = 0;
   1169     sv_total = 0;
   1170 }
   1171 
   1172 
   1173 CvSVM::CvSVM( const CvMat* _train_data, const CvMat* _responses,
   1174     const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
   1175 {
   1176     decision_func = 0;
   1177     class_labels = 0;
   1178     class_weights = 0;
   1179     storage = 0;
   1180     var_idx = 0;
   1181     kernel = 0;
   1182     solver = 0;
   1183     default_model_name = "my_svm";
   1184 
   1185     train( _train_data, _responses, _var_idx, _sample_idx, _params );
   1186 }
   1187 
   1188 
   1189 int CvSVM::get_support_vector_count() const
   1190 {
   1191     return sv_total;
   1192 }
   1193 
   1194 
   1195 const float* CvSVM::get_support_vector(int i) const
   1196 {
   1197     return sv && (unsigned)i < (unsigned)sv_total ? sv[i] : 0;
   1198 }
   1199 
   1200 
   1201 bool CvSVM::set_params( const CvSVMParams& _params )
   1202 {
   1203     bool ok = false;
   1204 
   1205     CV_FUNCNAME( "CvSVM::set_params" );
   1206 
   1207     __BEGIN__;
   1208 
   1209     int kernel_type, svm_type;
   1210 
   1211     params = _params;
   1212 
   1213     kernel_type = params.kernel_type;
   1214     svm_type = params.svm_type;
   1215 
   1216     if( kernel_type != LINEAR && kernel_type != POLY &&
   1217         kernel_type != SIGMOID && kernel_type != RBF )
   1218         CV_ERROR( CV_StsBadArg, "Unknown/unsupported kernel type" );
   1219 
   1220     if( kernel_type == LINEAR )
   1221         params.gamma = 1;
   1222     else if( params.gamma <= 0 )
   1223         CV_ERROR( CV_StsOutOfRange, "gamma parameter of the kernel must be positive" );
   1224 
   1225     if( kernel_type != SIGMOID && kernel_type != POLY )
   1226         params.coef0 = 0;
   1227     else if( params.coef0 < 0 )
   1228         CV_ERROR( CV_StsOutOfRange, "The kernel parameter <coef0> must be positive or zero" );
   1229 
   1230     if( kernel_type != POLY )
   1231         params.degree = 0;
   1232     else if( params.degree <= 0 )
   1233         CV_ERROR( CV_StsOutOfRange, "The kernel parameter <degree> must be positive" );
   1234 
   1235     if( svm_type != C_SVC && svm_type != NU_SVC &&
   1236         svm_type != ONE_CLASS && svm_type != EPS_SVR &&
   1237         svm_type != NU_SVR )
   1238         CV_ERROR( CV_StsBadArg, "Unknown/unsupported SVM type" );
   1239 
   1240     if( svm_type == ONE_CLASS || svm_type == NU_SVC )
   1241         params.C = 0;
   1242     else if( params.C <= 0 )
   1243         CV_ERROR( CV_StsOutOfRange, "The parameter C must be positive" );
   1244 
   1245     if( svm_type == C_SVC || svm_type == EPS_SVR )
   1246         params.nu = 0;
   1247     else if( params.nu <= 0 || params.nu >= 1 )
   1248         CV_ERROR( CV_StsOutOfRange, "The parameter nu must be between 0 and 1" );
   1249 
   1250     if( svm_type != EPS_SVR )
   1251         params.p = 0;
   1252     else if( params.p <= 0 )
   1253         CV_ERROR( CV_StsOutOfRange, "The parameter p must be positive" );
   1254 
   1255     if( svm_type != C_SVC )
   1256         params.class_weights = 0;
   1257 
   1258     params.term_crit = cvCheckTermCriteria( params.term_crit, DBL_EPSILON, INT_MAX );
   1259     params.term_crit.epsilon = MAX( params.term_crit.epsilon, DBL_EPSILON );
   1260     ok = true;
   1261 
   1262     __END__;
   1263 
   1264     return ok;
   1265 }
   1266 
   1267 
   1268 
   1269 void CvSVM::create_kernel()
   1270 {
   1271     kernel = new CvSVMKernel(&params,0);
   1272 }
   1273 
   1274 
   1275 void CvSVM::create_solver( )
   1276 {
   1277     solver = new CvSVMSolver;
   1278 }
   1279 
   1280 
   1281 // switching function
   1282 bool CvSVM::train1( int sample_count, int var_count, const float** samples,
   1283                     const void* _responses, double Cp, double Cn,
   1284                     CvMemStorage* _storage, double* alpha, double& rho )
   1285 {
   1286     bool ok = false;
   1287 
   1288     //CV_FUNCNAME( "CvSVM::train1" );
   1289 
   1290     __BEGIN__;
   1291 
   1292     CvSVMSolutionInfo si;
   1293     int svm_type = params.svm_type;
   1294 
   1295     si.rho = 0;
   1296 
   1297     ok = svm_type == C_SVC ? solver->solve_c_svc( sample_count, var_count, samples, (schar*)_responses,
   1298                                                   Cp, Cn, _storage, kernel, alpha, si ) :
   1299          svm_type == NU_SVC ? solver->solve_nu_svc( sample_count, var_count, samples, (schar*)_responses,
   1300                                                     _storage, kernel, alpha, si ) :
   1301          svm_type == ONE_CLASS ? solver->solve_one_class( sample_count, var_count, samples,
   1302                                                           _storage, kernel, alpha, si ) :
   1303          svm_type == EPS_SVR ? solver->solve_eps_svr( sample_count, var_count, samples, (float*)_responses,
   1304                                                       _storage, kernel, alpha, si ) :
   1305          svm_type == NU_SVR ? solver->solve_nu_svr( sample_count, var_count, samples, (float*)_responses,
   1306                                                     _storage, kernel, alpha, si ) : false;
   1307 
   1308     rho = si.rho;
   1309 
   1310     __END__;
   1311 
   1312     return ok;
   1313 }
   1314 
   1315 
   1316 bool CvSVM::do_train( int svm_type, int sample_count, int var_count, const float** samples,
   1317                     const CvMat* responses, CvMemStorage* temp_storage, double* alpha )
   1318 {
   1319     bool ok = false;
   1320 
   1321     CV_FUNCNAME( "CvSVM::do_train" );
   1322 
   1323     __BEGIN__;
   1324 
   1325     CvSVMDecisionFunc* df = 0;
   1326     const int sample_size = var_count*sizeof(samples[0][0]);
   1327     int i, j, k;
   1328 
   1329     if( svm_type == ONE_CLASS || svm_type == EPS_SVR || svm_type == NU_SVR )
   1330     {
   1331         int sv_count = 0;
   1332 
   1333         CV_CALL( decision_func = df =
   1334             (CvSVMDecisionFunc*)cvAlloc( sizeof(df[0]) ));
   1335 
   1336         df->rho = 0;
   1337         if( !train1( sample_count, var_count, samples, svm_type == ONE_CLASS ? 0 :
   1338             responses->data.i, 0, 0, temp_storage, alpha, df->rho ))
   1339             EXIT;
   1340 
   1341         for( i = 0; i < sample_count; i++ )
   1342             sv_count += fabs(alpha[i]) > 0;
   1343 
   1344         sv_total = df->sv_count = sv_count;
   1345         CV_CALL( df->alpha = (double*)cvMemStorageAlloc( storage, sv_count*sizeof(df->alpha[0])) );
   1346         CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_count*sizeof(sv[0])));
   1347 
   1348         for( i = k = 0; i < sample_count; i++ )
   1349         {
   1350             if( fabs(alpha[i]) > 0 )
   1351             {
   1352                 CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
   1353                 memcpy( sv[k], samples[i], sample_size );
   1354                 df->alpha[k++] = alpha[i];
   1355             }
   1356         }
   1357     }
   1358     else
   1359     {
   1360         int class_count = class_labels->cols;
   1361         int* sv_tab = 0;
   1362         const float** temp_samples = 0;
   1363         int* class_ranges = 0;
   1364         schar* temp_y = 0;
   1365         assert( svm_type == CvSVM::C_SVC || svm_type == CvSVM::NU_SVC );
   1366 
   1367         if( svm_type == CvSVM::C_SVC && params.class_weights )
   1368         {
   1369             const CvMat* cw = params.class_weights;
   1370 
   1371             if( !CV_IS_MAT(cw) || cw->cols != 1 && cw->rows != 1 ||
   1372                 cw->rows + cw->cols - 1 != class_count ||
   1373                 CV_MAT_TYPE(cw->type) != CV_32FC1 && CV_MAT_TYPE(cw->type) != CV_64FC1 )
   1374                 CV_ERROR( CV_StsBadArg, "params.class_weights must be 1d floating-point vector "
   1375                     "containing as many elements as the number of classes" );
   1376 
   1377             CV_CALL( class_weights = cvCreateMat( cw->rows, cw->cols, CV_64F ));
   1378             CV_CALL( cvConvert( cw, class_weights ));
   1379             CV_CALL( cvScale( class_weights, class_weights, params.C ));
   1380         }
   1381 
   1382         CV_CALL( decision_func = df = (CvSVMDecisionFunc*)cvAlloc(
   1383             (class_count*(class_count-1)/2)*sizeof(df[0])));
   1384 
   1385         CV_CALL( sv_tab = (int*)cvMemStorageAlloc( temp_storage, sample_count*sizeof(sv_tab[0]) ));
   1386         memset( sv_tab, 0, sample_count*sizeof(sv_tab[0]) );
   1387         CV_CALL( class_ranges = (int*)cvMemStorageAlloc( temp_storage,
   1388                             (class_count + 1)*sizeof(class_ranges[0])));
   1389         CV_CALL( temp_samples = (const float**)cvMemStorageAlloc( temp_storage,
   1390                             sample_count*sizeof(temp_samples[0])));
   1391         CV_CALL( temp_y = (schar*)cvMemStorageAlloc( temp_storage, sample_count));
   1392 
   1393         class_ranges[class_count] = 0;
   1394         cvSortSamplesByClasses( samples, responses, class_ranges, 0 );
   1395         //check that while cross-validation there were the samples from all the classes
   1396         if( class_ranges[class_count] <= 0 )
   1397             CV_ERROR( CV_StsBadArg, "While cross-validation one or more of the classes have "
   1398             "been fell out of the sample. Try to enlarge <CvSVMParams::k_fold>" );
   1399 
   1400         if( svm_type == NU_SVC )
   1401         {
   1402             // check if nu is feasible
   1403             for(i = 0; i < class_count; i++ )
   1404             {
   1405                 int ci = class_ranges[i+1] - class_ranges[i];
   1406                 for( j = i+1; j< class_count; j++ )
   1407                 {
   1408                     int cj = class_ranges[j+1] - class_ranges[j];
   1409                     if( params.nu*(ci + cj)*0.5 > MIN( ci, cj ) )
   1410                     {
   1411                         // !!!TODO!!! add some diagnostic
   1412                         EXIT; // exit immediately; will release the model and return NULL pointer
   1413                     }
   1414                 }
   1415             }
   1416         }
   1417 
   1418         // train n*(n-1)/2 classifiers
   1419         for( i = 0; i < class_count; i++ )
   1420         {
   1421             for( j = i+1; j < class_count; j++, df++ )
   1422             {
   1423                 int si = class_ranges[i], ci = class_ranges[i+1] - si;
   1424                 int sj = class_ranges[j], cj = class_ranges[j+1] - sj;
   1425                 double Cp = params.C, Cn = Cp;
   1426                 int k1 = 0, sv_count = 0;
   1427 
   1428                 for( k = 0; k < ci; k++ )
   1429                 {
   1430                     temp_samples[k] = samples[si + k];
   1431                     temp_y[k] = 1;
   1432                 }
   1433 
   1434                 for( k = 0; k < cj; k++ )
   1435                 {
   1436                     temp_samples[ci + k] = samples[sj + k];
   1437                     temp_y[ci + k] = -1;
   1438                 }
   1439 
   1440                 if( class_weights )
   1441                 {
   1442                     Cp = class_weights->data.db[i];
   1443                     Cn = class_weights->data.db[j];
   1444                 }
   1445 
   1446                 if( !train1( ci + cj, var_count, temp_samples, temp_y,
   1447                              Cp, Cn, temp_storage, alpha, df->rho ))
   1448                     EXIT;
   1449 
   1450                 for( k = 0; k < ci + cj; k++ )
   1451                     sv_count += fabs(alpha[k]) > 0;
   1452 
   1453                 df->sv_count = sv_count;
   1454 
   1455                 CV_CALL( df->alpha = (double*)cvMemStorageAlloc( temp_storage,
   1456                                                 sv_count*sizeof(df->alpha[0])));
   1457                 CV_CALL( df->sv_index = (int*)cvMemStorageAlloc( temp_storage,
   1458                                                 sv_count*sizeof(df->sv_index[0])));
   1459 
   1460                 for( k = 0; k < ci; k++ )
   1461                 {
   1462                     if( fabs(alpha[k]) > 0 )
   1463                     {
   1464                         sv_tab[si + k] = 1;
   1465                         df->sv_index[k1] = si + k;
   1466                         df->alpha[k1++] = alpha[k];
   1467                     }
   1468                 }
   1469 
   1470                 for( k = 0; k < cj; k++ )
   1471                 {
   1472                     if( fabs(alpha[ci + k]) > 0 )
   1473                     {
   1474                         sv_tab[sj + k] = 1;
   1475                         df->sv_index[k1] = sj + k;
   1476                         df->alpha[k1++] = alpha[ci + k];
   1477                     }
   1478                 }
   1479             }
   1480         }
   1481 
   1482         // allocate support vectors and initialize sv_tab
   1483         for( i = 0, k = 0; i < sample_count; i++ )
   1484         {
   1485             if( sv_tab[i] )
   1486                 sv_tab[i] = ++k;
   1487         }
   1488 
   1489         sv_total = k;
   1490         CV_CALL( sv = (float**)cvMemStorageAlloc( storage, sv_total*sizeof(sv[0])));
   1491 
   1492         for( i = 0, k = 0; i < sample_count; i++ )
   1493         {
   1494             if( sv_tab[i] )
   1495             {
   1496                 CV_CALL( sv[k] = (float*)cvMemStorageAlloc( storage, sample_size ));
   1497                 memcpy( sv[k], samples[i], sample_size );
   1498                 k++;
   1499             }
   1500         }
   1501 
   1502         df = (CvSVMDecisionFunc*)decision_func;
   1503 
   1504         // set sv pointers
   1505         for( i = 0; i < class_count; i++ )
   1506         {
   1507             for( j = i+1; j < class_count; j++, df++ )
   1508             {
   1509                 for( k = 0; k < df->sv_count; k++ )
   1510                 {
   1511                     df->sv_index[k] = sv_tab[df->sv_index[k]]-1;
   1512                     assert( (unsigned)df->sv_index[k] < (unsigned)sv_total );
   1513                 }
   1514             }
   1515         }
   1516     }
   1517 
   1518     ok = true;
   1519 
   1520     __END__;
   1521 
   1522     return ok;
   1523 }
   1524 
   1525 bool CvSVM::train( const CvMat* _train_data, const CvMat* _responses,
   1526     const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params )
   1527 {
   1528     bool ok = false;
   1529     CvMat* responses = 0;
   1530     CvMemStorage* temp_storage = 0;
   1531     const float** samples = 0;
   1532 
   1533     CV_FUNCNAME( "CvSVM::train" );
   1534 
   1535     __BEGIN__;
   1536 
   1537     int svm_type, sample_count, var_count, sample_size;
   1538     int block_size = 1 << 16;
   1539     double* alpha;
   1540 
   1541     clear();
   1542     CV_CALL( set_params( _params ));
   1543 
   1544     svm_type = _params.svm_type;
   1545 
   1546     /* Prepare training data and related parameters */
   1547     CV_CALL( cvPrepareTrainData( "CvSVM::train", _train_data, CV_ROW_SAMPLE,
   1548                                  svm_type != CvSVM::ONE_CLASS ? _responses : 0,
   1549                                  svm_type == CvSVM::C_SVC ||
   1550                                  svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
   1551                                  CV_VAR_ORDERED, _var_idx, _sample_idx,
   1552                                  false, &samples, &sample_count, &var_count, &var_all,
   1553                                  &responses, &class_labels, &var_idx ));
   1554 
   1555 
   1556     sample_size = var_count*sizeof(samples[0][0]);
   1557 
   1558     // make the storage block size large enough to fit all
   1559     // the temporary vectors and output support vectors.
   1560     block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
   1561     block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
   1562     block_size = MAX( block_size, sample_size*2 + 1024 );
   1563 
   1564     CV_CALL( storage = cvCreateMemStorage(block_size));
   1565     CV_CALL( temp_storage = cvCreateChildMemStorage(storage));
   1566     CV_CALL( alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
   1567 
   1568     create_kernel();
   1569     create_solver();
   1570 
   1571     if( !do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ))
   1572         EXIT;
   1573 
   1574     ok = true; // model has been trained succesfully
   1575 
   1576     __END__;
   1577 
   1578     delete solver;
   1579     solver = 0;
   1580     cvReleaseMemStorage( &temp_storage );
   1581     cvReleaseMat( &responses );
   1582     cvFree( &samples );
   1583 
   1584     if( cvGetErrStatus() < 0 || !ok )
   1585         clear();
   1586 
   1587     return ok;
   1588 }
   1589 
   1590 bool CvSVM::train_auto( const CvMat* _train_data, const CvMat* _responses,
   1591     const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params, int k_fold,
   1592     CvParamGrid C_grid, CvParamGrid gamma_grid, CvParamGrid p_grid,
   1593     CvParamGrid nu_grid, CvParamGrid coef_grid, CvParamGrid degree_grid )
   1594 {
   1595     bool ok = false;
   1596     CvMat* responses = 0;
   1597     CvMat* responses_local = 0;
   1598     CvMemStorage* temp_storage = 0;
   1599     const float** samples = 0;
   1600     const float** samples_local = 0;
   1601 
   1602     CV_FUNCNAME( "CvSVM::train_auto" );
   1603     __BEGIN__;
   1604 
   1605     int svm_type, sample_count, var_count, sample_size;
   1606     int block_size = 1 << 16;
   1607     double* alpha;
   1608     int i, k;
   1609     CvRNG rng = cvRNG(-1);
   1610 
   1611     // all steps are logarithmic and must be > 1
   1612     double degree_step = 10, g_step = 10, coef_step = 10, C_step = 10, nu_step = 10, p_step = 10;
   1613     double gamma = 0, C = 0, degree = 0, coef = 0, p = 0, nu = 0;
   1614     double best_degree = 0, best_gamma = 0, best_coef = 0, best_C = 0, best_nu = 0, best_p = 0;
   1615     float min_error = FLT_MAX, error;
   1616 
   1617     if( _params.svm_type == CvSVM::ONE_CLASS )
   1618     {
   1619         if(!train( _train_data, _responses, _var_idx, _sample_idx, _params ))
   1620             EXIT;
   1621         return true;
   1622     }
   1623 
   1624     clear();
   1625 
   1626     if( k_fold < 2 )
   1627         CV_ERROR( CV_StsBadArg, "Parameter <k_fold> must be > 1" );
   1628 
   1629     CV_CALL(set_params( _params ));
   1630     svm_type = _params.svm_type;
   1631 
   1632     // All the parameters except, possibly, <coef0> are positive.
   1633     // <coef0> is nonnegative
   1634     if( C_grid.step <= 1 )
   1635     {
   1636         C_grid.min_val = C_grid.max_val = params.C;
   1637         C_grid.step = 10;
   1638     }
   1639     else
   1640         CV_CALL(C_grid.check());
   1641 
   1642     if( gamma_grid.step <= 1 )
   1643     {
   1644         gamma_grid.min_val = gamma_grid.max_val = params.gamma;
   1645         gamma_grid.step = 10;
   1646     }
   1647     else
   1648         CV_CALL(gamma_grid.check());
   1649 
   1650     if( p_grid.step <= 1 )
   1651     {
   1652         p_grid.min_val = p_grid.max_val = params.p;
   1653         p_grid.step = 10;
   1654     }
   1655     else
   1656         CV_CALL(p_grid.check());
   1657 
   1658     if( nu_grid.step <= 1 )
   1659     {
   1660         nu_grid.min_val = nu_grid.max_val = params.nu;
   1661         nu_grid.step = 10;
   1662     }
   1663     else
   1664         CV_CALL(nu_grid.check());
   1665 
   1666     if( coef_grid.step <= 1 )
   1667     {
   1668         coef_grid.min_val = coef_grid.max_val = params.coef0;
   1669         coef_grid.step = 10;
   1670     }
   1671     else
   1672         CV_CALL(coef_grid.check());
   1673 
   1674     if( degree_grid.step <= 1 )
   1675     {
   1676         degree_grid.min_val = degree_grid.max_val = params.degree;
   1677         degree_grid.step = 10;
   1678     }
   1679     else
   1680         CV_CALL(degree_grid.check());
   1681 
   1682     // these parameters are not used:
   1683     if( params.kernel_type != CvSVM::POLY )
   1684         degree_grid.min_val = degree_grid.max_val = params.degree;
   1685     if( params.kernel_type == CvSVM::LINEAR )
   1686         gamma_grid.min_val = gamma_grid.max_val = params.gamma;
   1687     if( params.kernel_type != CvSVM::POLY && params.kernel_type != CvSVM::SIGMOID )
   1688         coef_grid.min_val = coef_grid.max_val = params.coef0;
   1689     if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
   1690         C_grid.min_val = C_grid.max_val = params.C;
   1691     if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
   1692         nu_grid.min_val = nu_grid.max_val = params.nu;
   1693     if( svm_type != CvSVM::EPS_SVR )
   1694         p_grid.min_val = p_grid.max_val = params.p;
   1695 
   1696     CV_ASSERT( g_step > 1 && degree_step > 1 && coef_step > 1);
   1697     CV_ASSERT( p_step > 1 && C_step > 1 && nu_step > 1 );
   1698 
   1699     /* Prepare training data and related parameters */
   1700     CV_CALL(cvPrepareTrainData( "CvSVM::train_auto", _train_data, CV_ROW_SAMPLE,
   1701                                  svm_type != CvSVM::ONE_CLASS ? _responses : 0,
   1702                                  svm_type == CvSVM::C_SVC ||
   1703                                  svm_type == CvSVM::NU_SVC ? CV_VAR_CATEGORICAL :
   1704                                  CV_VAR_ORDERED, _var_idx, _sample_idx,
   1705                                  false, &samples, &sample_count, &var_count, &var_all,
   1706                                  &responses, &class_labels, &var_idx ));
   1707 
   1708     sample_size = var_count*sizeof(samples[0][0]);
   1709 
   1710     // make the storage block size large enough to fit all
   1711     // the temporary vectors and output support vectors.
   1712     block_size = MAX( block_size, sample_count*(int)sizeof(CvSVMKernelRow));
   1713     block_size = MAX( block_size, sample_count*2*(int)sizeof(double) + 1024 );
   1714     block_size = MAX( block_size, sample_size*2 + 1024 );
   1715 
   1716     CV_CALL(storage = cvCreateMemStorage(block_size));
   1717     CV_CALL(temp_storage = cvCreateChildMemStorage(storage));
   1718     CV_CALL(alpha = (double*)cvMemStorageAlloc(temp_storage, sample_count*sizeof(double)));
   1719 
   1720     create_kernel();
   1721     create_solver();
   1722 
   1723     {
   1724     const int testset_size = sample_count/k_fold;
   1725     const int trainset_size = sample_count - testset_size;
   1726     const int last_testset_size = sample_count - testset_size*(k_fold-1);
   1727     const int last_trainset_size = sample_count - last_testset_size;
   1728     const bool is_regression = (svm_type == EPS_SVR) || (svm_type == NU_SVR);
   1729 
   1730     size_t resp_elem_size = CV_ELEM_SIZE(responses->type);
   1731     size_t size = 2*last_trainset_size*sizeof(samples[0]);
   1732 
   1733     samples_local = (const float**) cvAlloc( size );
   1734     memset( samples_local, 0, size );
   1735 
   1736     responses_local = cvCreateMat( 1, trainset_size, CV_MAT_TYPE(responses->type) );
   1737     cvZero( responses_local );
   1738 
   1739     // randomly permute samples and responses
   1740     for( i = 0; i < sample_count; i++ )
   1741     {
   1742         int i1 = cvRandInt( &rng ) % sample_count;
   1743         int i2 = cvRandInt( &rng ) % sample_count;
   1744         const float* temp;
   1745         float t;
   1746         int y;
   1747 
   1748         CV_SWAP( samples[i1], samples[i2], temp );
   1749         if( is_regression )
   1750             CV_SWAP( responses->data.fl[i1], responses->data.fl[i2], t );
   1751         else
   1752             CV_SWAP( responses->data.i[i1], responses->data.i[i2], y );
   1753     }
   1754 
   1755     C = C_grid.min_val;
   1756     do
   1757     {
   1758       params.C = C;
   1759       gamma = gamma_grid.min_val;
   1760       do
   1761       {
   1762         params.gamma = gamma;
   1763         p = p_grid.min_val;
   1764         do
   1765         {
   1766           params.p = p;
   1767           nu = nu_grid.min_val;
   1768           do
   1769           {
   1770             params.nu = nu;
   1771             coef = coef_grid.min_val;
   1772             do
   1773             {
   1774               params.coef0 = coef;
   1775               degree = degree_grid.min_val;
   1776               do
   1777               {
   1778                 params.degree = degree;
   1779 
   1780                 float** test_samples_ptr = (float**)samples;
   1781                 uchar* true_resp = responses->data.ptr;
   1782                 int test_size = testset_size;
   1783                 int train_size = trainset_size;
   1784 
   1785                 error = 0;
   1786                 for( k = 0; k < k_fold; k++ )
   1787                 {
   1788                     memcpy( samples_local, samples, sizeof(samples[0])*test_size*k );
   1789                     memcpy( samples_local + test_size*k, test_samples_ptr + test_size,
   1790                         sizeof(samples[0])*(sample_count - testset_size*(k+1)) );
   1791 
   1792                     memcpy( responses_local->data.ptr, responses->data.ptr, resp_elem_size*test_size*k );
   1793                     memcpy( responses_local->data.ptr + resp_elem_size*test_size*k,
   1794                         true_resp + resp_elem_size*test_size,
   1795                         sizeof(samples[0])*(sample_count - testset_size*(k+1)) );
   1796 
   1797                     if( k == k_fold - 1 )
   1798                     {
   1799                         test_size = last_testset_size;
   1800                         train_size = last_trainset_size;
   1801                         responses_local->cols = last_trainset_size;
   1802                     }
   1803 
   1804                     // Train SVM on <train_size> samples
   1805                     if( !do_train( svm_type, train_size, var_count,
   1806                         (const float**)samples_local, responses_local, temp_storage, alpha ) )
   1807                         EXIT;
   1808 
   1809                     // Compute test set error on <test_size> samples
   1810                     CvMat s = cvMat( 1, var_count, CV_32FC1 );
   1811                     for( i = 0; i < test_size; i++, true_resp += resp_elem_size, test_samples_ptr++ )
   1812                     {
   1813                         float resp;
   1814                         s.data.fl = *test_samples_ptr;
   1815                         resp = predict( &s );
   1816                         error += is_regression ? powf( resp - *(float*)true_resp, 2 )
   1817                             : ((int)resp != *(int*)true_resp);
   1818                     }
   1819                 }
   1820                 if( min_error > error )
   1821                 {
   1822                     min_error   = error;
   1823                     best_degree = degree;
   1824                     best_gamma  = gamma;
   1825                     best_coef   = coef;
   1826                     best_C      = C;
   1827                     best_nu     = nu;
   1828                     best_p      = p;
   1829                 }
   1830                 degree *= degree_grid.step;
   1831               }
   1832               while( degree < degree_grid.max_val );
   1833               coef *= coef_grid.step;
   1834             }
   1835             while( coef < coef_grid.max_val );
   1836             nu *= nu_grid.step;
   1837           }
   1838           while( nu < nu_grid.max_val );
   1839           p *= p_grid.step;
   1840         }
   1841         while( p < p_grid.max_val );
   1842         gamma *= gamma_grid.step;
   1843       }
   1844       while( gamma < gamma_grid.max_val );
   1845       C *= C_grid.step;
   1846     }
   1847     while( C < C_grid.max_val );
   1848     }
   1849 
   1850     min_error /= (float) sample_count;
   1851 
   1852     params.C      = best_C;
   1853     params.nu     = best_nu;
   1854     params.p      = best_p;
   1855     params.gamma  = best_gamma;
   1856     params.degree = best_degree;
   1857     params.coef0  = best_coef;
   1858 
   1859     CV_CALL(ok = do_train( svm_type, sample_count, var_count, samples, responses, temp_storage, alpha ));
   1860 
   1861     __END__;
   1862 
   1863     delete solver;
   1864     solver = 0;
   1865     cvReleaseMemStorage( &temp_storage );
   1866     cvReleaseMat( &responses );
   1867     cvReleaseMat( &responses_local );
   1868     cvFree( &samples );
   1869     cvFree( &samples_local );
   1870 
   1871     if( cvGetErrStatus() < 0 || !ok )
   1872         clear();
   1873 
   1874     return ok;
   1875 }
   1876 
   1877 float CvSVM::predict( const CvMat* sample ) const
   1878 {
   1879     bool local_alloc = 0;
   1880     float result = 0;
   1881     float* row_sample = 0;
   1882     Qfloat* buffer = 0;
   1883 
   1884     CV_FUNCNAME( "CvSVM::predict" );
   1885 
   1886     __BEGIN__;
   1887 
   1888     int class_count;
   1889     int var_count, buf_sz;
   1890 
   1891     if( !kernel )
   1892         CV_ERROR( CV_StsBadArg, "The SVM should be trained first" );
   1893 
   1894     class_count = class_labels ? class_labels->cols :
   1895                   params.svm_type == ONE_CLASS ? 1 : 0;
   1896 
   1897     CV_CALL( cvPreparePredictData( sample, var_all, var_idx,
   1898                                    class_count, 0, &row_sample ));
   1899 
   1900     var_count = get_var_count();
   1901 
   1902     buf_sz = sv_total*sizeof(buffer[0]) + (class_count+1)*sizeof(int);
   1903     if( buf_sz <= CV_MAX_LOCAL_SIZE )
   1904     {
   1905         CV_CALL( buffer = (Qfloat*)cvStackAlloc( buf_sz ));
   1906         local_alloc = 1;
   1907     }
   1908     else
   1909         CV_CALL( buffer = (Qfloat*)cvAlloc( buf_sz ));
   1910 
   1911     if( params.svm_type == EPS_SVR ||
   1912         params.svm_type == NU_SVR ||
   1913         params.svm_type == ONE_CLASS )
   1914     {
   1915         CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
   1916         int i, sv_count = df->sv_count;
   1917         double sum = -df->rho;
   1918 
   1919         kernel->calc( sv_count, var_count, (const float**)sv, row_sample, buffer );
   1920         for( i = 0; i < sv_count; i++ )
   1921             sum += buffer[i]*df->alpha[i];
   1922 
   1923         result = params.svm_type == ONE_CLASS ? (float)(sum > 0) : (float)sum;
   1924     }
   1925     else if( params.svm_type == C_SVC ||
   1926              params.svm_type == NU_SVC )
   1927     {
   1928         CvSVMDecisionFunc* df = (CvSVMDecisionFunc*)decision_func;
   1929         int* vote = (int*)(buffer + sv_total);
   1930         int i, j, k;
   1931 
   1932         memset( vote, 0, class_count*sizeof(vote[0]));
   1933         kernel->calc( sv_total, var_count, (const float**)sv, row_sample, buffer );
   1934 
   1935         for( i = 0; i < class_count; i++ )
   1936         {
   1937             for( j = i+1; j < class_count; j++, df++ )
   1938             {
   1939                 double sum = -df->rho;
   1940                 int sv_count = df->sv_count;
   1941                 for( k = 0; k < sv_count; k++ )
   1942                     sum += df->alpha[k]*buffer[df->sv_index[k]];
   1943 
   1944                 vote[sum > 0 ? i : j]++;
   1945             }
   1946         }
   1947 
   1948         for( i = 1, k = 0; i < class_count; i++ )
   1949         {
   1950             if( vote[i] > vote[k] )
   1951                 k = i;
   1952         }
   1953 
   1954         result = (float)(class_labels->data.i[k]);
   1955     }
   1956     else
   1957         CV_ERROR( CV_StsBadArg, "INTERNAL ERROR: Unknown SVM type, "
   1958                                 "the SVM structure is probably corrupted" );
   1959 
   1960     __END__;
   1961 
   1962     if( sample && (!CV_IS_MAT(sample) || sample->data.fl != row_sample) )
   1963         cvFree( &row_sample );
   1964 
   1965     if( !local_alloc )
   1966         cvFree( &buffer );
   1967 
   1968     return result;
   1969 }
   1970 
   1971 
   1972 void CvSVM::write_params( CvFileStorage* fs )
   1973 {
   1974     //CV_FUNCNAME( "CvSVM::write_params" );
   1975 
   1976     __BEGIN__;
   1977 
   1978     int svm_type = params.svm_type;
   1979     int kernel_type = params.kernel_type;
   1980 
   1981     const char* svm_type_str =
   1982         svm_type == CvSVM::C_SVC ? "C_SVC" :
   1983         svm_type == CvSVM::NU_SVC ? "NU_SVC" :
   1984         svm_type == CvSVM::ONE_CLASS ? "ONE_CLASS" :
   1985         svm_type == CvSVM::EPS_SVR ? "EPS_SVR" :
   1986         svm_type == CvSVM::NU_SVR ? "NU_SVR" : 0;
   1987     const char* kernel_type_str =
   1988         kernel_type == CvSVM::LINEAR ? "LINEAR" :
   1989         kernel_type == CvSVM::POLY ? "POLY" :
   1990         kernel_type == CvSVM::RBF ? "RBF" :
   1991         kernel_type == CvSVM::SIGMOID ? "SIGMOID" : 0;
   1992 
   1993     if( svm_type_str )
   1994         cvWriteString( fs, "svm_type", svm_type_str );
   1995     else
   1996         cvWriteInt( fs, "svm_type", svm_type );
   1997 
   1998     // save kernel
   1999     cvStartWriteStruct( fs, "kernel", CV_NODE_MAP + CV_NODE_FLOW );
   2000 
   2001     if( kernel_type_str )
   2002         cvWriteString( fs, "type", kernel_type_str );
   2003     else
   2004         cvWriteInt( fs, "type", kernel_type );
   2005 
   2006     if( kernel_type == CvSVM::POLY || !kernel_type_str )
   2007         cvWriteReal( fs, "degree", params.degree );
   2008 
   2009     if( kernel_type != CvSVM::LINEAR || !kernel_type_str )
   2010         cvWriteReal( fs, "gamma", params.gamma );
   2011 
   2012     if( kernel_type == CvSVM::POLY || kernel_type == CvSVM::SIGMOID || !kernel_type_str )
   2013         cvWriteReal( fs, "coef0", params.coef0 );
   2014 
   2015     cvEndWriteStruct(fs);
   2016 
   2017     if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR ||
   2018         svm_type == CvSVM::NU_SVR || !svm_type_str )
   2019         cvWriteReal( fs, "C", params.C );
   2020 
   2021     if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS ||
   2022         svm_type == CvSVM::NU_SVR || !svm_type_str )
   2023         cvWriteReal( fs, "nu", params.nu );
   2024 
   2025     if( svm_type == CvSVM::EPS_SVR || !svm_type_str )
   2026         cvWriteReal( fs, "p", params.p );
   2027 
   2028     cvStartWriteStruct( fs, "term_criteria", CV_NODE_MAP + CV_NODE_FLOW );
   2029     if( params.term_crit.type & CV_TERMCRIT_EPS )
   2030         cvWriteReal( fs, "epsilon", params.term_crit.epsilon );
   2031     if( params.term_crit.type & CV_TERMCRIT_ITER )
   2032         cvWriteInt( fs, "iterations", params.term_crit.max_iter );
   2033     cvEndWriteStruct( fs );
   2034 
   2035     __END__;
   2036 }
   2037 
   2038 
   2039 void CvSVM::write( CvFileStorage* fs, const char* name )
   2040 {
   2041     CV_FUNCNAME( "CvSVM::write" );
   2042 
   2043     __BEGIN__;
   2044 
   2045     int i, var_count = get_var_count(), df_count, class_count;
   2046     const CvSVMDecisionFunc* df = decision_func;
   2047 
   2048     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_SVM );
   2049 
   2050     write_params( fs );
   2051 
   2052     cvWriteInt( fs, "var_all", var_all );
   2053     cvWriteInt( fs, "var_count", var_count );
   2054 
   2055     class_count = class_labels ? class_labels->cols :
   2056                   params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
   2057 
   2058     if( class_count )
   2059     {
   2060         cvWriteInt( fs, "class_count", class_count );
   2061 
   2062         if( class_labels )
   2063             cvWrite( fs, "class_labels", class_labels );
   2064 
   2065         if( class_weights )
   2066             cvWrite( fs, "class_weights", class_weights );
   2067     }
   2068 
   2069     if( var_idx )
   2070         cvWrite( fs, "var_idx", var_idx );
   2071 
   2072     // write the joint collection of support vectors
   2073     cvWriteInt( fs, "sv_total", sv_total );
   2074     cvStartWriteStruct( fs, "support_vectors", CV_NODE_SEQ );
   2075     for( i = 0; i < sv_total; i++ )
   2076     {
   2077         cvStartWriteStruct( fs, 0, CV_NODE_SEQ + CV_NODE_FLOW );
   2078         cvWriteRawData( fs, sv[i], var_count, "f" );
   2079         cvEndWriteStruct( fs );
   2080     }
   2081 
   2082     cvEndWriteStruct( fs );
   2083 
   2084     // write decision functions
   2085     df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
   2086     df = decision_func;
   2087 
   2088     cvStartWriteStruct( fs, "decision_functions", CV_NODE_SEQ );
   2089     for( i = 0; i < df_count; i++ )
   2090     {
   2091         int sv_count = df[i].sv_count;
   2092         cvStartWriteStruct( fs, 0, CV_NODE_MAP );
   2093         cvWriteInt( fs, "sv_count", sv_count );
   2094         cvWriteReal( fs, "rho", df[i].rho );
   2095         cvStartWriteStruct( fs, "alpha", CV_NODE_SEQ+CV_NODE_FLOW );
   2096         cvWriteRawData( fs, df[i].alpha, df[i].sv_count, "d" );
   2097         cvEndWriteStruct( fs );
   2098         if( class_count > 1 )
   2099         {
   2100             cvStartWriteStruct( fs, "index", CV_NODE_SEQ+CV_NODE_FLOW );
   2101             cvWriteRawData( fs, df[i].sv_index, df[i].sv_count, "i" );
   2102             cvEndWriteStruct( fs );
   2103         }
   2104         else
   2105             CV_ASSERT( sv_count == sv_total );
   2106         cvEndWriteStruct( fs );
   2107     }
   2108     cvEndWriteStruct( fs );
   2109     cvEndWriteStruct( fs );
   2110 
   2111     __END__;
   2112 }
   2113 
   2114 
   2115 void CvSVM::read_params( CvFileStorage* fs, CvFileNode* svm_node )
   2116 {
   2117     CV_FUNCNAME( "CvSVM::read_params" );
   2118 
   2119     __BEGIN__;
   2120 
   2121     int svm_type, kernel_type;
   2122     CvSVMParams _params;
   2123 
   2124     CvFileNode* tmp_node = cvGetFileNodeByName( fs, svm_node, "svm_type" );
   2125     CvFileNode* kernel_node;
   2126     if( !tmp_node )
   2127         CV_ERROR( CV_StsBadArg, "svm_type tag is not found" );
   2128 
   2129     if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
   2130         svm_type = cvReadInt( tmp_node, -1 );
   2131     else
   2132     {
   2133         const char* svm_type_str = cvReadString( tmp_node, "" );
   2134         svm_type =
   2135             strcmp( svm_type_str, "C_SVC" ) == 0 ? CvSVM::C_SVC :
   2136             strcmp( svm_type_str, "NU_SVC" ) == 0 ? CvSVM::NU_SVC :
   2137             strcmp( svm_type_str, "ONE_CLASS" ) == 0 ? CvSVM::ONE_CLASS :
   2138             strcmp( svm_type_str, "EPS_SVR" ) == 0 ? CvSVM::EPS_SVR :
   2139             strcmp( svm_type_str, "NU_SVR" ) == 0 ? CvSVM::NU_SVR : -1;
   2140 
   2141         if( svm_type < 0 )
   2142             CV_ERROR( CV_StsParseError, "Missing of invalid SVM type" );
   2143     }
   2144 
   2145     kernel_node = cvGetFileNodeByName( fs, svm_node, "kernel" );
   2146     if( !kernel_node )
   2147         CV_ERROR( CV_StsParseError, "SVM kernel tag is not found" );
   2148 
   2149     tmp_node = cvGetFileNodeByName( fs, kernel_node, "type" );
   2150     if( !tmp_node )
   2151         CV_ERROR( CV_StsParseError, "SVM kernel type tag is not found" );
   2152 
   2153     if( CV_NODE_TYPE(tmp_node->tag) == CV_NODE_INT )
   2154         kernel_type = cvReadInt( tmp_node, -1 );
   2155     else
   2156     {
   2157         const char* kernel_type_str = cvReadString( tmp_node, "" );
   2158         kernel_type =
   2159             strcmp( kernel_type_str, "LINEAR" ) == 0 ? CvSVM::LINEAR :
   2160             strcmp( kernel_type_str, "POLY" ) == 0 ? CvSVM::POLY :
   2161             strcmp( kernel_type_str, "RBF" ) == 0 ? CvSVM::RBF :
   2162             strcmp( kernel_type_str, "SIGMOID" ) == 0 ? CvSVM::SIGMOID : -1;
   2163 
   2164         if( kernel_type < 0 )
   2165             CV_ERROR( CV_StsParseError, "Missing of invalid SVM kernel type" );
   2166     }
   2167 
   2168     _params.svm_type = svm_type;
   2169     _params.kernel_type = kernel_type;
   2170     _params.degree = cvReadRealByName( fs, kernel_node, "degree", 0 );
   2171     _params.gamma = cvReadRealByName( fs, kernel_node, "gamma", 0 );
   2172     _params.coef0 = cvReadRealByName( fs, kernel_node, "coef0", 0 );
   2173 
   2174     _params.C = cvReadRealByName( fs, svm_node, "C", 0 );
   2175     _params.nu = cvReadRealByName( fs, svm_node, "nu", 0 );
   2176     _params.p = cvReadRealByName( fs, svm_node, "p", 0 );
   2177     _params.class_weights = 0;
   2178 
   2179     tmp_node = cvGetFileNodeByName( fs, svm_node, "term_criteria" );
   2180     if( tmp_node )
   2181     {
   2182         _params.term_crit.epsilon = cvReadRealByName( fs, tmp_node, "epsilon", -1. );
   2183         _params.term_crit.max_iter = cvReadIntByName( fs, tmp_node, "iterations", -1 );
   2184         _params.term_crit.type = (_params.term_crit.epsilon >= 0 ? CV_TERMCRIT_EPS : 0) +
   2185                                (_params.term_crit.max_iter >= 0 ? CV_TERMCRIT_ITER : 0);
   2186     }
   2187     else
   2188         _params.term_crit = cvTermCriteria( CV_TERMCRIT_EPS + CV_TERMCRIT_ITER, 1000, FLT_EPSILON );
   2189 
   2190     set_params( _params );
   2191 
   2192     __END__;
   2193 }
   2194 
   2195 
   2196 void CvSVM::read( CvFileStorage* fs, CvFileNode* svm_node )
   2197 {
   2198     const double not_found_dbl = DBL_MAX;
   2199 
   2200     CV_FUNCNAME( "CvSVM::read" );
   2201 
   2202     __BEGIN__;
   2203 
   2204     int i, var_count, df_count, class_count;
   2205     int block_size = 1 << 16, sv_size;
   2206     CvFileNode *sv_node, *df_node;
   2207     CvSVMDecisionFunc* df;
   2208     CvSeqReader reader;
   2209 
   2210     if( !svm_node )
   2211         CV_ERROR( CV_StsParseError, "The requested element is not found" );
   2212 
   2213     clear();
   2214 
   2215     // read SVM parameters
   2216     read_params( fs, svm_node );
   2217 
   2218     // and top-level data
   2219     sv_total = cvReadIntByName( fs, svm_node, "sv_total", -1 );
   2220     var_all = cvReadIntByName( fs, svm_node, "var_all", -1 );
   2221     var_count = cvReadIntByName( fs, svm_node, "var_count", var_all );
   2222     class_count = cvReadIntByName( fs, svm_node, "class_count", 0 );
   2223 
   2224     if( sv_total <= 0 || var_all <= 0 || var_count <= 0 || var_count > var_all || class_count < 0 )
   2225         CV_ERROR( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );
   2226 
   2227     CV_CALL( class_labels = (CvMat*)cvReadByName( fs, svm_node, "class_labels" ));
   2228     CV_CALL( class_weights = (CvMat*)cvReadByName( fs, svm_node, "class_weights" ));
   2229     CV_CALL( var_idx = (CvMat*)cvReadByName( fs, svm_node, "comp_idx" ));
   2230 
   2231     if( class_count > 1 && (!class_labels ||
   2232         !CV_IS_MAT(class_labels) || class_labels->cols != class_count))
   2233         CV_ERROR( CV_StsParseError, "Array of class labels is missing or invalid" );
   2234 
   2235     if( var_count < var_all && (!var_idx || !CV_IS_MAT(var_idx) || var_idx->cols != var_count) )
   2236         CV_ERROR( CV_StsParseError, "var_idx array is missing or invalid" );
   2237 
   2238     // read support vectors
   2239     sv_node = cvGetFileNodeByName( fs, svm_node, "support_vectors" );
   2240     if( !sv_node || !CV_NODE_IS_SEQ(sv_node->tag))
   2241         CV_ERROR( CV_StsParseError, "Missing or invalid sequence of support vectors" );
   2242 
   2243     block_size = MAX( block_size, sv_total*(int)sizeof(CvSVMKernelRow));
   2244     block_size = MAX( block_size, sv_total*2*(int)sizeof(double));
   2245     block_size = MAX( block_size, var_all*(int)sizeof(double));
   2246     CV_CALL( storage = cvCreateMemStorage( block_size ));
   2247     CV_CALL( sv = (float**)cvMemStorageAlloc( storage,
   2248                                 sv_total*sizeof(sv[0]) ));
   2249 
   2250     CV_CALL( cvStartReadSeq( sv_node->data.seq, &reader, 0 ));
   2251     sv_size = var_count*sizeof(sv[0][0]);
   2252 
   2253     for( i = 0; i < sv_total; i++ )
   2254     {
   2255         CvFileNode* sv_elem = (CvFileNode*)reader.ptr;
   2256         CV_ASSERT( var_count == 1 || (CV_NODE_IS_SEQ(sv_elem->tag) &&
   2257                    sv_elem->data.seq->total == var_count) );
   2258 
   2259         CV_CALL( sv[i] = (float*)cvMemStorageAlloc( storage, sv_size ));
   2260         CV_CALL( cvReadRawData( fs, sv_elem, sv[i], "f" ));
   2261         CV_NEXT_SEQ_ELEM( sv_node->data.seq->elem_size, reader );
   2262     }
   2263 
   2264     // read decision functions
   2265     df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
   2266     df_node = cvGetFileNodeByName( fs, svm_node, "decision_functions" );
   2267     if( !df_node || !CV_NODE_IS_SEQ(df_node->tag) ||
   2268         df_node->data.seq->total != df_count )
   2269         CV_ERROR( CV_StsParseError, "decision_functions is missing or is not a collection "
   2270                   "or has a wrong number of elements" );
   2271 
   2272     CV_CALL( df = decision_func = (CvSVMDecisionFunc*)cvAlloc( df_count*sizeof(df[0]) ));
   2273     cvStartReadSeq( df_node->data.seq, &reader, 0 );
   2274 
   2275     for( i = 0; i < df_count; i++ )
   2276     {
   2277         CvFileNode* df_elem = (CvFileNode*)reader.ptr;
   2278         CvFileNode* alpha_node = cvGetFileNodeByName( fs, df_elem, "alpha" );
   2279 
   2280         int sv_count = cvReadIntByName( fs, df_elem, "sv_count", -1 );
   2281         if( sv_count <= 0 )
   2282             CV_ERROR( CV_StsParseError, "sv_count is missing or non-positive" );
   2283         df[i].sv_count = sv_count;
   2284 
   2285         df[i].rho = cvReadRealByName( fs, df_elem, "rho", not_found_dbl );
   2286         if( fabs(df[i].rho - not_found_dbl) < DBL_EPSILON )
   2287             CV_ERROR( CV_StsParseError, "rho is missing" );
   2288 
   2289         if( !alpha_node )
   2290             CV_ERROR( CV_StsParseError, "alpha is missing in the decision function" );
   2291 
   2292         CV_CALL( df[i].alpha = (double*)cvMemStorageAlloc( storage,
   2293                                         sv_count*sizeof(df[i].alpha[0])));
   2294         CV_ASSERT( sv_count == 1 || CV_NODE_IS_SEQ(alpha_node->tag) &&
   2295                    alpha_node->data.seq->total == sv_count );
   2296         CV_CALL( cvReadRawData( fs, alpha_node, df[i].alpha, "d" ));
   2297 
   2298         if( class_count > 1 )
   2299         {
   2300             CvFileNode* index_node = cvGetFileNodeByName( fs, df_elem, "index" );
   2301             if( !index_node )
   2302                 CV_ERROR( CV_StsParseError, "index is missing in the decision function" );
   2303             CV_CALL( df[i].sv_index = (int*)cvMemStorageAlloc( storage,
   2304                                             sv_count*sizeof(df[i].sv_index[0])));
   2305             CV_ASSERT( sv_count == 1 || CV_NODE_IS_SEQ(index_node->tag) &&
   2306                    index_node->data.seq->total == sv_count );
   2307             CV_CALL( cvReadRawData( fs, index_node, df[i].sv_index, "i" ));
   2308         }
   2309         else
   2310             df[i].sv_index = 0;
   2311 
   2312         CV_NEXT_SEQ_ELEM( df_node->data.seq->elem_size, reader );
   2313     }
   2314 
   2315     create_kernel();
   2316 
   2317     __END__;
   2318 }
   2319 
   2320 #if 0
   2321 
   2322 static void*
   2323 icvCloneSVM( const void* _src )
   2324 {
   2325     CvSVMModel* dst = 0;
   2326 
   2327     CV_FUNCNAME( "icvCloneSVM" );
   2328 
   2329     __BEGIN__;
   2330 
   2331     const CvSVMModel* src = (const CvSVMModel*)_src;
   2332     int var_count, class_count;
   2333     int i, sv_total, df_count;
   2334     int sv_size;
   2335 
   2336     if( !CV_IS_SVM(src) )
   2337         CV_ERROR( !src ? CV_StsNullPtr : CV_StsBadArg, "Input pointer is NULL or invalid" );
   2338 
   2339     // 0. create initial CvSVMModel structure
   2340     CV_CALL( dst = icvCreateSVM() );
   2341     dst->params = src->params;
   2342     dst->params.weight_labels = 0;
   2343     dst->params.weights = 0;
   2344 
   2345     dst->var_all = src->var_all;
   2346     if( src->class_labels )
   2347         dst->class_labels = cvCloneMat( src->class_labels );
   2348     if( src->class_weights )
   2349         dst->class_weights = cvCloneMat( src->class_weights );
   2350     if( src->comp_idx )
   2351         dst->comp_idx = cvCloneMat( src->comp_idx );
   2352 
   2353     var_count = src->comp_idx ? src->comp_idx->cols : src->var_all;
   2354     class_count = src->class_labels ? src->class_labels->cols :
   2355                   src->params.svm_type == CvSVM::ONE_CLASS ? 1 : 0;
   2356     sv_total = dst->sv_total = src->sv_total;
   2357     CV_CALL( dst->storage = cvCreateMemStorage( src->storage->block_size ));
   2358     CV_CALL( dst->sv = (float**)cvMemStorageAlloc( dst->storage,
   2359                                     sv_total*sizeof(dst->sv[0]) ));
   2360 
   2361     sv_size = var_count*sizeof(dst->sv[0][0]);
   2362 
   2363     for( i = 0; i < sv_total; i++ )
   2364     {
   2365         CV_CALL( dst->sv[i] = (float*)cvMemStorageAlloc( dst->storage, sv_size ));
   2366         memcpy( dst->sv[i], src->sv[i], sv_size );
   2367     }
   2368 
   2369     df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
   2370 
   2371     CV_CALL( dst->decision_func = cvAlloc( df_count*sizeof(CvSVMDecisionFunc) ));
   2372 
   2373     for( i = 0; i < df_count; i++ )
   2374     {
   2375         const CvSVMDecisionFunc *sdf =
   2376             (const CvSVMDecisionFunc*)src->decision_func+i;
   2377         CvSVMDecisionFunc *ddf =
   2378             (CvSVMDecisionFunc*)dst->decision_func+i;
   2379         int sv_count = sdf->sv_count;
   2380         ddf->sv_count = sv_count;
   2381         ddf->rho = sdf->rho;
   2382         CV_CALL( ddf->alpha = (double*)cvMemStorageAlloc( dst->storage,
   2383                                         sv_count*sizeof(ddf->alpha[0])));
   2384         memcpy( ddf->alpha, sdf->alpha, sv_count*sizeof(ddf->alpha[0]));
   2385 
   2386         if( class_count > 1 )
   2387         {
   2388             CV_CALL( ddf->sv_index = (int*)cvMemStorageAlloc( dst->storage,
   2389                                                 sv_count*sizeof(ddf->sv_index[0])));
   2390             memcpy( ddf->sv_index, sdf->sv_index, sv_count*sizeof(ddf->sv_index[0]));
   2391         }
   2392         else
   2393             ddf->sv_index = 0;
   2394     }
   2395 
   2396     __END__;
   2397 
   2398     if( cvGetErrStatus() < 0 && dst )
   2399         icvReleaseSVM( &dst );
   2400 
   2401     return dst;
   2402 }
   2403 
   2404 static int icvRegisterSVMType()
   2405 {
   2406     CvTypeInfo info;
   2407     memset( &info, 0, sizeof(info) );
   2408 
   2409     info.flags = 0;
   2410     info.header_size = sizeof( info );
   2411     info.is_instance = icvIsSVM;
   2412     info.release = (CvReleaseFunc)icvReleaseSVM;
   2413     info.read = icvReadSVM;
   2414     info.write = icvWriteSVM;
   2415     info.clone = icvCloneSVM;
   2416     info.type_name = CV_TYPE_NAME_ML_SVM;
   2417     cvRegisterType( &info );
   2418 
   2419     return 1;
   2420 }
   2421 
   2422 
   2423 static int svm = icvRegisterSVMType();
   2424 
   2425 /* The function trains an SVM model with optimal parameters, obtained by using cross-validation.
   2426 The parameters to be estimated should be indicated by setting their values to FLT_MAX.
   2427 The optimal parameters are saved in <model_params> */
   2428 CV_IMPL CvStatModel*
   2429 cvTrainSVM_CrossValidation( const CvMat* train_data, int tflag,
   2430             const CvMat* responses,
   2431             CvStatModelParams* model_params,
   2432             const CvStatModelParams* cross_valid_params,
   2433             const CvMat* comp_idx,
   2434             const CvMat* sample_idx,
   2435             const CvParamGrid* degree_grid,
   2436             const CvParamGrid* gamma_grid,
   2437             const CvParamGrid* coef_grid,
   2438             const CvParamGrid* C_grid,
   2439             const CvParamGrid* nu_grid,
   2440             const CvParamGrid* p_grid )
   2441 {
   2442     CvStatModel* svm = 0;
   2443 
   2444     CV_FUNCNAME("cvTainSVMCrossValidation");
   2445     __BEGIN__;
   2446 
   2447     double degree_step = 7,
   2448 	       g_step      = 15,
   2449 		   coef_step   = 14,
   2450 		   C_step      = 20,
   2451 		   nu_step     = 5,
   2452 		   p_step      = 7; // all steps must be > 1
   2453     double degree_begin = 0.01, degree_end = 2;
   2454     double g_begin      = 1e-5, g_end      = 0.5;
   2455     double coef_begin   = 0.1,  coef_end   = 300;
   2456     double C_begin      = 0.1,  C_end      = 6000;
   2457     double nu_begin     = 0.01,  nu_end    = 0.4;
   2458     double p_begin      = 0.01, p_end      = 100;
   2459 
   2460     double rate = 0, gamma = 0, C = 0, degree = 0, coef = 0, p = 0, nu = 0;
   2461 
   2462 	double best_rate    = 0;
   2463     double best_degree  = degree_begin;
   2464     double best_gamma   = g_begin;
   2465     double best_coef    = coef_begin;
   2466 	double best_C       = C_begin;
   2467 	double best_nu      = nu_begin;
   2468     double best_p       = p_begin;
   2469 
   2470     CvSVMModelParams svm_params, *psvm_params;
   2471     CvCrossValidationParams* cv_params = (CvCrossValidationParams*)cross_valid_params;
   2472     int svm_type, kernel;
   2473     int is_regression;
   2474 
   2475     if( !model_params )
   2476         CV_ERROR( CV_StsBadArg, "" );
   2477     if( !cv_params )
   2478         CV_ERROR( CV_StsBadArg, "" );
   2479 
   2480     svm_params = *(CvSVMModelParams*)model_params;
   2481     psvm_params = (CvSVMModelParams*)model_params;
   2482     svm_type = svm_params.svm_type;
   2483     kernel = svm_params.kernel_type;
   2484 
   2485     svm_params.degree = svm_params.degree > 0 ? svm_params.degree : 1;
   2486     svm_params.gamma = svm_params.gamma > 0 ? svm_params.gamma : 1;
   2487     svm_params.coef0 = svm_params.coef0 > 0 ? svm_params.coef0 : 1e-6;
   2488     svm_params.C = svm_params.C > 0 ? svm_params.C : 1;
   2489     svm_params.nu = svm_params.nu > 0 ? svm_params.nu : 1;
   2490     svm_params.p = svm_params.p > 0 ? svm_params.p : 1;
   2491 
   2492     if( degree_grid )
   2493     {
   2494         if( !(degree_grid->max_val == 0 && degree_grid->min_val == 0 &&
   2495               degree_grid->step == 0) )
   2496         {
   2497             if( degree_grid->min_val > degree_grid->max_val )
   2498                 CV_ERROR( CV_StsBadArg,
   2499                 "low bound of grid should be less then the upper one");
   2500             if( degree_grid->step <= 1 )
   2501                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
   2502             degree_begin = degree_grid->min_val;
   2503             degree_end   = degree_grid->max_val;
   2504             degree_step  = degree_grid->step;
   2505         }
   2506     }
   2507     else
   2508         degree_begin = degree_end = svm_params.degree;
   2509 
   2510     if( gamma_grid )
   2511     {
   2512         if( !(gamma_grid->max_val == 0 && gamma_grid->min_val == 0 &&
   2513               gamma_grid->step == 0) )
   2514         {
   2515             if( gamma_grid->min_val > gamma_grid->max_val )
   2516                 CV_ERROR( CV_StsBadArg,
   2517                 "low bound of grid should be less then the upper one");
   2518             if( gamma_grid->step <= 1 )
   2519                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
   2520             g_begin = gamma_grid->min_val;
   2521             g_end   = gamma_grid->max_val;
   2522             g_step  = gamma_grid->step;
   2523         }
   2524     }
   2525     else
   2526         g_begin = g_end = svm_params.gamma;
   2527 
   2528     if( coef_grid )
   2529     {
   2530         if( !(coef_grid->max_val == 0 && coef_grid->min_val == 0 &&
   2531               coef_grid->step == 0) )
   2532         {
   2533             if( coef_grid->min_val > coef_grid->max_val )
   2534                 CV_ERROR( CV_StsBadArg,
   2535                 "low bound of grid should be less then the upper one");
   2536             if( coef_grid->step <= 1 )
   2537                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
   2538             coef_begin = coef_grid->min_val;
   2539             coef_end   = coef_grid->max_val;
   2540             coef_step  = coef_grid->step;
   2541         }
   2542     }
   2543     else
   2544         coef_begin = coef_end = svm_params.coef0;
   2545 
   2546     if( C_grid )
   2547     {
   2548         if( !(C_grid->max_val == 0 && C_grid->min_val == 0 && C_grid->step == 0))
   2549         {
   2550             if( C_grid->min_val > C_grid->max_val )
   2551                 CV_ERROR( CV_StsBadArg,
   2552                 "low bound of grid should be less then the upper one");
   2553             if( C_grid->step <= 1 )
   2554                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
   2555             C_begin = C_grid->min_val;
   2556             C_end   = C_grid->max_val;
   2557             C_step  = C_grid->step;
   2558         }
   2559     }
   2560     else
   2561         C_begin = C_end = svm_params.C;
   2562 
   2563     if( nu_grid )
   2564     {
   2565         if(!(nu_grid->max_val == 0 && nu_grid->min_val == 0 && nu_grid->step==0))
   2566         {
   2567             if( nu_grid->min_val > nu_grid->max_val )
   2568                 CV_ERROR( CV_StsBadArg,
   2569                 "low bound of grid should be less then the upper one");
   2570             if( nu_grid->step <= 1 )
   2571                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
   2572             nu_begin = nu_grid->min_val;
   2573             nu_end   = nu_grid->max_val;
   2574             nu_step  = nu_grid->step;
   2575         }
   2576     }
   2577     else
   2578         nu_begin = nu_end = svm_params.nu;
   2579 
   2580     if( p_grid )
   2581     {
   2582         if( !(p_grid->max_val == 0 && p_grid->min_val == 0 && p_grid->step == 0))
   2583         {
   2584             if( p_grid->min_val > p_grid->max_val )
   2585                 CV_ERROR( CV_StsBadArg,
   2586                 "low bound of grid should be less then the upper one");
   2587             if( p_grid->step <= 1 )
   2588                 CV_ERROR( CV_StsBadArg, "grid step should be greater 1" );
   2589             p_begin = p_grid->min_val;
   2590             p_end   = p_grid->max_val;
   2591             p_step  = p_grid->step;
   2592         }
   2593     }
   2594     else
   2595         p_begin = p_end = svm_params.p;
   2596 
   2597     // these parameters are not used:
   2598     if( kernel != CvSVM::POLY )
   2599         degree_begin = degree_end = svm_params.degree;
   2600 
   2601    if( kernel == CvSVM::LINEAR )
   2602         g_begin = g_end = svm_params.gamma;
   2603 
   2604     if( kernel != CvSVM::POLY && kernel != CvSVM::SIGMOID )
   2605         coef_begin = coef_end = svm_params.coef0;
   2606 
   2607     if( svm_type == CvSVM::NU_SVC || svm_type == CvSVM::ONE_CLASS )
   2608         C_begin = C_end = svm_params.C;
   2609 
   2610     if( svm_type == CvSVM::C_SVC || svm_type == CvSVM::EPS_SVR )
   2611         nu_begin = nu_end = svm_params.nu;
   2612 
   2613     if( svm_type != CvSVM::EPS_SVR )
   2614         p_begin = p_end = svm_params.p;
   2615 
   2616     is_regression = cv_params->is_regression;
   2617     best_rate = is_regression ? FLT_MAX : 0;
   2618 
   2619     assert( g_step > 1 && degree_step > 1 && coef_step > 1);
   2620     assert( p_step > 1 && C_step > 1 && nu_step > 1 );
   2621 
   2622     for( degree = degree_begin; degree <= degree_end; degree *= degree_step )
   2623     {
   2624       svm_params.degree = degree;
   2625       //printf("degree = %.3f\n", degree );
   2626       for( gamma= g_begin; gamma <= g_end; gamma *= g_step )
   2627       {
   2628         svm_params.gamma = gamma;
   2629         //printf("   gamma = %.3f\n", gamma );
   2630         for( coef = coef_begin; coef <= coef_end; coef *= coef_step )
   2631         {
   2632           svm_params.coef0 = coef;
   2633           //printf("      coef = %.3f\n", coef );
   2634           for( C = C_begin; C <= C_end; C *= C_step )
   2635           {
   2636             svm_params.C = C;
   2637             //printf("         C = %.3f\n", C );
   2638             for( nu = nu_begin; nu <= nu_end; nu *= nu_step )
   2639             {
   2640               svm_params.nu = nu;
   2641               //printf("            nu = %.3f\n", nu );
   2642               for( p = p_begin; p <= p_end; p *= p_step )
   2643               {
   2644                 int well;
   2645                 svm_params.p = p;
   2646                 //printf("               p = %.3f\n", p );
   2647 
   2648                 CV_CALL(rate = cvCrossValidation( train_data, tflag, responses, &cvTrainSVM,
   2649                     cross_valid_params, (CvStatModelParams*)&svm_params, comp_idx, sample_idx ));
   2650 
   2651                 well =  rate > best_rate && !is_regression || rate < best_rate && is_regression;
   2652                 if( well || (rate == best_rate && C < best_C) )
   2653                 {
   2654                     best_rate   = rate;
   2655                     best_degree = degree;
   2656                     best_gamma  = gamma;
   2657                     best_coef   = coef;
   2658                     best_C      = C;
   2659                     best_nu     = nu;
   2660                     best_p      = p;
   2661                 }
   2662                 //printf("                  rate = %.2f\n", rate );
   2663               }
   2664             }
   2665           }
   2666         }
   2667       }
   2668     }
   2669     //printf("The best:\nrate = %.2f%% degree = %f gamma = %f coef = %f c = %f nu = %f p = %f\n",
   2670       //  best_rate, best_degree, best_gamma, best_coef, best_C, best_nu, best_p );
   2671 
   2672     psvm_params->C      = best_C;
   2673     psvm_params->nu     = best_nu;
   2674     psvm_params->p      = best_p;
   2675     psvm_params->gamma  = best_gamma;
   2676     psvm_params->degree = best_degree;
   2677     psvm_params->coef0  = best_coef;
   2678 
   2679     CV_CALL(svm = cvTrainSVM( train_data, tflag, responses, model_params, comp_idx, sample_idx ));
   2680 
   2681     __END__;
   2682 
   2683     return svm;
   2684 }
   2685 
   2686 #endif
   2687 
   2688 /* End of file. */
   2689 
   2690