Home | History | Annotate | Download | only in include
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                        Intel License Agreement
     11 //
     12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
     13 // Third party copyrights are property of their respective owners.
     14 //
     15 // Redistribution and use in source and binary forms, with or without modification,
     16 // are permitted provided that the following conditions are met:
     17 //
     18 //   * Redistribution's of source code must retain the above copyright notice,
     19 //     this list of conditions and the following disclaimer.
     20 //
     21 //   * Redistribution's in binary form must reproduce the above copyright notice,
     22 //     this list of conditions and the following disclaimer in the documentation
     23 //     and/or other materials provided with the distribution.
     24 //
     25 //   * The name of Intel Corporation may not be used to endorse or promote products
     26 //     derived from this software without specific prior written permission.
     27 //
     28 // This software is provided by the copyright holders and contributors "as is" and
     29 // any express or implied warranties, including, but not limited to, the implied
     30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     31 // In no event shall the Intel Corporation or contributors be liable for any direct,
     32 // indirect, incidental, special, exemplary, or consequential damages
     33 // (including, but not limited to, procurement of substitute goods or services;
     34 // loss of use, data, or profits; or business interruption) however caused
     35 // and on any theory of liability, whether in contract, strict liability,
     36 // or tort (including negligence or otherwise) arising in any way out of
     37 // the use of this software, even if advised of the possibility of such damage.
     38 //
     39 //M*/
     40 
     41 #ifndef __ML_H__
     42 #define __ML_H__
     43 
     44 // disable deprecation warning which appears in VisualStudio 8.0
     45 #if _MSC_VER >= 1400
     46 #pragma warning( disable : 4996 )
     47 #endif
     48 
     49 #ifndef SKIP_INCLUDES
     50 
     51   #include "cxcore.h"
     52   #include <limits.h>
     53 
     54   #if defined WIN32 || defined WIN64
     55     #include <windows.h>
     56   #endif
     57 
     58 #else // SKIP_INCLUDES
     59 
     60   #if defined WIN32 || defined WIN64
     61     #define CV_CDECL __cdecl
     62     #define CV_STDCALL __stdcall
     63   #else
     64     #define CV_CDECL
     65     #define CV_STDCALL
     66   #endif
     67 
     68   #ifndef CV_EXTERN_C
     69     #ifdef __cplusplus
     70       #define CV_EXTERN_C extern "C"
     71       #define CV_DEFAULT(val) = val
     72     #else
     73       #define CV_EXTERN_C
     74       #define CV_DEFAULT(val)
     75     #endif
     76   #endif
     77 
     78   #ifndef CV_EXTERN_C_FUNCPTR
     79     #ifdef __cplusplus
     80       #define CV_EXTERN_C_FUNCPTR(x) extern "C" { typedef x; }
     81     #else
     82       #define CV_EXTERN_C_FUNCPTR(x) typedef x
     83     #endif
     84   #endif
     85 
     86   #ifndef CV_INLINE
     87     #if defined __cplusplus
     88       #define CV_INLINE inline
     89     #elif (defined WIN32 || defined WIN64) && !defined __GNUC__
     90       #define CV_INLINE __inline
     91     #else
     92       #define CV_INLINE static
     93     #endif
     94   #endif /* CV_INLINE */
     95 
     96   #if (defined WIN32 || defined WIN64) && defined CVAPI_EXPORTS
     97     #define CV_EXPORTS __declspec(dllexport)
     98   #else
     99     #define CV_EXPORTS
    100   #endif
    101 
    102   #ifndef CVAPI
    103     #define CVAPI(rettype) CV_EXTERN_C CV_EXPORTS rettype CV_CDECL
    104   #endif
    105 
    106 #endif // SKIP_INCLUDES
    107 
    108 
#ifdef __cplusplus
// Everything from here up to the matching #endif is C++-only (classes,
// default arguments, inline member functions).

