// Source-browser navigation residue: Home | History | Annotate | Download | only in traincascade
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                        Intel License Agreement
     11 //
     12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
     13 // Third party copyrights are property of their respective owners.
     14 //
     15 // Redistribution and use in source and binary forms, with or without modification,
     16 // are permitted provided that the following conditions are met:
     17 //
     18 //   * Redistribution's of source code must retain the above copyright notice,
     19 //     this list of conditions and the following disclaimer.
     20 //
     21 //   * Redistribution's in binary form must reproduce the above copyright notice,
     22 //     this list of conditions and the following disclaimer in the documentation
     23 //     and/or other materials provided with the distribution.
     24 //
     25 //   * The name of Intel Corporation may not be used to endorse or promote products
     26 //     derived from this software without specific prior written permission.
     27 //
     28 // This software is provided by the copyright holders and contributors "as is" and
     29 // any express or implied warranties, including, but not limited to, the implied
     30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     31 // In no event shall the Intel Corporation or contributors be liable for any direct,
     32 // indirect, incidental, special, exemplary, or consequential damages
     33 // (including, but not limited to, procurement of substitute goods or services;
     34 // loss of use, data, or profits; or business interruption) however caused
     35 // and on any theory of liability, whether in contract, strict liability,
     36 // or tort (including negligence or otherwise) arising in any way out of
     37 // the use of this software, even if advised of the possibility of such damage.
     38 //
     39 //M*/
     40 
     41 #ifndef __OPENCV_ML_HPP__
     42 #define __OPENCV_ML_HPP__
     43 
     44 #ifdef __cplusplus
     45 #  include "opencv2/core.hpp"
     46 #endif
     47 
     48 #include "opencv2/core/core_c.h"
     49 #include <limits.h>
     50 
     51 #ifdef __cplusplus
     52 
     53 #include <map>
     54 #include <iostream>
     55 
