/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Copyright (C) 2014, Itseez Inc, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"

#include <stdarg.h>
#include <ctype.h>

/****************************************************************************************\
                                COPYRIGHT NOTICE
                                ----------------

  The code has been derived from the libsvm library (version 2.6)
  (http://www.csie.ntu.edu.tw/~cjlin/libsvm).

  Here is the original copyright:
------------------------------------------------------------------------------------------
    Copyright (c) 2000-2003 Chih-Chung Chang and Chih-Jen Lin
    All rights reserved.

    Redistribution and use in source and binary forms, with or without
    modification, are permitted provided that the following conditions
    are met:

    1. Redistributions of source code must retain the above copyright
    notice, this list of conditions and the following disclaimer.

    2. Redistributions in binary form must reproduce the above copyright
    notice, this list of conditions and the following disclaimer in the
    documentation and/or other materials provided with the distribution.

    3. Neither name of copyright holders nor the names of its contributors
    may be used to endorse or promote products derived from this software
    without specific prior written permission.


    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE REGENTS OR
    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
\****************************************************************************************/

namespace cv { namespace ml {

typedef float Qfloat;
const int QFLOAT_TYPE = DataDepth<Qfloat>::value;

// Param Grid
static void checkParamGrid(const ParamGrid& pg)
{
    if( pg.minVal > pg.maxVal )
        CV_Error( CV_StsBadArg, "Lower bound of the grid must be less than the upper one" );
    if( pg.minVal < DBL_EPSILON )
        CV_Error( CV_StsBadArg, "Lower bound of the grid must be positive" );
    if( pg.logStep < 1. + FLT_EPSILON )
        CV_Error( CV_StsBadArg, "Grid step must be greater than 1" );
}
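
// A ParamGrid describes the multiplicative sequence
//     minVal, minVal*logStep, minVal*logStep^2, ... (while the value stays below maxVal),
// i.e. roughly log(maxVal/minVal)/log(logStep) candidate values; this is why the
// checks above require a positive lower bound and a step strictly greater than 1.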

// SVM training parameters
struct SvmParams
{
    int         svmType;
    int         kernelType;
    double      gamma;
    double      coef0;
    double      degree;
    double      C;
    double      nu;
    double      p;
    Mat         classWeights;
    TermCriteria termCrit;

    SvmParams()
    {
        svmType = SVM::C_SVC;
        kernelType = SVM::RBF;
        degree = 0;
        gamma = 1;
        coef0 = 0;
        C = 1;
        nu = 0;
        p = 0;
        termCrit = TermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 1000, FLT_EPSILON );
    }

    SvmParams( int _svmType, int _kernelType,
            double _degree, double _gamma, double _coef0,
            double _Con, double _nu, double _p,
            const Mat& _classWeights, TermCriteria _termCrit )
    {
        svmType = _svmType;
        kernelType = _kernelType;
        degree = _degree;
        gamma = _gamma;
        coef0 = _coef0;
        C = _Con;
        nu = _nu;
        p = _p;
        classWeights = _classWeights;
        termCrit = _termCrit;
    }

};

/////////////////////////////////////// SVM kernel ///////////////////////////////////////
class SVMKernelImpl : public SVM::Kernel
{
public:
    SVMKernelImpl( const SvmParams& _params = SvmParams() )
    {
        params = _params;
    }

    int getType() const
    {
        return params.kernelType;
    }

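    // Shared core for the dot-product based kernels: for each of the vcount rows of
    // vecs it computes results[j] = alpha * <vecs_j, another> + beta, with the inner
    // loop unrolled by four. LINEAR calls it with (alpha, beta) = (1, 0); POLY with
    // (gamma, coef0) and then raises the results elementwise to params.degree;
    // SIGMOID with (-2*gamma, -2*coef0) before the tanh step below.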
    void calc_non_rbf_base( int vcount, int var_count, const float* vecs,
                            const float* another, Qfloat* results,
                            double alpha, double beta )
    {
        int j, k;
        for( j = 0; j < vcount; j++ )
        {
            const float* sample = &vecs[j*var_count];
            double s = 0;
            for( k = 0; k <= var_count - 4; k += 4 )
                s += sample[k]*another[k] + sample[k+1]*another[k+1] +
                sample[k+2]*another[k+2] + sample[k+3]*another[k+3];
            for( ; k < var_count; k++ )
                s += sample[k]*another[k];
            results[j] = (Qfloat)(s*alpha + beta);
        }
    }

    void calc_linear( int vcount, int var_count, const float* vecs,
                      const float* another, Qfloat* results )
    {
        calc_non_rbf_base( vcount, var_count, vecs, another, results, 1, 0 );
    }

    void calc_poly( int vcount, int var_count, const float* vecs,
                    const float* another, Qfloat* results )
    {
        Mat R( 1, vcount, QFLOAT_TYPE, results );
        calc_non_rbf_base( vcount, var_count, vecs, another, results, params.gamma, params.coef0 );
        if( vcount > 0 )
            pow( R, params.degree, R );
    }

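    // calc_non_rbf_base produces t = -2*(gamma*<x,y> + coef0); the two branches in
    // the loop below evaluate (1 - e)/(1 + e) and (e - 1)/(e + 1) with e = exp(-|t|),
    // both equal to tanh(t/2), so the exponential argument is never positive and
    // cannot overflow.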
    void calc_sigmoid( int vcount, int var_count, const float* vecs,
                       const float* another, Qfloat* results )
    {
        int j;
        calc_non_rbf_base( vcount, var_count, vecs, another, results,
                          -2*params.gamma, -2*params.coef0 );
        // TODO: speed this up
        for( j = 0; j < vcount; j++ )
        {
            Qfloat t = results[j];
            Qfloat e = std::exp(-std::abs(t));
            if( t > 0 )
                results[j] = (Qfloat)((1. - e)/(1. + e));
            else
                results[j] = (Qfloat)((e - 1.)/(e + 1.));
        }
    }


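    // Gaussian (RBF) kernel: results[j] = exp(-gamma * ||vecs_j - another||^2).
    // For instance, with gamma = 0.5, x = (1, 0, 0) and y = (0, 1, 0):
    // ||x - y||^2 = 2, so K(x, y) = exp(-1), about 0.368. The squared distance is
    // accumulated with the loop unrolled by four, then exponentiated in one batch.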
    void calc_rbf( int vcount, int var_count, const float* vecs,
                   const float* another, Qfloat* results )
    {
        double gamma = -params.gamma;
        int j, k;

        for( j = 0; j < vcount; j++ )
        {
            const float* sample = &vecs[j*var_count];
            double s = 0;

            for( k = 0; k <= var_count - 4; k += 4 )
            {
                double t0 = sample[k] - another[k];
                double t1 = sample[k+1] - another[k+1];

                s += t0*t0 + t1*t1;

                t0 = sample[k+2] - another[k+2];
                t1 = sample[k+3] - another[k+3];

                s += t0*t0 + t1*t1;
            }

            for( ; k < var_count; k++ )
            {
                double t0 = sample[k] - another[k];
                s += t0*t0;
            }
            results[j] = (Qfloat)(s*gamma);
        }

        if( vcount > 0 )
        {
            Mat R( 1, vcount, QFLOAT_TYPE, results );
            exp( R, R );
        }
    }

    /// Histogram intersection kernel
    void calc_intersec( int vcount, int var_count, const float* vecs,
                        const float* another, Qfloat* results )
    {
        int j, k;
        for( j = 0; j < vcount; j++ )
        {
            const float* sample = &vecs[j*var_count];
            double s = 0;
            for( k = 0; k <= var_count - 4; k += 4 )
                s += std::min(sample[k],another[k]) + std::min(sample[k+1],another[k+1]) +
                std::min(sample[k+2],another[k+2]) + std::min(sample[k+3],another[k+3]);
            for( ; k < var_count; k++ )
                s += std::min(sample[k],another[k]);
            results[j] = (Qfloat)(s);
        }
    }

    /// Exponential chi2 kernel
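    /// K(x, y) = exp(-gamma * sum_k (x_k - y_k)^2 / (x_k + y_k));
    /// components with a zero denominator contribute nothing to the sum.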
    void calc_chi2( int vcount, int var_count, const float* vecs,
                    const float* another, Qfloat* results )
    {
        Mat R( 1, vcount, QFLOAT_TYPE, results );
        double gamma = -params.gamma;
        int j, k;
        for( j = 0; j < vcount; j++ )
        {
            const float* sample = &vecs[j*var_count];
            double chi2 = 0;
            for(k = 0 ; k < var_count; k++ )
            {
                double d = sample[k]-another[k];
                double divisor = sample[k]+another[k];
                /// if divisor == 0, the Chi2 distance would be zero,
                // but the calculation would raise a division-by-zero error
                if (divisor != 0)
                {
                    chi2 += d*d/divisor;
                }
            }
            results[j] = (Qfloat) (gamma*chi2);
        }
        if( vcount > 0 )
            exp( R, R );
    }

    void calc( int vcount, int var_count, const float* vecs,
               const float* another, Qfloat* results )
    {
        switch( params.kernelType )
        {
        case SVM::LINEAR:
            calc_linear(vcount, var_count, vecs, another, results);
            break;
        case SVM::RBF:
            calc_rbf(vcount, var_count, vecs, another, results);
            break;
        case SVM::POLY:
            calc_poly(vcount, var_count, vecs, another, results);
            break;
        case SVM::SIGMOID:
            calc_sigmoid(vcount, var_count, vecs, another, results);
            break;
        case SVM::CHI2:
            calc_chi2(vcount, var_count, vecs, another, results);
            break;
        case SVM::INTER:
            calc_intersec(vcount, var_count, vecs, another, results);
            break;
        default:
            CV_Error(CV_StsBadArg, "Unknown kernel type");
        }
        const Qfloat max_val = (Qfloat)(FLT_MAX*1e-3);
        for( int j = 0; j < vcount; j++ )
        {
            if( results[j] > max_val )
                results[j] = max_val;
        }
    }

    SvmParams params;
};




/////////////////////////////////////////////////////////////////////////

static void sortSamplesByClasses( const Mat& _samples, const Mat& _responses,
                           vector<int>& sidx_all, vector<int>& class_ranges )
{
    int i, nsamples = _samples.rows;
    CV_Assert( _responses.isContinuous() && _responses.checkVector(1, CV_32S) == nsamples );

    setRangeVector(sidx_all, nsamples);

    const int* rptr = _responses.ptr<int>();
    std::sort(sidx_all.begin(), sidx_all.end(), cmp_lt_idx<int>(rptr));
    class_ranges.clear();
    class_ranges.push_back(0);

    for( i = 0; i < nsamples; i++ )
    {
        if( i == nsamples-1 || rptr[sidx_all[i]] != rptr[sidx_all[i+1]] )
            class_ranges.push_back(i+1);
    }
}
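
// After the sort, samples of class k occupy sidx_all[class_ranges[k] .. class_ranges[k+1]).
// For example, responses (0, 0, 1, 1, 1, 2) yield class_ranges = (0, 2, 5, 6).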

//////////////////////// SVM implementation //////////////////////////////

ParamGrid SVM::getDefaultGrid( int param_id )
{
    ParamGrid grid;
    if( param_id == SVM::C )
    {
        grid.minVal = 0.1;
        grid.maxVal = 500;
        grid.logStep = 5; // total iterations = 5
    }
    else if( param_id == SVM::GAMMA )
    {
        grid.minVal = 1e-5;
        grid.maxVal = 0.6;
        grid.logStep = 15; // total iterations = 4
    }
    else if( param_id == SVM::P )
    {
        grid.minVal = 0.01;
        grid.maxVal = 100;
        grid.logStep = 7; // total iterations = 4
    }
    else if( param_id == SVM::NU )
    {
        grid.minVal = 0.01;
        grid.maxVal = 0.2;
        grid.logStep = 3; // total iterations = 3
    }
    else if( param_id == SVM::COEF )
    {
        grid.minVal = 0.1;
        grid.maxVal = 300;
        grid.logStep = 14; // total iterations = 3
    }
    else if( param_id == SVM::DEGREE )
    {
        grid.minVal = 0.01;
        grid.maxVal = 4;
        grid.logStep = 7; // total iterations = 3
    }
    else
        cvError( CV_StsBadArg, "SVM::getDefaultGrid", "Invalid type of parameter "
                "(use one of SVM::C, SVM::GAMMA et al.)", __FILE__, __LINE__ );
    return grid;
}
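
// For illustration, parameter search walks such a grid multiplicatively,
// roughly as in this sketch (not the exact loop used by trainAuto):
//
//     for( double v = grid.minVal; v < grid.maxVal; v *= grid.logStep )
//         /* evaluate cross-validation error with the parameter set to v */;
//
// hence the "total iterations" annotations above, approximately
// log(maxVal/minVal)/log(logStep) steps per parameter.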


class SVMImpl : public SVM
{
public:
    struct DecisionFunc
    {
        DecisionFunc(double _rho, int _ofs) : rho(_rho), ofs(_ofs) {}
        DecisionFunc() : rho(0.), ofs(0) {}
        double rho;
        int ofs;
    };

    // Generalized SMO+SVMlight algorithm
    // Solves:
    //
    //  min [0.5(\alpha^T Q \alpha) + b^T \alpha]
    //
    //      y^T \alpha = \delta
    //      y_i = +1 or -1
    //      0 <= alpha_i <= Cp for y_i = 1
    //      0 <= alpha_i <= Cn for y_i = -1
    //
    // Given:
    //
    //  Q, b, y, Cp, Cn, and an initial feasible point \alpha
    //  l is the size of vectors and matrices
    //  eps is the stopping criterion
    //
    // solution will be put in \alpha, objective value will be put in obj
    //
    class Solver
    {
    public:
        enum { MIN_CACHE_SIZE = (40 << 20) /* 40Mb */, MAX_CACHE_SIZE = (500 << 20) /* 500Mb */ };

        typedef bool (Solver::*SelectWorkingSet)( int& i, int& j );
        typedef Qfloat* (Solver::*GetRow)( int i, Qfloat* row, Qfloat* dst, bool existed );
        typedef void (Solver::*CalcRho)( double& rho, double& r );

        struct KernelRow
        {
            KernelRow() { idx = -1; prev = next = 0; }
            KernelRow(int _idx, int _prev, int _next) : idx(_idx), prev(_prev), next(_next) {}
            int idx;
            int prev;
            int next;
        };

        struct SolutionInfo
        {
            SolutionInfo() { obj = rho = upper_bound_p = upper_bound_n = r = 0; }
            double obj;
            double rho;
            double upper_bound_p;
            double upper_bound_n;
            double r;   // for Solver_NU
        };

        void clear()
        {
            alpha_vec = 0;
            select_working_set_func = 0;
            calc_rho_func = 0;
            get_row_func = 0;
            lru_cache.clear();
        }

        Solver( const Mat& _samples, const vector<schar>& _y,
                vector<double>& _alpha, const vector<double>& _b,
                double _Cp, double _Cn,
                const Ptr<SVM::Kernel>& _kernel, GetRow _get_row,
                SelectWorkingSet _select_working_set, CalcRho _calc_rho,
                TermCriteria _termCrit )
        {
            clear();

            samples = _samples;
            sample_count = samples.rows;
            var_count = samples.cols;

            y_vec = _y;
            alpha_vec = &_alpha;
            alpha_count = (int)alpha_vec->size();
            b_vec = _b;
            kernel = _kernel;

            C[0] = _Cn;
            C[1] = _Cp;
            eps = _termCrit.epsilon;
            max_iter = _termCrit.maxCount;

            G_vec.resize(alpha_count);
            alpha_status_vec.resize(alpha_count);
            buf[0].resize(sample_count*2);
            buf[1].resize(sample_count*2);

            select_working_set_func = _select_working_set;
            CV_Assert(select_working_set_func != 0);

            calc_rho_func = _calc_rho;
            CV_Assert(calc_rho_func != 0);

            get_row_func = _get_row;
            CV_Assert(get_row_func != 0);

            // assume that for large training sets ~25% of Q matrix is used
            int64 csize = (int64)sample_count*sample_count/4;
            csize = std::max(csize, (int64)(MIN_CACHE_SIZE/sizeof(Qfloat)) );
            csize = std::min(csize, (int64)(MAX_CACHE_SIZE/sizeof(Qfloat)) );
            max_cache_size = (int)((csize + sample_count-1)/sample_count);
            max_cache_size = std::min(std::max(max_cache_size, 1), sample_count);
            cache_size = 0;

            lru_cache.clear();
            lru_cache.resize(sample_count+1, KernelRow(-1, 0, 0));
            lru_first = lru_last = 0;
            lru_cache_data.create(max_cache_size, sample_count, QFLOAT_TYPE);
        }

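        // LRU cache of kernel rows. lru_cache[i1+1] is the bookkeeping node for
        // sample i1 (index 0 is reserved so that 0 can mean "none" in the prev/next
        // links and in lru_first/lru_last); node.idx is the row in lru_cache_data
        // holding the computed kernel values, or -1 if the row is not cached. On
        // every access the node is moved to the front of the list; when the cache
        // is full, the data row owned by the least recently used node (lru_last)
        // is taken over and recomputed for the new sample. For SVR, indices
        // i >= sample_count map back to i - sample_count, since both halves of the
        // doubled problem share the same kernel row.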
        Qfloat* get_row_base( int i, bool* _existed )
        {
            int i1 = i < sample_count ? i : i - sample_count;
            KernelRow& kr = lru_cache[i1+1];
            if( _existed )
                *_existed = kr.idx >= 0;
            if( kr.idx < 0 )
            {
                if( cache_size < max_cache_size )
                {
                    kr.idx = cache_size;
                    cache_size++;
                    if (!lru_last)
                        lru_last = i1+1;
                }
                else
                {
                    KernelRow& last = lru_cache[lru_last];
                    kr.idx = last.idx;
                    last.idx = -1;
                    lru_cache[last.prev].next = 0;
                    lru_last = last.prev;
                    last.prev = 0;
                    last.next = 0;
                }
                kernel->calc( sample_count, var_count, samples.ptr<float>(),
                              samples.ptr<float>(i1), lru_cache_data.ptr<Qfloat>(kr.idx) );
            }
            else
            {
                if( kr.next )
                    lru_cache[kr.next].prev = kr.prev;
                else
                    lru_last = kr.prev;
                if( kr.prev )
                    lru_cache[kr.prev].next = kr.next;
                else
                    lru_first = kr.next;
            }
            if (lru_first)
                lru_cache[lru_first].prev = i1+1;
            kr.next = lru_first;
            kr.prev = 0;
            lru_first = i1+1;

            return lru_cache_data.ptr<Qfloat>(kr.idx);
        }

        Qfloat* get_row_svc( int i, Qfloat* row, Qfloat*, bool existed )
        {
            if( !existed )
            {
                const schar* _y = &y_vec[0];
                int j, len = sample_count;

                if( _y[i] > 0 )
                {
                    for( j = 0; j < len; j++ )
                        row[j] = _y[j]*row[j];
                }
                else
                {
                    for( j = 0; j < len; j++ )
                        row[j] = -_y[j]*row[j];
                }
            }
            return row;
        }

        Qfloat* get_row_one_class( int, Qfloat* row, Qfloat*, bool )
        {
            return row;
        }

        Qfloat* get_row_svr( int i, Qfloat* row, Qfloat* dst, bool )
        {
            int j, len = sample_count;
            Qfloat* dst_pos = dst;
            Qfloat* dst_neg = dst + len;
            if( i >= len )
                std::swap(dst_pos, dst_neg);

            for( j = 0; j < len; j++ )
            {
                Qfloat t = row[j];
                dst_pos[j] = t;
                dst_neg[j] = -t;
            }
            return dst;
        }

        Qfloat* get_row( int i, float* dst )
        {
            bool existed = false;
            float* row = get_row_base( i, &existed );
            return (this->*get_row_func)( i, row, dst, existed );
        }

        #undef is_upper_bound
        #define is_upper_bound(i) (alpha_status[i] > 0)

        #undef is_lower_bound
        #define is_lower_bound(i) (alpha_status[i] < 0)

        #undef is_free
        #define is_free(i) (alpha_status[i] == 0)

        #undef get_C
        #define get_C(i) (C[y[i]>0])

        #undef update_alpha_status
        #define update_alpha_status(i) \
            alpha_status[i] = (schar)(alpha[i] >= get_C(i) ? 1 : alpha[i] <= 0 ? -1 : 0)

        #undef reconstruct_gradient
        #define reconstruct_gradient() /* empty for now */

        bool solve_generic( SolutionInfo& si )
        {
            const schar* y = &y_vec[0];
            double* alpha = &alpha_vec->at(0);
            schar* alpha_status = &alpha_status_vec[0];
            double* G = &G_vec[0];
            double* b = &b_vec[0];

            int iter = 0;
            int i, j, k;

            // 1. initialize gradient and alpha status
            for( i = 0; i < alpha_count; i++ )
            {
                update_alpha_status(i);
                G[i] = b[i];
                if( fabs(G[i]) > 1e200 )
                    return false;
            }

            for( i = 0; i < alpha_count; i++ )
            {
                if( !is_lower_bound(i) )
                {
                    const Qfloat *Q_i = get_row( i, &buf[0][0] );
                    double alpha_i = alpha[i];

                    for( j = 0; j < alpha_count; j++ )
                        G[j] += alpha_i*Q_i[j];
                }
            }

            // 2. optimization loop
            for(;;)
            {
                const Qfloat *Q_i, *Q_j;
                double C_i, C_j;
                double old_alpha_i, old_alpha_j, alpha_i, alpha_j;
                double delta_alpha_i, delta_alpha_j;

        #ifdef _DEBUG
                for( i = 0; i < alpha_count; i++ )
                {
                    if( fabs(G[i]) > 1e+300 )
                        return false;

                    if( fabs(alpha[i]) > 1e16 )
                        return false;
                }
        #endif

                if( (this->*select_working_set_func)( i, j ) != 0 || iter++ >= max_iter )
                    break;

                Q_i = get_row( i, &buf[0][0] );
                Q_j = get_row( j, &buf[1][0] );

                C_i = get_C(i);
                C_j = get_C(j);

                alpha_i = old_alpha_i = alpha[i];
                alpha_j = old_alpha_j = alpha[j];

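                // Solve the two-variable subproblem in (alpha_i, alpha_j) analytically,
                // keeping all other alphas fixed. The equality constraint preserves
                // alpha_i - alpha_j when y_i != y_j and alpha_i + alpha_j when y_i == y_j;
                // the unconstrained step along the feasible direction is
                // delta = -(G_i + G_j)/(Q_ii + Q_jj + 2*Q_ij) in the first case and
                // delta = (G_i - G_j)/(Q_ii + Q_jj - 2*Q_ij) in the second, after which
                // the pair is clipped back into the box [0, C_i] x [0, C_j].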
                if( y[i] != y[j] )
                {
                    double denom = Q_i[i]+Q_j[j]+2*Q_i[j];
                    double delta = (-G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
                    double diff = alpha_i - alpha_j;
                    alpha_i += delta;
                    alpha_j += delta;

                    if( diff > 0 && alpha_j < 0 )
                    {
                        alpha_j = 0;
                        alpha_i = diff;
                    }
                    else if( diff <= 0 && alpha_i < 0 )
                    {
                        alpha_i = 0;
                        alpha_j = -diff;
                    }

                    if( diff > C_i - C_j && alpha_i > C_i )
                    {
                        alpha_i = C_i;
                        alpha_j = C_i - diff;
                    }
                    else if( diff <= C_i - C_j && alpha_j > C_j )
                    {
                        alpha_j = C_j;
                        alpha_i = C_j + diff;
                    }
                }
                else
                {
                    double denom = Q_i[i]+Q_j[j]-2*Q_i[j];
                    double delta = (G[i]-G[j])/MAX(fabs(denom),FLT_EPSILON);
                    double sum = alpha_i + alpha_j;
                    alpha_i -= delta;
                    alpha_j += delta;

                    if( sum > C_i && alpha_i > C_i )
                    {
                        alpha_i = C_i;
                        alpha_j = sum - C_i;
                    }
                    else if( sum <= C_i && alpha_j < 0)
                    {
                        alpha_j = 0;
                        alpha_i = sum;
                    }

                    if( sum > C_j && alpha_j > C_j )
                    {
                        alpha_j = C_j;
                        alpha_i = sum - C_j;
                    }
                    else if( sum <= C_j && alpha_i < 0 )
                    {
                        alpha_i = 0;
                        alpha_j = sum;
                    }
                }

                // update alpha
                alpha[i] = alpha_i;
                alpha[j] = alpha_j;
                update_alpha_status(i);
                update_alpha_status(j);

                // update G
                delta_alpha_i = alpha_i - old_alpha_i;
                delta_alpha_j = alpha_j - old_alpha_j;

                for( k = 0; k < alpha_count; k++ )
                    G[k] += Q_i[k]*delta_alpha_i + Q_j[k]*delta_alpha_j;
            }

            // calculate rho
            (this->*calc_rho_func)( si.rho, si.r );

            // calculate objective value
            for( i = 0, si.obj = 0; i < alpha_count; i++ )
                si.obj += alpha[i] * (G[i] + b[i]);

            si.obj *= 0.5;

            si.upper_bound_p = C[1];
            si.upper_bound_n = C[0];

            return true;
        }

        // returns true if the solution is already optimal, false otherwise
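        // This is the "maximal violating pair" working-set selection: Gmax1 and
        // Gmax2 are the largest KKT violations along the two feasible directions
        // d = +1 and d = -1, and optimality is declared once their sum drops
        // below eps.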
        bool select_working_set( int& out_i, int& out_j )
        {
            // return i,j which maximize -grad(f)^T d , under constraint
            // if alpha_i == C, d != +1
            // if alpha_i == 0, d != -1
            double Gmax1 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = +1 }
            int Gmax1_idx = -1;

            double Gmax2 = -DBL_MAX;        // max { -grad(f)_i * d | y_i*d = -1 }
            int Gmax2_idx = -1;

            const schar* y = &y_vec[0];
            const schar* alpha_status = &alpha_status_vec[0];
            const double* G = &G_vec[0];

            for( int i = 0; i < alpha_count; i++ )
            {
                double t;

                if( y[i] > 0 )    // y = +1
                {
                    if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
                    {
                        Gmax1 = t;
                        Gmax1_idx = i;
                    }
                    if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
                    {
                        Gmax2 = t;
                        Gmax2_idx = i;
                    }
                }
                else        // y = -1
                {
                    if( !is_upper_bound(i) && (t = -G[i]) > Gmax2 )  // d = +1
                    {
                        Gmax2 = t;
                        Gmax2_idx = i;
                    }
                    if( !is_lower_bound(i) && (t = G[i]) > Gmax1 )  // d = -1
                    {
                        Gmax1 = t;
                        Gmax1_idx = i;
                    }
                }
            }

            out_i = Gmax1_idx;
            out_j = Gmax2_idx;

            return Gmax1 + Gmax2 < eps;
        }

        void calc_rho( double& rho, double& r )
        {
            int nr_free = 0;
            double ub = DBL_MAX, lb = -DBL_MAX, sum_free = 0;
            const schar* y = &y_vec[0];
            const schar* alpha_status = &alpha_status_vec[0];
            const double* G = &G_vec[0];

            for( int i = 0; i < alpha_count; i++ )
            {
                double yG = y[i]*G[i];

                if( is_lower_bound(i) )
                {
                    if( y[i] > 0 )
                        ub = MIN(ub,yG);
                    else
                        lb = MAX(lb,yG);
                }
                else if( is_upper_bound(i) )
                {
                    if( y[i] < 0)
                        ub = MIN(ub,yG);
                    else
                        lb = MAX(lb,yG);
                }
                else
                {
                    ++nr_free;
                    sum_free += yG;
                }
            }

            rho = nr_free > 0 ? sum_free/nr_free : (ub + lb)*0.5;
            r = 0;
        }

        bool select_working_set_nu_svm( int& out_i, int& out_j )
        {
            // return i,j which maximize -grad(f)^T d , under constraint
            // if alpha_i == C, d != +1
            // if alpha_i == 0, d != -1
            double Gmax1 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = +1 }
            int Gmax1_idx = -1;

            double Gmax2 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = +1, d = -1 }
            int Gmax2_idx = -1;

            double Gmax3 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = +1 }
            int Gmax3_idx = -1;

            double Gmax4 = -DBL_MAX;    // max { -grad(f)_i * d | y_i = -1, d = -1 }
            int Gmax4_idx = -1;

            const schar* y = &y_vec[0];
            const schar* alpha_status = &alpha_status_vec[0];
            const double* G = &G_vec[0];

            for( int i = 0; i < alpha_count; i++ )
            {
                double t;

                if( y[i] > 0 )    // y == +1
                {
                    if( !is_upper_bound(i) && (t = -G[i]) > Gmax1 )  // d = +1
                    {
                        Gmax1 = t;
                        Gmax1_idx = i;
                    }
                    if( !is_lower_bound(i) && (t = G[i]) > Gmax2 )  // d = -1
                    {
                        Gmax2 = t;
                        Gmax2_idx = i;
                    }
                }
                else        // y == -1
                {
                    if( !is_upper_bound(i) && (t = -G[i]) > Gmax3 )  // d = +1
                    {
                        Gmax3 = t;
                        Gmax3_idx = i;
                    }
                    if( !is_lower_bound(i) && (t = G[i]) > Gmax4 )  // d = -1
                    {
                        Gmax4 = t;
                        Gmax4_idx = i;
                    }
                }
            }

            if( MAX(Gmax1 + Gmax2, Gmax3 + Gmax4) < eps )
                return true;

            if( Gmax1 + Gmax2 > Gmax3 + Gmax4 )
            {
                out_i = Gmax1_idx;
                out_j = Gmax2_idx;
            }
            else
            {
                out_i = Gmax3_idx;
                out_j = Gmax4_idx;
            }
            return false;
        }

        void calc_rho_nu_svm( double& rho, double& r )
        {
            int nr_free1 = 0, nr_free2 = 0;
            double ub1 = DBL_MAX, ub2 = DBL_MAX;
            double lb1 = -DBL_MAX, lb2 = -DBL_MAX;
            double sum_free1 = 0, sum_free2 = 0;

            const schar* y = &y_vec[0];
            const schar* alpha_status = &alpha_status_vec[0];
            const double* G = &G_vec[0];

            for( int i = 0; i < alpha_count; i++ )
            {
                double G_i = G[i];
                if( y[i] > 0 )
                {
                    if( is_lower_bound(i) )
                        ub1 = MIN( ub1, G_i );
                    else if( is_upper_bound(i) )
                        lb1 = MAX( lb1, G_i );
                    else
                    {
                        ++nr_free1;
                        sum_free1 += G_i;
                    }
                }
                else
                {
                    if( is_lower_bound(i) )
                        ub2 = MIN( ub2, G_i );
                    else if( is_upper_bound(i) )
                        lb2 = MAX( lb2, G_i );
                    else
                    {
                        ++nr_free2;
                        sum_free2 += G_i;
                    }
                }
            }

            double r1 = nr_free1 > 0 ? sum_free1/nr_free1 : (ub1 + lb1)*0.5;
            double r2 = nr_free2 > 0 ? sum_free2/nr_free2 : (ub2 + lb2)*0.5;

            rho = (r1 - r2)*0.5;
            r = (r1 + r2)*0.5;
        }

        /*
        ///////////////////////// construct and solve various formulations ///////////////////////
        */
        static bool solve_c_svc( const Mat& _samples, const vector<schar>& _y,
                                 double _Cp, double _Cn, const Ptr<SVM::Kernel>& _kernel,
                                 vector<double>& _alpha, SolutionInfo& _si, TermCriteria termCrit )
        {
            int sample_count = _samples.rows;

            _alpha.assign(sample_count, 0.);
            vector<double> _b(sample_count, -1.);

            Solver solver( _samples, _y, _alpha, _b, _Cp, _Cn, _kernel,
                           &Solver::get_row_svc,
                           &Solver::select_working_set,
                           &Solver::calc_rho,
                           termCrit );

            if( !solver.solve_generic( _si ))
                return false;

            for( int i = 0; i < sample_count; i++ )
                _alpha[i] *= _y[i];

            return true;
        }


        static bool solve_nu_svc( const Mat& _samples, const vector<schar>& _y,
                                  double nu, const Ptr<SVM::Kernel>& _kernel,
                                  vector<double>& _alpha, SolutionInfo& _si,
                                  TermCriteria termCrit )
        {
            int sample_count = _samples.rows;

            _alpha.resize(sample_count);
            vector<double> _b(sample_count, 0.);

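            // Build an initial feasible point: half of the total alpha mass
            // nu*sample_count is distributed over each class, with each alpha_i
            // receiving min(1, remaining mass for its class).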
            double sum_pos = nu * sample_count * 0.5;
            double sum_neg = nu * sample_count * 0.5;

            for( int i = 0; i < sample_count; i++ )
            {
                double a;
                if( _y[i] > 0 )
                {
                    a = std::min(1.0, sum_pos);
                    sum_pos -= a;
                }
                else
                {
                    a = std::min(1.0, sum_neg);
                    sum_neg -= a;
                }
                _alpha[i] = a;
            }

            Solver solver( _samples, _y, _alpha, _b, 1., 1., _kernel,
                           &Solver::get_row_svc,
                           &Solver::select_working_set_nu_svm,
                           &Solver::calc_rho_nu_svm,
                           termCrit );

            if( !solver.solve_generic( _si ))
                return false;

            double inv_r = 1./_si.r;

            for( int i = 0; i < sample_count; i++ )
                _alpha[i] *= _y[i]*inv_r;

            _si.rho *= inv_r;
            _si.obj *= (inv_r*inv_r);
            _si.upper_bound_p = inv_r;
            _si.upper_bound_n = inv_r;

            return true;
        }

        static bool solve_one_class( const Mat& _samples, double nu,
                                     const Ptr<SVM::Kernel>& _kernel,
                                     vector<double>& _alpha, SolutionInfo& _si,
                                     TermCriteria termCrit )
        {
            int sample_count = _samples.rows;
            vector<schar> _y(sample_count, 1);
            vector<double> _b(sample_count, 0.);

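            // Initial feasible point with sum(alpha) == nu*sample_count: the first
            // n = round(nu*sample_count) alphas are set to 1 and the fractional
            // remainder nu*sample_count - n goes to the next sample (or to the
            // last one when n == sample_count).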
            int i, n = cvRound( nu*sample_count );

            _alpha.resize(sample_count);
            for( i = 0; i < sample_count; i++ )
                _alpha[i] = i < n ? 1 : 0;

            if( n < sample_count )
                _alpha[n] = nu * sample_count - n;
            else
                _alpha[n-1] = nu * sample_count - (n-1);

            Solver solver( _samples, _y, _alpha, _b, 1., 1., _kernel,
                           &Solver::get_row_one_class,
                           &Solver::select_working_set,
                           &Solver::calc_rho,
                           termCrit );

            return solver.solve_generic(_si);
        }

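        // epsilon-SVR is rewritten as a QP over 2*sample_count variables: entries
        // 0..sample_count-1 hold alpha+ (label +1, b_i = p - y_i) and entries
        // sample_count..2*sample_count-1 hold alpha- (label -1, b_{i+n} = p + y_i).
        // After solving, the final coefficients alpha+ - alpha- are stored in the
        // first sample_count entries.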
        static bool solve_eps_svr( const Mat& _samples, const vector<float>& _yf,
                                   double p, double C, const Ptr<SVM::Kernel>& _kernel,
                                   vector<double>& _alpha, SolutionInfo& _si,
                                   TermCriteria termCrit )
        {
            int sample_count = _samples.rows;
            int alpha_count = sample_count*2;

            CV_Assert( (int)_yf.size() == sample_count );

            _alpha.assign(alpha_count, 0.);
            vector<schar> _y(alpha_count);
            vector<double> _b(alpha_count);

            for( int i = 0; i < sample_count; i++ )
            {
                _b[i] = p - _yf[i];
                _y[i] = 1;

                _b[i+sample_count] = p + _yf[i];
                _y[i+sample_count] = -1;
            }

            Solver solver( _samples, _y, _alpha, _b, C, C, _kernel,
                           &Solver::get_row_svr,
                           &Solver::select_working_set,
                           &Solver::calc_rho,
                           termCrit );

            if( !solver.solve_generic( _si ))
                return false;

            for( int i = 0; i < sample_count; i++ )
                _alpha[i] -= _alpha[i+sample_count];

            return true;
        }


        static bool solve_nu_svr( const Mat& _samples, const vector<float>& _yf,
                                  double nu, double C, const Ptr<SVM::Kernel>& _kernel,
                                  vector<double>& _alpha, SolutionInfo& _si,
                                  TermCriteria termCrit )
        {
            int sample_count = _samples.rows;
            int alpha_count = sample_count*2;
            double sum = C * nu * sample_count * 0.5;

            CV_Assert( (int)_yf.size() == sample_count );

            _alpha.resize(alpha_count);
            vector<schar> _y(alpha_count);
            vector<double> _b(alpha_count);

            for( int i = 0; i < sample_count; i++ )
            {
                _alpha[i] = _alpha[i + sample_count] = std::min(sum, C);
                sum -= _alpha[i];

                _b[i] = -_yf[i];
                _y[i] = 1;

                _b[i + sample_count] = _yf[i];
                _y[i + sample_count] = -1;
            }

            Solver solver( _samples, _y, _alpha, _b, 1., 1., _kernel,
                           &Solver::get_row_svr,
                           &Solver::select_working_set_nu_svm,
                           &Solver::calc_rho_nu_svm,
                           termCrit );

            if( !solver.solve_generic( _si ))
                return false;

            for( int i = 0; i < sample_count; i++ )
                _alpha[i] -= _alpha[i+sample_count];

            return true;
        }

        int sample_count;
        int var_count;
        int cache_size;
        int max_cache_size;
        Mat samples;
        SvmParams params;
        vector<KernelRow> lru_cache;
        int lru_first;
        int lru_last;
        Mat lru_cache_data;

        int alpha_count;

        vector<double> G_vec;
        vector<double>* alpha_vec;
        vector<schar> y_vec;
        // -1 - lower bound, 0 - free, 1 - upper bound
        vector<schar> alpha_status_vec;
        vector<double> b_vec;

        vector<Qfloat> buf[2];
        double eps;
        int max_iter;
        double C[2];  // C[0] == Cn, C[1] == Cp
        Ptr<SVM::Kernel> kernel;

        SelectWorkingSet select_working_set_func;
        CalcRho calc_rho_func;
        GetRow get_row_func;
    };

    //////////////////////////////////////////////////////////////////////////////////////////
    SVMImpl()
    {
        clear();
        checkParams();
    }

    ~SVMImpl()
    {
        clear();
    }

    void clear()
    {
        decision_func.clear();
        df_alpha.clear();
        df_index.clear();
        sv.release();
    }

    Mat getSupportVectors() const
    {
        return sv;
    }

    CV_IMPL_PROPERTY(int, Type, params.svmType)
    CV_IMPL_PROPERTY(double, Gamma, params.gamma)
    CV_IMPL_PROPERTY(double, Coef0, params.coef0)
    CV_IMPL_PROPERTY(double, Degree, params.degree)
    CV_IMPL_PROPERTY(double, C, params.C)
    CV_IMPL_PROPERTY(double, Nu, params.nu)
    CV_IMPL_PROPERTY(double, P, params.p)
    CV_IMPL_PROPERTY_S(cv::Mat, ClassWeights, params.classWeights)
    CV_IMPL_PROPERTY_S(cv::TermCriteria, TermCriteria, params.termCrit)

    int getKernelType() const
    {
        return params.kernelType;
    }

    void setKernel(int kernelType)
    {
        params.kernelType = kernelType;
        if (kernelType != CUSTOM)
            kernel = makePtr<SVMKernelImpl>(params);
    }

    void setCustomKernel(const Ptr<Kernel> &_kernel)
    {
        params.kernelType = CUSTOM;
        kernel = _kernel;
    }

    void checkParams()
    {
        int kernelType = params.kernelType;
        if (kernelType != CUSTOM)
        {
            if( kernelType != LINEAR && kernelType != POLY &&
                kernelType != SIGMOID && kernelType != RBF &&
                kernelType != INTER && kernelType != CHI2)
                CV_Error( CV_StsBadArg, "Unknown/unsupported kernel type" );

            if( kernelType == LINEAR )
                params.gamma = 1;
            else if( params.gamma <= 0 )
                CV_Error( CV_StsOutOfRange, "gamma parameter of the kernel must be positive" );

            if( kernelType != SIGMOID && kernelType != POLY )
                params.coef0 = 0;
            else if( params.coef0 < 0 )
                CV_Error( CV_StsOutOfRange, "The kernel parameter <coef0> must be positive or zero" );

            if( kernelType != POLY )
                params.degree = 0;
            else if( params.degree <= 0 )
                CV_Error( CV_StsOutOfRange, "The kernel parameter <degree> must be positive" );

            kernel = makePtr<SVMKernelImpl>(params);
        }
        else
        {
            if (!kernel)
                CV_Error( CV_StsBadArg, "Custom kernel is not set" );
        }

        int svmType = params.svmType;

        if( svmType != C_SVC && svmType != NU_SVC &&
            svmType != ONE_CLASS && svmType != EPS_SVR &&
            svmType != NU_SVR )
            CV_Error( CV_StsBadArg, "Unknown/unsupported SVM type" );

        if( svmType == ONE_CLASS || svmType == NU_SVC )
            params.C = 0;
        else if( params.C <= 0 )
            CV_Error( CV_StsOutOfRange, "The parameter C must be positive" );

        if( svmType == C_SVC || svmType == EPS_SVR )
            params.nu = 0;
        else if( params.nu <= 0 || params.nu >= 1 )
            CV_Error( CV_StsOutOfRange, "The parameter nu must be between 0 and 1" );

        if( svmType != EPS_SVR )
            params.p = 0;
        else if( params.p <= 0 )
            CV_Error( CV_StsOutOfRange, "The parameter p must be positive" );

        if( svmType != C_SVC )
            params.classWeights.release();

        if( !(params.termCrit.type & TermCriteria::EPS) )
            params.termCrit.epsilon = DBL_EPSILON;
        params.termCrit.epsilon = std::max(params.termCrit.epsilon, DBL_EPSILON);
        if( !(params.termCrit.type & TermCriteria::COUNT) )
            params.termCrit.maxCount = INT_MAX;
        params.termCrit.maxCount = std::max(params.termCrit.maxCount, 1);
    }

    void setParams( const SvmParams& _params)
    {
        params = _params;
        checkParams();
    }

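    // The support vectors of decision function i occupy
    // df_index[decision_func[i].ofs .. next_ofs), where next_ofs is
    // decision_func[i+1].ofs for all but the last function and df_index.size()
    // for the last one; the count is the difference of the two offsets.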
    int getSVCount(int i) const
    {
        return (i < (int)(decision_func.size()-1) ? decision_func[i+1].ofs :
                (int)df_index.size()) - decision_func[i].ofs;
    }

    bool do_train( const Mat& _samples, const Mat& _responses )
    {
        int svmType = params.svmType;
        int i, j, k, sample_count = _samples.rows;
        vector<double> _alpha;
        Solver::SolutionInfo sinfo;

        CV_Assert( _samples.type() == CV_32F );
        var_count = _samples.cols;

        if( svmType == ONE_CLASS || svmType == EPS_SVR || svmType == NU_SVR )
        {
            int sv_count = 0;
            decision_func.clear();

            vector<float> _yf;
            if( !_responses.empty() )
                _responses.convertTo(_yf, CV_32F);

            bool ok =
            svmType == ONE_CLASS ? Solver::solve_one_class( _samples, params.nu, kernel, _alpha, sinfo, params.termCrit ) :
            svmType == EPS_SVR ? Solver::solve_eps_svr( _samples, _yf, params.p, params.C, kernel, _alpha, sinfo, params.termCrit ) :
            svmType == NU_SVR ? Solver::solve_nu_svr( _samples, _yf, params.nu, params.C, kernel, _alpha, sinfo, params.termCrit ) : false;

            if( !ok )
                return false;

            for( i = 0; i < sample_count; i++ )
                sv_count += fabs(_alpha[i]) > 0;

            CV_Assert(sv_count != 0);

            sv.create(sv_count, _samples.cols, CV_32F);
            df_alpha.resize(sv_count);
            df_index.resize(sv_count);

            for( i = k = 0; i < sample_count; i++ )
            {
                if( std::abs(_alpha[i]) > 0 )
                {
                    _samples.row(i).copyTo(sv.row(k));
                    df_alpha[k] = _alpha[i];
                    df_index[k] = k;
                    k++;
                }
            }

            decision_func.push_back(DecisionFunc(sinfo.rho, 0));
        }
        else
        {
            int class_count = (int)class_labels.total();
            vector<int> svidx, sidx, sidx_all, sv_tab(sample_count, 0);
            Mat temp_samples, class_weights;
            vector<int> class_ranges;
            vector<schar> temp_y;
            double nu = params.nu;
            CV_Assert( svmType == C_SVC || svmType == NU_SVC );

            if( svmType == C_SVC && !params.classWeights.empty() )
            {
                const Mat cw = params.classWeights;

                if( (cw.cols != 1 && cw.rows != 1) ||
                    (int)cw.total() != class_count ||
                    (cw.type() != CV_32F && cw.type() != CV_64F) )
                    CV_Error( CV_StsBadArg, "params.class_weights must be a 1d floating-point vector "
                        "containing as many elements as the number of classes" );

                cw.convertTo(class_weights, CV_64F, params.C);
                //normalize(cw, class_weights, params.C, 0, NORM_L1, CV_64F);
            }

            decision_func.clear();
            df_alpha.clear();
            df_index.clear();

            sortSamplesByClasses( _samples, _responses, sidx_all, class_ranges );

            // check that during cross-validation all the classes are represented in the sample
            if( class_ranges[class_count] <= 0 )
                CV_Error( CV_StsBadArg, "During cross-validation one or more of the classes has "
                "fallen out of the sample. Try to enlarge <Params::k_fold>" );

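            // For NU_SVC, first check the standard feasibility condition:
            // for every class pair (i, j) with sizes ci and cj, the dual
            // problem is solvable only if nu*(ci + cj)/2 <= min(ci, cj),
            // i.e. nu may not exceed 2*min(ci, cj)/(ci + cj).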
            if( svmType == NU_SVC )
            {
                // check if nu is feasible
                for( i = 0; i < class_count; i++ )
                {
                    int ci = class_ranges[i+1] - class_ranges[i];
                    for( j = i+1; j < class_count; j++ )
                    {
                        int cj = class_ranges[j+1] - class_ranges[j];
                        if( nu*(ci + cj)*0.5 > std::min( ci, cj ) )
                            // TODO: add some diagnostic
                            return false;
                    }
                }
            }

            size_t samplesize = _samples.cols*_samples.elemSize();

            // train n*(n-1)/2 binary classifiers, one per class pair, in the
            // order (0,1), (0,2), ..., (0,n-1), (1,2), ...; prediction walks
            // the decision functions in the same order
            for( i = 0; i < class_count; i++ )
            {
                for( j = i+1; j < class_count; j++ )
                {
                    int si = class_ranges[i], ci = class_ranges[i+1] - si;
                    int sj = class_ranges[j], cj = class_ranges[j+1] - sj;
                    double Cp = params.C, Cn = Cp;

                    temp_samples.create(ci + cj, _samples.cols, _samples.type());
                    sidx.resize(ci + cj);
                    temp_y.resize(ci + cj);

                    // form input for the binary classification problem
                    for( k = 0; k < ci+cj; k++ )
                    {
                        int idx = k < ci ? si+k : sj+k-ci;
                        memcpy(temp_samples.ptr(k), _samples.ptr(sidx_all[idx]), samplesize);
                        sidx[k] = sidx_all[idx];
                        temp_y[k] = k < ci ? 1 : -1;
                    }

                    if( !class_weights.empty() )
                    {
                        Cp = class_weights.at<double>(i);
                        Cn = class_weights.at<double>(j);
                    }

                    DecisionFunc df;
                    bool ok = params.svmType == C_SVC ?
                                Solver::solve_c_svc( temp_samples, temp_y, Cp, Cn,
                                                     kernel, _alpha, sinfo, params.termCrit ) :
                              params.svmType == NU_SVC ?
                                Solver::solve_nu_svc( temp_samples, temp_y, params.nu,
                                                      kernel, _alpha, sinfo, params.termCrit ) :
                              false;
                    if( !ok )
                        return false;
                    df.rho = sinfo.rho;
                    df.ofs = (int)df_index.size();
                    decision_func.push_back(df);

                    for( k = 0; k < ci + cj; k++ )
                    {
                        if( std::abs(_alpha[k]) > 0 )
                        {
                            int idx = k < ci ? si+k : sj+k-ci;
                            sv_tab[sidx_all[idx]] = 1;
                            df_index.push_back(sidx_all[idx]);
                            df_alpha.push_back(_alpha[k]);
                        }
                    }
                }
            }

            // convert the sv_tab flags into 1-based indices of the kept support vectors
            for( i = 0, k = 0; i < sample_count; i++ )
            {
                if( sv_tab[i] )
                    sv_tab[i] = ++k;
            }

            int sv_total = k;
            sv.create(sv_total, _samples.cols, _samples.type());

            for( i = 0; i < sample_count; i++ )
            {
                if( !sv_tab[i] )
                    continue;
                memcpy(sv.ptr(sv_tab[i]-1), _samples.ptr(i), samplesize);
            }

            // remap df_index entries from training-sample indices to indices
            // into the compacted support-vector matrix
            int n = (int)df_index.size();
            for( i = 0; i < n; i++ )
            {
                CV_Assert( sv_tab[df_index[i]] > 0 );
                df_index[i] = sv_tab[df_index[i]] - 1;
            }
        }

        optimize_linear_svm();
        return true;
    }

    void optimize_linear_svm()
    {
        // we only optimize linear SVMs: compress each decision function's
        // support vectors into a single weighted-sum vector.
        if( params.kernelType != LINEAR )
            return;

        int i, df_count = (int)decision_func.size();

        for( i = 0; i < df_count; i++ )
        {
            if( getSVCount(i) != 1 )
                break;
        }

        // if every decision function already uses a single support vector,
        // the model is compressed; skip it then.
        if( i == df_count )
            return;

        AutoBuffer<double> vbuf(var_count);
        double* v = vbuf;
        Mat new_sv(df_count, var_count, CV_32F);

        vector<DecisionFunc> new_df;

        for( i = 0; i < df_count; i++ )
        {
            float* dst = new_sv.ptr<float>(i);
            memset(v, 0, var_count*sizeof(v[0]));
            int j, k, sv_count = getSVCount(i);
            const DecisionFunc& df = decision_func[i];
            const int* sv_index = &df_index[df.ofs];
            const double* sv_alpha = &df_alpha[df.ofs];
            for( j = 0; j < sv_count; j++ )
            {
                const float* src = sv.ptr<float>(sv_index[j]);
                double a = sv_alpha[j];
                for( k = 0; k < var_count; k++ )
                    v[k] += src[k]*a;
            }
            for( k = 0; k < var_count; k++ )
                dst[k] = (float)v[k];
            new_df.push_back(DecisionFunc(df.rho, i));
        }

        setRangeVector(df_index, df_count);
        df_alpha.assign(df_count, 1.);
        std::swap(sv, new_sv);
        std::swap(decision_func, new_df);
    }
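
    // The compression above relies on the kernel being linear:
    //   sum_j alpha_j * K(sv_j, x) = sum_j alpha_j * (sv_j . x)
    //                              = (sum_j alpha_j * sv_j) . x,
    // so each decision function can keep the single precomputed vector
    // w = sum_j alpha_j*sv_j with a weight of 1, making prediction
    // O(var_count) per decision function instead of O(sv_count*var_count).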

    bool train( const Ptr<TrainData>& data, int )
    {
        clear();

        checkParams();

        int svmType = params.svmType;
        Mat samples = data->getTrainSamples();
        Mat responses;

        if( svmType == C_SVC || svmType == NU_SVC )
        {
            responses = data->getTrainNormCatResponses();
            if( responses.empty() )
                CV_Error(CV_StsBadArg, "in the case of a classification problem the responses must be categorical; "
                                       "either specify varType when creating TrainData, or pass integer responses");
            class_labels = data->getClassLabels();
        }
        else
            responses = data->getTrainResponses();

        if( !do_train( samples, responses ))
        {
            clear();
            return false;
        }

        return true;
    }

    bool trainAuto( const Ptr<TrainData>& data, int k_fold,
                    ParamGrid C_grid, ParamGrid gamma_grid, ParamGrid p_grid,
                    ParamGrid nu_grid, ParamGrid coef_grid, ParamGrid degree_grid,
                    bool balanced )
    {
        checkParams();

        int svmType = params.svmType;
        RNG rng((uint64)-1);

        if( svmType == ONE_CLASS )
            // current implementation of "auto" svm does not support the 1-class case.
            return train( data, 0 );

        clear();

        CV_Assert( k_fold >= 2 );

        // All the parameters except, possibly, <coef0> are positive.
        // <coef0> is nonnegative
        #define CHECK_GRID(grid, param) \
        if( grid.logStep <= 1 ) \
        { \
            grid.minVal = grid.maxVal = params.param; \
            grid.logStep = 10; \
        } \
        else \
            checkParamGrid(grid)

        CHECK_GRID(C_grid, C);
        CHECK_GRID(gamma_grid, gamma);
        CHECK_GRID(p_grid, p);
        CHECK_GRID(nu_grid, nu);
        CHECK_GRID(coef_grid, coef0);
        CHECK_GRID(degree_grid, degree);

        // collapse the grids of parameters that the chosen kernel/SVM type
        // does not use, so each contributes exactly one value to the search:
        if( params.kernelType != POLY )
            degree_grid.minVal = degree_grid.maxVal = params.degree;
        if( params.kernelType == LINEAR )
            gamma_grid.minVal = gamma_grid.maxVal = params.gamma;
        if( params.kernelType != POLY && params.kernelType != SIGMOID )
            coef_grid.minVal = coef_grid.maxVal = params.coef0;
        if( svmType == NU_SVC || svmType == ONE_CLASS )
            C_grid.minVal = C_grid.maxVal = params.C;
        if( svmType == C_SVC || svmType == EPS_SVR )
            nu_grid.minVal = nu_grid.maxVal = params.nu;
        if( svmType != EPS_SVR )
            p_grid.minVal = p_grid.maxVal = params.p;

        Mat samples = data->getTrainSamples();
        Mat responses;
        bool is_classification = false;
        int class_count = (int)class_labels.total();

        if( svmType == C_SVC || svmType == NU_SVC )
        {
            responses = data->getTrainNormCatResponses();
            class_labels = data->getClassLabels();
            class_count = (int)class_labels.total();
            is_classification = true;

            vector<int> temp_class_labels;
            setRangeVector(temp_class_labels, class_count);

            // temporarily replace class labels with 0, 1, ..., NCLASSES-1
            Mat(temp_class_labels).copyTo(class_labels);
        }
        else
            responses = data->getTrainResponses();

        CV_Assert(samples.type() == CV_32F);

        int sample_count = samples.rows;
        var_count = samples.cols;
        size_t sample_size = var_count*samples.elemSize();

        vector<int> sidx;
        setRangeVector(sidx, sample_count);

        int i, j, k;

        // randomly permute training samples
        for( i = 0; i < sample_count; i++ )
        {
            int i1 = rng.uniform(0, sample_count);
            int i2 = rng.uniform(0, sample_count);
            std::swap(sidx[i1], sidx[i2]);
        }

        if( is_classification && class_count == 2 && balanced )
        {
            // reshuffle the training set in such a way that
            // instances of each class are divided more or less evenly
            // between the k_fold parts.
            vector<int> sidx0, sidx1;

            for( i = 0; i < sample_count; i++ )
            {
                if( responses.at<int>(sidx[i]) == 0 )
                    sidx0.push_back(sidx[i]);
                else
                    sidx1.push_back(sidx[i]);
            }

            int n0 = (int)sidx0.size(), n1 = (int)sidx1.size();
            int a0 = 0, a1 = 0;
            sidx.clear();
            for( k = 0; k < k_fold; k++ )
            {
                int b0 = ((k+1)*n0 + k_fold/2)/k_fold, b1 = ((k+1)*n1 + k_fold/2)/k_fold;
                int a = (int)sidx.size(), b = a + (b0 - a0) + (b1 - a1);
                for( i = a0; i < b0; i++ )
                    sidx.push_back(sidx0[i]);
                for( i = a1; i < b1; i++ )
                    sidx.push_back(sidx1[i]);
                for( i = 0; i < (b - a); i++ )
                {
                    int i1 = rng.uniform(a, b);
                    int i2 = rng.uniform(a, b);
                    std::swap(sidx[i1], sidx[i2]);
                }
                a0 = b0; a1 = b1;
            }
        }
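
        // The boundaries b0/b1 above are rounded proportional cut points,
        // so fold k receives roughly n0/k_fold class-0 and n1/k_fold class-1
        // samples, then gets shuffled internally. For example, with n0 = 10,
        // n1 = 6, k_fold = 4 the class-0 boundaries are 3, 5, 8, 10 and the
        // class-1 boundaries are 2, 3, 5, 6, giving folds of sizes
        // (3+2), (2+1), (3+2), (2+1).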

        int test_sample_count = (sample_count + k_fold/2)/k_fold;
        int train_sample_count = sample_count - test_sample_count;

        SvmParams best_params = params;
        double min_error = FLT_MAX;

        int rtype = responses.type();

        Mat temp_train_samples(train_sample_count, var_count, CV_32F);
        Mat temp_test_samples(test_sample_count, var_count, CV_32F);
        Mat temp_train_responses(train_sample_count, 1, rtype);
        Mat temp_test_responses;

        // If grid.minVal == grid.maxVal, this will allow one and only one pass through the loop with params.var = grid.minVal.
        #define FOR_IN_GRID(var, grid) \
            for( params.var = grid.minVal; params.var == grid.minVal || params.var < grid.maxVal; params.var = (grid.minVal == grid.maxVal) ? grid.maxVal + 1 : params.var * grid.logStep )
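
        // A non-degenerate grid is swept geometrically: the parameter takes
        // the values minVal, minVal*logStep, minVal*logStep^2, ... while
        // they remain below maxVal (maxVal itself is not tried). E.g. a grid
        // with minVal = 0.1, maxVal = 1000, logStep = 10 tries
        // 0.1, 1, 10 and 100.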

        FOR_IN_GRID(C, C_grid)
        FOR_IN_GRID(gamma, gamma_grid)
        FOR_IN_GRID(p, p_grid)
        FOR_IN_GRID(nu, nu_grid)
        FOR_IN_GRID(coef0, coef_grid)
        FOR_IN_GRID(degree, degree_grid)
        {
            // make sure we updated the kernel and other parameters
            setParams(params);

            double error = 0;
            for( k = 0; k < k_fold; k++ )
            {
                int start = (k*sample_count + k_fold/2)/k_fold;
                for( i = 0; i < train_sample_count; i++ )
                {
                    j = sidx[(i+start)%sample_count];
                    memcpy(temp_train_samples.ptr(i), samples.ptr(j), sample_size);
                    if( is_classification )
                        temp_train_responses.at<int>(i) = responses.at<int>(j);
                    else if( !responses.empty() )
                        temp_train_responses.at<float>(i) = responses.at<float>(j);
                }

                // Train SVM on <train_sample_count> samples
                if( !do_train( temp_train_samples, temp_train_responses ))
                    continue;

                // copy the held-out test fold into temp_test_samples
                for( i = 0; i < test_sample_count; i++ )
                {
                    j = sidx[(i+start+train_sample_count) % sample_count];
                    memcpy(temp_test_samples.ptr(i), samples.ptr(j), sample_size);
                }

                predict(temp_test_samples, temp_test_responses, 0);
                for( i = 0; i < test_sample_count; i++ )
                {
                    float val = temp_test_responses.at<float>(i);
                    j = sidx[(i+start+train_sample_count) % sample_count];
                    if( is_classification )
                        error += (float)(val != responses.at<int>(j));
                    else
                    {
                        val -= responses.at<float>(j);
                        error += val*val;
                    }
                }
            }
            if( min_error > error )
            {
                min_error   = error;
                best_params = params;
            }
        }

        params = best_params;
        return do_train( samples, responses );
    }

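    // Prediction is parallelized over input samples: each PredictBody call
    // processes a contiguous range of rows independently, which is safe
    // because the model data is only read. For C_SVC/NU_SVC the one-vs-one
    // voting below walks the decision functions in the same (i, j) pair
    // order in which do_train() created them.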
    struct PredictBody : ParallelLoopBody
    {
        PredictBody( const SVMImpl* _svm, const Mat& _samples, Mat& _results, bool _returnDFVal )
        {
            svm = _svm;
            results = &_results;
            samples = &_samples;
            returnDFVal = _returnDFVal;
        }

        void operator()( const Range& range ) const
        {
            int svmType = svm->params.svmType;
            int sv_total = svm->sv.rows;
            int class_count = !svm->class_labels.empty() ? (int)svm->class_labels.total() : svmType == ONE_CLASS ? 1 : 0;

            AutoBuffer<float> _buffer(sv_total + (class_count+1)*2);
            float* buffer = _buffer;

            int i, j, dfi, k, si;

            if( svmType == EPS_SVR || svmType == NU_SVR || svmType == ONE_CLASS )
            {
                for( si = range.start; si < range.end; si++ )
                {
                    const float* row_sample = samples->ptr<float>(si);
                    svm->kernel->calc( sv_total, svm->var_count, svm->sv.ptr<float>(), row_sample, buffer );

                    const SVMImpl::DecisionFunc* df = &svm->decision_func[0];
                    double sum = -df->rho;
                    for( i = 0; i < sv_total; i++ )
                        sum += buffer[i]*svm->df_alpha[i];
                    float result = svm->params.svmType == ONE_CLASS && !returnDFVal ? (float)(sum > 0) : (float)sum;
                    results->at<float>(si) = result;
                }
            }
            else if( svmType == C_SVC || svmType == NU_SVC )
            {
                int* vote = (int*)(buffer + sv_total);

                for( si = range.start; si < range.end; si++ )
                {
                    svm->kernel->calc( sv_total, svm->var_count, svm->sv.ptr<float>(),
                                       samples->ptr<float>(si), buffer );
                    double sum = 0.;

                    memset( vote, 0, class_count*sizeof(vote[0]));

                    for( i = dfi = 0; i < class_count; i++ )
                    {
                        for( j = i+1; j < class_count; j++, dfi++ )
                        {
                            const DecisionFunc& df = svm->decision_func[dfi];
                            sum = -df.rho;
                            int sv_count = svm->getSVCount(dfi);
                            const double* alpha = &svm->df_alpha[df.ofs];
                            const int* sv_index = &svm->df_index[df.ofs];
                            for( k = 0; k < sv_count; k++ )
                                sum += alpha[k]*buffer[sv_index[k]];

                            vote[sum > 0 ? i : j]++;
                        }
                    }

                    for( i = 1, k = 0; i < class_count; i++ )
                    {
                        if( vote[i] > vote[k] )
                            k = i;
                    }
                    float result = returnDFVal && class_count == 2 ?
                        (float)sum : (float)(svm->class_labels.at<int>(k));
                    results->at<float>(si) = result;
                }
            }
            else
                CV_Error( CV_StsBadArg, "INTERNAL ERROR: Unknown SVM type, "
                         "the SVM structure is probably corrupted" );
        }

        const SVMImpl* svm;
        const Mat* samples;
        Mat* results;
        bool returnDFVal;
    };

    float predict( InputArray _samples, OutputArray _results, int flags ) const
    {
        float result = 0;
        Mat samples = _samples.getMat(), results;
        int nsamples = samples.rows;
        bool returnDFVal = (flags & RAW_OUTPUT) != 0;

        CV_Assert( samples.cols == var_count && samples.type() == CV_32F );

        if( _results.needed() )
        {
            _results.create( nsamples, 1, samples.type() );
            results = _results.getMat();
        }
        else
        {
            CV_Assert( nsamples == 1 );
            results = Mat(1, 1, CV_32F, &result);
        }

        PredictBody invoker(this, samples, results, returnDFVal);
        if( nsamples < 10 )
            invoker(Range(0, nsamples));
        else
            parallel_for_(Range(0, nsamples), invoker);
        return result;
    }
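
    // Hypothetical usage sketch of the public API implemented above
    // (the model file name and sample values are illustrative only):
    //
    //   Ptr<SVM> model = Algorithm::load<SVM>("model.xml");
    //   Mat query(1, model->getVarCount(), CV_32F, Scalar::all(0));
    //   float label = model->predict(query);  // predicted label/value
    //   // raw decision-function value instead of the label:
    //   float dfval = model->predict(query, noArray(), StatModel::RAW_OUTPUT);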

    double getDecisionFunction(int i, OutputArray _alpha, OutputArray _svidx ) const
    {
        CV_Assert( 0 <= i && i < (int)decision_func.size());
        const DecisionFunc& df = decision_func[i];
        int count = getSVCount(i);
        Mat(1, count, CV_64F, (double*)&df_alpha[df.ofs]).copyTo(_alpha);
        Mat(1, count, CV_32S, (int*)&df_index[df.ofs]).copyTo(_svidx);
        return df.rho;
    }
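
    // Note: for a LINEAR model that has been through optimize_linear_svm(),
    // every decision function reports exactly one "support vector" (the
    // compressed weighted sum) with alpha == 1, so the returned arrays each
    // contain a single element.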

    void write_params( FileStorage& fs ) const
    {
        int svmType = params.svmType;
        int kernelType = params.kernelType;

        String svm_type_str =
            svmType == C_SVC ? "C_SVC" :
            svmType == NU_SVC ? "NU_SVC" :
            svmType == ONE_CLASS ? "ONE_CLASS" :
            svmType == EPS_SVR ? "EPS_SVR" :
            svmType == NU_SVR ? "NU_SVR" : format("Unknown_%d", svmType);
        String kernel_type_str =
            kernelType == LINEAR ? "LINEAR" :
            kernelType == POLY ? "POLY" :
            kernelType == RBF ? "RBF" :
            kernelType == SIGMOID ? "SIGMOID" :
            kernelType == CHI2 ? "CHI2" :
            kernelType == INTER ? "INTER" : format("Unknown_%d", kernelType);

        fs << "svmType" << svm_type_str;

        // save kernel
        fs << "kernel" << "{" << "type" << kernel_type_str;

        if( kernelType == POLY )
            fs << "degree" << params.degree;

        if( kernelType != LINEAR )
            fs << "gamma" << params.gamma;

        if( kernelType == POLY || kernelType == SIGMOID )
            fs << "coef0" << params.coef0;

        fs << "}";

        if( svmType == C_SVC || svmType == EPS_SVR || svmType == NU_SVR )
            fs << "C" << params.C;

        if( svmType == NU_SVC || svmType == ONE_CLASS || svmType == NU_SVR )
            fs << "nu" << params.nu;

        if( svmType == EPS_SVR )
            fs << "p" << params.p;

        fs << "term_criteria" << "{:";
        if( params.termCrit.type & TermCriteria::EPS )
            fs << "epsilon" << params.termCrit.epsilon;
        if( params.termCrit.type & TermCriteria::COUNT )
            fs << "iterations" << params.termCrit.maxCount;
        fs << "}";
    }

    bool isTrained() const
    {
        return !sv.empty();
    }

    bool isClassifier() const
    {
        return params.svmType == C_SVC || params.svmType == NU_SVC || params.svmType == ONE_CLASS;
    }

    int getVarCount() const
    {
        return var_count;
    }

    String getDefaultName() const
    {
        return "opencv_ml_svm";
    }

    void write( FileStorage& fs ) const
    {
        int class_count = !class_labels.empty() ? (int)class_labels.total() :
                          params.svmType == ONE_CLASS ? 1 : 0;
        if( !isTrained() )
            CV_Error( CV_StsBadArg, "The SVM model is not trained; train it before calling write()" );

        write_params( fs );

        fs << "var_count" << var_count;

        if( class_count > 0 )
        {
            fs << "class_count" << class_count;

            if( !class_labels.empty() )
                fs << "class_labels" << class_labels;

            if( !params.classWeights.empty() )
                fs << "class_weights" << params.classWeights;
        }

        // write the joint collection of support vectors
        int i, sv_total = sv.rows;
        fs << "sv_total" << sv_total;
        fs << "support_vectors" << "[";
        for( i = 0; i < sv_total; i++ )
        {
            fs << "[:";
            fs.writeRaw("f", sv.ptr(i), sv.cols*sv.elemSize());
            fs << "]";
        }
        fs << "]";

        // write decision functions
        int df_count = (int)decision_func.size();

        fs << "decision_functions" << "[";
        for( i = 0; i < df_count; i++ )
        {
            const DecisionFunc& df = decision_func[i];
            int sv_count = getSVCount(i);
            fs << "{" << "sv_count" << sv_count
               << "rho" << df.rho
               << "alpha" << "[:";
            fs.writeRaw("d", (const uchar*)&df_alpha[df.ofs], sv_count*sizeof(df_alpha[0]));
            fs << "]";
            if( class_count > 2 )
            {
                fs << "index" << "[:";
                fs.writeRaw("i", (const uchar*)&df_index[df.ofs], sv_count*sizeof(df_index[0]));
                fs << "]";
            }
            else
                CV_Assert( sv_count == sv_total );
            fs << "}";
        }
        fs << "]";
    }
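
    // For reference, a small trained model serialized by write() has roughly
    // this shape in YAML (all field values below are illustrative only):
    //
    //   svmType: C_SVC
    //   kernel: { type: RBF, gamma: 1. }
    //   C: 1.
    //   term_criteria: { epsilon: 1.1920928955078125e-07, iterations: 1000 }
    //   var_count: 2
    //   class_count: 2
    //   sv_total: 3
    //   support_vectors: [ [: 0., 1. ], ... ]
    //   decision_functions:
    //      - { sv_count: 3, rho: 5.0e-01, alpha: [: ... ] }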

    void read_params( const FileNode& fn )
    {
        SvmParams _params;

        // check for the old attribute naming ("svm_type" vs "svmType")
        String svm_type_str = (String)(fn["svm_type"].empty() ? fn["svmType"] : fn["svm_type"]);
        int svmType =
            svm_type_str == "C_SVC" ? C_SVC :
            svm_type_str == "NU_SVC" ? NU_SVC :
            svm_type_str == "ONE_CLASS" ? ONE_CLASS :
            svm_type_str == "EPS_SVR" ? EPS_SVR :
            svm_type_str == "NU_SVR" ? NU_SVR : -1;

        if( svmType < 0 )
            CV_Error( CV_StsParseError, "Missing or invalid SVM type" );

        FileNode kernel_node = fn["kernel"];
        if( kernel_node.empty() )
            CV_Error( CV_StsParseError, "SVM kernel tag is not found" );

        String kernel_type_str = (String)kernel_node["type"];
        int kernelType =
            kernel_type_str == "LINEAR" ? LINEAR :
            kernel_type_str == "POLY" ? POLY :
            kernel_type_str == "RBF" ? RBF :
            kernel_type_str == "SIGMOID" ? SIGMOID :
            kernel_type_str == "CHI2" ? CHI2 :
            kernel_type_str == "INTER" ? INTER : CUSTOM;

        if( kernelType == CUSTOM )
            CV_Error( CV_StsParseError, "Invalid SVM kernel type (or custom kernel)" );

        _params.svmType = svmType;
        _params.kernelType = kernelType;
        _params.degree = (double)kernel_node["degree"];
        _params.gamma = (double)kernel_node["gamma"];
        _params.coef0 = (double)kernel_node["coef0"];

        _params.C = (double)fn["C"];
        _params.nu = (double)fn["nu"];
        _params.p = (double)fn["p"];
        _params.classWeights = Mat();

        FileNode tcnode = fn["term_criteria"];
        if( !tcnode.empty() )
        {
            _params.termCrit.epsilon = (double)tcnode["epsilon"];
            _params.termCrit.maxCount = (int)tcnode["iterations"];
            _params.termCrit.type = (_params.termCrit.epsilon > 0 ? TermCriteria::EPS : 0) +
                                   (_params.termCrit.maxCount > 0 ? TermCriteria::COUNT : 0);
        }
        else
            _params.termCrit = TermCriteria( TermCriteria::EPS + TermCriteria::COUNT, 1000, FLT_EPSILON );

        setParams( _params );
    }

    void read( const FileNode& fn )
    {
        clear();

        // read SVM parameters
        read_params( fn );

        // and top-level data
        int i, sv_total = (int)fn["sv_total"];
        var_count = (int)fn["var_count"];
        int class_count = (int)fn["class_count"];

        if( sv_total <= 0 || var_count <= 0 )
            CV_Error( CV_StsParseError, "SVM model data is invalid, check sv_count, var_* and class_count tags" );

        FileNode m = fn["class_labels"];
        if( !m.empty() )
            m >> class_labels;
        m = fn["class_weights"];
        if( !m.empty() )
            m >> params.classWeights;

        if( class_count > 1 && (class_labels.empty() || (int)class_labels.total() != class_count))
            CV_Error( CV_StsParseError, "Array of class labels is missing or invalid" );

        // read support vectors
        FileNode sv_node = fn["support_vectors"];

        CV_Assert((int)sv_node.size() == sv_total);
        sv.create(sv_total, var_count, CV_32F);

        FileNodeIterator sv_it = sv_node.begin();
        for( i = 0; i < sv_total; i++, ++sv_it )
        {
            (*sv_it).readRaw("f", sv.ptr(i), var_count*sv.elemSize());
        }

        // read decision functions
        int df_count = class_count > 1 ? class_count*(class_count-1)/2 : 1;
        FileNode df_node = fn["decision_functions"];

        CV_Assert((int)df_node.size() == df_count);

        FileNodeIterator df_it = df_node.begin();
        for( i = 0; i < df_count; i++, ++df_it )
        {
            FileNode dfi = *df_it;
            DecisionFunc df;
            int sv_count = (int)dfi["sv_count"];
            int ofs = (int)df_index.size();
            df.rho = (double)dfi["rho"];
            df.ofs = ofs;
            df_index.resize(ofs + sv_count);
            df_alpha.resize(ofs + sv_count);
            dfi["alpha"].readRaw("d", (uchar*)&df_alpha[ofs], sv_count*sizeof(df_alpha[0]));
            if( class_count > 2 )
                dfi["index"].readRaw("i", (uchar*)&df_index[ofs], sv_count*sizeof(df_index[0]));
            decision_func.push_back(df);
        }
        if( class_count <= 2 )
            setRangeVector(df_index, sv_total);
        if( (int)fn["optimize_linear"] != 0 )
            optimize_linear_svm();
    }
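
    // Note that in the <= 2 class case (and for regression/one-class
    // models) the per-function index array is not stored: every decision
    // function uses all sv_total support vectors in order. That is why
    // write() asserts sv_count == sv_total there and read() regenerates
    // df_index with setRangeVector().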

    SvmParams params;
    Mat class_labels;
    int var_count;
    Mat sv;
    vector<DecisionFunc> decision_func;
    vector<double> df_alpha;
    vector<int> df_index;

    Ptr<Kernel> kernel;
};


Ptr<SVM> SVM::create()
{
    return makePtr<SVMImpl>();
}

}
}

/* End of file. */