// Apple's debug headers define a macro named check(), which would be expanded
// inside method declarations of this header such as CvParamGrid::check().
// Remove it before any of the declarations below.
#undef check
    114 
    115 /****************************************************************************************\
    116 *                               Main struct definitions                                  *
    117 \****************************************************************************************/
    118 
    119 /* log(2*PI) */
    120 #define CV_LOG2PI (1.8378770664093454835606594728112)
    121 
    122 /* columns of <trainData> matrix are training samples */
    123 #define CV_COL_SAMPLE 0
    124 
    125 /* rows of <trainData> matrix are training samples */
    126 #define CV_ROW_SAMPLE 1
    127 
    128 #define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
    129 
    130 struct CvVectors
    131 {
    132     int type;
    133     int dims, count;
    134     CvVectors* next;
    135     union
    136     {
    137         uchar** ptr;
    138         float** fl;
    139         double** db;
    140     } data;
    141 };
    142 
    143 #if 0
    144 /* A structure, representing the lattice range of statmodel parameters.
    145    It is used for optimizing statmodel parameters by cross-validation method.
    146    The lattice is logarithmic, so <step> must be greater then 1. */
    147 typedef struct CvParamLattice
    148 {
    149     double min_val;
    150     double max_val;
    151     double step;
    152 }
    153 CvParamLattice;
    154 
    155 CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
    156                                          double log_step )
    157 {
    158     CvParamLattice pl;
    159     pl.min_val = MIN( min_val, max_val );
    160     pl.max_val = MAX( min_val, max_val );
    161     pl.step = MAX( log_step, 1. );
    162     return pl;
    163 }
    164 
    165 CV_INLINE CvParamLattice cvDefaultParamLattice( void )
    166 {
    167     CvParamLattice pl = {0,0,0};
    168     return pl;
    169 }
    170 #endif
    171 
    172 /* Variable type */
    173 #define CV_VAR_NUMERICAL    0
    174 #define CV_VAR_ORDERED      0
    175 #define CV_VAR_CATEGORICAL  1
    176 
    177 #define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
    178 #define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
    179 #define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
    180 #define CV_TYPE_NAME_ML_EM          "opencv-ml-em"
    181 #define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
    182 #define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
    183 #define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
    184 #define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
    185 #define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
    186 
    187 class CV_EXPORTS CvStatModel
    188 {
    189 public:
    190     CvStatModel();
    191     virtual ~CvStatModel();
    192 
    193     virtual void clear();
    194 
    195     virtual void save( const char* filename, const char* name=0 );
    196     virtual void load( const char* filename, const char* name=0 );
    197 
    198     virtual void write( CvFileStorage* storage, const char* name );
    199     virtual void read( CvFileStorage* storage, CvFileNode* node );
    200 
    201 protected:
    202     const char* default_model_name;
    203 };
    204 
    205 
    206 /****************************************************************************************\
    207 *                                 Normal Bayes Classifier                                *
    208 \****************************************************************************************/
    209 
    210 /* The structure, representing the grid range of statmodel parameters.
    211    It is used for optimizing statmodel accuracy by varying model parameters,
    212    the accuracy estimate being computed by cross-validation.
    213    The grid is logarithmic, so <step> must be greater then 1. */
    214 struct CV_EXPORTS CvParamGrid
    215 {
    216     // SVM params type
    217     enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
    218 
    219     CvParamGrid()
    220     {
    221         min_val = max_val = step = 0;
    222     }
    223 
    224     CvParamGrid( double _min_val, double _max_val, double log_step )
    225     {
    226         min_val = _min_val;
    227         max_val = _max_val;
    228         step = log_step;
    229     }
    230     //CvParamGrid( int param_id );
    231     bool check() const;
    232 
    233     double min_val;
    234     double max_val;
    235     double step;
    236 };
    237 
    238 class CV_EXPORTS CvNormalBayesClassifier : public CvStatModel
    239 {
    240 public:
    241     CvNormalBayesClassifier();
    242     virtual ~CvNormalBayesClassifier();
    243 
    244     CvNormalBayesClassifier( const CvMat* _train_data, const CvMat* _responses,
    245         const CvMat* _var_idx=0, const CvMat* _sample_idx=0 );
    246 
    247     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
    248         const CvMat* _var_idx = 0, const CvMat* _sample_idx=0, bool update=false );
    249 
    250     virtual float predict( const CvMat* _samples, CvMat* results=0 ) const;
    251     virtual void clear();
    252 
    253     virtual void write( CvFileStorage* storage, const char* name );
    254     virtual void read( CvFileStorage* storage, CvFileNode* node );
    255 
    256 protected:
    257     int     var_count, var_all;
    258     CvMat*  var_idx;
    259     CvMat*  cls_labels;
    260     CvMat** count;
    261     CvMat** sum;
    262     CvMat** productsum;
    263     CvMat** avg;
    264     CvMat** inv_eigen_values;
    265     CvMat** cov_rotate_mats;
    266     CvMat*  c;
    267 };
    268 
    269 
    270 /****************************************************************************************\
    271 *                          K-Nearest Neighbour Classifier                                *
    272 \****************************************************************************************/
    273 
    274 // k Nearest Neighbors
    275 class CV_EXPORTS CvKNearest : public CvStatModel
    276 {
    277 public:
    278 
    279     CvKNearest();
    280     virtual ~CvKNearest();
    281 
    282     CvKNearest( const CvMat* _train_data, const CvMat* _responses,
    283                 const CvMat* _sample_idx=0, bool _is_regression=false, int max_k=32 );
    284 
    285     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
    286                         const CvMat* _sample_idx=0, bool is_regression=false,
    287                         int _max_k=32, bool _update_base=false );
    288 
    289     virtual float find_nearest( const CvMat* _samples, int k, CvMat* results=0,
    290         const float** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 ) const;
    291 
    292     virtual void clear();
    293     int get_max_k() const;
    294     int get_var_count() const;
    295     int get_sample_count() const;
    296     bool is_regression() const;
    297 
    298 protected:
    299 
    300     virtual float write_results( int k, int k1, int start, int end,
    301         const float* neighbor_responses, const float* dist, CvMat* _results,
    302         CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
    303 
    304     virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
    305         float* neighbor_responses, const float** neighbors, float* dist ) const;
    306 
    307 
    308     int max_k, var_count;
    309     int total;
    310     bool regression;
    311     CvVectors* samples;
    312 };
    313 
    314 /****************************************************************************************\
    315 *                                   Support Vector Machines                              *
    316 \****************************************************************************************/
    317 
    318 // SVM training parameters
    319 struct CV_EXPORTS CvSVMParams
    320 {
    321     CvSVMParams();
    322     CvSVMParams( int _svm_type, int _kernel_type,
    323                  double _degree, double _gamma, double _coef0,
    324                  double _C, double _nu, double _p,
    325                  CvMat* _class_weights, CvTermCriteria _term_crit );
    326 
    327     int         svm_type;
    328     int         kernel_type;
    329     double      degree; // for poly
    330     double      gamma;  // for poly/rbf/sigmoid
    331     double      coef0;  // for poly/sigmoid
    332 
    333     double      C;  // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
    334     double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
    335     double      p; // for CV_SVM_EPS_SVR
    336     CvMat*      class_weights; // for CV_SVM_C_SVC
    337     CvTermCriteria term_crit; // termination criteria
    338 };
    339 
    340 
    341 struct CV_EXPORTS CvSVMKernel
    342 {
    343     typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
    344                                        const float* another, float* results );
    345     CvSVMKernel();
    346     CvSVMKernel( const CvSVMParams* _params, Calc _calc_func );
    347     virtual bool create( const CvSVMParams* _params, Calc _calc_func );
    348     virtual ~CvSVMKernel();
    349 
    350     virtual void clear();
    351     virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
    352 
    353     const CvSVMParams* params;
    354     Calc calc_func;
    355 
    356     virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
    357                                     const float* another, float* results,
    358                                     double alpha, double beta );
    359 
    360     virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
    361                               const float* another, float* results );
    362     virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
    363                            const float* another, float* results );
    364     virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
    365                             const float* another, float* results );
    366     virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
    367                                const float* another, float* results );
    368 };
    369 
    370 
    371 struct CvSVMKernelRow
    372 {
    373     CvSVMKernelRow* prev;
    374     CvSVMKernelRow* next;
    375     float* data;
    376 };
    377 
    378 
    379 struct CvSVMSolutionInfo
    380 {
    381     double obj;
    382     double rho;
    383     double upper_bound_p;
    384     double upper_bound_n;
    385     double r;   // for Solver_NU
    386 };
    387 
    388 class CV_EXPORTS CvSVMSolver
    389 {
    390 public:
    391     typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
    392     typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
    393     typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
    394 
    395     CvSVMSolver();
    396 
    397     CvSVMSolver( int count, int var_count, const float** samples, schar* y,
    398                  int alpha_count, double* alpha, double Cp, double Cn,
    399                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
    400                  SelectWorkingSet select_working_set, CalcRho calc_rho );
    401     virtual bool create( int count, int var_count, const float** samples, schar* y,
    402                  int alpha_count, double* alpha, double Cp, double Cn,
    403                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
    404                  SelectWorkingSet select_working_set, CalcRho calc_rho );
    405     virtual ~CvSVMSolver();
    406 
    407     virtual void clear();
    408     virtual bool solve_generic( CvSVMSolutionInfo& si );
    409 
    410     virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
    411                               double Cp, double Cn, CvMemStorage* storage,
    412                               CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
    413     virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
    414                                CvMemStorage* storage, CvSVMKernel* kernel,
    415                                double* alpha, CvSVMSolutionInfo& si );
    416     virtual bool solve_one_class( int count, int var_count, const float** samples,
    417                                   CvMemStorage* storage, CvSVMKernel* kernel,
    418                                   double* alpha, CvSVMSolutionInfo& si );
    419 
    420     virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
    421                                 CvMemStorage* storage, CvSVMKernel* kernel,
    422                                 double* alpha, CvSVMSolutionInfo& si );
    423 
    424     virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
    425                                CvMemStorage* storage, CvSVMKernel* kernel,
    426                                double* alpha, CvSVMSolutionInfo& si );
    427 
    428     virtual float* get_row_base( int i, bool* _existed );
    429     virtual float* get_row( int i, float* dst );
    430 
    431     int sample_count;
    432     int var_count;
    433     int cache_size;
    434     int cache_line_size;
    435     const float** samples;
    436     const CvSVMParams* params;
    437     CvMemStorage* storage;
    438     CvSVMKernelRow lru_list;
    439     CvSVMKernelRow* rows;
    440 
    441     int alpha_count;
    442 
    443     double* G;
    444     double* alpha;
    445 
    446     // -1 - lower bound, 0 - free, 1 - upper bound
    447     schar* alpha_status;
    448 
    449     schar* y;
    450     double* b;
    451     float* buf[2];
    452     double eps;
    453     int max_iter;
    454     double C[2];  // C[0] == Cn, C[1] == Cp
    455     CvSVMKernel* kernel;
    456 
    457     SelectWorkingSet select_working_set_func;
    458     CalcRho calc_rho_func;
    459     GetRow get_row_func;
    460 
    461     virtual bool select_working_set( int& i, int& j );
    462     virtual bool select_working_set_nu_svm( int& i, int& j );
    463     virtual void calc_rho( double& rho, double& r );
    464     virtual void calc_rho_nu_svm( double& rho, double& r );
    465 
    466     virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
    467     virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
    468     virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
    469 };
    470 
    471 
    472 struct CvSVMDecisionFunc
    473 {
    474     double rho;
    475     int sv_count;
    476     double* alpha;
    477     int* sv_index;
    478 };
    479 
    480 
    481 // SVM model
    482 class CV_EXPORTS CvSVM : public CvStatModel
    483 {
    484 public:
    485     // SVM type
    486     enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
    487 
    488     // SVM kernel type
    489     enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 };
    490 
    491     // SVM params type
    492     enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
    493 
    494     CvSVM();
    495     virtual ~CvSVM();
    496 
    497     CvSVM( const CvMat* _train_data, const CvMat* _responses,
    498            const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
    499            CvSVMParams _params=CvSVMParams() );
    500 
    501     virtual bool train( const CvMat* _train_data, const CvMat* _responses,
    502                         const CvMat* _var_idx=0, const CvMat* _sample_idx=0,
    503                         CvSVMParams _params=CvSVMParams() );
    504     virtual bool train_auto( const CvMat* _train_data, const CvMat* _responses,
    505         const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params,
    506         int k_fold = 10,
    507         CvParamGrid C_grid      = get_default_grid(CvSVM::C),
    508         CvParamGrid gamma_grid  = get_default_grid(CvSVM::GAMMA),
    509         CvParamGrid p_grid      = get_default_grid(CvSVM::P),
    510         CvParamGrid nu_grid     = get_default_grid(CvSVM::NU),
    511         CvParamGrid coef_grid   = get_default_grid(CvSVM::COEF),
    512         CvParamGrid degree_grid = get_default_grid(CvSVM::DEGREE) );
    513 
    514     virtual float predict( const CvMat* _sample ) const;
    515 
    516     virtual int get_support_vector_count() const;
    517     virtual const float* get_support_vector(int i) const;
    518     virtual CvSVMParams get_params() const { return params; };
    519     virtual void clear();
    520 
    521     static CvParamGrid get_default_grid( int param_id );
    522 
    523     virtual void write( CvFileStorage* storage, const char* name );
    524     virtual void read( CvFileStorage* storage, CvFileNode* node );
    525     int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
    526 
    527 protected:
    528 
    529     virtual bool set_params( const CvSVMParams& _params );
    530     virtual bool train1( int sample_count, int var_count, const float** samples,
    531                     const void* _responses, double Cp, double Cn,
    532                     CvMemStorage* _storage, double* alpha, double& rho );
    533     virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
    534                     const CvMat* _responses, CvMemStorage* _storage, double* alpha );
    535     virtual void create_kernel();
    536     virtual void create_solver();
    537 
    538     virtual void write_params( CvFileStorage* fs );
    539     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
    540 
    541     CvSVMParams params;
    542     CvMat* class_labels;
    543     int var_all;
    544     float** sv;
    545     int sv_total;
    546     CvMat* var_idx;
    547     CvMat* class_weights;
    548     CvSVMDecisionFunc* decision_func;
    549     CvMemStorage* storage;
    550 
    551     CvSVMSolver* solver;
    552     CvSVMKernel* kernel;
    553 };
    554 
    555 /****************************************************************************************\
    556 *                              Expectation - Maximization                                *
    557 \****************************************************************************************/
    558 
    559 struct CV_EXPORTS CvEMParams
    560 {
    561     CvEMParams() : nclusters(10), cov_mat_type(1/*CvEM::COV_MAT_DIAGONAL*/),
    562         start_step(0/*CvEM::START_AUTO_STEP*/), probs(0), weights(0), means(0), covs(0)
    563     {
    564         term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON );
    565     }
    566 
    567     CvEMParams( int _nclusters, int _cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/,
    568                 int _start_step=0/*CvEM::START_AUTO_STEP*/,
    569                 CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON),
    570                 const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) :
    571                 nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step),
    572                 probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit)
    573     {}
    574 
    575     int nclusters;
    576     int cov_mat_type;
    577     int start_step;
    578     const CvMat* probs;
    579     const CvMat* weights;
    580     const CvMat* means;
    581     const CvMat** covs;
    582     CvTermCriteria term_crit;
    583 };
    584 
    585 
    586 class CV_EXPORTS CvEM : public CvStatModel
    587 {
    588 public:
    589     // Type of covariation matrices
    590     enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 };
    591 
    592     // The initial step
    593     enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 };
    594 
    595     CvEM();
    596     CvEM( const CvMat* samples, const CvMat* sample_idx=0,
    597           CvEMParams params=CvEMParams(), CvMat* labels=0 );
    598 
    599     virtual ~CvEM();
    600 
    601     virtual bool train( const CvMat* samples, const CvMat* sample_idx=0,
    602                         CvEMParams params=CvEMParams(), CvMat* labels=0 );
    603 
    604     virtual float predict( const CvMat* sample, CvMat* probs ) const;
    605     virtual void clear();
    606 
    607     int get_nclusters() const;
    608     const CvMat* get_means() const;
    609     const CvMat** get_covs() const;
    610     const CvMat* get_weights() const;
    611     const CvMat* get_probs() const;
    612 
    613     inline double get_log_likelihood () const { return log_likelihood; };
    614 
    615 protected:
    616 
    617     virtual void set_params( const CvEMParams& params,
    618                              const CvVectors& train_data );
    619     virtual void init_em( const CvVectors& train_data );
    620     virtual double run_em( const CvVectors& train_data );
    621     virtual void init_auto( const CvVectors& samples );
    622     virtual void kmeans( const CvVectors& train_data, int nclusters,
    623                          CvMat* labels, CvTermCriteria criteria,
    624                          const CvMat* means );
    625     CvEMParams params;
    626     double log_likelihood;
    627 
    628     CvMat* means;
    629     CvMat** covs;
    630     CvMat* weights;
    631     CvMat* probs;
    632 
    633     CvMat* log_weight_div_det;
    634     CvMat* inv_eigen_values;
    635     CvMat** cov_rotate_mats;
    636 };
    637 
    638 /****************************************************************************************\
    639 *                                      Decision Tree                                     *
    640 \****************************************************************************************/
    641 
    642 struct CvPair32s32f
    643 {
    644     int i;
    645     float val;
    646 };
    647 
    648 
    649 #define CV_DTREE_CAT_DIR(idx,subset) \
    650     (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
    651 
    652 struct CvDTreeSplit
    653 {
    654     int var_idx;
    655     int inversed;
    656     float quality;
    657     CvDTreeSplit* next;
    658     union
    659     {
    660         int subset[2];
    661         struct
    662         {
    663             float c;
    664             int split_point;
    665         }
    666         ord;
    667     };
    668 };
    669 
    670 
    671 struct CvDTreeNode
    672 {
    673     int class_idx;
    674     int Tn;
    675     double value;
    676 
    677     CvDTreeNode* parent;
    678     CvDTreeNode* left;
    679     CvDTreeNode* right;
    680 
    681     CvDTreeSplit* split;
    682 
    683     int sample_count;
    684     int depth;
    685     int* num_valid;
    686     int offset;
    687     int buf_idx;
    688     double maxlr;
    689 
    690     // global pruning data
    691     int complexity;
    692     double alpha;
    693     double node_risk, tree_risk, tree_error;
    694 
    695     // cross-validation pruning data
    696     int* cv_Tn;
    697     double* cv_node_risk;
    698     double* cv_node_error;
    699 
    700     int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
    701     void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
    702 };
    703 
    704 
    705 struct CV_EXPORTS CvDTreeParams
    706 {
    707     int   max_categories;
    708     int   max_depth;
    709     int   min_sample_count;
    710     int   cv_folds;
    711     bool  use_surrogates;
    712     bool  use_1se_rule;
    713     bool  truncate_pruned_tree;
    714     float regression_accuracy;
    715     const float* priors;
    716 
    717     CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
    718         cv_folds(10), use_surrogates(true), use_1se_rule(true),
    719         truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
    720     {}
    721 
    722     CvDTreeParams( int _max_depth, int _min_sample_count,
    723                    float _regression_accuracy, bool _use_surrogates,
    724                    int _max_categories, int _cv_folds,
    725                    bool _use_1se_rule, bool _truncate_pruned_tree,
    726                    const float* _priors ) :
    727         max_categories(_max_categories), max_depth(_max_depth),
    728         min_sample_count(_min_sample_count), cv_folds (_cv_folds),
    729         use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
    730         truncate_pruned_tree(_truncate_pruned_tree),
    731         regression_accuracy(_regression_accuracy),
    732         priors(_priors)
    733     {}
    734 };
    735 
    736 
    737 struct CV_EXPORTS CvDTreeTrainData
    738 {
    739     CvDTreeTrainData();
    740     CvDTreeTrainData( const CvMat* _train_data, int _tflag,
    741                       const CvMat* _responses, const CvMat* _var_idx=0,
    742                       const CvMat* _sample_idx=0, const CvMat* _var_type=0,
    743                       const CvMat* _missing_mask=0,
    744                       const CvDTreeParams& _params=CvDTreeParams(),
    745                       bool _shared=false, bool _add_labels=false );
    746     virtual ~CvDTreeTrainData();
    747 
    748     virtual void set_data( const CvMat* _train_data, int _tflag,
    749                           const CvMat* _responses, const CvMat* _var_idx=0,
    750                           const CvMat* _sample_idx=0, const CvMat* _var_type=0,
    751                           const CvMat* _missing_mask=0,
    752                           const CvDTreeParams& _params=CvDTreeParams(),
    753                           bool _shared=false, bool _add_labels=false,
    754                           bool _update_data=false );
    755 
    756     virtual void get_vectors( const CvMat* _subsample_idx,
    757          float* values, uchar* missing, float* responses, bool get_class_idx=false );
    758 
    759     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
    760 
    761     virtual void write_params( CvFileStorage* fs );
    762     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
    763 
    764     // release all the data
    765     virtual void clear();
    766 
    767     int get_num_classes() const;
    768     int get_var_type(int vi) const;
    769     int get_work_var_count() const;
    770 
    771     virtual int* get_class_labels( CvDTreeNode* n );
    772     virtual float* get_ord_responses( CvDTreeNode* n );
    773     virtual int* get_labels( CvDTreeNode* n );
    774     virtual int* get_cat_var_data( CvDTreeNode* n, int vi );
    775     virtual CvPair32s32f* get_ord_var_data( CvDTreeNode* n, int vi );
    776     virtual int get_child_buf_idx( CvDTreeNode* n );
    777 
    778     ////////////////////////////////////
    779 
    780     virtual bool set_params( const CvDTreeParams& params );
    781     virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
    782                                    int storage_idx, int offset );
    783 
    784     virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
    785                 int split_point, int inversed, float quality );
    786     virtual CvDTreeSplit* new_split_cat( int vi, float quality );
    787     virtual void free_node_data( CvDTreeNode* node );
    788     virtual void free_train_data();
    789     virtual void free_node( CvDTreeNode* node );
    790 
    791     int sample_count, var_all, var_count, max_c_count;
    792     int ord_var_count, cat_var_count;
    793     bool have_labels, have_priors;
    794     bool is_classifier;
    795 
    796     int buf_count, buf_size;
    797     bool shared;
    798 
    799     CvMat* cat_count;
    800     CvMat* cat_ofs;
    801     CvMat* cat_map;
    802 
    803     CvMat* counts;
    804     CvMat* buf;
    805     CvMat* direction;
    806     CvMat* split_buf;
    807 
    808     CvMat* var_idx;
    809     CvMat* var_type; // i-th element =
    810                      //   k<0  - ordered
    811                      //   k>=0 - categorical, see k-th element of cat_* arrays
    812     CvMat* priors;
    813     CvMat* priors_mult;
    814 
    815     CvDTreeParams params;
    816 
    817     CvMemStorage* tree_storage;
    818     CvMemStorage* temp_storage;
    819 
    820     CvDTreeNode* data_root;
    821 
    822     CvSet* node_heap;
    823     CvSet* split_heap;
    824     CvSet* cv_heap;
    825     CvSet* nv_heap;
    826 
    827     CvRNG rng;
    828 };
    829 
    830 
    831 class CV_EXPORTS CvDTree : public CvStatModel
    832 {
            // Single decision tree model. Split search has separate routines for
            // ordered ("_ord") vs. categorical ("_cat") variables and for
            // classification ("_class") vs. regression ("_reg") responses (see the
            // protected find_split_* methods). The training set is held by the
            // CvDTreeTrainData object referenced by the `data` member.
            // NOTE: this class is exported and has virtual methods; the order of
            // virtual functions and data members is part of its binary layout.
    833 public:
    834     CvDTree();
    835     virtual ~CvDTree();
    836 
            // Builds the tree from raw matrices. The arguments carry the same names
            // and defaults as CvDTreeTrainData::set_data(); presumably they are
            // forwarded to it (NOTE(review): confirm in the implementation).
    837     virtual bool train( const CvMat* _train_data, int _tflag,
    838                         const CvMat* _responses, const CvMat* _var_idx=0,
    839                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
    840                         const CvMat* _missing_mask=0,
    841                         CvDTreeParams params=CvDTreeParams() );
    842 
            // Builds the tree on an already prepared CvDTreeTrainData object, using
            // the samples selected by _subsample_idx (semantics of a 0 index matrix
            // not visible here -- TODO confirm).
    843     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
    844 
            // Returns the tree node reached by _sample (presumably a leaf carrying
            // the prediction -- confirm). _missing_data_mask optionally marks
            // missing components of _sample.
    845     virtual CvDTreeNode* predict( const CvMat* _sample, const CvMat* _missing_data_mask=0,
    846                                   bool preprocessed_input=false ) const;
    847     virtual const CvMat* get_var_importance();
    848     virtual void clear();
    849 
    850     virtual void read( CvFileStorage* fs, CvFileNode* node );
    851     virtual void write( CvFileStorage* fs, const char* name );
    852 
    853     // special read & write methods for trees in the tree ensembles
            // (the read overload receives an existing CvDTreeTrainData instead of
            // loading its own)
    854     virtual void read( CvFileStorage* fs, CvFileNode* node,
    855                        CvDTreeTrainData* data );
    856     virtual void write( CvFileStorage* fs );
    857 
    858     const CvDTreeNode* get_root() const;
    859     int get_pruned_tree_idx() const;
    860     CvDTreeTrainData* get_data();
    861 
    862 protected:
    863 
    864     virtual bool do_train( const CvMat* _subsample_idx );
    865 
            // Node splitting. find_best_split presumably dispatches to the
            // per-variable find_split_* routines by variable kind (ord/cat) and task
            // (class/reg); the surrogate variants look for replacement splits.
    866     virtual void try_split_node( CvDTreeNode* n );
    867     virtual void split_node_data( CvDTreeNode* n );
    868     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    869     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi );
    870     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi );
    871     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi );
    872     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi );
    873     virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
    874     virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
    875     virtual double calc_node_dir( CvDTreeNode* node );
    876     virtual void complete_node_dir( CvDTreeNode* node );
    877     virtual void cluster_categories( const int* vectors, int vector_count,
    878         int var_count, int* sums, int k, int* cluster_labels );
    879 
    880     virtual void calc_node_value( CvDTreeNode* node );
    881 
            // Pruning helpers (the T/fold arguments suggest per-fold cross-validation
            // bookkeeping -- confirm in the implementation).
    882     virtual void prune_cv();
    883     virtual double update_tree_rnc( int T, int fold );
    884     virtual int cut_tree( int T, int fold, double min_alpha );
    885     virtual void free_prune_data(bool cut_tree);
    886     virtual void free_tree();
    887 
            // Serialization of individual nodes/splits and of the whole node tree.
    888     virtual void write_node( CvFileStorage* fs, CvDTreeNode* node );
    889     virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split );
    890     virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
    891     virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
    892     virtual void write_tree_nodes( CvFileStorage* fs );
    893     virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
    894 
            // root of the tree (exposed read-only through get_root())
    895     CvDTreeNode* root;
    896 
            // exposed through get_pruned_tree_idx()
    897     int pruned_tree_idx;
    898     CvMat* var_importance;
    899 
            // training data the tree was built on (exposed through get_data())
    900     CvDTreeTrainData* data;
    901 };
    902 
    903 
    904 /****************************************************************************************\
    905 *                                   Random Trees Classifier                              *
    906 \****************************************************************************************/
    907 
    908 class CvRTrees;
    909 
    910 class CV_EXPORTS CvForestTree: public CvDTree
    911 {
            // Tree belonging to a CvRTrees forest. It is trained/read together with
            // a pointer to its forest (kept in the `forest` member) and overrides
            // find_best_split().
    912 public:
    913     CvForestTree();
    914     virtual ~CvForestTree();
    915 
    916     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx, CvRTrees* forest );
    917 
            // Number of variables of the attached training data, or 0 when no
            // training data is attached.
    918     virtual int get_var_count() const {return data ? data->var_count : 0;}
    919     virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );
    920 
    921     /* dummy methods to avoid warnings: BEGIN */
            // Declaring the train()/read() overloads above hides every CvDTree
            // overload with the same name (C++ name hiding). Re-declaring the base
            // signatures here keeps them callable on CvForestTree and silences
            // "overloaded virtual function is hidden" warnings.
    922     virtual bool train( const CvMat* _train_data, int _tflag,
    923                         const CvMat* _responses, const CvMat* _var_idx=0,
    924                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
    925                         const CvMat* _missing_mask=0,
    926                         CvDTreeParams params=CvDTreeParams() );
    927 
    928     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
    929     virtual void read( CvFileStorage* fs, CvFileNode* node );
    930     virtual void read( CvFileStorage* fs, CvFileNode* node,
    931                        CvDTreeTrainData* data );
    932     /* dummy methods to avoid warnings: END */
    933 
    934 protected:
    935     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
            // owning forest (set via train()/read() above -- presumably not owned)
    936     CvRTrees* forest;
    937 };
    938 
    939 
    940 struct CV_EXPORTS CvRTParams : public CvDTreeParams
    941 {
            // Random-trees training parameters: the per-tree settings inherited from
            // CvDTreeParams plus the forest-level fields below. The struct is
            // exported, so its member order is part of the binary layout.
    942     // Forest-level parameters
    943     bool calc_var_importance; // true <=> RF processes variable importance
    944     int nactive_vars;
    945     CvTermCriteria term_crit;
    946 
            // Defaults: CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
            // no variable importance, nactive_vars = 0, and a forest stopping rule of
            // CV_TERMCRIT_ITER+CV_TERMCRIT_EPS with max_iter 50 and epsilon 0.1.
    947     CvRTParams()
    948         : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
    949           calc_var_importance( false ),
    950           nactive_vars( 0 ),
    951           term_crit( cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 ) )
    952     {}
    953 
            // Fully specified constructor. The tree-level arguments go to
            // CvDTreeParams, with 0, false, false filling the three slots between
            // _max_categories and _priors. The last three arguments build term_crit
            // as (termcrit_type, max_num_of_trees_in_the_forest, forest_accuracy).
    954     CvRTParams( int _max_depth, int _min_sample_count, float _regression_accuracy,
    955                 bool _use_surrogates, int _max_categories, const float* _priors,
    956                 bool _calc_var_importance, int _nactive_vars,
    957                 int max_num_of_trees_in_the_forest, float forest_accuracy,
    958                 int termcrit_type )
    959         : CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy,
    960                          _use_surrogates, _max_categories, 0, false, false, _priors ),
    961           calc_var_importance( _calc_var_importance ),
    962           nactive_vars( _nactive_vars ),
    963           term_crit( cvTermCriteria( termcrit_type,
    964                                      max_num_of_trees_in_the_forest,
    965                                      forest_accuracy ) )
    966     {}
    967 };
    968 
    969 
    970 class CV_EXPORTS CvRTrees : public CvStatModel
    971 {
            // Random trees (random forest) model: an array of `ntrees` CvForestTree
            // objects (`trees`) plus the training data object `data` (presumably
            // shared by the trees -- confirm in the implementation).
    972 public:
    973     CvRTrees();
    974     virtual ~CvRTrees();
    975     virtual bool train( const CvMat* _train_data, int _tflag,
    976                         const CvMat* _responses, const CvMat* _var_idx=0,
    977                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
    978                         const CvMat* _missing_mask=0,
    979                         CvRTParams params=CvRTParams() );
    980     virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
    981     virtual void clear();
    982 
    983     virtual const CvMat* get_var_importance();
            // Similarity of two samples as judged by the forest (exact definition
            // not visible here -- confirm); missing1/missing2 are optional masks.
    984     virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
    985         const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
    986 
    987     virtual void read( CvFileStorage* fs, CvFileNode* node );
    988     virtual void write( CvFileStorage* fs, const char* name );
    989 
    990     CvMat* get_active_var_mask();
    991     CvRNG* get_rng();
    992 
            // get_tree(i): the i-th tree of the forest (range handling not visible
            // here -- confirm).
    993     int get_tree_count() const;
    994     CvForestTree* get_tree(int i) const;
    995 
    996 protected:
    997 
            // NOTE(review): the top-level `const` on the by-value term_crit parameter
            // is not part of the function type and has no effect on callers.
    998     bool grow_forest( const CvTermCriteria term_crit );
    999 
   1000     // array of the trees of the forest
   1001     CvForestTree** trees;
   1002     CvDTreeTrainData* data;
   1003     int ntrees;
   1004     int nclasses;
            // presumably the out-of-bag error estimate -- confirm
   1005     double oob_error;
   1006     CvMat* var_importance;
   1007     int nsamples;
   1008 
   1009     CvRNG rng;
   1010     CvMat* active_var_mask;
   1011 };
   1012 
   1013 
   1014 /****************************************************************************************\
   1015 *                                   Boosted tree classifier                              *
   1016 \****************************************************************************************/
   1017 
   1018 struct CV_EXPORTS CvBoostParams : public CvDTreeParams
   1019 {
            // Boosting parameters layered on top of the per-tree CvDTreeParams.
            // NOTE(review): boost_type presumably takes CvBoost::DISCRETE/REAL/LOGIT/
            // GENTLE and split_criteria CvBoost::DEFAULT/GINI/MISCLASS/SQERR -- confirm.
   1020     int boost_type;
            // presumably the number of weak learners (trees) to build -- confirm
   1021     int weak_count;
   1022     int split_criteria;
   1023     double weight_trim_rate;
   1024 
            // The second constructor takes no split_criteria argument; its value is
            // chosen by the implementation (not visible here).
   1025     CvBoostParams();
   1026     CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
   1027                    int max_depth, bool use_surrogates, const float* priors );
   1028 };
   1029 
   1030 
   1031 class CvBoost;
   1032 
   1033 class CV_EXPORTS CvBoostTree: public CvDTree
   1034 {
            // Weak learner of a CvBoost ensemble. It is trained/read together with a
            // pointer to its ensemble (kept in the `ensemble` member) and overrides
            // the split search, node value and node direction computations.
   1035 public:
   1036     CvBoostTree();
   1037     virtual ~CvBoostTree();
   1038 
   1039     virtual bool train( CvDTreeTrainData* _train_data,
   1040                         const CvMat* subsample_idx, CvBoost* ensemble );
   1041 
            // presumably multiplies the node values by s -- confirm in the implementation
   1042     virtual void scale( double s );
   1043     virtual void read( CvFileStorage* fs, CvFileNode* node,
   1044                        CvBoost* ensemble, CvDTreeTrainData* _data );
   1045     virtual void clear();
   1046 
   1047     /* dummy methods to avoid warnings: BEGIN */
            // Declaring the train()/read() overloads above hides every CvDTree
            // overload with the same name (C++ name hiding). Re-declaring the base
            // signatures here keeps them callable on CvBoostTree and silences
            // "overloaded virtual function is hidden" warnings.
   1048     virtual bool train( const CvMat* _train_data, int _tflag,
   1049                         const CvMat* _responses, const CvMat* _var_idx=0,
   1050                         const CvMat* _sample_idx=0, const CvMat* _var_type=0,
   1051                         const CvMat* _missing_mask=0,
   1052                         CvDTreeParams params=CvDTreeParams() );
   1053 
   1054     virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
   1055     virtual void read( CvFileStorage* fs, CvFileNode* node );
   1056     virtual void read( CvFileStorage* fs, CvFileNode* node,
   1057                        CvDTreeTrainData* data );
   1058     /* dummy methods to avoid warnings: END */
   1059 
   1060 protected:
   1061 
   1062     virtual void try_split_node( CvDTreeNode* n );
   1063     virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
   1064     virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
   1065     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi );
   1066     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi );
   1067     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi );
   1068     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi );
   1069     virtual void calc_node_value( CvDTreeNode* n );
   1070     virtual double calc_node_dir( CvDTreeNode* n );
   1071 
            // owning ensemble (set via train()/read() above -- presumably not owned)
   1072     CvBoost* ensemble;
   1073 };
   1074 
   1075 
   1076 class CV_EXPORTS CvBoost : public CvStatModel
   1077 {
            // Boosted ensemble of CvBoostTree weak learners built on the training
            // data object `data`; the weak learners are presumably kept in the
            // `weak` sequence (returned by get_weak_predictors()) -- confirm.
   1078 public:
   1079     // Boosting type
   1080     enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
   1081 
   1082     // Splitting criteria
            // (no constant has the value 2; keep the explicit values as they are)
   1083     enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };
   1084 
   1085     CvBoost();
   1086     virtual ~CvBoost();
   1087 
            // Convenience constructor taking the same arguments as train() (minus
            // `update`); presumably it trains immediately -- confirm.
   1088     CvBoost( const CvMat* _train_data, int _tflag,
   1089              const CvMat* _responses, const CvMat* _var_idx=0,
   1090              const CvMat* _sample_idx=0, const CvMat* _var_type=0,
   1091              const CvMat* _missing_mask=0,
   1092              CvBoostParams params=CvBoostParams() );
   1093 
            // update: presumably continues an existing ensemble instead of starting
            // over -- confirm in the implementation.
   1094     virtual bool train( const CvMat* _train_data, int _tflag,
   1095              const CvMat* _responses, const CvMat* _var_idx=0,
   1096              const CvMat* _sample_idx=0, const CvMat* _var_type=0,
   1097              const CvMat* _missing_mask=0,
   1098              CvBoostParams params=CvBoostParams(),
   1099              bool update=false );
   1100 
            // slice: range of weak learners to evaluate (CV_WHOLE_SEQ by default).
            // weak_responses, if non-null, presumably receives per-learner outputs;
            // raw_mode semantics not visible here -- confirm.
   1101     virtual float predict( const CvMat* _sample, const CvMat* _missing=0,
   1102                            CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
   1103                            bool raw_mode=false ) const;
   1104 
   1105     virtual void prune( CvSlice slice );
   1106 
   1107     virtual void clear();
   1108 
   1109     virtual void write( CvFileStorage* storage, const char* name );
   1110     virtual void read( CvFileStorage* storage, CvFileNode* node );
   1111 
   1112     CvSeq* get_weak_predictors();
   1113 
   1114     CvMat* get_weights();
   1115     CvMat* get_subtree_weights();
   1116     CvMat* get_weak_response();
   1117     const CvBoostParams& get_params() const;
   1118 
   1119 protected:
   1120 
   1121     virtual bool set_params( const CvBoostParams& _params );
   1122     virtual void update_weights( CvBoostTree* tree );
   1123     virtual void trim_weights();
   1124     virtual void write_params( CvFileStorage* fs );
   1125     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
   1126 
   1127     CvDTreeTrainData* data;
   1128     CvBoostParams params;
   1129     CvSeq* weak;
   1130 
            // per-sample working buffers (roles inferred from names only -- confirm)
   1131     CvMat* orig_response;
   1132     CvMat* sum_response;
   1133     CvMat* weak_eval;
   1134     CvMat* subsample_mask;
   1135     CvMat* weights;
   1136     CvMat* subtree_weights;
   1137     bool have_subsample;
   1138 };
   1139 
   1140 
   1141 /****************************************************************************************\
   1142 *                              Artificial Neural Networks (ANN)                          *
   1143 \****************************************************************************************/
   1144 
   1145 /////////////////////////////////// Multi-Layer Perceptrons //////////////////////////////
   1146 
   1147 struct CV_EXPORTS CvANN_MLP_TrainParams
   1148 {
            // Training settings for CvANN_MLP::train(): stopping criteria, the
            // training method (BACKPROP or RPROP) and the method-specific step
            // parameters below.
            // NOTE(review): param1/param2 of the second constructor presumably fill
            // the fields of the selected method -- confirm in the implementation.
   1149     CvANN_MLP_TrainParams();
   1150     CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
   1151                            double param1, double param2=0 );
   1152     ~CvANN_MLP_TrainParams();
   1153 
            // values for train_method
   1154     enum { BACKPROP=0, RPROP=1 };
   1155 
   1156     CvTermCriteria term_crit;
   1157     int train_method;
   1158 
   1159     // backpropagation parameters
   1160     double bp_dw_scale, bp_moment_scale;
   1161 
   1162     // rprop parameters
   1163     double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
   1164 };
   1165 
   1166 
   1167 class CV_EXPORTS CvANN_MLP : public CvStatModel
   1168 {
            // Multi-layer perceptron. The topology comes from layer_sizes (the number
            // of layers is layer_sizes->cols); per-layer weight arrays are reached
            // through `weights` (presumably pointing into wbuf -- confirm in create()).
   1169 public:
   1170     CvANN_MLP();
   1171     CvANN_MLP( const CvMat* _layer_sizes,
   1172                int _activ_func=SIGMOID_SYM,
   1173                double _f_param1=0, double _f_param2=0 );
   1174 
   1175     virtual ~CvANN_MLP();
   1176 
   1177     virtual void create( const CvMat* _layer_sizes,
   1178                          int _activ_func=SIGMOID_SYM,
   1179                          double _f_param1=0, double _f_param2=0 );
   1180 
            // flags: combination of UPDATE_WEIGHTS / NO_INPUT_SCALE / NO_OUTPUT_SCALE.
            // The int result is presumably the number of iterations done -- confirm.
   1181     virtual int train( const CvMat* _inputs, const CvMat* _outputs,
   1182                        const CvMat* _sample_weights, const CvMat* _sample_idx=0,
   1183                        CvANN_MLP_TrainParams _params = CvANN_MLP_TrainParams(),
   1184                        int flags=0 );
   1185     virtual float predict( const CvMat* _inputs,
   1186                            CvMat* _outputs ) const;
   1187 
   1188     virtual void clear();
   1189 
   1190     // possible activation functions
   1191     enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
   1192 
   1193     // available training flags
   1194     enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
   1195 
   1196     virtual void read( CvFileStorage* fs, CvFileNode* node );
   1197     virtual void write( CvFileStorage* storage, const char* name );
   1198 
            // Number of layers, or 0 when no layer_sizes matrix is set.
   1199     int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
   1200     const CvMat* get_layer_sizes() { return layer_sizes; }
            // Returns weights[layer], or 0 when layer_sizes/weights are not set or
            // layer is outside [0, layer_sizes->cols]. The unsigned casts also reject
            // negative indices. The upper bound is inclusive; NOTE(review): this
            // implies `weights` holds more than layer_sizes->cols entries (e.g. the
            // scaling arrays) -- confirm against the allocation in create().
   1201     double* get_weights(int layer)
   1202     {
   1203         return layer_sizes && weights &&
   1204             (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
   1205     }
   1206 
   1207 protected:
   1208 
   1209     virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
   1210             const CvMat* _sample_weights, const CvMat* _sample_idx,
   1211             CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
   1212 
   1213     // sequential random backpropagation
   1214     virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
   1215 
   1216     // RPROP algorithm
   1217     virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
   1218 
   1219     virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
   1220     virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
   1221     virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
   1222                                  double _f_param1=0, double _f_param2=0 );
   1223     virtual void init_weights();
   1224     virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
   1225     virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
   1226     virtual void calc_input_scale( const CvVectors* vecs, int flags );
   1227     virtual void calc_output_scale( const CvVectors* vecs, int flags );
   1228 
   1229     virtual void write_params( CvFileStorage* fs );
   1230     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
   1231 
   1232     CvMat* layer_sizes;
   1233     CvMat* wbuf;
   1234     CvMat* sample_weights;
   1235     double** weights;
            // activation function id and its two parameters (see set_activ_func())
   1236     double f_param1, f_param2;
   1237     double min_val, max_val, min_val1, max_val1;
   1238     int activ_func;
   1239     int max_count, max_buf_sz;
   1240     CvANN_MLP_TrainParams params;
   1241     CvRNG rng;
   1242 };
   1243 
   1244 #if 0
   1245 /****************************************************************************************\
   1246 *                            Convolutional Neural Network                                *
   1247 \****************************************************************************************/
   1248 typedef struct CvCNNLayer CvCNNLayer;
   1249 typedef struct CvCNNetwork CvCNNetwork;
   1250 
            // NOTE: everything from here to the end of this section is compiled out
            // by the `#if 0` that opens it.
            // Learning-rate schedules, presumably the values of the layers'
            // learn_rate_decrease_type field.
   1251 #define CV_CNN_LEARN_RATE_DECREASE_HYPERBOLICALLY  1
   1252 #define CV_CNN_LEARN_RATE_DECREASE_SQRT_INV        2
   1253 #define CV_CNN_LEARN_RATE_DECREASE_LOG_INV         3
   1254 
            // Gradient-estimation strategies, presumably the values of
            // CvCNNStatModelParams::grad_estim_type.
   1255 #define CV_CNN_GRAD_ESTIM_RANDOM        0
   1256 #define CV_CNN_GRAD_ESTIM_BY_WORST_IMG  1
   1257 
            // Layer type tags stored in a layer's `flags`: ICV_CNN_LAYER is compared
            // against the CV_MAGIC_MASK bits and the *_LAYER kind against the
            // remaining bits by the ICV_IS_CNN_* macros.
   1258 #define ICV_CNN_LAYER                0x55550000
   1259 #define ICV_CNN_CONVOLUTION_LAYER    0x00001111
   1260 #define ICV_CNN_SUBSAMPLING_LAYER    0x00002222
   1261 #define ICV_CNN_FULLCONNECT_LAYER    0x00003333
   1262 
            // Layer type checks. ICV_CNN_LAYER_MAGIC / ICV_CNN_LAYER_KIND split a
            // layer's `flags` into the CV_MAGIC_MASK part and the remaining kind bits.
            // All checks are NULL-safe, but `layer` is expanded several times, so do
            // not pass an argument with side effects.
   1263 #define ICV_CNN_LAYER_MAGIC( layer ) ( ((CvCNNLayer*)(layer))->flags & CV_MAGIC_MASK )
   1264 #define ICV_CNN_LAYER_KIND( layer )  ( ((CvCNNLayer*)(layer))->flags & ~CV_MAGIC_MASK )
   1265 
   1266 #define ICV_IS_CNN_LAYER( layer ) \
   1267     ( (layer) != NULL && ICV_CNN_LAYER_MAGIC( layer ) == ICV_CNN_LAYER )
   1268 
   1269 #define ICV_IS_CNN_CONVOLUTION_LAYER( layer ) \
   1270     ( ICV_IS_CNN_LAYER( layer ) && ICV_CNN_LAYER_KIND( layer ) == ICV_CNN_CONVOLUTION_LAYER )
   1271 
   1272 #define ICV_IS_CNN_SUBSAMPLING_LAYER( layer ) \
   1273     ( ICV_IS_CNN_LAYER( layer ) && ICV_CNN_LAYER_KIND( layer ) == ICV_CNN_SUBSAMPLING_LAYER )
   1274 
   1275 #define ICV_IS_CNN_FULLCONNECT_LAYER( layer ) \
   1276     ( ICV_IS_CNN_LAYER( layer ) && ICV_CNN_LAYER_KIND( layer ) == ICV_CNN_FULLCONNECT_LAYER )
   1277 
   1278 
            // Operation hooks stored in the layer and network structures: forward
            // pass, backward pass (t, X and dE_dY in, dE_dX out), release, and
            // adding a layer to a network.
   1279 typedef void (CV_CDECL *CvCNNLayerForward)
   1280     ( CvCNNLayer* layer, const CvMat* input, CvMat* output );
   1281 
   1282 typedef void (CV_CDECL *CvCNNLayerBackward)
   1283     ( CvCNNLayer* layer, int t, const CvMat* X, const CvMat* dE_dY, CvMat* dE_dX );
   1284 
   1285 typedef void (CV_CDECL *CvCNNLayerRelease)
   1286     (CvCNNLayer** layer);
   1287 
   1288 typedef void (CV_CDECL *CvCNNetworkAddLayer)
   1289     (CvCNNetwork* network, CvCNNLayer* layer);
   1290 
   1291 typedef void (CV_CDECL *CvCNNetworkRelease)
   1292     (CvCNNetwork** network);
   1293 
            // Field list shared by every CNN layer struct. Each layer type begins
            // with these fields, so `flags` sits at the same position in all of them.
            // That is what lets the ICV_IS_CNN_* macros read it through a
            // CvCNNLayer* cast.
   1294 #define CV_CNN_LAYER_FIELDS()           \
   1295     /* Indicator of the layer's type */ \
   1296     int flags;                          \
   1297                                         \
   1298     /* Number of input images */        \
   1299     int n_input_planes;                 \
   1300     /* Height of each input image */    \
   1301     int input_height;                   \
   1302     /* Width of each input image */     \
   1303     int input_width;                    \
   1304                                         \
   1305     /* Number of output images */       \
   1306     int n_output_planes;                \
   1307     /* Height of each output image */   \
   1308     int output_height;                  \
   1309     /* Width of each output image */    \
   1310     int output_width;                   \
   1311                                         \
   1312     /* Learning rate at the first iteration */                      \
   1313     float init_learn_rate;                                          \
   1314     /* Dynamics of learning rate decreasing */                      \
   1315     int learn_rate_decrease_type;                                   \
   1316     /* Trainable weights of the layer (including bias) */           \
   1317     /* i-th row is a set of weights of the i-th output plane */     \
   1318     CvMat* weights;                                                 \
   1319                                                                     \
   1320     CvCNNLayerForward  forward;                                     \
   1321     CvCNNLayerBackward backward;                                    \
   1322     CvCNNLayerRelease  release;                                     \
   1323     /* Pointers to the previous and next layers in the network */   \
   1324     CvCNNLayer* prev_layer;                                         \
   1325     CvCNNLayer* next_layer
   1326 
            // Generic layer view containing only the common CV_CNN_LAYER_FIELDS.
   1327 typedef struct CvCNNLayer
   1328 {
   1329     CV_CNN_LAYER_FIELDS();
   1330 }CvCNNLayer;
   1331 
            // Convolutional layer: output planes are computed with a K x K kernel
            // from the input planes enabled in connect_mask.
   1332 typedef struct CvCNNConvolutionLayer
   1333 {
   1334     CV_CNN_LAYER_FIELDS();
   1335     // Kernel size (height and width) for convolution.
   1336     int K;
   1337     // connections matrix, (i,j)-th element is 1 iff there is a connection between
   1338     // i-th plane of the current layer and j-th plane of the previous layer;
   1339     // (i,j)-th element is equal to 0 otherwise
   1340     CvMat *connect_mask;
   1341     // (the learning rate used at the first iteration is the common init_learn_rate field)
   1342 }CvCNNConvolutionLayer;
   1343 
            // Sub-sampling layer: shrinks the planes by sub_samp_scale (same ratio for
            // height and width) and applies a sigmoid-type activation with amplitude
            // `a` and scale `s`.
   1344 typedef struct CvCNNSubSamplingLayer
   1345 {
   1346     CV_CNN_LAYER_FIELDS();
   1347     // ratio between the heights (or widths - ratios are supposed to be equal)
   1348     // of the input and output planes
   1349     int sub_samp_scale;
   1350     // amplitude of sigmoid activation function
   1351     float a;
   1352     // scale parameter of sigmoid activation function
   1353     float s;
   1354     // exp2ssumWX = exp(2<s>*(bias+w*(x1+...+x4))), where x1,...x4 are some elements of X
   1355     // - is the vector used in computing of the activation function in backward
   1356     CvMat* exp2ssumWX;
   1357     // (x1+x2+x3+x4), where x1,...x4 are some elements of X
   1358     // - is the vector used in computing of the activation function in backward
   1359     CvMat* sumX;
   1360 }CvCNNSubSamplingLayer;
   1361 
   1362 // Structure of the last layer.
            // Fully connected layer with a sigmoid-type activation (amplitude `a`,
            // scale `s`).
   1363 typedef struct CvCNNFullConnectLayer
   1364 {
   1365     CV_CNN_LAYER_FIELDS();
   1366     // amplitude of sigmoid activation function
   1367     float a;
   1368     // scale parameter of sigmoid activation function
   1369     float s;
   1370     // exp2ssumWX = exp(2*<s>*(W*X)) - is the vector used in computing of the
   1371     // activation function and its derivative by the formulae
   1372     // activ.func. = <a>(exp(2<s>WX)-1)/(exp(2<s>WX)+1) == <a> - 2<a>/(<exp2ssumWX> + 1)
   1373     // (activ.func.)' = 4<a><s>exp(2<s>WX)/(exp(2<s>WX)+1)^2
   1374     CvMat* exp2ssumWX;
   1375 }CvCNNFullConnectLayer;
   1376 
            // Network container: n_layers layers reachable from `layers` (presumably
            // the first layer, the rest linked via prev_layer/next_layer -- confirm),
            // plus the add_layer/release hooks.
   1377 typedef struct CvCNNetwork
   1378 {
   1379     int n_layers;
   1380     CvCNNLayer* layers;
   1381     CvCNNetworkAddLayer add_layer;
   1382     CvCNNetworkRelease release;
   1383 }CvCNNetwork;
   1384 
            // Trained CNN classifier: the common stat-model fields plus the network
            // and the per-class etalon rows.
   1385 typedef struct CvCNNStatModel
   1386 {
   1387     CV_STAT_MODEL_FIELDS();
   1388     CvCNNetwork* network;
   1389     // etalons are allocated as rows, the i-th etalon has label cls_labels[i]
   1390     CvMat* etalons;
   1391     // classes labels
   1392     CvMat* cls_labels;
   1393 }CvCNNStatModel;
   1394 
            // Training parameters for cvTrainCNNClassifier().
   1395 typedef struct CvCNNStatModelParams
   1396 {
   1397     CV_STAT_MODEL_PARAM_FIELDS();
   1398     // network must be created by the functions cvCreateCNNetwork and <add_layer>
   1399     CvCNNetwork* network;
   1400     CvMat* etalons;
   1401     // termination criteria
   1402     int max_iter;
   1403     int start_iter;
            // one of the CV_CNN_GRAD_ESTIM_* values (presumably)
   1404     int grad_estim_type;
   1405 }CvCNNStatModelParams;
   1406 
            // Layer/network constructors and the CNN training entry point.
   1407 CVAPI(CvCNNLayer*) cvCreateCNNConvolutionLayer(
   1408     int n_input_planes, int input_height, int input_width,
   1409     int n_output_planes, int K,
   1410     float init_learn_rate, int learn_rate_decrease_type,
   1411     CvMat* connect_mask CV_DEFAULT(0), CvMat* weights CV_DEFAULT(0) );
   1412 
   1413 CVAPI(CvCNNLayer*) cvCreateCNNSubSamplingLayer(
   1414     int n_input_planes, int input_height, int input_width,
   1415     int sub_samp_scale, float a, float s,
   1416     float init_learn_rate, int learn_rate_decrease_type, CvMat* weights CV_DEFAULT(0) );
   1417 
   1418 CVAPI(CvCNNLayer*) cvCreateCNNFullConnectLayer(
   1419     int n_inputs, int n_outputs, float a, float s,
   1420     float init_learn_rate, int learning_type, CvMat* weights CV_DEFAULT(0) );
   1421 
   1422 CVAPI(CvCNNetwork*) cvCreateCNNetwork( CvCNNLayer* first_layer );
   1423 
            // NOTE(review): three CvMat* parameters are unnamed; by position they
            // presumably match var_idx, var_type and missing_mask of the other
            // train functions -- confirm and name them.
   1424 CVAPI(CvStatModel*) cvTrainCNNClassifier(
   1425             const CvMat* train_data, int tflag,
   1426             const CvMat* responses,
   1427             const CvStatModelParams* params,
   1428             const CvMat* CV_DEFAULT(0),
   1429             const CvMat* sample_idx CV_DEFAULT(0),
   1430             const CvMat* CV_DEFAULT(0), const CvMat* CV_DEFAULT(0) );
   1431 
   1432 /****************************************************************************************\
   1433 *                              Classifier estimation algorithms                          *
   1434 \****************************************************************************************/
   1435 typedef const CvMat* (CV_CDECL *CvStatModelEstimateGetMat)
   1436                     ( const CvStatModel* estimateModel );
   1437 
            // Callback types used by estimate models (see the getTrainIdxMat ...
            // reset hooks of CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS).
   1438 typedef int (CV_CDECL *CvStatModelEstimateNextStep)
   1439                     ( CvStatModel* estimateModel );
   1440 
   1441 typedef void (CV_CDECL *CvStatModelEstimateCheckClassifier)
   1442                     ( CvStatModel* estimateModel,
   1443                 const CvStatModel* model,
   1444                 const CvMat*       features,
   1445                       int          sample_t_flag,
   1446                 const CvMat*       responses );
   1447 
   1448 typedef void (CV_CDECL *CvStatModelEstimateCheckClassifierEasy)
   1449                     ( CvStatModel* estimateModel,
   1450                 const CvStatModel* model );
   1451 
   1452 typedef float (CV_CDECL *CvStatModelEstimateGetCurrentResult)
   1453                     ( const CvStatModel* estimateModel,
   1454                             float*       correlation );
   1455 
   1456 typedef void (CV_CDECL *CvStatModelEstimateReset)
   1457                     ( CvStatModel* estimateModel );
   1458 
   1459 //-------------------------------- Cross-validation --------------------------------------
            // Parameters of the cross-validation estimate model: the common
            // stat-model parameter fields plus the fold count, a regression flag and
            // a random number generator.
   1460 #define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS()    \
   1461     CV_STAT_MODEL_PARAM_FIELDS();                                 \
   1462     int     k_fold;                                               \
   1463     int     is_regression;                                        \
   1464     CvRNG*  rng
   1465 
   1466 typedef struct CvCrossValidationParams
   1467 {
   1468     CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS();
   1469 } CvCrossValidationParams;
   1470 
            // State of a cross-validation estimate model: the callback hooks, the
            // fold assignment of the samples, the current fold's train/eval index
            // matrices, and running statistics (correct_results, sq_error, sums for
            // a correlation estimate -- roles inferred from names; confirm).
   1471 #define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS()    \
   1472     CvStatModelEstimateGetMat               getTrainIdxMat; \
   1473     CvStatModelEstimateGetMat               getCheckIdxMat; \
   1474     CvStatModelEstimateNextStep             nextStep;       \
   1475     CvStatModelEstimateCheckClassifier      check;          \
   1476     CvStatModelEstimateGetCurrentResult     getResult;      \
   1477     CvStatModelEstimateReset                reset;          \
   1478     int     is_regression;                                  \
   1479     int     folds_all;                                      \
   1480     int     samples_all;                                    \
   1481     int*    sampleIdxAll;                                   \
   1482     int*    folds;                                          \
   1483     int     max_fold_size;                                  \
   1484     int         current_fold;                               \
   1485     int         is_checked;                                 \
   1486     CvMat*      sampleIdxTrain;                             \
   1487     CvMat*      sampleIdxEval;                              \
   1488     CvMat*      predict_results;                            \
   1489     int     correct_results;                                \
   1490     int     all_results;                                    \
   1491     double  sq_error;                                       \
   1492     double  sum_correct;                                    \
   1493     double  sum_predict;                                    \
   1494     double  sum_cc;                                         \
   1495     double  sum_pp;                                         \
   1496     double  sum_cp
   1497 
   1498 typedef struct CvCrossValidationModel
   1499 {
   1500     CV_STAT_MODEL_FIELDS();
   1501     CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS();
   1502 } CvCrossValidationModel;
   1503 
   1504 CVAPI(CvStatModel*)
   1505 cvCreateCrossValidationEstimateModel
   1506            ( int                samples_all,
   1507        const CvStatModelParams* estimateParams CV_DEFAULT(0),
   1508        const CvMat*             sampleIdx CV_DEFAULT(0) );
   1509 
   1510 CVAPI(float)
   1511 cvCrossValidation( const CvMat*             trueData,
   1512                          int                tflag,
   1513                    const CvMat*             trueClasses,
   1514                          CvStatModel*     (*createClassifier)( const CvMat*,
   1515                                                                      int,
   1516                                                                const CvMat*,
   1517                                                                const CvStatModelParams*,
   1518                                                                const CvMat*,
   1519                                                                const CvMat*,
   1520                                                                const CvMat*,
   1521                                                                const CvMat* ),
   1522                    const CvStatModelParams* estimateParams CV_DEFAULT(0),
   1523                    const CvStatModelParams* trainParams CV_DEFAULT(0),
   1524                    const CvMat*             compIdx CV_DEFAULT(0),
   1525                    const CvMat*             sampleIdx CV_DEFAULT(0),
   1526                          CvStatModel**      pCrValModel CV_DEFAULT(0),
   1527                    const CvMat*             typeMask CV_DEFAULT(0),
   1528                    const CvMat*             missedMeasurementMask CV_DEFAULT(0) );
   1529 #endif
   1530 