// Apple defines a check() macro somewhere in the debug headers
// that interferes with a method definition in this header
     58 #undef check
     59 
     60 /****************************************************************************************\
     61 *                               Main struct definitions                                  *
     62 \****************************************************************************************/
     63 
     64 /* log(2*PI) */
     65 #define CV_LOG2PI (1.8378770664093454835606594728112)
     66 
     67 /* columns of <trainData> matrix are training samples */
     68 #define CV_COL_SAMPLE 0
     69 
     70 /* rows of <trainData> matrix are training samples */
     71 #define CV_ROW_SAMPLE 1
     72 
     73 #define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
     74 
     75 struct CvVectors
     76 {
     77     int type;
     78     int dims, count;
     79     CvVectors* next;
     80     union
     81     {
     82         uchar** ptr;
     83         float** fl;
     84         double** db;
     85     } data;
     86 };
     87 
     88 #if 0
     89 /* A structure, representing the lattice range of statmodel parameters.
     90    It is used for optimizing statmodel parameters by cross-validation method.
     91    The lattice is logarithmic, so <step> must be greater then 1. */
     92 typedef struct CvParamLattice
     93 {
     94     double min_val;
     95     double max_val;
     96     double step;
     97 }
     98 CvParamLattice;
     99 
    100 CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
    101                                          double log_step )
    102 {
    103     CvParamLattice pl;
    104     pl.min_val = MIN( min_val, max_val );
    105     pl.max_val = MAX( min_val, max_val );
    106     pl.step = MAX( log_step, 1. );
    107     return pl;
    108 }
    109 
    110 CV_INLINE CvParamLattice cvDefaultParamLattice( void )
    111 {
    112     CvParamLattice pl = {0,0,0};
    113     return pl;
    114 }
    115 #endif
    116 
    117 /* Variable type */
    118 #define CV_VAR_NUMERICAL    0
    119 #define CV_VAR_ORDERED      0
    120 #define CV_VAR_CATEGORICAL  1
    121 
    122 #define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
    123 #define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
    124 #define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
    125 #define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
    126 #define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
    127 #define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
    128 #define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
    129 #define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
    130 #define CV_TYPE_NAME_ML_ERTREES     "opencv-ml-extremely-randomized-trees"
    131 #define CV_TYPE_NAME_ML_GBT         "opencv-ml-gradient-boosting-trees"
    132 
    133 #define CV_TRAIN_ERROR  0
    134 #define CV_TEST_ERROR   1
    135 
    136 class CvStatModel
    137 {
    138 public:
    139     CvStatModel();
    140     virtual ~CvStatModel();
    141 
    142     virtual void clear();
    143 
    144     CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
    145     CV_WRAP virtual void load( const char* filename, const char* name=0 );
    146 
    147     virtual void write( CvFileStorage* storage, const char* name ) const;
    148     virtual void read( CvFileStorage* storage, CvFileNode* node );
    149 
    150 protected:
    151     const char* default_model_name;
    152 };
    153 
    154 /****************************************************************************************\
    155 *                                 Normal Bayes Classifier                                *
    156 \****************************************************************************************/
    157 
    158 /* The structure, representing the grid range of statmodel parameters.
    159    It is used for optimizing statmodel accuracy by varying model parameters,
    160    the accuracy estimate being computed by cross-validation.
    161    The grid is logarithmic, so <step> must be greater then 1. */
    162 
    163 class CvMLData;
    164 
    165 struct CvParamGrid
    166 {
    167     // SVM params type
    168     enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };
    169 
    170     CvParamGrid()
    171     {
    172         min_val = max_val = step = 0;
    173     }
    174 
    175     CvParamGrid( double min_val, double max_val, double log_step );
    176     //CvParamGrid( int param_id );
    177     bool check() const;
    178 
    179     CV_PROP_RW double min_val;
    180     CV_PROP_RW double max_val;
    181     CV_PROP_RW double step;
    182 };
    183 
    184 inline CvParamGrid::CvParamGrid( double _min_val, double _max_val, double _log_step )
    185 {
    186     min_val = _min_val;
    187     max_val = _max_val;
    188     step = _log_step;
    189 }
    190 
    191 class CvNormalBayesClassifier : public CvStatModel
    192 {
    193 public:
    194     CV_WRAP CvNormalBayesClassifier();
    195     virtual ~CvNormalBayesClassifier();
    196 
    197     CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,
    198         const CvMat* varIdx=0, const CvMat* sampleIdx=0 );
    199 
    200     virtual bool train( const CvMat* trainData, const CvMat* responses,
    201         const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );
    202 
    203     virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0, CV_OUT CvMat* results_prob=0 ) const;
    204     CV_WRAP virtual void clear();
    205 
    206     CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
    207                             const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );
    208     CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
    209                        const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
    210                        bool update=false );
    211     CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0, CV_OUT cv::Mat* results_prob=0 ) const;
    212 
    213     virtual void write( CvFileStorage* storage, const char* name ) const;
    214     virtual void read( CvFileStorage* storage, CvFileNode* node );
    215 
    216 protected:
    217     int     var_count, var_all;
    218     CvMat*  var_idx;
    219     CvMat*  cls_labels;
    220     CvMat** count;
    221     CvMat** sum;
    222     CvMat** productsum;
    223     CvMat** avg;
    224     CvMat** inv_eigen_values;
    225     CvMat** cov_rotate_mats;
    226     CvMat*  c;
    227 };
    228 
    229 
    230 /****************************************************************************************\
    231 *                          K-Nearest Neighbour Classifier                                *
    232 \****************************************************************************************/
    233 
    234 // k Nearest Neighbors
    235 class CvKNearest : public CvStatModel
    236 {
    237 public:
    238 
    239     CV_WRAP CvKNearest();
    240     virtual ~CvKNearest();
    241 
    242     CvKNearest( const CvMat* trainData, const CvMat* responses,
    243                 const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 );
    244 
    245     virtual bool train( const CvMat* trainData, const CvMat* responses,
    246                         const CvMat* sampleIdx=0, bool is_regression=false,
    247                         int maxK=32, bool updateBase=false );
    248 
    249     virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0,
    250         const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const;
    251 
    252     CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses,
    253                const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 );
    254 
    255     CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
    256                        const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false,
    257                        int maxK=32, bool updateBase=false );
    258 
    259     virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0,
    260                                 const float** neighbors=0, cv::Mat* neighborResponses=0,
    261                                 cv::Mat* dist=0 ) const;
    262     CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results,
    263                                         CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const;
    264 
    265     virtual void clear();
    266     int get_max_k() const;
    267     int get_var_count() const;
    268     int get_sample_count() const;
    269     bool is_regression() const;
    270 
    271     virtual float write_results( int k, int k1, int start, int end,
    272         const float* neighbor_responses, const float* dist, CvMat* _results,
    273         CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;
    274 
    275     virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
    276         float* neighbor_responses, const float** neighbors, float* dist ) const;
    277 
    278 protected:
    279 
    280     int max_k, var_count;
    281     int total;
    282     bool regression;
    283     CvVectors* samples;
    284 };
    285 
    286 /****************************************************************************************\
    287 *                                   Support Vector Machines                              *
    288 \****************************************************************************************/
    289 
    290 // SVM training parameters
    291 struct CvSVMParams
    292 {
    293     CvSVMParams();
    294     CvSVMParams( int svm_type, int kernel_type,
    295                  double degree, double gamma, double coef0,
    296                  double Cvalue, double nu, double p,
    297                  CvMat* class_weights, CvTermCriteria term_crit );
    298 
    299     CV_PROP_RW int         svm_type;
    300     CV_PROP_RW int         kernel_type;
    301     CV_PROP_RW double      degree; // for poly
    302     CV_PROP_RW double      gamma;  // for poly/rbf/sigmoid/chi2
    303     CV_PROP_RW double      coef0;  // for poly/sigmoid
    304 
    305     CV_PROP_RW double      C;  // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
    306     CV_PROP_RW double      nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
    307     CV_PROP_RW double      p; // for CV_SVM_EPS_SVR
    308     CvMat*      class_weights; // for CV_SVM_C_SVC
    309     CV_PROP_RW CvTermCriteria term_crit; // termination criteria
    310 };
    311 
    312 
    313 struct CvSVMKernel
    314 {
    315     typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
    316                                        const float* another, float* results );
    317     CvSVMKernel();
    318     CvSVMKernel( const CvSVMParams* params, Calc _calc_func );
    319     virtual bool create( const CvSVMParams* params, Calc _calc_func );
    320     virtual ~CvSVMKernel();
    321 
    322     virtual void clear();
    323     virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );
    324 
    325     const CvSVMParams* params;
    326     Calc calc_func;
    327 
    328     virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
    329                                     const float* another, float* results,
    330                                     double alpha, double beta );
    331     virtual void calc_intersec( int vcount, int var_count, const float** vecs,
    332                             const float* another, float* results );
    333     virtual void calc_chi2( int vec_count, int vec_size, const float** vecs,
    334                               const float* another, float* results );
    335     virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
    336                               const float* another, float* results );
    337     virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
    338                            const float* another, float* results );
    339     virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
    340                             const float* another, float* results );
    341     virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
    342                                const float* another, float* results );
    343 };
    344 
    345 
    346 struct CvSVMKernelRow
    347 {
    348     CvSVMKernelRow* prev;
    349     CvSVMKernelRow* next;
    350     float* data;
    351 };
    352 
    353 
    354 struct CvSVMSolutionInfo
    355 {
    356     double obj;
    357     double rho;
    358     double upper_bound_p;
    359     double upper_bound_n;
    360     double r;   // for Solver_NU
    361 };
    362 
    363 class CvSVMSolver
    364 {
    365 public:
    366     typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
    367     typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
    368     typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );
    369 
    370     CvSVMSolver();
    371 
    372     CvSVMSolver( int count, int var_count, const float** samples, schar* y,
    373                  int alpha_count, double* alpha, double Cp, double Cn,
    374                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
    375                  SelectWorkingSet select_working_set, CalcRho calc_rho );
    376     virtual bool create( int count, int var_count, const float** samples, schar* y,
    377                  int alpha_count, double* alpha, double Cp, double Cn,
    378                  CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
    379                  SelectWorkingSet select_working_set, CalcRho calc_rho );
    380     virtual ~CvSVMSolver();
    381 
    382     virtual void clear();
    383     virtual bool solve_generic( CvSVMSolutionInfo& si );
    384 
    385     virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
    386                               double Cp, double Cn, CvMemStorage* storage,
    387                               CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
    388     virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
    389                                CvMemStorage* storage, CvSVMKernel* kernel,
    390                                double* alpha, CvSVMSolutionInfo& si );
    391     virtual bool solve_one_class( int count, int var_count, const float** samples,
    392                                   CvMemStorage* storage, CvSVMKernel* kernel,
    393                                   double* alpha, CvSVMSolutionInfo& si );
    394 
    395     virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
    396                                 CvMemStorage* storage, CvSVMKernel* kernel,
    397                                 double* alpha, CvSVMSolutionInfo& si );
    398 
    399     virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
    400                                CvMemStorage* storage, CvSVMKernel* kernel,
    401                                double* alpha, CvSVMSolutionInfo& si );
    402 
    403     virtual float* get_row_base( int i, bool* _existed );
    404     virtual float* get_row( int i, float* dst );
    405 
    406     int sample_count;
    407     int var_count;
    408     int cache_size;
    409     int cache_line_size;
    410     const float** samples;
    411     const CvSVMParams* params;
    412     CvMemStorage* storage;
    413     CvSVMKernelRow lru_list;
    414     CvSVMKernelRow* rows;
    415 
    416     int alpha_count;
    417 
    418     double* G;
    419     double* alpha;
    420 
    421     // -1 - lower bound, 0 - free, 1 - upper bound
    422     schar* alpha_status;
    423 
    424     schar* y;
    425     double* b;
    426     float* buf[2];
    427     double eps;
    428     int max_iter;
    429     double C[2];  // C[0] == Cn, C[1] == Cp
    430     CvSVMKernel* kernel;
    431 
    432     SelectWorkingSet select_working_set_func;
    433     CalcRho calc_rho_func;
    434     GetRow get_row_func;
    435 
    436     virtual bool select_working_set( int& i, int& j );
    437     virtual bool select_working_set_nu_svm( int& i, int& j );
    438     virtual void calc_rho( double& rho, double& r );
    439     virtual void calc_rho_nu_svm( double& rho, double& r );
    440 
    441     virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
    442     virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
    443     virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
    444 };
    445 
    446 
    447 struct CvSVMDecisionFunc
    448 {
    449     double rho;
    450     int sv_count;
    451     double* alpha;
    452     int* sv_index;
    453 };
    454 
    455 
    456 // SVM model
    457 class CvSVM : public CvStatModel
    458 {
    459 public:
    460     // SVM type
    461     enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
    462 
    463     // SVM kernel type
    464     enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3, CHI2=4, INTER=5 };
    465 
    466     // SVM params type
    467     enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };
    468 
    469     CV_WRAP CvSVM();
    470     virtual ~CvSVM();
    471 
    472     CvSVM( const CvMat* trainData, const CvMat* responses,
    473            const CvMat* varIdx=0, const CvMat* sampleIdx=0,
    474            CvSVMParams params=CvSVMParams() );
    475 
    476     virtual bool train( const CvMat* trainData, const CvMat* responses,
    477                         const CvMat* varIdx=0, const CvMat* sampleIdx=0,
    478                         CvSVMParams params=CvSVMParams() );
    479 
    480     virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
    481         const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
    482         int kfold = 10,
    483         CvParamGrid Cgrid      = get_default_grid(CvSVM::C),
    484         CvParamGrid gammaGrid  = get_default_grid(CvSVM::GAMMA),
    485         CvParamGrid pGrid      = get_default_grid(CvSVM::P),
    486         CvParamGrid nuGrid     = get_default_grid(CvSVM::NU),
    487         CvParamGrid coeffGrid  = get_default_grid(CvSVM::COEF),
    488         CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
    489         bool balanced=false );
    490 
    491     virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
    492     virtual float predict( const CvMat* samples, CV_OUT CvMat* results, bool returnDFVal=false ) const;
    493 
    494     CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
    495           const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
    496           CvSVMParams params=CvSVMParams() );
    497 
    498     CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
    499                        const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
    500                        CvSVMParams params=CvSVMParams() );
    501 
    502     CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
    503                             const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
    504                             int k_fold = 10,
    505                             CvParamGrid Cgrid      = CvSVM::get_default_grid(CvSVM::C),
    506                             CvParamGrid gammaGrid  = CvSVM::get_default_grid(CvSVM::GAMMA),
    507                             CvParamGrid pGrid      = CvSVM::get_default_grid(CvSVM::P),
    508                             CvParamGrid nuGrid     = CvSVM::get_default_grid(CvSVM::NU),
    509                             CvParamGrid coeffGrid  = CvSVM::get_default_grid(CvSVM::COEF),
    510                             CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
    511                             bool balanced=false);
    512     CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
    513     CV_WRAP_AS(predict_all) virtual void predict( cv::InputArray samples, cv::OutputArray results ) const;
    514 
    515     CV_WRAP virtual int get_support_vector_count() const;
    516     virtual const float* get_support_vector(int i) const;
    517     virtual CvSVMParams get_params() const { return params; }
    518     CV_WRAP virtual void clear();
    519 
    520     virtual const CvSVMDecisionFunc* get_decision_function() const { return decision_func; }
    521 
    522     static CvParamGrid get_default_grid( int param_id );
    523 
    524     virtual void write( CvFileStorage* storage, const char* name ) const;
    525     virtual void read( CvFileStorage* storage, CvFileNode* node );
    526     CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }
    527 
    528 protected:
    529 
    530     virtual bool set_params( const CvSVMParams& params );
    531     virtual bool train1( int sample_count, int var_count, const float** samples,
    532                     const void* responses, double Cp, double Cn,
    533                     CvMemStorage* _storage, double* alpha, double& rho );
    534     virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
    535                     const CvMat* responses, CvMemStorage* _storage, double* alpha );
    536     virtual void create_kernel();
    537     virtual void create_solver();
    538 
    539     virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;
    540 
    541     virtual void write_params( CvFileStorage* fs ) const;
    542     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
    543 
    544     void optimize_linear_svm();
    545 
    546     CvSVMParams params;
    547     CvMat* class_labels;
    548     int var_all;
    549     float** sv;
    550     int sv_total;
    551     CvMat* var_idx;
    552     CvMat* class_weights;
    553     CvSVMDecisionFunc* decision_func;
    554     CvMemStorage* storage;
    555 
    556     CvSVMSolver* solver;
    557     CvSVMKernel* kernel;
    558 
    559 private:
    560     CvSVM(const CvSVM&);
    561     CvSVM& operator = (const CvSVM&);
    562 };
    563 
