/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of Intel Corporation may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#ifndef __OPENCV_ML_HPP__
#define __OPENCV_ML_HPP__

#ifdef __cplusplus
#  include "opencv2/core.hpp"
#endif

#include "opencv2/core/core_c.h"
#include <limits.h>

// Everything below this point is C++ only; the matching #endif lies further
// down the file.
#ifdef __cplusplus

#include <map>
#include <iostream>

// Apple defines a check() macro somewhere in the debug headers
// that interferes with a method definition in this header
// (CvParamGrid::check() is declared further down).
#undef check

/****************************************************************************************\
*                               Main struct definitions                                  *
\****************************************************************************************/

/* log(2*PI) */
#define CV_LOG2PI (1.8378770664093454835606594728112)

/* columns of <trainData> matrix are training samples */
#define CV_COL_SAMPLE 0

/* rows of <trainData> matrix are training samples */
#define CV_ROW_SAMPLE 1

/* Non-zero when the CV_ROW_SAMPLE bit is set in <flags>. */
#define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)

/* Node of a linked list (chained through <next>) of vector arrays.
   <data> exposes one pointer-to-row-pointers viewed as uchar**, float** or double**.
   Presumably <count> vectors of <dims> elements each, with <type> selecting which
   union view is valid -- TODO confirm against the implementation. */
struct CvVectors
{
    int type;
    int dims, count;
    CvVectors* next;
    union
    {
        uchar** ptr;
        float** fl;
        double** db;
    } data;
};

/* Disabled code: the whole CvParamLattice section is excluded from compilation
   by "#if 0". */
#if 0
/* A structure, representing the lattice range of statmodel parameters.
   It is used for optimizing statmodel parameters by cross-validation method.
   The lattice is logarithmic, so <step> must be greater than 1. */
typedef struct CvParamLattice
{
    double min_val;
    double max_val;
    double step;
}
CvParamLattice;

/* Builds a lattice: the bounds are ordered (min <= max) and <log_step> is clamped
   to be at least 1. */
CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
                                         double log_step )
{
    CvParamLattice pl;
    pl.min_val = MIN( min_val, max_val );
    pl.max_val = MAX( min_val, max_val );
    pl.step = MAX( log_step, 1. );
    return pl;
}

/* Returns a lattice with all three fields set to zero. */
CV_INLINE CvParamLattice cvDefaultParamLattice( void )
{
    CvParamLattice pl = {0,0,0};
    return pl;
}
#endif

/* Variable type */
/* CV_VAR_NUMERICAL and CV_VAR_ORDERED are two names for the same value. */
#define CV_VAR_NUMERICAL    0
#define CV_VAR_ORDERED      0
#define CV_VAR_CATEGORICAL  1

/* String identifiers of the model types (presumably the type names written to
   file storage -- TODO confirm against the implementation). */
#define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
#define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
#define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
#define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
#define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
#define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
#define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
#define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
#define CV_TYPE_NAME_ML_ERTREES     "opencv-ml-extremely-randomized-trees"
#define CV_TYPE_NAME_ML_GBT         "opencv-ml-gradient-boosting-trees"

/* Values of the <type> argument of the calc_error() methods declared further down. */
#define CV_TRAIN_ERROR  0
#define CV_TEST_ERROR   1

/* Common base class of the statistical models declared in this header.
   It declares the virtual interface for resetting a model (clear) and for
   persisting it either by file name (save/load) or through an already opened
   CvFileStorage (write/read). All member functions are defined out of line. */
class CvStatModel
{
public:
    CvStatModel();
    virtual ~CvStatModel();

    virtual void clear();

    /* <name> defaults to a null pointer; presumably default_model_name is used
       in that case -- TODO confirm against the implementation. */
    CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
    CV_WRAP virtual void load( const char* filename, const char* name=0 );

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );

protected:
    const char* default_model_name;
};