/****************************************************************************************\
*                           Auxiliary function declarations                              *
\****************************************************************************************/
   1534 
   1535 /* Generates <sample> from multivariate normal distribution, where <mean> - is an
   1536    average row vector, <cov> - symmetric covariation matrix */
   1537 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
   1538                            CvRNG* rng CV_DEFAULT(0) );
   1539 
   1540 /* Generates sample from gaussian mixture distribution */
   1541 CVAPI(void) cvRandGaussMixture( CvMat* means[],
   1542                                CvMat* covs[],
   1543                                float weights[],
   1544                                int clsnum,
   1545                                CvMat* sample,
   1546                                CvMat* sampClasses CV_DEFAULT(0) );
   1547 
   1548 #define CV_TS_CONCENTRIC_SPHERES 0
   1549 
   1550 /* creates test set */
   1551 CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
   1552                  int num_samples,
   1553                  int num_features,
   1554                  CvMat** responses,
   1555                  int num_classes, ... );
   1556 
   1557 /* Aij <- Aji for i > j if lower_to_upper != 0
   1558               for i < j if lower_to_upper = 0 */
   1559 CVAPI(void) cvCompleteSymm( CvMat* matrix, int lower_to_upper );
   1560 
   1561 #endif
   1562 
   1563 #endif /*__ML_H__*/
   1564 /* End of file. */
   1565