/****************************************************************************************\
*                                      Decision Tree                                     *
\****************************************************************************************/
    567 struct CvPair16u32s
    568 {
    569     unsigned short* u;
    570     int* i;
    571 };
    572 
    573 
    574 #define CV_DTREE_CAT_DIR(idx,subset) \
    575     (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
    576 
    577 struct CvDTreeSplit
    578 {
    579     int var_idx;
    580     int condensed_idx;
    581     int inversed;
    582     float quality;
    583     CvDTreeSplit* next;
    584     union
    585     {
    586         int subset[2];
    587         struct
    588         {
    589             float c;
    590             int split_point;
    591         }
    592         ord;
    593     };
    594 };
    595 
    596 struct CvDTreeNode
    597 {
    598     int class_idx;
    599     int Tn;
    600     double value;
    601 
    602     CvDTreeNode* parent;
    603     CvDTreeNode* left;
    604     CvDTreeNode* right;
    605 
    606     CvDTreeSplit* split;
    607 
    608     int sample_count;
    609     int depth;
    610     int* num_valid;
    611     int offset;
    612     int buf_idx;
    613     double maxlr;
    614 
    615     // global pruning data
    616     int complexity;
    617     double alpha;
    618     double node_risk, tree_risk, tree_error;
    619 
    620     // cross-validation pruning data
    621     int* cv_Tn;
    622     double* cv_node_risk;
    623     double* cv_node_error;
    624 
    625     int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
    626     void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
    627 };
    628 
    629 
    630 struct CvDTreeParams
    631 {
    632     CV_PROP_RW int   max_categories;
    633     CV_PROP_RW int   max_depth;
    634     CV_PROP_RW int   min_sample_count;
    635     CV_PROP_RW int   cv_folds;
    636     CV_PROP_RW bool  use_surrogates;
    637     CV_PROP_RW bool  use_1se_rule;
    638     CV_PROP_RW bool  truncate_pruned_tree;
    639     CV_PROP_RW float regression_accuracy;
    640     const float* priors;
    641 
    642     CvDTreeParams();
    643     CvDTreeParams( int max_depth, int min_sample_count,
    644                    float regression_accuracy, bool use_surrogates,
    645                    int max_categories, int cv_folds,
    646                    bool use_1se_rule, bool truncate_pruned_tree,
    647                    const float* priors );
    648 };
    649 
    650 
    651 struct CvDTreeTrainData
    652 {
    653     CvDTreeTrainData();
    654     CvDTreeTrainData( const CvMat* trainData, int tflag,
    655                       const CvMat* responses, const CvMat* varIdx=0,
    656                       const CvMat* sampleIdx=0, const CvMat* varType=0,
    657                       const CvMat* missingDataMask=0,
    658                       const CvDTreeParams& params=CvDTreeParams(),
    659                       bool _shared=false, bool _add_labels=false );
    660     virtual ~CvDTreeTrainData();
    661 
    662     virtual void set_data( const CvMat* trainData, int tflag,
    663                           const CvMat* responses, const CvMat* varIdx=0,
    664                           const CvMat* sampleIdx=0, const CvMat* varType=0,
    665                           const CvMat* missingDataMask=0,
    666                           const CvDTreeParams& params=CvDTreeParams(),
    667                           bool _shared=false, bool _add_labels=false,
    668                           bool _update_data=false );
    669     virtual void do_responses_copy();
    670 
    671     virtual void get_vectors( const CvMat* _subsample_idx,
    672          float* values, uchar* missing, float* responses, bool get_class_idx=false );
    673 
    674     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
    675 
    676     virtual void write_params( CvFileStorage* fs ) const;
    677     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
    678 
    679     // release all the data
    680     virtual void clear();
    681 
    682     int get_num_classes() const;
    683     int get_var_type(int vi) const;
    684     int get_work_var_count() const {return work_var_count;}
    685 
    686     virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf );
    687     virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf );
    688     virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
    689     virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
    690     virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
    691     virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
    692                                    const float** ord_values, const int** sorted_indices, int* sample_indices_buf );
    693     virtual int get_child_buf_idx( CvDTreeNode* n );
    694 
    695     ////////////////////////////////////
    696 
    697     virtual bool set_params( const CvDTreeParams& params );
    698     virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
    699                                    int storage_idx, int offset );
    700 
    701     virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
    702                 int split_point, int inversed, float quality );
    703     virtual CvDTreeSplit* new_split_cat( int vi, float quality );
    704     virtual void free_node_data( CvDTreeNode* node );
    705     virtual void free_train_data();
    706     virtual void free_node( CvDTreeNode* node );
    707 
    708     int sample_count, var_all, var_count, max_c_count;
    709     int ord_var_count, cat_var_count, work_var_count;
    710     bool have_labels, have_priors;
    711     bool is_classifier;
    712     int tflag;
    713 
    714     const CvMat* train_data;
    715     const CvMat* responses;
    716     CvMat* responses_copy; // used in Boosting
    717 
    718     int buf_count, buf_size; // buf_size is obsolete, please do not use it, use expression ((int64)buf->rows * (int64)buf->cols / buf_count) instead
    719     bool shared;
    720     int is_buf_16u;
    721 
    722     CvMat* cat_count;
    723     CvMat* cat_ofs;
    724     CvMat* cat_map;
    725 
    726     CvMat* counts;
    727     CvMat* buf;
    728     inline size_t get_length_subbuf() const
    729     {
    730         size_t res = (size_t)(work_var_count + 1) * (size_t)sample_count;
    731         return res;
    732     }
    733 
    734     CvMat* direction;
    735     CvMat* split_buf;
    736 
    737     CvMat* var_idx;
    738     CvMat* var_type; // i-th element =
    739                      //   k<0  - ordered
    740                      //   k>=0 - categorical, see k-th element of cat_* arrays
    741     CvMat* priors;
    742     CvMat* priors_mult;
    743 
    744     CvDTreeParams params;
    745 
    746     CvMemStorage* tree_storage;
    747     CvMemStorage* temp_storage;
    748 
    749     CvDTreeNode* data_root;
    750 
    751     CvSet* node_heap;
    752     CvSet* split_heap;
    753     CvSet* cv_heap;
    754     CvSet* nv_heap;
    755 
    756     cv::RNG* rng;
    757 };
    758 
    759 class CvDTree;
    760 class CvForestTree;
    761 
    762 namespace cv
    763 {
    764     struct DTreeBestSplitFinder;
    765     struct ForestTreeBestSplitFinder;
    766 }
    767 
    768 class CvDTree : public CvStatModel
    769 {
    770 public:
    771     CV_WRAP CvDTree();
    772     virtual ~CvDTree();
    773 
    774     virtual bool train( const CvMat* trainData, int tflag,
    775                         const CvMat* responses, const CvMat* varIdx=0,
    776                         const CvMat* sampleIdx=0, const CvMat* varType=0,
    777                         const CvMat* missingDataMask=0,
    778                         CvDTreeParams params=CvDTreeParams() );
    779 
    780     virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() );
    781 
    782     // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
    783     virtual float calc_error( CvMLData* trainData, int type, std::vector<float> *resp = 0 );
    784 
    785     virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx );
    786 
    787     virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0,
    788                                   bool preprocessedInput=false ) const;
    789 
    790     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
    791                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
    792                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
    793                        const cv::Mat& missingDataMask=cv::Mat(),
    794                        CvDTreeParams params=CvDTreeParams() );
    795 
    796     CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(),
    797                                   bool preprocessedInput=false ) const;
    798     CV_WRAP virtual cv::Mat getVarImportance();
    799 
    800     virtual const CvMat* get_var_importance();
    801     CV_WRAP virtual void clear();
    802 
    803     virtual void read( CvFileStorage* fs, CvFileNode* node );
    804     virtual void write( CvFileStorage* fs, const char* name ) const;
    805 
    806     // special read & write methods for trees in the tree ensembles
    807     virtual void read( CvFileStorage* fs, CvFileNode* node,
    808                        CvDTreeTrainData* data );
    809     virtual void write( CvFileStorage* fs ) const;
    810 
    811     const CvDTreeNode* get_root() const;
    812     int get_pruned_tree_idx() const;
    813     CvDTreeTrainData* get_data();
    814 
    815 protected:
    816     friend struct cv::DTreeBestSplitFinder;
    817 
    818     virtual bool do_train( const CvMat* _subsample_idx );
    819 
    820     virtual void try_split_node( CvDTreeNode* n );
    821     virtual void split_node_data( CvDTreeNode* n );
    822     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    823     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
    824                             float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    825     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
    826                             float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    827     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
    828                             float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    829     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
    830                             float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
    831     virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    832     virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
    833     virtual double calc_node_dir( CvDTreeNode* node );
    834     virtual void complete_node_dir( CvDTreeNode* node );
    835     virtual void cluster_categories( const int* vectors, int vector_count,
    836         int var_count, int* sums, int k, int* cluster_labels );
    837 
    838     virtual void calc_node_value( CvDTreeNode* node );
    839 
    840     virtual void prune_cv();
    841     virtual double update_tree_rnc( int T, int fold );
    842     virtual int cut_tree( int T, int fold, double min_alpha );
    843     virtual void free_prune_data(bool cut_tree);
    844     virtual void free_tree();
    845 
    846     virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const;
    847     virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const;
    848     virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
    849     virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
    850     virtual void write_tree_nodes( CvFileStorage* fs ) const;
    851     virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );
    852 
    853     CvDTreeNode* root;
    854     CvMat* var_importance;
    855     CvDTreeTrainData* data;
    856     CvMat train_data_hdr, responses_hdr;
    857     cv::Mat train_data_mat, responses_mat;
    858 
    859 public:
    860     int pruned_tree_idx;
    861 };
    862 
    863 
    864 /****************************************************************************************\
    865 *                                   Random Trees Classifier                              *
    866 \****************************************************************************************/
    867 
    868 class CvRTrees;
    869 
    870 class CvForestTree: public CvDTree
    871 {
    872 public:
    873     CvForestTree();
    874     virtual ~CvForestTree();
    875 
    876     virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest );
    877 
    878     virtual int get_var_count() const {return data ? data->var_count : 0;}
    879     virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );
    880 
    881     /* dummy methods to avoid warnings: BEGIN */
    882     virtual bool train( const CvMat* trainData, int tflag,
    883                         const CvMat* responses, const CvMat* varIdx=0,
    884                         const CvMat* sampleIdx=0, const CvMat* varType=0,
    885                         const CvMat* missingDataMask=0,
    886                         CvDTreeParams params=CvDTreeParams() );
    887 
    888     virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
    889     virtual void read( CvFileStorage* fs, CvFileNode* node );
    890     virtual void read( CvFileStorage* fs, CvFileNode* node,
    891                        CvDTreeTrainData* data );
    892     /* dummy methods to avoid warnings: END */
    893 
    894 protected:
    895     friend struct cv::ForestTreeBestSplitFinder;
    896 
    897     virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    898     CvRTrees* forest;
    899 };
    900 
    901 
    902 struct CvRTParams : public CvDTreeParams
    903 {
    904     //Parameters for the forest
    905     CV_PROP_RW bool calc_var_importance; // true <=> RF processes variable importance
    906     CV_PROP_RW int nactive_vars;
    907     CV_PROP_RW CvTermCriteria term_crit;
    908 
    909     CvRTParams();
    910     CvRTParams( int max_depth, int min_sample_count,
    911                 float regression_accuracy, bool use_surrogates,
    912                 int max_categories, const float* priors, bool calc_var_importance,
    913                 int nactive_vars, int max_num_of_trees_in_the_forest,
    914                 float forest_accuracy, int termcrit_type );
    915 };
    916 
    917 
    918 class CvRTrees : public CvStatModel
    919 {
    920 public:
    921     CV_WRAP CvRTrees();
    922     virtual ~CvRTrees();
    923     virtual bool train( const CvMat* trainData, int tflag,
    924                         const CvMat* responses, const CvMat* varIdx=0,
    925                         const CvMat* sampleIdx=0, const CvMat* varType=0,
    926                         const CvMat* missingDataMask=0,
    927                         CvRTParams params=CvRTParams() );
    928 
    929     virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
    930     virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
    931     virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;
    932 
    933     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
    934                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
    935                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
    936                        const cv::Mat& missingDataMask=cv::Mat(),
    937                        CvRTParams params=CvRTParams() );
    938     CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    939     CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
    940     CV_WRAP virtual cv::Mat getVarImportance();
    941 
    942     CV_WRAP virtual void clear();
    943 
    944     virtual const CvMat* get_var_importance();
    945     virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
    946         const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
    947 
    948     virtual float calc_error( CvMLData* data, int type , std::vector<float>* resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
    949 
    950     virtual float get_train_error();
    951 
    952     virtual void read( CvFileStorage* fs, CvFileNode* node );
    953     virtual void write( CvFileStorage* fs, const char* name ) const;
    954 
    955     CvMat* get_active_var_mask();
    956     CvRNG* get_rng();
    957 
    958     int get_tree_count() const;
    959     CvForestTree* get_tree(int i) const;
    960 
    961 protected:
    962     virtual cv::String getName() const;
    963 
    964     virtual bool grow_forest( const CvTermCriteria term_crit );
    965 
    966     // array of the trees of the forest
    967     CvForestTree** trees;
    968     CvDTreeTrainData* data;
    969     CvMat train_data_hdr, responses_hdr;
    970     cv::Mat train_data_mat, responses_mat;
    971     int ntrees;
    972     int nclasses;
    973     double oob_error;
    974     CvMat* var_importance;
    975     int nsamples;
    976 
    977     cv::RNG* rng;
    978     CvMat* active_var_mask;
    979 };
    980 
    981 /****************************************************************************************\
    982 *                           Extremely randomized trees Classifier                        *
    983 \****************************************************************************************/
    984 struct CvERTreeTrainData : public CvDTreeTrainData
    985 {
    986     virtual void set_data( const CvMat* trainData, int tflag,
    987                           const CvMat* responses, const CvMat* varIdx=0,
    988                           const CvMat* sampleIdx=0, const CvMat* varType=0,
    989                           const CvMat* missingDataMask=0,
    990                           const CvDTreeParams& params=CvDTreeParams(),
    991                           bool _shared=false, bool _add_labels=false,
    992                           bool _update_data=false );
    993     virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
    994                                    const float** ord_values, const int** missing, int* sample_buf = 0 );
    995     virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
    996     virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
    997     virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
    998     virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing,
    999                               float* responses, bool get_class_idx=false );
   1000     virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
   1001     const CvMat* missing_mask;
   1002 };
   1003 
   1004 class CvForestERTree : public CvForestTree
   1005 {
   1006 protected:
   1007     virtual double calc_node_dir( CvDTreeNode* node );
   1008     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
   1009         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
   1010     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
   1011         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
   1012     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
   1013         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
   1014     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
   1015         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
   1016     virtual void split_node_data( CvDTreeNode* n );
   1017 };
   1018 
   1019 class CvERTrees : public CvRTrees
   1020 {
   1021 public:
   1022     CV_WRAP CvERTrees();
   1023     virtual ~CvERTrees();
   1024     virtual bool train( const CvMat* trainData, int tflag,
   1025                         const CvMat* responses, const CvMat* varIdx=0,
   1026                         const CvMat* sampleIdx=0, const CvMat* varType=0,
   1027                         const CvMat* missingDataMask=0,
   1028                         CvRTParams params=CvRTParams());
   1029     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
   1030                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
   1031                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
   1032                        const cv::Mat& missingDataMask=cv::Mat(),
   1033                        CvRTParams params=CvRTParams());
   1034     virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
   1035 protected:
   1036     virtual cv::String getName() const;
   1037     virtual bool grow_forest( const CvTermCriteria term_crit );
   1038 };
   1039 
   1040 
   1041 /****************************************************************************************\
   1042 *                                   Boosted tree classifier                              *
   1043 \****************************************************************************************/
   1044 
   1045 struct CvBoostParams : public CvDTreeParams
   1046 {
   1047     CV_PROP_RW int boost_type;
   1048     CV_PROP_RW int weak_count;
   1049     CV_PROP_RW int split_criteria;
   1050     CV_PROP_RW double weight_trim_rate;
   1051 
   1052     CvBoostParams();
   1053     CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
   1054                    int max_depth, bool use_surrogates, const float* priors );
   1055 };
   1056 
   1057 
   1058 class CvBoost;
   1059 
   1060 class CvBoostTree: public CvDTree
   1061 {
   1062 public:
   1063     CvBoostTree();
   1064     virtual ~CvBoostTree();
   1065 
   1066     virtual bool train( CvDTreeTrainData* trainData,
   1067                         const CvMat* subsample_idx, CvBoost* ensemble );
   1068 
   1069     virtual void scale( double s );
   1070     virtual void read( CvFileStorage* fs, CvFileNode* node,
   1071                        CvBoost* ensemble, CvDTreeTrainData* _data );
   1072     virtual void clear();
   1073 
   1074     /* dummy methods to avoid warnings: BEGIN */
   1075     virtual bool train( const CvMat* trainData, int tflag,
   1076                         const CvMat* responses, const CvMat* varIdx=0,
   1077                         const CvMat* sampleIdx=0, const CvMat* varType=0,
   1078                         const CvMat* missingDataMask=0,
   1079                         CvDTreeParams params=CvDTreeParams() );
   1080     virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
   1081 
   1082     virtual void read( CvFileStorage* fs, CvFileNode* node );
   1083     virtual void read( CvFileStorage* fs, CvFileNode* node,
   1084                        CvDTreeTrainData* data );
   1085     /* dummy methods to avoid warnings: END */
   1086 
   1087 protected:
   1088 
   1089     virtual void try_split_node( CvDTreeNode* n );
   1090     virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
   1091     virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
   1092     virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
   1093         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
   1094     virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
   1095         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
   1096     virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
   1097         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
   1098     virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
   1099         float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
   1100     virtual void calc_node_value( CvDTreeNode* n );
   1101     virtual double calc_node_dir( CvDTreeNode* n );
   1102 
   1103     CvBoost* ensemble;
   1104 };
   1105 
   1106 
   1107 class CvBoost : public CvStatModel
   1108 {
   1109 public:
   1110     // Boosting type
   1111     enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
   1112 
   1113     // Splitting criteria
   1114     enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };
   1115 
   1116     CV_WRAP CvBoost();
   1117     virtual ~CvBoost();
   1118 
   1119     CvBoost( const CvMat* trainData, int tflag,
   1120              const CvMat* responses, const CvMat* varIdx=0,
   1121              const CvMat* sampleIdx=0, const CvMat* varType=0,
   1122              const CvMat* missingDataMask=0,
   1123              CvBoostParams params=CvBoostParams() );
   1124 
   1125     virtual bool train( const CvMat* trainData, int tflag,
   1126              const CvMat* responses, const CvMat* varIdx=0,
   1127              const CvMat* sampleIdx=0, const CvMat* varType=0,
   1128              const CvMat* missingDataMask=0,
   1129              CvBoostParams params=CvBoostParams(),
   1130              bool update=false );
   1131 
   1132     virtual bool train( CvMLData* data,
   1133              CvBoostParams params=CvBoostParams(),
   1134              bool update=false );
   1135 
   1136     virtual float predict( const CvMat* sample, const CvMat* missing=0,
   1137                            CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
   1138                            bool raw_mode=false, bool return_sum=false ) const;
   1139 
   1140     CV_WRAP CvBoost( const cv::Mat& trainData, int tflag,
   1141             const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
   1142             const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
   1143             const cv::Mat& missingDataMask=cv::Mat(),
   1144             CvBoostParams params=CvBoostParams() );
   1145 
   1146     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
   1147                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
   1148                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
   1149                        const cv::Mat& missingDataMask=cv::Mat(),
   1150                        CvBoostParams params=CvBoostParams(),
   1151                        bool update=false );
   1152 
   1153     CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
   1154                                    const cv::Range& slice=cv::Range::all(), bool rawMode=false,
   1155                                    bool returnSum=false ) const;
   1156 
   1157     virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
   1158 
   1159     CV_WRAP virtual void prune( CvSlice slice );
   1160 
   1161     CV_WRAP virtual void clear();
   1162 
   1163     virtual void write( CvFileStorage* storage, const char* name ) const;
   1164     virtual void read( CvFileStorage* storage, CvFileNode* node );
   1165     virtual const CvMat* get_active_vars(bool absolute_idx=true);
   1166 
   1167     CvSeq* get_weak_predictors();
   1168 
   1169     CvMat* get_weights();
   1170     CvMat* get_subtree_weights();
   1171     CvMat* get_weak_response();
   1172     const CvBoostParams& get_params() const;
   1173     const CvDTreeTrainData* get_data() const;
   1174 
   1175 protected:
   1176 
   1177     virtual bool set_params( const CvBoostParams& params );
   1178     virtual void update_weights( CvBoostTree* tree );
   1179     virtual void trim_weights();
   1180     virtual void write_params( CvFileStorage* fs ) const;
   1181     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
   1182 
   1183     virtual void initialize_weights(double (&p)[2]);
   1184 
   1185     CvDTreeTrainData* data;
   1186     CvMat train_data_hdr, responses_hdr;
   1187     cv::Mat train_data_mat, responses_mat;
   1188     CvBoostParams params;
   1189     CvSeq* weak;
   1190 
   1191     CvMat* active_vars;
   1192     CvMat* active_vars_abs;
   1193     bool have_active_cat_vars;
   1194 
   1195     CvMat* orig_response;
   1196     CvMat* sum_response;
   1197     CvMat* weak_eval;
   1198     CvMat* subsample_mask;
   1199     CvMat* weights;
   1200     CvMat* subtree_weights;
   1201     bool have_subsample;
   1202 };
   1203 
   1204 
   1205 /****************************************************************************************\
   1206 *                                   Gradient Boosted Trees                               *
   1207 \****************************************************************************************/
   1208 