/****************************************************************************************\
*                                 Normal Bayes Classifier                                *
\****************************************************************************************/
158 /* The structure, representing the grid range of statmodel parameters. 159 It is used for optimizing statmodel accuracy by varying model parameters, 160 the accuracy estimate being computed by cross-validation. 161 The grid is logarithmic, so <step> must be greater then 1. */ 162 163 class CvMLData; 164 165 struct CvParamGrid 166 { 167 // SVM params type 168 enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 }; 169 170 CvParamGrid() 171 { 172 min_val = max_val = step = 0; 173 } 174 175 CvParamGrid( double min_val, double max_val, double log_step ); 176 //CvParamGrid( int param_id ); 177 bool check() const; 178 179 CV_PROP_RW double min_val; 180 CV_PROP_RW double max_val; 181 CV_PROP_RW double step; 182 }; 183 184 inline CvParamGrid::CvParamGrid( double _min_val, double _max_val, double _log_step ) 185 { 186 min_val = _min_val; 187 max_val = _max_val; 188 step = _log_step; 189 } 190 191 class CvNormalBayesClassifier : public CvStatModel 192 { 193 public: 194 CV_WRAP CvNormalBayesClassifier(); 195 virtual ~CvNormalBayesClassifier(); 196 197 CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses, 198 const CvMat* varIdx=0, const CvMat* sampleIdx=0 ); 199 200 virtual bool train( const CvMat* trainData, const CvMat* responses, 201 const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false ); 202 203 virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0, CV_OUT CvMat* results_prob=0 ) const; 204 CV_WRAP virtual void clear(); 205 206 CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses, 207 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() ); 208 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses, 209 const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(), 210 bool update=false ); 211 CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0, CV_OUT cv::Mat* results_prob=0 ) const; 212 213 
virtual void write( CvFileStorage* storage, const char* name ) const; 214 virtual void read( CvFileStorage* storage, CvFileNode* node ); 215 216 protected: 217 int var_count, var_all; 218 CvMat* var_idx; 219 CvMat* cls_labels; 220 CvMat** count; 221 CvMat** sum; 222 CvMat** productsum; 223 CvMat** avg; 224 CvMat** inv_eigen_values; 225 CvMat** cov_rotate_mats; 226 CvMat* c; 227 }; 228 229 230 /****************************************************************************************\ 231 * K-Nearest Neighbour Classifier * 232 \****************************************************************************************/ 233 234 // k Nearest Neighbors 235 class CvKNearest : public CvStatModel 236 { 237 public: 238 239 CV_WRAP CvKNearest(); 240 virtual ~CvKNearest(); 241 242 CvKNearest( const CvMat* trainData, const CvMat* responses, 243 const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 ); 244 245 virtual bool train( const CvMat* trainData, const CvMat* responses, 246 const CvMat* sampleIdx=0, bool is_regression=false, 247 int maxK=32, bool updateBase=false ); 248 249 virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0, 250 const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const; 251 252 CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses, 253 const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 ); 254 255 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses, 256 const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, 257 int maxK=32, bool updateBase=false ); 258 259 virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0, 260 const float** neighbors=0, cv::Mat* neighborResponses=0, 261 cv::Mat* dist=0 ) const; 262 CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results, 263 CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const; 264 265 virtual void clear(); 266 
int get_max_k() const; 267 int get_var_count() const; 268 int get_sample_count() const; 269 bool is_regression() const; 270 271 virtual float write_results( int k, int k1, int start, int end, 272 const float* neighbor_responses, const float* dist, CvMat* _results, 273 CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const; 274 275 virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end, 276 float* neighbor_responses, const float** neighbors, float* dist ) const; 277 278 protected: 279 280 int max_k, var_count; 281 int total; 282 bool regression; 283 CvVectors* samples; 284 }; 285 286 /****************************************************************************************\ 287 * Support Vector Machines * 288 \****************************************************************************************/ 289 290 // SVM training parameters 291 struct CvSVMParams 292 { 293 CvSVMParams(); 294 CvSVMParams( int svm_type, int kernel_type, 295 double degree, double gamma, double coef0, 296 double Cvalue, double nu, double p, 297 CvMat* class_weights, CvTermCriteria term_crit ); 298 299 CV_PROP_RW int svm_type; 300 CV_PROP_RW int kernel_type; 301 CV_PROP_RW double degree; // for poly 302 CV_PROP_RW double gamma; // for poly/rbf/sigmoid/chi2 303 CV_PROP_RW double coef0; // for poly/sigmoid 304 305 CV_PROP_RW double C; // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR 306 CV_PROP_RW double nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR 307 CV_PROP_RW double p; // for CV_SVM_EPS_SVR 308 CvMat* class_weights; // for CV_SVM_C_SVC 309 CV_PROP_RW CvTermCriteria term_crit; // termination criteria 310 }; 311 312 313 struct CvSVMKernel 314 { 315 typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs, 316 const float* another, float* results ); 317 CvSVMKernel(); 318 CvSVMKernel( const CvSVMParams* params, Calc _calc_func ); 319 virtual bool create( const CvSVMParams* params, Calc _calc_func ); 320 
virtual ~CvSVMKernel(); 321 322 virtual void clear(); 323 virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results ); 324 325 const CvSVMParams* params; 326 Calc calc_func; 327 328 virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs, 329 const float* another, float* results, 330 double alpha, double beta ); 331 virtual void calc_intersec( int vcount, int var_count, const float** vecs, 332 const float* another, float* results ); 333 virtual void calc_chi2( int vec_count, int vec_size, const float** vecs, 334 const float* another, float* results ); 335 virtual void calc_linear( int vec_count, int vec_size, const float** vecs, 336 const float* another, float* results ); 337 virtual void calc_rbf( int vec_count, int vec_size, const float** vecs, 338 const float* another, float* results ); 339 virtual void calc_poly( int vec_count, int vec_size, const float** vecs, 340 const float* another, float* results ); 341 virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs, 342 const float* another, float* results ); 343 }; 344 345 346 struct CvSVMKernelRow 347 { 348 CvSVMKernelRow* prev; 349 CvSVMKernelRow* next; 350 float* data; 351 }; 352 353 354 struct CvSVMSolutionInfo 355 { 356 double obj; 357 double rho; 358 double upper_bound_p; 359 double upper_bound_n; 360 double r; // for Solver_NU 361 }; 362 363 class CvSVMSolver 364 { 365 public: 366 typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j ); 367 typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed ); 368 typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r ); 369 370 CvSVMSolver(); 371 372 CvSVMSolver( int count, int var_count, const float** samples, schar* y, 373 int alpha_count, double* alpha, double Cp, double Cn, 374 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row, 375 SelectWorkingSet select_working_set, CalcRho calc_rho ); 376 virtual bool create( int count, int 
var_count, const float** samples, schar* y, 377 int alpha_count, double* alpha, double Cp, double Cn, 378 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row, 379 SelectWorkingSet select_working_set, CalcRho calc_rho ); 380 virtual ~CvSVMSolver(); 381 382 virtual void clear(); 383 virtual bool solve_generic( CvSVMSolutionInfo& si ); 384 385 virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y, 386 double Cp, double Cn, CvMemStorage* storage, 387 CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si ); 388 virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y, 389 CvMemStorage* storage, CvSVMKernel* kernel, 390 double* alpha, CvSVMSolutionInfo& si ); 391 virtual bool solve_one_class( int count, int var_count, const float** samples, 392 CvMemStorage* storage, CvSVMKernel* kernel, 393 double* alpha, CvSVMSolutionInfo& si ); 394 395 virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y, 396 CvMemStorage* storage, CvSVMKernel* kernel, 397 double* alpha, CvSVMSolutionInfo& si ); 398 399 virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y, 400 CvMemStorage* storage, CvSVMKernel* kernel, 401 double* alpha, CvSVMSolutionInfo& si ); 402 403 virtual float* get_row_base( int i, bool* _existed ); 404 virtual float* get_row( int i, float* dst ); 405 406 int sample_count; 407 int var_count; 408 int cache_size; 409 int cache_line_size; 410 const float** samples; 411 const CvSVMParams* params; 412 CvMemStorage* storage; 413 CvSVMKernelRow lru_list; 414 CvSVMKernelRow* rows; 415 416 int alpha_count; 417 418 double* G; 419 double* alpha; 420 421 // -1 - lower bound, 0 - free, 1 - upper bound 422 schar* alpha_status; 423 424 schar* y; 425 double* b; 426 float* buf[2]; 427 double eps; 428 int max_iter; 429 double C[2]; // C[0] == Cn, C[1] == Cp 430 CvSVMKernel* kernel; 431 432 SelectWorkingSet select_working_set_func; 433 CalcRho 
calc_rho_func; 434 GetRow get_row_func; 435 436 virtual bool select_working_set( int& i, int& j ); 437 virtual bool select_working_set_nu_svm( int& i, int& j ); 438 virtual void calc_rho( double& rho, double& r ); 439 virtual void calc_rho_nu_svm( double& rho, double& r ); 440 441 virtual float* get_row_svc( int i, float* row, float* dst, bool existed ); 442 virtual float* get_row_one_class( int i, float* row, float* dst, bool existed ); 443 virtual float* get_row_svr( int i, float* row, float* dst, bool existed ); 444 }; 445 446 447 struct CvSVMDecisionFunc 448 { 449 double rho; 450 int sv_count; 451 double* alpha; 452 int* sv_index; 453 }; 454 455 456 // SVM model 457 class CvSVM : public CvStatModel 458 { 459 public: 460 // SVM type 461 enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 }; 462 463 // SVM kernel type 464 enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3, CHI2=4, INTER=5 }; 465 466 // SVM params type 467 enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 }; 468 469 CV_WRAP CvSVM(); 470 virtual ~CvSVM(); 471 472 CvSVM( const CvMat* trainData, const CvMat* responses, 473 const CvMat* varIdx=0, const CvMat* sampleIdx=0, 474 CvSVMParams params=CvSVMParams() ); 475 476 virtual bool train( const CvMat* trainData, const CvMat* responses, 477 const CvMat* varIdx=0, const CvMat* sampleIdx=0, 478 CvSVMParams params=CvSVMParams() ); 479 480 virtual bool train_auto( const CvMat* trainData, const CvMat* responses, 481 const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params, 482 int kfold = 10, 483 CvParamGrid Cgrid = get_default_grid(CvSVM::C), 484 CvParamGrid gammaGrid = get_default_grid(CvSVM::GAMMA), 485 CvParamGrid pGrid = get_default_grid(CvSVM::P), 486 CvParamGrid nuGrid = get_default_grid(CvSVM::NU), 487 CvParamGrid coeffGrid = get_default_grid(CvSVM::COEF), 488 CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE), 489 bool balanced=false ); 490 491 virtual float predict( const CvMat* sample, bool returnDFVal=false ) const; 492 
virtual float predict( const CvMat* samples, CV_OUT CvMat* results, bool returnDFVal=false ) const; 493 494 CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses, 495 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(), 496 CvSVMParams params=CvSVMParams() ); 497 498 CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses, 499 const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(), 500 CvSVMParams params=CvSVMParams() ); 501 502 CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses, 503 const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params, 504 int k_fold = 10, 505 CvParamGrid Cgrid = CvSVM::get_default_grid(CvSVM::C), 506 CvParamGrid gammaGrid = CvSVM::get_default_grid(CvSVM::GAMMA), 507 CvParamGrid pGrid = CvSVM::get_default_grid(CvSVM::P), 508 CvParamGrid nuGrid = CvSVM::get_default_grid(CvSVM::NU), 509 CvParamGrid coeffGrid = CvSVM::get_default_grid(CvSVM::COEF), 510 CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE), 511 bool balanced=false); 512 CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const; 513 CV_WRAP_AS(predict_all) virtual void predict( cv::InputArray samples, cv::OutputArray results ) const; 514 515 CV_WRAP virtual int get_support_vector_count() const; 516 virtual const float* get_support_vector(int i) const; 517 virtual CvSVMParams get_params() const { return params; } 518 CV_WRAP virtual void clear(); 519 520 virtual const CvSVMDecisionFunc* get_decision_function() const { return decision_func; } 521 522 static CvParamGrid get_default_grid( int param_id ); 523 524 virtual void write( CvFileStorage* storage, const char* name ) const; 525 virtual void read( CvFileStorage* storage, CvFileNode* node ); 526 CV_WRAP int get_var_count() const { return var_idx ? 
var_idx->cols : var_all; } 527 528 protected: 529 530 virtual bool set_params( const CvSVMParams& params ); 531 virtual bool train1( int sample_count, int var_count, const float** samples, 532 const void* responses, double Cp, double Cn, 533 CvMemStorage* _storage, double* alpha, double& rho ); 534 virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples, 535 const CvMat* responses, CvMemStorage* _storage, double* alpha ); 536 virtual void create_kernel(); 537 virtual void create_solver(); 538 539 virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const; 540 541 virtual void write_params( CvFileStorage* fs ) const; 542 virtual void read_params( CvFileStorage* fs, CvFileNode* node ); 543 544 void optimize_linear_svm(); 545 546 CvSVMParams params; 547 CvMat* class_labels; 548 int var_all; 549 float** sv; 550 int sv_total; 551 CvMat* var_idx; 552 CvMat* class_weights; 553 CvSVMDecisionFunc* decision_func; 554 CvMemStorage* storage; 555 556 CvSVMSolver* solver; 557 CvSVMKernel* kernel; 558 559 private: 560 CvSVM(const CvSVM&); 561 CvSVM& operator = (const CvSVM&); 562 }; 563 564 /****************************************************************************************\ 565 * Decision Tree * 566 \****************************************************************************************/\ 567 struct CvPair16u32s 568 { 569 unsigned short* u; 570 int* i; 571 }; 572 573 574 #define CV_DTREE_CAT_DIR(idx,subset) \ 575 (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1) 576 577 struct CvDTreeSplit 578 { 579 int var_idx; 580 int condensed_idx; 581 int inversed; 582 float quality; 583 CvDTreeSplit* next; 584 union 585 { 586 int subset[2]; 587 struct 588 { 589 float c; 590 int split_point; 591 } 592 ord; 593 }; 594 }; 595 596 struct CvDTreeNode 597 { 598 int class_idx; 599 int Tn; 600 double value; 601 602 CvDTreeNode* parent; 603 CvDTreeNode* left; 604 CvDTreeNode* right; 605 606 CvDTreeSplit* split; 607 608 int 
sample_count; 609 int depth; 610 int* num_valid; 611 int offset; 612 int buf_idx; 613 double maxlr; 614 615 // global pruning data 616 int complexity; 617 double alpha; 618 double node_risk, tree_risk, tree_error; 619 620 // cross-validation pruning data 621 int* cv_Tn; 622 double* cv_node_risk; 623 double* cv_node_error; 624 625 int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; } 626 void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; } 627 }; 628 629 630 struct CvDTreeParams 631 { 632 CV_PROP_RW int max_categories; 633 CV_PROP_RW int max_depth; 634 CV_PROP_RW int min_sample_count; 635 CV_PROP_RW int cv_folds; 636 CV_PROP_RW bool use_surrogates; 637 CV_PROP_RW bool use_1se_rule; 638 CV_PROP_RW bool truncate_pruned_tree; 639 CV_PROP_RW float regression_accuracy; 640 const float* priors; 641 642 CvDTreeParams(); 643 CvDTreeParams( int max_depth, int min_sample_count, 644 float regression_accuracy, bool use_surrogates, 645 int max_categories, int cv_folds, 646 bool use_1se_rule, bool truncate_pruned_tree, 647 const float* priors ); 648 }; 649 650 651 struct CvDTreeTrainData 652 { 653 CvDTreeTrainData(); 654 CvDTreeTrainData( const CvMat* trainData, int tflag, 655 const CvMat* responses, const CvMat* varIdx=0, 656 const CvMat* sampleIdx=0, const CvMat* varType=0, 657 const CvMat* missingDataMask=0, 658 const CvDTreeParams& params=CvDTreeParams(), 659 bool _shared=false, bool _add_labels=false ); 660 virtual ~CvDTreeTrainData(); 661 662 virtual void set_data( const CvMat* trainData, int tflag, 663 const CvMat* responses, const CvMat* varIdx=0, 664 const CvMat* sampleIdx=0, const CvMat* varType=0, 665 const CvMat* missingDataMask=0, 666 const CvDTreeParams& params=CvDTreeParams(), 667 bool _shared=false, bool _add_labels=false, 668 bool _update_data=false ); 669 virtual void do_responses_copy(); 670 671 virtual void get_vectors( const CvMat* _subsample_idx, 672 float* values, uchar* missing, float* responses, bool 
get_class_idx=false ); 673 674 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx ); 675 676 virtual void write_params( CvFileStorage* fs ) const; 677 virtual void read_params( CvFileStorage* fs, CvFileNode* node ); 678 679 // release all the data 680 virtual void clear(); 681 682 int get_num_classes() const; 683 int get_var_type(int vi) const; 684 int get_work_var_count() const {return work_var_count;} 685 686 virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf ); 687 virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf ); 688 virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf ); 689 virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf ); 690 virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf ); 691 virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf, 692 const float** ord_values, const int** sorted_indices, int* sample_indices_buf ); 693 virtual int get_child_buf_idx( CvDTreeNode* n ); 694 695 //////////////////////////////////// 696 697 virtual bool set_params( const CvDTreeParams& params ); 698 virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count, 699 int storage_idx, int offset ); 700 701 virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val, 702 int split_point, int inversed, float quality ); 703 virtual CvDTreeSplit* new_split_cat( int vi, float quality ); 704 virtual void free_node_data( CvDTreeNode* node ); 705 virtual void free_train_data(); 706 virtual void free_node( CvDTreeNode* node ); 707 708 int sample_count, var_all, var_count, max_c_count; 709 int ord_var_count, cat_var_count, work_var_count; 710 bool have_labels, have_priors; 711 bool is_classifier; 712 int tflag; 713 714 const CvMat* train_data; 715 const CvMat* responses; 716 CvMat* responses_copy; // used in Boosting 717 718 int buf_count, buf_size; // buf_size is obsolete, please do not use 
it, use expression ((int64)buf->rows * (int64)buf->cols / buf_count) instead 719 bool shared; 720 int is_buf_16u; 721 722 CvMat* cat_count; 723 CvMat* cat_ofs; 724 CvMat* cat_map; 725 726 CvMat* counts; 727 CvMat* buf; 728 inline size_t get_length_subbuf() const 729 { 730 size_t res = (size_t)(work_var_count + 1) * (size_t)sample_count; 731 return res; 732 } 733 734 CvMat* direction; 735 CvMat* split_buf; 736 737 CvMat* var_idx; 738 CvMat* var_type; // i-th element = 739 // k<0 - ordered 740 // k>=0 - categorical, see k-th element of cat_* arrays 741 CvMat* priors; 742 CvMat* priors_mult; 743 744 CvDTreeParams params; 745 746 CvMemStorage* tree_storage; 747 CvMemStorage* temp_storage; 748 749 CvDTreeNode* data_root; 750 751 CvSet* node_heap; 752 CvSet* split_heap; 753 CvSet* cv_heap; 754 CvSet* nv_heap; 755 756 cv::RNG* rng; 757 }; 758 759 class CvDTree; 760 class CvForestTree; 761 762 namespace cv 763 { 764 struct DTreeBestSplitFinder; 765 struct ForestTreeBestSplitFinder; 766 } 767 768 class CvDTree : public CvStatModel 769 { 770 public: 771 CV_WRAP CvDTree(); 772 virtual ~CvDTree(); 773 774 virtual bool train( const CvMat* trainData, int tflag, 775 const CvMat* responses, const CvMat* varIdx=0, 776 const CvMat* sampleIdx=0, const CvMat* varType=0, 777 const CvMat* missingDataMask=0, 778 CvDTreeParams params=CvDTreeParams() ); 779 780 virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() ); 781 782 // type in {CV_TRAIN_ERROR, CV_TEST_ERROR} 783 virtual float calc_error( CvMLData* trainData, int type, std::vector<float> *resp = 0 ); 784 785 virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx ); 786 787 virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0, 788 bool preprocessedInput=false ) const; 789 790 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag, 791 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(), 792 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& 
varType=cv::Mat(), 793 const cv::Mat& missingDataMask=cv::Mat(), 794 CvDTreeParams params=CvDTreeParams() ); 795 796 CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(), 797 bool preprocessedInput=false ) const; 798 CV_WRAP virtual cv::Mat getVarImportance(); 799 800 virtual const CvMat* get_var_importance(); 801 CV_WRAP virtual void clear(); 802 803 virtual void read( CvFileStorage* fs, CvFileNode* node ); 804 virtual void write( CvFileStorage* fs, const char* name ) const; 805 806 // special read & write methods for trees in the tree ensembles 807 virtual void read( CvFileStorage* fs, CvFileNode* node, 808 CvDTreeTrainData* data ); 809 virtual void write( CvFileStorage* fs ) const; 810 811 const CvDTreeNode* get_root() const; 812 int get_pruned_tree_idx() const; 813 CvDTreeTrainData* get_data(); 814 815 protected: 816 friend struct cv::DTreeBestSplitFinder; 817 818 virtual bool do_train( const CvMat* _subsample_idx ); 819 820 virtual void try_split_node( CvDTreeNode* n ); 821 virtual void split_node_data( CvDTreeNode* n ); 822 virtual CvDTreeSplit* find_best_split( CvDTreeNode* n ); 823 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 824 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 825 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi, 826 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 827 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 828 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 829 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 830 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 831 virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 ); 832 virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 ); 833 virtual double calc_node_dir( CvDTreeNode* node ); 834 
virtual void complete_node_dir( CvDTreeNode* node ); 835 virtual void cluster_categories( const int* vectors, int vector_count, 836 int var_count, int* sums, int k, int* cluster_labels ); 837 838 virtual void calc_node_value( CvDTreeNode* node ); 839 840 virtual void prune_cv(); 841 virtual double update_tree_rnc( int T, int fold ); 842 virtual int cut_tree( int T, int fold, double min_alpha ); 843 virtual void free_prune_data(bool cut_tree); 844 virtual void free_tree(); 845 846 virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ) const; 847 virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ) const; 848 virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent ); 849 virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node ); 850 virtual void write_tree_nodes( CvFileStorage* fs ) const; 851 virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node ); 852 853 CvDTreeNode* root; 854 CvMat* var_importance; 855 CvDTreeTrainData* data; 856 CvMat train_data_hdr, responses_hdr; 857 cv::Mat train_data_mat, responses_mat; 858 859 public: 860 int pruned_tree_idx; 861 }; 862 863 864 /****************************************************************************************\ 865 * Random Trees Classifier * 866 \****************************************************************************************/ 867 868 class CvRTrees; 869 870 class CvForestTree: public CvDTree 871 { 872 public: 873 CvForestTree(); 874 virtual ~CvForestTree(); 875 876 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest ); 877 878 virtual int get_var_count() const {return data ? 
data->var_count : 0;} 879 virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data ); 880 881 /* dummy methods to avoid warnings: BEGIN */ 882 virtual bool train( const CvMat* trainData, int tflag, 883 const CvMat* responses, const CvMat* varIdx=0, 884 const CvMat* sampleIdx=0, const CvMat* varType=0, 885 const CvMat* missingDataMask=0, 886 CvDTreeParams params=CvDTreeParams() ); 887 888 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx ); 889 virtual void read( CvFileStorage* fs, CvFileNode* node ); 890 virtual void read( CvFileStorage* fs, CvFileNode* node, 891 CvDTreeTrainData* data ); 892 /* dummy methods to avoid warnings: END */ 893 894 protected: 895 friend struct cv::ForestTreeBestSplitFinder; 896 897 virtual CvDTreeSplit* find_best_split( CvDTreeNode* n ); 898 CvRTrees* forest; 899 }; 900 901 902 struct CvRTParams : public CvDTreeParams 903 { 904 //Parameters for the forest 905 CV_PROP_RW bool calc_var_importance; // true <=> RF processes variable importance 906 CV_PROP_RW int nactive_vars; 907 CV_PROP_RW CvTermCriteria term_crit; 908 909 CvRTParams(); 910 CvRTParams( int max_depth, int min_sample_count, 911 float regression_accuracy, bool use_surrogates, 912 int max_categories, const float* priors, bool calc_var_importance, 913 int nactive_vars, int max_num_of_trees_in_the_forest, 914 float forest_accuracy, int termcrit_type ); 915 }; 916 917 918 class CvRTrees : public CvStatModel 919 { 920 public: 921 CV_WRAP CvRTrees(); 922 virtual ~CvRTrees(); 923 virtual bool train( const CvMat* trainData, int tflag, 924 const CvMat* responses, const CvMat* varIdx=0, 925 const CvMat* sampleIdx=0, const CvMat* varType=0, 926 const CvMat* missingDataMask=0, 927 CvRTParams params=CvRTParams() ); 928 929 virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() ); 930 virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const; 931 virtual float predict_prob( const CvMat* 
sample, const CvMat* missing = 0 ) const; 932 933 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag, 934 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(), 935 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(), 936 const cv::Mat& missingDataMask=cv::Mat(), 937 CvRTParams params=CvRTParams() ); 938 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const; 939 CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const; 940 CV_WRAP virtual cv::Mat getVarImportance(); 941 942 CV_WRAP virtual void clear(); 943 944 virtual const CvMat* get_var_importance(); 945 virtual float get_proximity( const CvMat* sample1, const CvMat* sample2, 946 const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const; 947 948 virtual float calc_error( CvMLData* data, int type , std::vector<float>* resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR} 949 950 virtual float get_train_error(); 951 952 virtual void read( CvFileStorage* fs, CvFileNode* node ); 953 virtual void write( CvFileStorage* fs, const char* name ) const; 954 955 CvMat* get_active_var_mask(); 956 CvRNG* get_rng(); 957 958 int get_tree_count() const; 959 CvForestTree* get_tree(int i) const; 960 961 protected: 962 virtual cv::String getName() const; 963 964 virtual bool grow_forest( const CvTermCriteria term_crit ); 965 966 // array of the trees of the forest 967 CvForestTree** trees; 968 CvDTreeTrainData* data; 969 CvMat train_data_hdr, responses_hdr; 970 cv::Mat train_data_mat, responses_mat; 971 int ntrees; 972 int nclasses; 973 double oob_error; 974 CvMat* var_importance; 975 int nsamples; 976 977 cv::RNG* rng; 978 CvMat* active_var_mask; 979 }; 980 981 /****************************************************************************************\ 982 * Extremely randomized trees Classifier * 983 \****************************************************************************************/ 984 struct 
CvERTreeTrainData : public CvDTreeTrainData 985 { 986 virtual void set_data( const CvMat* trainData, int tflag, 987 const CvMat* responses, const CvMat* varIdx=0, 988 const CvMat* sampleIdx=0, const CvMat* varType=0, 989 const CvMat* missingDataMask=0, 990 const CvDTreeParams& params=CvDTreeParams(), 991 bool _shared=false, bool _add_labels=false, 992 bool _update_data=false ); 993 virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf, 994 const float** ord_values, const int** missing, int* sample_buf = 0 ); 995 virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf ); 996 virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf ); 997 virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf ); 998 virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing, 999 float* responses, bool get_class_idx=false ); 1000 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx ); 1001 const CvMat* missing_mask; 1002 }; 1003 1004 class CvForestERTree : public CvForestTree 1005 { 1006 protected: 1007 virtual double calc_node_dir( CvDTreeNode* node ); 1008 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 1009 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 1010 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi, 1011 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 1012 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 1013 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 1014 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 1015 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 1016 virtual void split_node_data( CvDTreeNode* n ); 1017 }; 1018 1019 class CvERTrees : public CvRTrees 1020 { 1021 public: 1022 CV_WRAP CvERTrees(); 1023 virtual ~CvERTrees(); 1024 virtual bool train( const CvMat* 
trainData, int tflag, 1025 const CvMat* responses, const CvMat* varIdx=0, 1026 const CvMat* sampleIdx=0, const CvMat* varType=0, 1027 const CvMat* missingDataMask=0, 1028 CvRTParams params=CvRTParams()); 1029 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag, 1030 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(), 1031 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(), 1032 const cv::Mat& missingDataMask=cv::Mat(), 1033 CvRTParams params=CvRTParams()); 1034 virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() ); 1035 protected: 1036 virtual cv::String getName() const; 1037 virtual bool grow_forest( const CvTermCriteria term_crit ); 1038 }; 1039 1040 1041 /****************************************************************************************\ 1042 * Boosted tree classifier * 1043 \****************************************************************************************/ 1044 1045 struct CvBoostParams : public CvDTreeParams 1046 { 1047 CV_PROP_RW int boost_type; 1048 CV_PROP_RW int weak_count; 1049 CV_PROP_RW int split_criteria; 1050 CV_PROP_RW double weight_trim_rate; 1051 1052 CvBoostParams(); 1053 CvBoostParams( int boost_type, int weak_count, double weight_trim_rate, 1054 int max_depth, bool use_surrogates, const float* priors ); 1055 }; 1056 1057 1058 class CvBoost; 1059 1060 class CvBoostTree: public CvDTree 1061 { 1062 public: 1063 CvBoostTree(); 1064 virtual ~CvBoostTree(); 1065 1066 virtual bool train( CvDTreeTrainData* trainData, 1067 const CvMat* subsample_idx, CvBoost* ensemble ); 1068 1069 virtual void scale( double s ); 1070 virtual void read( CvFileStorage* fs, CvFileNode* node, 1071 CvBoost* ensemble, CvDTreeTrainData* _data ); 1072 virtual void clear(); 1073 1074 /* dummy methods to avoid warnings: BEGIN */ 1075 virtual bool train( const CvMat* trainData, int tflag, 1076 const CvMat* responses, const CvMat* varIdx=0, 1077 const CvMat* sampleIdx=0, const CvMat* varType=0, 1078 const CvMat* 
missingDataMask=0, 1079 CvDTreeParams params=CvDTreeParams() ); 1080 virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx ); 1081 1082 virtual void read( CvFileStorage* fs, CvFileNode* node ); 1083 virtual void read( CvFileStorage* fs, CvFileNode* node, 1084 CvDTreeTrainData* data ); 1085 /* dummy methods to avoid warnings: END */ 1086 1087 protected: 1088 1089 virtual void try_split_node( CvDTreeNode* n ); 1090 virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 ); 1091 virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 ); 1092 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi, 1093 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 1094 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi, 1095 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 1096 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi, 1097 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 1098 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi, 1099 float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 ); 1100 virtual void calc_node_value( CvDTreeNode* n ); 1101 virtual double calc_node_dir( CvDTreeNode* n ); 1102 1103 CvBoost* ensemble; 1104 }; 1105 1106 1107 class CvBoost : public CvStatModel 1108 { 1109 public: 1110 // Boosting type 1111 enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 }; 1112 1113 // Splitting criteria 1114 enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 }; 1115 1116 CV_WRAP CvBoost(); 1117 virtual ~CvBoost(); 1118 1119 CvBoost( const CvMat* trainData, int tflag, 1120 const CvMat* responses, const CvMat* varIdx=0, 1121 const CvMat* sampleIdx=0, const CvMat* varType=0, 1122 const CvMat* missingDataMask=0, 1123 CvBoostParams params=CvBoostParams() ); 1124 1125 virtual bool train( const CvMat* trainData, int tflag, 1126 const CvMat* responses, const 
CvMat* varIdx=0, 1127 const CvMat* sampleIdx=0, const CvMat* varType=0, 1128 const CvMat* missingDataMask=0, 1129 CvBoostParams params=CvBoostParams(), 1130 bool update=false ); 1131 1132 virtual bool train( CvMLData* data, 1133 CvBoostParams params=CvBoostParams(), 1134 bool update=false ); 1135 1136 virtual float predict( const CvMat* sample, const CvMat* missing=0, 1137 CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ, 1138 bool raw_mode=false, bool return_sum=false ) const; 1139 1140 CV_WRAP CvBoost( const cv::Mat& trainData, int tflag, 1141 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(), 1142 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(), 1143 const cv::Mat& missingDataMask=cv::Mat(), 1144 CvBoostParams params=CvBoostParams() ); 1145 1146 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag, 1147 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(), 1148 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(), 1149 const cv::Mat& missingDataMask=cv::Mat(), 1150 CvBoostParams params=CvBoostParams(), 1151 bool update=false ); 1152 1153 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(), 1154 const cv::Range& slice=cv::Range::all(), bool rawMode=false, 1155 bool returnSum=false ) const; 1156 1157 virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR} 1158 1159 CV_WRAP virtual void prune( CvSlice slice ); 1160 1161 CV_WRAP virtual void clear(); 1162 1163 virtual void write( CvFileStorage* storage, const char* name ) const; 1164 virtual void read( CvFileStorage* storage, CvFileNode* node ); 1165 virtual const CvMat* get_active_vars(bool absolute_idx=true); 1166 1167 CvSeq* get_weak_predictors(); 1168 1169 CvMat* get_weights(); 1170 CvMat* get_subtree_weights(); 1171 CvMat* get_weak_response(); 1172 const CvBoostParams& get_params() const; 1173 const CvDTreeTrainData* get_data() const; 1174 
1175 protected: 1176 1177 virtual bool set_params( const CvBoostParams& params ); 1178 virtual void update_weights( CvBoostTree* tree ); 1179 virtual void trim_weights(); 1180 virtual void write_params( CvFileStorage* fs ) const; 1181 virtual void read_params( CvFileStorage* fs, CvFileNode* node ); 1182 1183 virtual void initialize_weights(double (&p)[2]); 1184 1185 CvDTreeTrainData* data; 1186 CvMat train_data_hdr, responses_hdr; 1187 cv::Mat train_data_mat, responses_mat; 1188 CvBoostParams params; 1189 CvSeq* weak; 1190 1191 CvMat* active_vars; 1192 CvMat* active_vars_abs; 1193 bool have_active_cat_vars; 1194 1195 CvMat* orig_response; 1196 CvMat* sum_response; 1197 CvMat* weak_eval; 1198 CvMat* subsample_mask; 1199 CvMat* weights; 1200 CvMat* subtree_weights; 1201 bool have_subsample; 1202 }; 1203 1204 1205 /****************************************************************************************\ 1206 * Gradient Boosted Trees * 1207 \****************************************************************************************/ 1208 1209 // DataType: STRUCT CvGBTreesParams 1210 // Parameters of GBT (Gradient Boosted trees model), including single 1211 // tree settings and ensemble parameters. 1212 // 1213 // weak_count - count of trees in the ensemble 1214 // loss_function_type - loss function used for ensemble training 1215 // subsample_portion - portion of whole training set used for 1216 // every single tree training. 1217 // subsample_portion value is in (0.0, 1.0]. 1218 // subsample_portion == 1.0 when whole dataset is 1219 // used on each step. Count of sample used on each 1220 // step is computed as 1221 // int(total_samples_count * subsample_portion). 1222 // shrinkage - regularization parameter. 1223 // Each tree prediction is multiplied on shrinkage value. 
1224 1225 1226 struct CvGBTreesParams : public CvDTreeParams 1227 { 1228 CV_PROP_RW int weak_count; 1229 CV_PROP_RW int loss_function_type; 1230 CV_PROP_RW float subsample_portion; 1231 CV_PROP_RW float shrinkage; 1232 1233 CvGBTreesParams(); 1234 CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage, 1235 float subsample_portion, int max_depth, bool use_surrogates ); 1236 }; 1237 1238 // DataType: CLASS CvGBTrees 1239 // Gradient Boosting Trees (GBT) algorithm implementation. 1240 // 1241 // data - training dataset 1242 // params - parameters of the CvGBTrees 1243 // weak - array[0..(class_count-1)] of CvSeq 1244 // for storing tree ensembles 1245 // orig_response - original responses of the training set samples 1246 // sum_response - predicitons of the current model on the training dataset. 1247 // this matrix is updated on every iteration. 1248 // sum_response_tmp - predicitons of the model on the training set on the next 1249 // step. On every iteration values of sum_responses_tmp are 1250 // computed via sum_responses values. When the current 1251 // step is complete sum_response values become equal to 1252 // sum_responses_tmp. 1253 // sampleIdx - indices of samples used for training the ensemble. 1254 // CvGBTrees training procedure takes a set of samples 1255 // (train_data) and a set of responses (responses). 1256 // Only pairs (train_data[i], responses[i]), where i is 1257 // in sample_idx are used for training the ensemble. 1258 // subsample_train - indices of samples used for training a single decision 1259 // tree on the current step. This indices are countered 1260 // relatively to the sample_idx, so that pairs 1261 // (train_data[sample_idx[i]], responses[sample_idx[i]]) 1262 // are used for training a decision tree. 1263 // Training set is randomly splited 1264 // in two parts (subsample_train and subsample_test) 1265 // on every iteration accordingly to the portion parameter. 
1266 // subsample_test - relative indices of samples from the training set, 1267 // which are not used for training a tree on the current 1268 // step. 1269 // missing - mask of the missing values in the training set. This 1270 // matrix has the same size as train_data. 1 - missing 1271 // value, 0 - not a missing value. 1272 // class_labels - output class labels map. 1273 // rng - random number generator. Used for spliting the 1274 // training set. 1275 // class_count - count of output classes. 1276 // class_count == 1 in the case of regression, 1277 // and > 1 in the case of classification. 1278 // delta - Huber loss function parameter. 1279 // base_value - start point of the gradient descent procedure. 1280 // model prediction is 1281 // f(x) = f_0 + sum_{i=1..weak_count-1}(f_i(x)), where 1282 // f_0 is the base value. 1283 1284 1285 1286 class CvGBTrees : public CvStatModel 1287 { 1288 public: 1289 1290 /* 1291 // DataType: ENUM 1292 // Loss functions implemented in CvGBTrees. 1293 // 1294 // SQUARED_LOSS 1295 // problem: regression 1296 // loss = (x - x')^2 1297 // 1298 // ABSOLUTE_LOSS 1299 // problem: regression 1300 // loss = abs(x - x') 1301 // 1302 // HUBER_LOSS 1303 // problem: regression 1304 // loss = delta*( abs(x - x') - delta/2), if abs(x - x') > delta 1305 // 1/2*(x - x')^2, if abs(x - x') <= delta, 1306 // where delta is the alpha-quantile of pseudo responses from 1307 // the training set. 1308 // 1309 // DEVIANCE_LOSS 1310 // problem: classification 1311 // 1312 */ 1313 enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS}; 1314 1315 1316 /* 1317 // Default constructor. Creates a model only (without training). 1318 // Should be followed by one form of the train(...) function. 1319 // 1320 // API 1321 // CvGBTrees(); 1322 1323 // INPUT 1324 // OUTPUT 1325 // RESULT 1326 */ 1327 CV_WRAP CvGBTrees(); 1328 1329 1330 /* 1331 // Full form constructor. Creates a gradient boosting model and does the 1332 // train. 
1333 // 1334 // API 1335 // CvGBTrees( const CvMat* trainData, int tflag, 1336 const CvMat* responses, const CvMat* varIdx=0, 1337 const CvMat* sampleIdx=0, const CvMat* varType=0, 1338 const CvMat* missingDataMask=0, 1339 CvGBTreesParams params=CvGBTreesParams() ); 1340 1341 // INPUT 1342 // trainData - a set of input feature vectors. 1343 // size of matrix is 1344 // <count of samples> x <variables count> 1345 // or <variables count> x <count of samples> 1346 // depending on the tflag parameter. 1347 // matrix values are float. 1348 // tflag - a flag showing how do samples stored in the 1349 // trainData matrix row by row (tflag=CV_ROW_SAMPLE) 1350 // or column by column (tflag=CV_COL_SAMPLE). 1351 // responses - a vector of responses corresponding to the samples 1352 // in trainData. 1353 // varIdx - indices of used variables. zero value means that all 1354 // variables are active. 1355 // sampleIdx - indices of used samples. zero value means that all 1356 // samples from trainData are in the training set. 1357 // varType - vector of <variables count> length. gives every 1358 // variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED. 1359 // varType = 0 means all variables are numerical. 1360 // missingDataMask - a mask of misiing values in trainData. 1361 // missingDataMask = 0 means that there are no missing 1362 // values. 1363 // params - parameters of GTB algorithm. 1364 // OUTPUT 1365 // RESULT 1366 */ 1367 CvGBTrees( const CvMat* trainData, int tflag, 1368 const CvMat* responses, const CvMat* varIdx=0, 1369 const CvMat* sampleIdx=0, const CvMat* varType=0, 1370 const CvMat* missingDataMask=0, 1371 CvGBTreesParams params=CvGBTreesParams() ); 1372 1373 1374 /* 1375 // Destructor. 
1376 */ 1377 virtual ~CvGBTrees(); 1378 1379 1380 /* 1381 // Gradient tree boosting model training 1382 // 1383 // API 1384 // virtual bool train( const CvMat* trainData, int tflag, 1385 const CvMat* responses, const CvMat* varIdx=0, 1386 const CvMat* sampleIdx=0, const CvMat* varType=0, 1387 const CvMat* missingDataMask=0, 1388 CvGBTreesParams params=CvGBTreesParams(), 1389 bool update=false ); 1390 1391 // INPUT 1392 // trainData - a set of input feature vectors. 1393 // size of matrix is 1394 // <count of samples> x <variables count> 1395 // or <variables count> x <count of samples> 1396 // depending on the tflag parameter. 1397 // matrix values are float. 1398 // tflag - a flag showing how do samples stored in the 1399 // trainData matrix row by row (tflag=CV_ROW_SAMPLE) 1400 // or column by column (tflag=CV_COL_SAMPLE). 1401 // responses - a vector of responses corresponding to the samples 1402 // in trainData. 1403 // varIdx - indices of used variables. zero value means that all 1404 // variables are active. 1405 // sampleIdx - indices of used samples. zero value means that all 1406 // samples from trainData are in the training set. 1407 // varType - vector of <variables count> length. gives every 1408 // variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED. 1409 // varType = 0 means all variables are numerical. 1410 // missingDataMask - a mask of misiing values in trainData. 1411 // missingDataMask = 0 means that there are no missing 1412 // values. 1413 // params - parameters of GTB algorithm. 1414 // update - is not supported now. (!) 1415 // OUTPUT 1416 // RESULT 1417 // Error state. 
1418 */ 1419 virtual bool train( const CvMat* trainData, int tflag, 1420 const CvMat* responses, const CvMat* varIdx=0, 1421 const CvMat* sampleIdx=0, const CvMat* varType=0, 1422 const CvMat* missingDataMask=0, 1423 CvGBTreesParams params=CvGBTreesParams(), 1424 bool update=false ); 1425 1426 1427 /* 1428 // Gradient tree boosting model training 1429 // 1430 // API 1431 // virtual bool train( CvMLData* data, 1432 CvGBTreesParams params=CvGBTreesParams(), 1433 bool update=false ) {return false;} 1434 1435 // INPUT 1436 // data - training set. 1437 // params - parameters of GTB algorithm. 1438 // update - is not supported now. (!) 1439 // OUTPUT 1440 // RESULT 1441 // Error state. 1442 */ 1443 virtual bool train( CvMLData* data, 1444 CvGBTreesParams params=CvGBTreesParams(), 1445 bool update=false ); 1446 1447 1448 /* 1449 // Response value prediction 1450 // 1451 // API 1452 // virtual float predict_serial( const CvMat* sample, const CvMat* missing=0, 1453 CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ, 1454 int k=-1 ) const; 1455 1456 // INPUT 1457 // sample - input sample of the same type as in the training set. 1458 // missing - missing values mask. missing=0 if there are no 1459 // missing values in sample vector. 1460 // weak_responses - predictions of all of the trees. 1461 // not implemented (!) 1462 // slice - part of the ensemble used for prediction. 1463 // slice = CV_WHOLE_SEQ when all trees are used. 1464 // k - number of ensemble used. 1465 // k is in {-1,0,1,..,<count of output classes-1>}. 1466 // in the case of classification problem 1467 // <count of output classes-1> ensembles are built. 1468 // If k = -1 ordinary prediction is the result, 1469 // otherwise function gives the prediction of the 1470 // k-th ensemble only. 1471 // OUTPUT 1472 // RESULT 1473 // Predicted value. 
1474 */ 1475 virtual float predict_serial( const CvMat* sample, const CvMat* missing=0, 1476 CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ, 1477 int k=-1 ) const; 1478 1479 /* 1480 // Response value prediction. 1481 // Parallel version (in the case of TBB existence) 1482 // 1483 // API 1484 // virtual float predict( const CvMat* sample, const CvMat* missing=0, 1485 CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ, 1486 int k=-1 ) const; 1487 1488 // INPUT 1489 // sample - input sample of the same type as in the training set. 1490 // missing - missing values mask. missing=0 if there are no 1491 // missing values in sample vector. 1492 // weak_responses - predictions of all of the trees. 1493 // not implemented (!) 1494 // slice - part of the ensemble used for prediction. 1495 // slice = CV_WHOLE_SEQ when all trees are used. 1496 // k - number of ensemble used. 1497 // k is in {-1,0,1,..,<count of output classes-1>}. 1498 // in the case of classification problem 1499 // <count of output classes-1> ensembles are built. 1500 // If k = -1 ordinary prediction is the result, 1501 // otherwise function gives the prediction of the 1502 // k-th ensemble only. 1503 // OUTPUT 1504 // RESULT 1505 // Predicted value. 1506 */ 1507 virtual float predict( const CvMat* sample, const CvMat* missing=0, 1508 CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ, 1509 int k=-1 ) const; 1510 1511 /* 1512 // Deletes all the data. 1513 // 1514 // API 1515 // virtual void clear(); 1516 1517 // INPUT 1518 // OUTPUT 1519 // delete data, weak, orig_response, sum_response, 1520 // weak_eval, subsample_train, subsample_test, 1521 // sample_idx, missing, lass_labels 1522 // delta = 0.0 1523 // RESULT 1524 */ 1525 CV_WRAP virtual void clear(); 1526 1527 /* 1528 // Compute error on the train/test set. 
1529 // 1530 // API 1531 // virtual float calc_error( CvMLData* _data, int type, 1532 // std::vector<float> *resp = 0 ); 1533 // 1534 // INPUT 1535 // data - dataset 1536 // type - defines which error is to compute: train (CV_TRAIN_ERROR) or 1537 // test (CV_TEST_ERROR). 1538 // OUTPUT 1539 // resp - vector of predicitons 1540 // RESULT 1541 // Error value. 1542 */ 1543 virtual float calc_error( CvMLData* _data, int type, 1544 std::vector<float> *resp = 0 ); 1545 1546 /* 1547 // 1548 // Write parameters of the gtb model and data. Write learned model. 1549 // 1550 // API 1551 // virtual void write( CvFileStorage* fs, const char* name ) const; 1552 // 1553 // INPUT 1554 // fs - file storage to read parameters from. 1555 // name - model name. 1556 // OUTPUT 1557 // RESULT 1558 */ 1559 virtual void write( CvFileStorage* fs, const char* name ) const; 1560 1561 1562 /* 1563 // 1564 // Read parameters of the gtb model and data. Read learned model. 1565 // 1566 // API 1567 // virtual void read( CvFileStorage* fs, CvFileNode* node ); 1568 // 1569 // INPUT 1570 // fs - file storage to read parameters from. 1571 // node - file node. 
1572 // OUTPUT 1573 // RESULT 1574 */ 1575 virtual void read( CvFileStorage* fs, CvFileNode* node ); 1576 1577 1578 // new-style C++ interface 1579 CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag, 1580 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(), 1581 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(), 1582 const cv::Mat& missingDataMask=cv::Mat(), 1583 CvGBTreesParams params=CvGBTreesParams() ); 1584 1585 CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag, 1586 const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(), 1587 const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(), 1588 const cv::Mat& missingDataMask=cv::Mat(), 1589 CvGBTreesParams params=CvGBTreesParams(), 1590 bool update=false ); 1591 1592 CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(), 1593 const cv::Range& slice = cv::Range::all(), 1594 int k=-1 ) const; 1595 1596 protected: 1597 1598 /* 1599 // Compute the gradient vector components. 1600 // 1601 // API 1602 // virtual void find_gradient( const int k = 0); 1603 1604 // INPUT 1605 // k - used for classification problem, determining current 1606 // tree ensemble. 1607 // OUTPUT 1608 // changes components of data->responses 1609 // which correspond to samples used for training 1610 // on the current step. 1611 // RESULT 1612 */ 1613 virtual void find_gradient( const int k = 0); 1614 1615 1616 /* 1617 // 1618 // Change values in tree leaves according to the used loss function. 1619 // 1620 // API 1621 // virtual void change_values(CvDTree* tree, const int k = 0); 1622 // 1623 // INPUT 1624 // tree - decision tree to change. 1625 // k - used for classification problem, determining current 1626 // tree ensemble. 1627 // OUTPUT 1628 // changes 'value' fields of the trees' leaves. 1629 // changes sum_response_tmp. 
1630 // RESULT 1631 */ 1632 virtual void change_values(CvDTree* tree, const int k = 0); 1633 1634 1635 /* 1636 // 1637 // Find optimal constant prediction value according to the used loss 1638 // function. 1639 // The goal is to find a constant which gives the minimal summary loss 1640 // on the _Idx samples. 1641 // 1642 // API 1643 // virtual float find_optimal_value( const CvMat* _Idx ); 1644 // 1645 // INPUT 1646 // _Idx - indices of the samples from the training set. 1647 // OUTPUT 1648 // RESULT 1649 // optimal constant value. 1650 */ 1651 virtual float find_optimal_value( const CvMat* _Idx ); 1652 1653 1654 /* 1655 // 1656 // Randomly split the whole training set in two parts according 1657 // to params.portion. 1658 // 1659 // API 1660 // virtual void do_subsample(); 1661 // 1662 // INPUT 1663 // OUTPUT 1664 // subsample_train - indices of samples used for training 1665 // subsample_test - indices of samples used for test 1666 // RESULT 1667 */ 1668 virtual void do_subsample(); 1669 1670 1671 /* 1672 // 1673 // Internal recursive function giving an array of subtree tree leaves. 1674 // 1675 // API 1676 // void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node ); 1677 // 1678 // INPUT 1679 // node - current leaf. 1680 // OUTPUT 1681 // count - count of leaves in the subtree. 1682 // leaves - array of pointers to leaves. 1683 // RESULT 1684 */ 1685 void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node ); 1686 1687 1688 /* 1689 // 1690 // Get leaves of the tree. 1691 // 1692 // API 1693 // CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len ); 1694 // 1695 // INPUT 1696 // dtree - decision tree. 1697 // OUTPUT 1698 // len - count of the leaves. 1699 // RESULT 1700 // CvDTreeNode** - array of pointers to leaves. 1701 */ 1702 CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len ); 1703 1704 1705 /* 1706 // 1707 // Is it a regression or a classification. 
1708 // 1709 // API 1710 // bool problem_type(); 1711 // 1712 // INPUT 1713 // OUTPUT 1714 // RESULT 1715 // false if it is a classification problem, 1716 // true - if regression. 1717 */ 1718 virtual bool problem_type() const; 1719 1720 1721 /* 1722 // 1723 // Write parameters of the gtb model. 1724 // 1725 // API 1726 // virtual void write_params( CvFileStorage* fs ) const; 1727 // 1728 // INPUT 1729 // fs - file storage to write parameters to. 1730 // OUTPUT 1731 // RESULT 1732 */ 1733 virtual void write_params( CvFileStorage* fs ) const; 1734 1735 1736 /* 1737 // 1738 // Read parameters of the gtb model and data. 1739 // 1740 // API 1741 // virtual void read_params( CvFileStorage* fs ); 1742 // 1743 // INPUT 1744 // fs - file storage to read parameters from. 1745 // OUTPUT 1746 // params - parameters of the gtb model. 1747 // data - contains information about the structure 1748 // of the data set (count of variables, 1749 // their types, etc.). 1750 // class_labels - output class labels map. 
1751 // RESULT 1752 */ 1753 virtual void read_params( CvFileStorage* fs, CvFileNode* fnode ); 1754 int get_len(const CvMat* mat) const; 1755 1756 1757 CvDTreeTrainData* data; 1758 CvGBTreesParams params; 1759 1760 CvSeq** weak; 1761 CvMat* orig_response; 1762 CvMat* sum_response; 1763 CvMat* sum_response_tmp; 1764 CvMat* sample_idx; 1765 CvMat* subsample_train; 1766 CvMat* subsample_test; 1767 CvMat* missing; 1768 CvMat* class_labels; 1769 1770 cv::RNG* rng; 1771 1772 int class_count; 1773 float delta; 1774 float base_value; 1775 1776 }; 1777 1778 1779 1780 /****************************************************************************************\ 1781 * Artificial Neural Networks (ANN) * 1782 \****************************************************************************************/ 1783 1784 /////////////////////////////////// Multi-Layer Perceptrons ////////////////////////////// 1785 1786 struct CvANN_MLP_TrainParams 1787 { 1788 CvANN_MLP_TrainParams(); 1789 CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method, 1790 double param1, double param2=0 ); 1791 ~CvANN_MLP_TrainParams(); 1792 1793 enum { BACKPROP=0, RPROP=1 }; 1794 1795 CV_PROP_RW CvTermCriteria term_crit; 1796 CV_PROP_RW int train_method; 1797 1798 // backpropagation parameters 1799 CV_PROP_RW double bp_dw_scale, bp_moment_scale; 1800 1801 // rprop parameters 1802 CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max; 1803 }; 1804 1805 1806 class CvANN_MLP : public CvStatModel 1807 { 1808 public: 1809 CV_WRAP CvANN_MLP(); 1810 CvANN_MLP( const CvMat* layerSizes, 1811 int activateFunc=CvANN_MLP::SIGMOID_SYM, 1812 double fparam1=0, double fparam2=0 ); 1813 1814 virtual ~CvANN_MLP(); 1815 1816 virtual void create( const CvMat* layerSizes, 1817 int activateFunc=CvANN_MLP::SIGMOID_SYM, 1818 double fparam1=0, double fparam2=0 ); 1819 1820 virtual int train( const CvMat* inputs, const CvMat* outputs, 1821 const CvMat* sampleWeights, const CvMat* sampleIdx=0, 1822 
CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(), 1823 int flags=0 ); 1824 virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const; 1825 1826 CV_WRAP CvANN_MLP( const cv::Mat& layerSizes, 1827 int activateFunc=CvANN_MLP::SIGMOID_SYM, 1828 double fparam1=0, double fparam2=0 ); 1829 1830 CV_WRAP virtual void create( const cv::Mat& layerSizes, 1831 int activateFunc=CvANN_MLP::SIGMOID_SYM, 1832 double fparam1=0, double fparam2=0 ); 1833 1834 CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs, 1835 const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(), 1836 CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(), 1837 int flags=0 ); 1838 1839 CV_WRAP virtual float predict( const cv::Mat& inputs, CV_OUT cv::Mat& outputs ) const; 1840 1841 CV_WRAP virtual void clear(); 1842 1843 // possible activation functions 1844 enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 }; 1845 1846 // available training flags 1847 enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 }; 1848 1849 virtual void read( CvFileStorage* fs, CvFileNode* node ); 1850 virtual void write( CvFileStorage* storage, const char* name ) const; 1851 1852 int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; } 1853 const CvMat* get_layer_sizes() { return layer_sizes; } 1854 double* get_weights(int layer) 1855 { 1856 return layer_sizes && weights && 1857 (unsigned)layer <= (unsigned)layer_sizes->cols ? 
weights[layer] : 0; 1858 } 1859 1860 virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const; 1861 1862 protected: 1863 1864 virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs, 1865 const CvMat* _sample_weights, const CvMat* sampleIdx, 1866 CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags ); 1867 1868 // sequential random backpropagation 1869 virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw ); 1870 1871 // RPROP algorithm 1872 virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw ); 1873 1874 virtual void calc_activ_func( CvMat* xf, const double* bias ) const; 1875 virtual void set_activ_func( int _activ_func=SIGMOID_SYM, 1876 double _f_param1=0, double _f_param2=0 ); 1877 virtual void init_weights(); 1878 virtual void scale_input( const CvMat* _src, CvMat* _dst ) const; 1879 virtual void scale_output( const CvMat* _src, CvMat* _dst ) const; 1880 virtual void calc_input_scale( const CvVectors* vecs, int flags ); 1881 virtual void calc_output_scale( const CvVectors* vecs, int flags ); 1882 1883 virtual void write_params( CvFileStorage* fs ) const; 1884 virtual void read_params( CvFileStorage* fs, CvFileNode* node ); 1885 1886 CvMat* layer_sizes; 1887 CvMat* wbuf; 1888 CvMat* sample_weights; 1889 double** weights; 1890 double f_param1, f_param2; 1891 double min_val, max_val, min_val1, max_val1; 1892 int activ_func; 1893 int max_count, max_buf_sz; 1894 CvANN_MLP_TrainParams params; 1895 cv::RNG* rng; 1896 }; 1897 1898 /****************************************************************************************\ 1899 * Auxilary functions declarations * 1900 \****************************************************************************************/ 1901 1902 /* Generates <sample> from multivariate normal distribution, where <mean> - is an 1903 average row vector, <cov> - symmetric covariation matrix */ 1904 CVAPI(void) cvRandMVNormal( CvMat* mean, 
CvMat* cov, CvMat* sample, 1905 CvRNG* rng CV_DEFAULT(0) ); 1906 1907 /* Generates sample from gaussian mixture distribution */ 1908 CVAPI(void) cvRandGaussMixture( CvMat* means[], 1909 CvMat* covs[], 1910 float weights[], 1911 int clsnum, 1912 CvMat* sample, 1913 CvMat* sampClasses CV_DEFAULT(0) ); 1914 1915 #define CV_TS_CONCENTRIC_SPHERES 0 1916 1917 /* creates test set */ 1918 CVAPI(void) cvCreateTestSet( int type, CvMat** samples, 1919 int num_samples, 1920 int num_features, 1921 CvMat** responses, 1922 int num_classes, ... ); 1923 1924 /****************************************************************************************\ 1925 * Data * 1926 \****************************************************************************************/ 1927 1928 #define CV_COUNT 0 1929 #define CV_PORTION 1 1930 1931 struct CvTrainTestSplit 1932 { 1933 CvTrainTestSplit(); 1934 CvTrainTestSplit( int train_sample_count, bool mix = true); 1935 CvTrainTestSplit( float train_sample_portion, bool mix = true); 1936 1937 union 1938 { 1939 int count; 1940 float portion; 1941 } train_sample_part; 1942 int train_sample_part_mode; 1943 1944 bool mix; 1945 }; 1946 1947 class CvMLData 1948 { 1949 public: 1950 CvMLData(); 1951 virtual ~CvMLData(); 1952 1953 // returns: 1954 // 0 - OK 1955 // -1 - file can not be opened or is not correct 1956 int read_csv( const char* filename ); 1957 1958 const CvMat* get_values() const; 1959 const CvMat* get_responses(); 1960 const CvMat* get_missing() const; 1961 1962 void set_header_lines_number( int n ); 1963 int get_header_lines_number() const; 1964 1965 void set_response_idx( int idx ); // old response become predictors, new response_idx = idx 1966 // if idx < 0 there will be no response 1967 int get_response_idx() const; 1968 1969 void set_train_test_split( const CvTrainTestSplit * spl ); 1970 const CvMat* get_train_sample_idx() const; 1971 const CvMat* get_test_sample_idx() const; 1972 void mix_train_and_test_idx(); 1973 1974 const CvMat* get_var_idx(); 
1975 void chahge_var_idx( int vi, bool state ); // misspelled (saved for back compitability), 1976 // use change_var_idx 1977 void change_var_idx( int vi, bool state ); // state == true to set vi-variable as predictor 1978 1979 const CvMat* get_var_types(); 1980 int get_var_type( int var_idx ) const; 1981 // following 2 methods enable to change vars type 1982 // use these methods to assign CV_VAR_CATEGORICAL type for categorical variable 1983 // with numerical labels; in the other cases var types are correctly determined automatically 1984 void set_var_types( const char* str ); // str examples: 1985 // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]", 1986 // "cat", "ord" (all vars are categorical/ordered) 1987 void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL } 1988 1989 void set_delimiter( char ch ); 1990 char get_delimiter() const; 1991 1992 void set_miss_ch( char ch ); 1993 char get_miss_ch() const; 1994 1995 const std::map<cv::String, int>& get_class_labels_map() const; 1996 1997 protected: 1998 virtual void clear(); 1999 2000 void str_to_flt_elem( const char* token, float& flt_elem, int& type); 2001 void free_train_test_idx(); 2002 2003 char delimiter; 2004 char miss_ch; 2005 //char flt_separator; 2006 2007 CvMat* values; 2008 CvMat* missing; 2009 CvMat* var_types; 2010 CvMat* var_idx_mask; 2011 2012 CvMat* response_out; // header 2013 CvMat* var_idx_out; // mat 2014 CvMat* var_types_out; // mat 2015 2016 int header_lines_number; 2017 2018 int response_idx; 2019 2020 int train_sample_count; 2021 bool mix; 2022 2023 int total_class_count; 2024 std::map<cv::String, int> class_map; 2025 2026 CvMat* train_sample_idx; 2027 CvMat* test_sample_idx; 2028 int* sample_idx; // data of train_sample_idx and test_sample_idx 2029 2030 cv::RNG* rng; 2031 }; 2032 2033 2034 namespace cv 2035 { 2036 2037 typedef CvStatModel StatModel; 2038 typedef CvParamGrid ParamGrid; 2039 typedef CvNormalBayesClassifier 
NormalBayesClassifier; 2040 typedef CvKNearest KNearest; 2041 typedef CvSVMParams SVMParams; 2042 typedef CvSVMKernel SVMKernel; 2043 typedef CvSVMSolver SVMSolver; 2044 typedef CvSVM SVM; 2045 typedef CvDTreeParams DTreeParams; 2046 typedef CvMLData TrainData; 2047 typedef CvDTree DecisionTree; 2048 typedef CvForestTree ForestTree; 2049 typedef CvRTParams RandomTreeParams; 2050 typedef CvRTrees RandomTrees; 2051 typedef CvERTreeTrainData ERTreeTRainData; 2052 typedef CvForestERTree ERTree; 2053 typedef CvERTrees ERTrees; 2054 typedef CvBoostParams BoostParams; 2055 typedef CvBoostTree BoostTree; 2056 typedef CvBoost Boost; 2057 typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams; 2058 typedef CvANN_MLP NeuralNet_MLP; 2059 typedef CvGBTreesParams GradientBoostingTreeParams; 2060 typedef CvGBTrees GradientBoostingTrees; 2061 2062 template<> void DefaultDeleter<CvDTreeSplit>::operator ()(CvDTreeSplit* obj) const; 2063 } 2064 2065 #endif // __cplusplus 2066 #endif // __OPENCV_ML_HPP__ 2067 2068 /* End of file. */ 2069