// DataType: STRUCT CvGBTreesParams
// Parameters of GBT (Gradient Boosted Trees model), including single
// tree settings and ensemble parameters.
//
// weak_count          - count of trees in the ensemble.
// loss_function_type  - loss function used for ensemble training.
// subsample_portion   - portion of the whole training set used for
//                       training every single tree.
//                       subsample_portion value is in (0.0, 1.0].
//                       subsample_portion == 1.0 when the whole dataset is
//                       used on each step. The count of samples used on each
//                       step is computed as
//                       int(total_samples_count * subsample_portion).
// shrinkage           - regularization parameter.
//                       Each tree prediction is multiplied by the shrinkage value.
   1224 
   1225 
   1226 struct CvGBTreesParams : public CvDTreeParams
   1227 {
   1228     CV_PROP_RW int weak_count;
   1229     CV_PROP_RW int loss_function_type;
   1230     CV_PROP_RW float subsample_portion;
   1231     CV_PROP_RW float shrinkage;
   1232 
   1233     CvGBTreesParams();
   1234     CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
   1235         float subsample_portion, int max_depth, bool use_surrogates );
   1236 };
   1237 
// DataType: CLASS CvGBTrees
// Gradient Boosting Trees (GBT) algorithm implementation.
//
// data             - training dataset
// params           - parameters of the CvGBTrees
// weak             - array[0..(class_count-1)] of CvSeq
//                    for storing tree ensembles
// orig_response    - original responses of the training set samples
// sum_response     - predictions of the current model on the training dataset.
//                    This matrix is updated on every iteration.
// sum_response_tmp - predictions of the model on the training set on the next
//                    step. On every iteration the values of sum_response_tmp
//                    are computed from the sum_response values. When the
//                    current step is complete, the sum_response values become
//                    equal to sum_response_tmp.
// sample_idx       - indices of samples used for training the ensemble.
//                    The CvGBTrees training procedure takes a set of samples
//                    (train_data) and a set of responses (responses).
//                    Only pairs (train_data[i], responses[i]), where i is
//                    in sample_idx, are used for training the ensemble.
// subsample_train  - indices of samples used for training a single decision
//                    tree on the current step. These indices are relative to
//                    sample_idx, so that pairs
//                    (train_data[sample_idx[i]], responses[sample_idx[i]])
//                    are used for training a decision tree.
//                    The training set is randomly split
//                    into two parts (subsample_train and subsample_test)
//                    on every iteration according to the portion parameter.
// subsample_test   - relative indices of samples from the training set
//                    that are not used for training a tree on the current
//                    step.
// missing          - mask of the missing values in the training set. This
//                    matrix has the same size as train_data. 1 - missing
//                    value, 0 - not a missing value.
// class_labels     - output class labels map.
// rng              - random number generator. Used for splitting the
//                    training set.
// class_count      - count of output classes.
//                    class_count == 1 in the case of regression,
//                    and > 1 in the case of classification.
// delta            - Huber loss function parameter.
// base_value       - start point of the gradient descent procedure.
//                    The model prediction is
//                    f(x) = f_0 + sum_{i=1..weak_count-1}(f_i(x)), where
//                    f_0 is the base value.
   1283 
   1284 
   1285 
   1286 class CvGBTrees : public CvStatModel
   1287 {
   1288 public:
   1289 
   1290     /*
   1291     // DataType: ENUM
   1292     // Loss functions implemented in CvGBTrees.
   1293     //
   1294     // SQUARED_LOSS
   1295     // problem: regression
   1296     // loss = (x - x')^2
   1297     //
   1298     // ABSOLUTE_LOSS
   1299     // problem: regression
   1300     // loss = abs(x - x')
   1301     //
   1302     // HUBER_LOSS
   1303     // problem: regression
   1304     // loss = delta*( abs(x - x') - delta/2), if abs(x - x') > delta
   1305     //           1/2*(x - x')^2, if abs(x - x') <= delta,
   1306     //           where delta is the alpha-quantile of pseudo responses from
   1307     //           the training set.
   1308     //
   1309     // DEVIANCE_LOSS
   1310     // problem: classification
   1311     //
   1312     */
   1313     enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS};
   1314 
   1315 
   1316     /*
   1317     // Default constructor. Creates a model only (without training).
   1318     // Should be followed by one form of the train(...) function.
   1319     //
   1320     // API
   1321     // CvGBTrees();
   1322 
   1323     // INPUT
   1324     // OUTPUT
   1325     // RESULT
   1326     */
   1327     CV_WRAP CvGBTrees();
   1328 
   1329 
   1330     /*
   1331     // Full form constructor. Creates a gradient boosting model and does the
   1332     // train.
   1333     //
   1334     // API
   1335     // CvGBTrees( const CvMat* trainData, int tflag,
   1336              const CvMat* responses, const CvMat* varIdx=0,
   1337              const CvMat* sampleIdx=0, const CvMat* varType=0,
   1338              const CvMat* missingDataMask=0,
   1339              CvGBTreesParams params=CvGBTreesParams() );
   1340 
   1341     // INPUT
   1342     // trainData    - a set of input feature vectors.
   1343     //                  size of matrix is
   1344     //                  <count of samples> x <variables count>
   1345     //                  or <variables count> x <count of samples>
   1346     //                  depending on the tflag parameter.
   1347     //                  matrix values are float.
   1348     // tflag         - a flag showing how do samples stored in the
   1349     //                  trainData matrix row by row (tflag=CV_ROW_SAMPLE)
   1350     //                  or column by column (tflag=CV_COL_SAMPLE).
   1351     // responses     - a vector of responses corresponding to the samples
   1352     //                  in trainData.
   1353     // varIdx       - indices of used variables. zero value means that all
   1354     //                  variables are active.
   1355     // sampleIdx    - indices of used samples. zero value means that all
   1356     //                  samples from trainData are in the training set.
   1357     // varType      - vector of <variables count> length. gives every
   1358     //                  variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
   1359     //                  varType = 0 means all variables are numerical.
   1360     // missingDataMask  - a mask of misiing values in trainData.
   1361     //                  missingDataMask = 0 means that there are no missing
   1362     //                  values.
   1363     // params         - parameters of GTB algorithm.
   1364     // OUTPUT
   1365     // RESULT
   1366     */
   1367     CvGBTrees( const CvMat* trainData, int tflag,
   1368              const CvMat* responses, const CvMat* varIdx=0,
   1369              const CvMat* sampleIdx=0, const CvMat* varType=0,
   1370              const CvMat* missingDataMask=0,
   1371              CvGBTreesParams params=CvGBTreesParams() );
   1372 
   1373 
   1374     /*
   1375     // Destructor.
   1376     */
   1377     virtual ~CvGBTrees();
   1378 
   1379 
   1380     /*
   1381     // Gradient tree boosting model training
   1382     //
   1383     // API
   1384     // virtual bool train( const CvMat* trainData, int tflag,
   1385              const CvMat* responses, const CvMat* varIdx=0,
   1386              const CvMat* sampleIdx=0, const CvMat* varType=0,
   1387              const CvMat* missingDataMask=0,
   1388              CvGBTreesParams params=CvGBTreesParams(),
   1389              bool update=false );
   1390 
   1391     // INPUT
   1392     // trainData    - a set of input feature vectors.
   1393     //                  size of matrix is
   1394     //                  <count of samples> x <variables count>
   1395     //                  or <variables count> x <count of samples>
   1396     //                  depending on the tflag parameter.
   1397     //                  matrix values are float.
    // tflag         - a flag showing how samples are stored in the
    //                  trainData matrix: row by row (tflag=CV_ROW_SAMPLE)
    //                  or column by column (tflag=CV_COL_SAMPLE).
   1401     // responses     - a vector of responses corresponding to the samples
   1402     //                  in trainData.
   1403     // varIdx       - indices of used variables. zero value means that all
   1404     //                  variables are active.
   1405     // sampleIdx    - indices of used samples. zero value means that all
   1406     //                  samples from trainData are in the training set.
   1407     // varType      - vector of <variables count> length. gives every
   1408     //                  variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
   1409     //                  varType = 0 means all variables are numerical.
    // missingDataMask  - a mask of missing values in trainData.
   1411     //                  missingDataMask = 0 means that there are no missing
   1412     //                  values.
   1413     // params         - parameters of GTB algorithm.
   1414     // update         - is not supported now. (!)
   1415     // OUTPUT
   1416     // RESULT
   1417     // Error state.
   1418     */
   1419     virtual bool train( const CvMat* trainData, int tflag,
   1420              const CvMat* responses, const CvMat* varIdx=0,
   1421              const CvMat* sampleIdx=0, const CvMat* varType=0,
   1422              const CvMat* missingDataMask=0,
   1423              CvGBTreesParams params=CvGBTreesParams(),
   1424              bool update=false );
   1425 
   1426 
   1427     /*
   1428     // Gradient tree boosting model training
   1429     //
   1430     // API
   1431     // virtual bool train( CvMLData* data,
   1432              CvGBTreesParams params=CvGBTreesParams(),
             bool update=false );
   1434 
   1435     // INPUT
   1436     // data          - training set.
   1437     // params        - parameters of GTB algorithm.
   1438     // update        - is not supported now. (!)
   1439     // OUTPUT
   1440     // RESULT
   1441     // Error state.
   1442     */
   1443     virtual bool train( CvMLData* data,
   1444              CvGBTreesParams params=CvGBTreesParams(),
   1445              bool update=false );
   1446 
   1447 
   1448     /*
   1449     // Response value prediction
   1450     //
   1451     // API
   1452     // virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
   1453              CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
   1454              int k=-1 ) const;
   1455 
   1456     // INPUT
   1457     // sample         - input sample of the same type as in the training set.
   1458     // missing        - missing values mask. missing=0 if there are no
   1459     //                   missing values in sample vector.
   1460     // weak_responses  - predictions of all of the trees.
   1461     //                   not implemented (!)
   1462     // slice           - part of the ensemble used for prediction.
   1463     //                   slice = CV_WHOLE_SEQ when all trees are used.
    // k               - index of the ensemble used.
    //                   k is in {-1,0,1,..,<count of output classes-1>}.
    //                   in the case of classification problem
    //                   <count of output classes> ensembles are built.
   1468     //                   If k = -1 ordinary prediction is the result,
   1469     //                   otherwise function gives the prediction of the
   1470     //                   k-th ensemble only.
   1471     // OUTPUT
   1472     // RESULT
   1473     // Predicted value.
   1474     */
   1475     virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
   1476             CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
   1477             int k=-1 ) const;
   1478 
   1479     /*
   1480     // Response value prediction.
   1481     // Parallel version (in the case of TBB existence)
   1482     //
   1483     // API
   1484     // virtual float predict( const CvMat* sample, const CvMat* missing=0,
   1485              CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
   1486              int k=-1 ) const;
   1487 
   1488     // INPUT
   1489     // sample         - input sample of the same type as in the training set.
   1490     // missing        - missing values mask. missing=0 if there are no
   1491     //                   missing values in sample vector.
   1492     // weak_responses  - predictions of all of the trees.
   1493     //                   not implemented (!)
   1494     // slice           - part of the ensemble used for prediction.
   1495     //                   slice = CV_WHOLE_SEQ when all trees are used.
    // k               - index of the ensemble used.
    //                   k is in {-1,0,1,..,<count of output classes-1>}.
    //                   in the case of classification problem
    //                   <count of output classes> ensembles are built.
   1500     //                   If k = -1 ordinary prediction is the result,
   1501     //                   otherwise function gives the prediction of the
   1502     //                   k-th ensemble only.
   1503     // OUTPUT
   1504     // RESULT
   1505     // Predicted value.
   1506     */
   1507     virtual float predict( const CvMat* sample, const CvMat* missing=0,
   1508             CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
   1509             int k=-1 ) const;
   1510 
   1511     /*
   1512     // Deletes all the data.
   1513     //
   1514     // API
   1515     // virtual void clear();
   1516 
   1517     // INPUT
   1518     // OUTPUT
   1519     // delete data, weak, orig_response, sum_response,
   1520     //        weak_eval, subsample_train, subsample_test,
    //        sample_idx, missing, class_labels
   1522     // delta = 0.0
   1523     // RESULT
   1524     */
   1525     CV_WRAP virtual void clear();
   1526 
   1527     /*
   1528     // Compute error on the train/test set.
   1529     //
   1530     // API
   1531     // virtual float calc_error( CvMLData* _data, int type,
   1532     //        std::vector<float> *resp = 0 );
   1533     //
   1534     // INPUT
   1535     // data  - dataset
   1536     // type  - defines which error is to compute: train (CV_TRAIN_ERROR) or
   1537     //         test (CV_TEST_ERROR).
   1538     // OUTPUT
    // resp  - vector of predictions
   1540     // RESULT
   1541     // Error value.
   1542     */
   1543     virtual float calc_error( CvMLData* _data, int type,
   1544             std::vector<float> *resp = 0 );
   1545 
   1546     /*
   1547     //
   1548     // Write parameters of the gtb model and data. Write learned model.
   1549     //
   1550     // API
   1551     // virtual void write( CvFileStorage* fs, const char* name ) const;
   1552     //
   1553     // INPUT
    // fs     - file storage to write the parameters and the model to.
   1555     // name   - model name.
   1556     // OUTPUT
   1557     // RESULT
   1558     */
   1559     virtual void write( CvFileStorage* fs, const char* name ) const;
   1560 
   1561 
   1562     /*
   1563     //
   1564     // Read parameters of the gtb model and data. Read learned model.
   1565     //
   1566     // API
   1567     // virtual void read( CvFileStorage* fs, CvFileNode* node );
   1568     //
   1569     // INPUT
   1570     // fs     - file storage to read parameters from.
   1571     // node   - file node.
   1572     // OUTPUT
   1573     // RESULT
   1574     */
   1575     virtual void read( CvFileStorage* fs, CvFileNode* node );
   1576 
   1577 
   1578     // new-style C++ interface
   1579     CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag,
   1580               const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
   1581               const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
   1582               const cv::Mat& missingDataMask=cv::Mat(),
   1583               CvGBTreesParams params=CvGBTreesParams() );
   1584 
   1585     CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
   1586                        const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
   1587                        const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
   1588                        const cv::Mat& missingDataMask=cv::Mat(),
   1589                        CvGBTreesParams params=CvGBTreesParams(),
   1590                        bool update=false );
   1591 
   1592     CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
   1593                            const cv::Range& slice = cv::Range::all(),
   1594                            int k=-1 ) const;
   1595 
   1596 protected:
   1597 
   1598     /*
   1599     // Compute the gradient vector components.
   1600     //
   1601     // API
   1602     // virtual void find_gradient( const int k = 0);
   1603 
   1604     // INPUT
   1605     // k        - used for classification problem, determining current
   1606     //            tree ensemble.
   1607     // OUTPUT
   1608     // changes components of data->responses
   1609     // which correspond to samples used for training
   1610     // on the current step.
   1611     // RESULT
   1612     */
   1613     virtual void find_gradient( const int k = 0);
   1614 
   1615 
   1616     /*
   1617     //
   1618     // Change values in tree leaves according to the used loss function.
   1619     //
   1620     // API
   1621     // virtual void change_values(CvDTree* tree, const int k = 0);
   1622     //
   1623     // INPUT
   1624     // tree      - decision tree to change.
   1625     // k         - used for classification problem, determining current
   1626     //             tree ensemble.
   1627     // OUTPUT
   1628     // changes 'value' fields of the trees' leaves.
   1629     // changes sum_response_tmp.
   1630     // RESULT
   1631     */
   1632     virtual void change_values(CvDTree* tree, const int k = 0);
   1633 
   1634 
   1635     /*
   1636     //
   1637     // Find optimal constant prediction value according to the used loss
   1638     // function.
   1639     // The goal is to find a constant which gives the minimal summary loss
   1640     // on the _Idx samples.
   1641     //
   1642     // API
   1643     // virtual float find_optimal_value( const CvMat* _Idx );
   1644     //
   1645     // INPUT
   1646     // _Idx        - indices of the samples from the training set.
   1647     // OUTPUT
   1648     // RESULT
   1649     // optimal constant value.
   1650     */
   1651     virtual float find_optimal_value( const CvMat* _Idx );
   1652 
   1653 
   1654     /*
   1655     //
   1656     // Randomly split the whole training set in two parts according
   1657     // to params.portion.
   1658     //
   1659     // API
   1660     // virtual void do_subsample();
   1661     //
   1662     // INPUT
   1663     // OUTPUT
   1664     // subsample_train - indices of samples used for training
   1665     // subsample_test  - indices of samples used for test
   1666     // RESULT
   1667     */
   1668     virtual void do_subsample();
   1669 
   1670 
   1671     /*
   1672     //
    // Internal recursive function collecting the leaves of the subtree rooted at node.
   1674     //
   1675     // API
   1676     // void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
   1677     //
   1678     // INPUT
    // node         - root of the current subtree.
   1680     // OUTPUT
   1681     // count        - count of leaves in the subtree.
   1682     // leaves       - array of pointers to leaves.
   1683     // RESULT
   1684     */
   1685     void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
   1686 
   1687 
   1688     /*
   1689     //
   1690     // Get leaves of the tree.
   1691     //
   1692     // API
   1693     // CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
   1694     //
   1695     // INPUT
   1696     // dtree            - decision tree.
   1697     // OUTPUT
   1698     // len              - count of the leaves.
   1699     // RESULT
   1700     // CvDTreeNode**    - array of pointers to leaves.
   1701     */
   1702     CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
   1703 
   1704 
   1705     /*
   1706     //
   1707     // Is it a regression or a classification.
   1708     //
   1709     // API
    // virtual bool problem_type() const;
   1711     //
   1712     // INPUT
   1713     // OUTPUT
   1714     // RESULT
   1715     // false if it is a classification problem,
   1716     // true - if regression.
   1717     */
   1718     virtual bool problem_type() const;
   1719 
   1720 
   1721     /*
   1722     //
   1723     // Write parameters of the gtb model.
   1724     //
   1725     // API
   1726     // virtual void write_params( CvFileStorage* fs ) const;
   1727     //
   1728     // INPUT
   1729     // fs           - file storage to write parameters to.
   1730     // OUTPUT
   1731     // RESULT
   1732     */
   1733     virtual void write_params( CvFileStorage* fs ) const;
   1734 
   1735 
   1736     /*
   1737     //
   1738     // Read parameters of the gtb model and data.
   1739     //
   1740     // API
    // virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
    //
    // INPUT
    // fs           - file storage to read parameters from.
    // fnode        - file node to read the parameters from.
   1745     // OUTPUT
   1746     // params       - parameters of the gtb model.
   1747     // data         - contains information about the structure
   1748     //                of the data set (count of variables,
   1749     //                their types, etc.).
   1750     // class_labels - output class labels map.
   1751     // RESULT
   1752     */
   1753     virtual void read_params( CvFileStorage* fs, CvFileNode* fnode );
   1754     int get_len(const CvMat* mat) const;
   1755 
   1756 
   1757     CvDTreeTrainData* data;
   1758     CvGBTreesParams params;
   1759 
   1760     CvSeq** weak;
   1761     CvMat* orig_response;
   1762     CvMat* sum_response;
   1763     CvMat* sum_response_tmp;
   1764     CvMat* sample_idx;
   1765     CvMat* subsample_train;
   1766     CvMat* subsample_test;
   1767     CvMat* missing;
   1768     CvMat* class_labels;
   1769 
   1770     cv::RNG* rng;
   1771 
   1772     int class_count;
   1773     float delta;
   1774     float base_value;
   1775 
   1776 };
   1777 
   1778 
   1779 
   1780 /****************************************************************************************\
   1781 *                              Artificial Neural Networks (ANN)                          *
   1782 \****************************************************************************************/
   1783 
   1784 /////////////////////////////////// Multi-Layer Perceptrons //////////////////////////////
   1785 
   1786 struct CvANN_MLP_TrainParams
   1787 {
   1788     CvANN_MLP_TrainParams();
   1789     CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
   1790                            double param1, double param2=0 );
   1791     ~CvANN_MLP_TrainParams();
   1792 
   1793     enum { BACKPROP=0, RPROP=1 };
   1794 
   1795     CV_PROP_RW CvTermCriteria term_crit;
   1796     CV_PROP_RW int train_method;
   1797 
   1798     // backpropagation parameters
   1799     CV_PROP_RW double bp_dw_scale, bp_moment_scale;
   1800 
   1801     // rprop parameters
   1802     CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
   1803 };
   1804 
   1805 
   1806 class CvANN_MLP : public CvStatModel
   1807 {
   1808 public:
   1809     CV_WRAP CvANN_MLP();
   1810     CvANN_MLP( const CvMat* layerSizes,
   1811                int activateFunc=CvANN_MLP::SIGMOID_SYM,
   1812                double fparam1=0, double fparam2=0 );
   1813 
   1814     virtual ~CvANN_MLP();
   1815 
   1816     virtual void create( const CvMat* layerSizes,
   1817                          int activateFunc=CvANN_MLP::SIGMOID_SYM,
   1818                          double fparam1=0, double fparam2=0 );
   1819 
   1820     virtual int train( const CvMat* inputs, const CvMat* outputs,
   1821                        const CvMat* sampleWeights, const CvMat* sampleIdx=0,
   1822                        CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
   1823                        int flags=0 );
   1824     virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const;
   1825 
   1826     CV_WRAP CvANN_MLP( const cv::Mat& layerSizes,
   1827               int activateFunc=CvANN_MLP::SIGMOID_SYM,
   1828               double fparam1=0, double fparam2=0 );
   1829 
   1830     CV_WRAP virtual void create( const cv::Mat& layerSizes,
   1831                         int activateFunc=CvANN_MLP::SIGMOID_SYM,
   1832                         double fparam1=0, double fparam2=0 );
   1833 
   1834     CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs,
   1835                       const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(),
   1836                       CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
   1837                       int flags=0 );
   1838 
   1839     CV_WRAP virtual float predict( const cv::Mat& inputs, CV_OUT cv::Mat& outputs ) const;
   1840 
   1841     CV_WRAP virtual void clear();
   1842 
   1843     // possible activation functions
   1844     enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };
   1845 
   1846     // available training flags
   1847     enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };
   1848 
   1849     virtual void read( CvFileStorage* fs, CvFileNode* node );
   1850     virtual void write( CvFileStorage* storage, const char* name ) const;
   1851 
   1852     int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
   1853     const CvMat* get_layer_sizes() { return layer_sizes; }
   1854     double* get_weights(int layer)
   1855     {
   1856         return layer_sizes && weights &&
   1857             (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
   1858     }
   1859 
   1860     virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
   1861 
   1862 protected:
   1863 
   1864     virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
   1865             const CvMat* _sample_weights, const CvMat* sampleIdx,
   1866             CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );
   1867 
   1868     // sequential random backpropagation
   1869     virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
   1870 
   1871     // RPROP algorithm
   1872     virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );
   1873 
   1874     virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
   1875     virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
   1876                                  double _f_param1=0, double _f_param2=0 );
   1877     virtual void init_weights();
   1878     virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
   1879     virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
   1880     virtual void calc_input_scale( const CvVectors* vecs, int flags );
   1881     virtual void calc_output_scale( const CvVectors* vecs, int flags );
   1882 
   1883     virtual void write_params( CvFileStorage* fs ) const;
   1884     virtual void read_params( CvFileStorage* fs, CvFileNode* node );
   1885 
   1886     CvMat* layer_sizes;
   1887     CvMat* wbuf;
   1888     CvMat* sample_weights;
   1889     double** weights;
   1890     double f_param1, f_param2;
   1891     double min_val, max_val, min_val1, max_val1;
   1892     int activ_func;
   1893     int max_count, max_buf_sz;
   1894     CvANN_MLP_TrainParams params;
   1895     cv::RNG* rng;
   1896 };
   1897 
   1898 /****************************************************************************************\
*                           Auxiliary functions declarations                             *
   1900 \****************************************************************************************/
   1901 
   1902 /* Generates <sample> from multivariate normal distribution, where <mean> - is an
   1903    average row vector, <cov> - symmetric covariation matrix */
   1904 CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
   1905                            CvRNG* rng CV_DEFAULT(0) );
   1906 
   1907 /* Generates sample from gaussian mixture distribution */
   1908 CVAPI(void) cvRandGaussMixture( CvMat* means[],
   1909                                CvMat* covs[],
   1910                                float weights[],
   1911                                int clsnum,
   1912                                CvMat* sample,
   1913                                CvMat* sampClasses CV_DEFAULT(0) );
   1914 
   1915 #define CV_TS_CONCENTRIC_SPHERES 0
   1916 
   1917 /* creates test set */
   1918 CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
   1919                  int num_samples,
   1920                  int num_features,
   1921                  CvMat** responses,
   1922                  int num_classes, ... );
   1923 
   1924 /****************************************************************************************\
   1925 *                                      Data                                             *
   1926 \****************************************************************************************/
   1927 
   1928 #define CV_COUNT     0
   1929 #define CV_PORTION   1
   1930 
   1931 struct CvTrainTestSplit
   1932 {
   1933     CvTrainTestSplit();
   1934     CvTrainTestSplit( int train_sample_count, bool mix = true);
   1935     CvTrainTestSplit( float train_sample_portion, bool mix = true);
   1936 
   1937     union
   1938     {
   1939         int count;
   1940         float portion;
   1941     } train_sample_part;
   1942     int train_sample_part_mode;
   1943 
   1944     bool mix;
   1945 };
   1946 
   1947 class CvMLData
   1948 {
   1949 public:
   1950     CvMLData();
   1951     virtual ~CvMLData();
   1952 
   1953     // returns:
   1954     // 0 - OK
   1955     // -1 - file can not be opened or is not correct
   1956     int read_csv( const char* filename );
   1957 
   1958     const CvMat* get_values() const;
   1959     const CvMat* get_responses();
   1960     const CvMat* get_missing() const;
   1961 
   1962     void set_header_lines_number( int n );
   1963     int get_header_lines_number() const;
   1964 
   1965     void set_response_idx( int idx ); // old response become predictors, new response_idx = idx
   1966                                       // if idx < 0 there will be no response
   1967     int get_response_idx() const;
   1968 
   1969     void set_train_test_split( const CvTrainTestSplit * spl );
   1970     const CvMat* get_train_sample_idx() const;
   1971     const CvMat* get_test_sample_idx() const;
   1972     void mix_train_and_test_idx();
   1973 
   1974     const CvMat* get_var_idx();
   1975     void chahge_var_idx( int vi, bool state ); // misspelled (saved for back compitability),
   1976                                                // use change_var_idx
   1977     void change_var_idx( int vi, bool state ); // state == true to set vi-variable as predictor
   1978 
   1979     const CvMat* get_var_types();
   1980     int get_var_type( int var_idx ) const;
   1981     // following 2 methods enable to change vars type
   1982     // use these methods to assign CV_VAR_CATEGORICAL type for categorical variable
   1983     // with numerical labels; in the other cases var types are correctly determined automatically
   1984     void set_var_types( const char* str );  // str examples:
   1985                                             // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]",
   1986                                             // "cat", "ord" (all vars are categorical/ordered)
   1987     void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }
   1988 
   1989     void set_delimiter( char ch );
   1990     char get_delimiter() const;
   1991 
   1992     void set_miss_ch( char ch );
   1993     char get_miss_ch() const;
   1994 
   1995     const std::map<cv::String, int>& get_class_labels_map() const;
   1996 
   1997 protected:
   1998     virtual void clear();
   1999 
   2000     void str_to_flt_elem( const char* token, float& flt_elem, int& type);
   2001     void free_train_test_idx();
   2002 
   2003     char delimiter;
   2004     char miss_ch;
   2005     //char flt_separator;
   2006 
   2007     CvMat* values;
   2008     CvMat* missing;
   2009     CvMat* var_types;
   2010     CvMat* var_idx_mask;
   2011 
   2012     CvMat* response_out; // header
   2013     CvMat* var_idx_out; // mat
   2014     CvMat* var_types_out; // mat
   2015 
   2016     int header_lines_number;
   2017 
   2018     int response_idx;
   2019 
   2020     int train_sample_count;
   2021     bool mix;
   2022 
   2023     int total_class_count;
   2024     std::map<cv::String, int> class_map;
   2025 
   2026     CvMat* train_sample_idx;
   2027     CvMat* test_sample_idx;
   2028     int* sample_idx; // data of train_sample_idx and test_sample_idx
   2029 
   2030     cv::RNG* rng;
   2031 };
   2032 
   2033 
   2034 namespace cv
   2035 {
   2036 
   2037 typedef CvStatModel StatModel;
   2038 typedef CvParamGrid ParamGrid;
   2039 typedef CvNormalBayesClassifier NormalBayesClassifier;
   2040 typedef CvKNearest KNearest;
   2041 typedef CvSVMParams SVMParams;
   2042 typedef CvSVMKernel SVMKernel;
   2043 typedef CvSVMSolver SVMSolver;
   2044 typedef CvSVM SVM;
   2045 typedef CvDTreeParams DTreeParams;
   2046 typedef CvMLData TrainData;
   2047 typedef CvDTree DecisionTree;
   2048 typedef CvForestTree ForestTree;
   2049 typedef CvRTParams RandomTreeParams;
   2050 typedef CvRTrees RandomTrees;
   2051 typedef CvERTreeTrainData ERTreeTRainData;
   2052 typedef CvForestERTree ERTree;
   2053 typedef CvERTrees ERTrees;
   2054 typedef CvBoostParams BoostParams;
   2055 typedef CvBoostTree BoostTree;
   2056 typedef CvBoost Boost;
   2057 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
   2058 typedef CvANN_MLP NeuralNet_MLP;
   2059 typedef CvGBTreesParams GradientBoostingTreeParams;
   2060 typedef CvGBTrees GradientBoostingTrees;
   2061 
   2062 template<> void DefaultDeleter<CvDTreeSplit>::operator ()(CvDTreeSplit* obj) const;
   2063 }
   2064 
   2065 #endif // __cplusplus
   2066 #endif // __OPENCV_ML_HPP__
   2067 
   2068 /* End of file. */
   2069