1 /*M/////////////////////////////////////////////////////////////////////////////////////// 2 // 3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. 4 // 5 // By downloading, copying, installing or using the software you agree to this license. 6 // If you do not agree to this license, do not download, install, 7 // copy or use the software. 8 // 9 // 10 // Intel License Agreement 11 // 12 // Copyright (C) 2000, Intel Corporation, all rights reserved. 13 // Third party copyrights are property of their respective owners. 14 // 15 // Redistribution and use in source and binary forms, with or without modification, 16 // are permitted provided that the following conditions are met: 17 // 18 // * Redistribution's of source code must retain the above copyright notice, 19 // this list of conditions and the following disclaimer. 20 // 21 // * Redistribution's in binary form must reproduce the above copyright notice, 22 // this list of conditions and the following disclaimer in the documentation 23 // and/or other materials provided with the distribution. 24 // 25 // * The name of Intel Corporation may not be used to endorse or promote products 26 // derived from this software without specific prior written permission. 27 // 28 // This software is provided by the copyright holders and contributors "as is" and 29 // any express or implied warranties, including, but not limited to, the implied 30 // warranties of merchantability and fitness for a particular purpose are disclaimed. 
31 // In no event shall the Intel Corporation or contributors be liable for any direct, 32 // indirect, incidental, special, exemplary, or consequential damages 33 // (including, but not limited to, procurement of substitute goods or services; 34 // loss of use, data, or profits; or business interruption) however caused 35 // and on any theory of liability, whether in contract, strict liability, 36 // or tort (including negligence or otherwise) arising in any way out of 37 // the use of this software, even if advised of the possibility of such damage. 38 // 39 //M*/ 40 41 #ifndef __ML_H__ 42 #define __ML_H__ 43 44 // disable deprecation warning which appears in VisualStudio 8.0 45 #if _MSC_VER >= 1400 46 #pragma warning( disable : 4996 ) 47 #endif 48 49 #ifndef SKIP_INCLUDES 50 51 #include "cxcore.h" 52 #include <limits.h> 53 54 #if defined WIN32 || defined WIN64 55 #include <windows.h> 56 #endif 57 58 #else // SKIP_INCLUDES 59 60 #if defined WIN32 || defined WIN64 61 #define CV_CDECL __cdecl 62 #define CV_STDCALL __stdcall 63 #else 64 #define CV_CDECL 65 #define CV_STDCALL 66 #endif 67 68 #ifndef CV_EXTERN_C 69 #ifdef __cplusplus 70 #define CV_EXTERN_C extern "C" 71 #define CV_DEFAULT(val) = val 72 #else 73 #define CV_EXTERN_C 74 #define CV_DEFAULT(val) 75 #endif 76 #endif 77 78 #ifndef CV_EXTERN_C_FUNCPTR 79 #ifdef __cplusplus 80 #define CV_EXTERN_C_FUNCPTR(x) extern "C" { typedef x; } 81 #else 82 #define CV_EXTERN_C_FUNCPTR(x) typedef x 83 #endif 84 #endif 85 86 #ifndef CV_INLINE 87 #if defined __cplusplus 88 #define CV_INLINE inline 89 #elif (defined WIN32 || defined WIN64) && !defined __GNUC__ 90 #define CV_INLINE __inline 91 #else 92 #define CV_INLINE static 93 #endif 94 #endif /* CV_INLINE */ 95 96 #if (defined WIN32 || defined WIN64) && defined CVAPI_EXPORTS 97 #define CV_EXPORTS __declspec(dllexport) 98 #else 99 #define CV_EXPORTS 100 #endif 101 102 #ifndef CVAPI 103 #define CVAPI(rettype) CV_EXTERN_C CV_EXPORTS rettype CV_CDECL 104 #endif 105 106 #endif 
// SKIP_INCLUDES 107 108 109 #ifdef __cplusplus 110 111 // Apple defines a check() macro somewhere in the debug headers 112 // that interferes with a method definiton in this header 113 #undef check 114 115 /****************************************************************************************\ 116 * Main struct definitions * 117 \****************************************************************************************/ 118 119 /* log(2*PI) */ 120 #define CV_LOG2PI (1.8378770664093454835606594728112) 121 122 /* columns of <trainData> matrix are training samples */ 123 #define CV_COL_SAMPLE 0 124 125 /* rows of <trainData> matrix are training samples */ 126 #define CV_ROW_SAMPLE 1 127 128 #define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE) 129 130 struct CvVectors 131 { 132 int type; 133 int dims, count; 134 CvVectors* next; 135 union 136 { 137 uchar** ptr; 138 float** fl; 139 double** db; 140 } data; 141 }; 142 143 #if 0 144 /* A structure, representing the lattice range of statmodel parameters. 145 It is used for optimizing statmodel parameters by cross-validation method. 146 The lattice is logarithmic, so <step> must be greater then 1. */ 147 typedef struct CvParamLattice 148 { 149 double min_val; 150 double max_val; 151 double step; 152 } 153 CvParamLattice; 154 155 CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val, 156 double log_step ) 157 { 158 CvParamLattice pl; 159 pl.min_val = MIN( min_val, max_val ); 160 pl.max_val = MAX( min_val, max_val ); 161 pl.step = MAX( log_step, 1. 
); 162 return pl; 163 } 164 165 CV_INLINE CvParamLattice cvDefaultParamLattice( void ) 166 { 167 CvParamLattice pl = {0,0,0}; 168 return pl; 169 } 170 #endif 171 172 /* Variable type */ 173 #define CV_VAR_NUMERICAL 0 174 #define CV_VAR_ORDERED 0 175 #define CV_VAR_CATEGORICAL 1 176 177 #define CV_TYPE_NAME_ML_SVM "opencv-ml-svm" 178 #define CV_TYPE_NAME_ML_KNN "opencv-ml-knn" 179 #define CV_TYPE_NAME_ML_NBAYES "opencv-ml-bayesian" 180 #define CV_TYPE_NAME_ML_EM "opencv-ml-em" 181 #define CV_TYPE_NAME_ML_BOOSTING "opencv-ml-boost-tree" 182 #define CV_TYPE_NAME_ML_TREE "opencv-ml-tree" 183 #define CV_TYPE_NAME_ML_ANN_MLP "opencv-ml-ann-mlp" 184 #define CV_TYPE_NAME_ML_CNN "opencv-ml-cnn" 185 #define CV_TYPE_NAME_ML_RTREES "opencv-ml-random-trees" 186 187 class CV_EXPORTS CvStatModel 188 { 189 public: 190 CvStatModel(); 191 virtual ~CvStatModel(); 192 193 virtual void clear(); 194 195 virtual void save( const char* filename, const char* name=0 ); 196 virtual void load( const char* filename, const char* name=0 ); 197 198 virtual void write( CvFileStorage* storage, const char* name ); 199 virtual void read( CvFileStorage* storage, CvFileNode* node ); 200 201 protected: 202 const char* default_model_name; 203 }; 204 205 206 /****************************************************************************************\ 207 * Normal Bayes Classifier * 208 \****************************************************************************************/ 209 210 /* The structure, representing the grid range of statmodel parameters. 211 It is used for optimizing statmodel accuracy by varying model parameters, 212 the accuracy estimate being computed by cross-validation. 213 The grid is logarithmic, so <step> must be greater then 1. 
*/ 214 struct CV_EXPORTS CvParamGrid 215 { 216 // SVM params type 217 enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 }; 218 219 CvParamGrid() 220 { 221 min_val = max_val = step = 0; 222 } 223 224 CvParamGrid( double _min_val, double _max_val, double log_step ) 225 { 226 min_val = _min_val; 227 max_val = _max_val; 228 step = log_step; 229 } 230 //CvParamGrid( int param_id ); 231 bool check() const; 232 233 double min_val; 234 double max_val; 235 double step; 236 }; 237 238 class CV_EXPORTS CvNormalBayesClassifier : public CvStatModel 239 { 240 public: 241 CvNormalBayesClassifier(); 242 virtual ~CvNormalBayesClassifier(); 243 244 CvNormalBayesClassifier( const CvMat* _train_data, const CvMat* _responses, 245 const CvMat* _var_idx=0, const CvMat* _sample_idx=0 ); 246 247 virtual bool train( const CvMat* _train_data, const CvMat* _responses, 248 const CvMat* _var_idx = 0, const CvMat* _sample_idx=0, bool update=false ); 249 250 virtual float predict( const CvMat* _samples, CvMat* results=0 ) const; 251 virtual void clear(); 252 253 virtual void write( CvFileStorage* storage, const char* name ); 254 virtual void read( CvFileStorage* storage, CvFileNode* node ); 255 256 protected: 257 int var_count, var_all; 258 CvMat* var_idx; 259 CvMat* cls_labels; 260 CvMat** count; 261 CvMat** sum; 262 CvMat** productsum; 263 CvMat** avg; 264 CvMat** inv_eigen_values; 265 CvMat** cov_rotate_mats; 266 CvMat* c; 267 }; 268 269 270 /****************************************************************************************\ 271 * K-Nearest Neighbour Classifier * 272 \****************************************************************************************/ 273 274 // k Nearest Neighbors 275 class CV_EXPORTS CvKNearest : public CvStatModel 276 { 277 public: 278 279 CvKNearest(); 280 virtual ~CvKNearest(); 281 282 CvKNearest( const CvMat* _train_data, const CvMat* _responses, 283 const CvMat* _sample_idx=0, bool _is_regression=false, int max_k=32 ); 284 285 virtual 
bool train( const CvMat* _train_data, const CvMat* _responses, 286 const CvMat* _sample_idx=0, bool is_regression=false, 287 int _max_k=32, bool _update_base=false ); 288 289 virtual float find_nearest( const CvMat* _samples, int k, CvMat* results=0, 290 const float** neighbors=0, CvMat* neighbor_responses=0, CvMat* dist=0 ) const; 291 292 virtual void clear(); 293 int get_max_k() const; 294 int get_var_count() const; 295 int get_sample_count() const; 296 bool is_regression() const; 297 298 protected: 299 300 virtual float write_results( int k, int k1, int start, int end, 301 const float* neighbor_responses, const float* dist, CvMat* _results, 302 CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const; 303 304 virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end, 305 float* neighbor_responses, const float** neighbors, float* dist ) const; 306 307 308 int max_k, var_count; 309 int total; 310 bool regression; 311 CvVectors* samples; 312 }; 313 314 /****************************************************************************************\ 315 * Support Vector Machines * 316 \****************************************************************************************/ 317 318 // SVM training parameters 319 struct CV_EXPORTS CvSVMParams 320 { 321 CvSVMParams(); 322 CvSVMParams( int _svm_type, int _kernel_type, 323 double _degree, double _gamma, double _coef0, 324 double _C, double _nu, double _p, 325 CvMat* _class_weights, CvTermCriteria _term_crit ); 326 327 int svm_type; 328 int kernel_type; 329 double degree; // for poly 330 double gamma; // for poly/rbf/sigmoid 331 double coef0; // for poly/sigmoid 332 333 double C; // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR 334 double nu; // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR 335 double p; // for CV_SVM_EPS_SVR 336 CvMat* class_weights; // for CV_SVM_C_SVC 337 CvTermCriteria term_crit; // termination criteria 338 }; 339 340 341 struct CV_EXPORTS CvSVMKernel 342 
{ 343 typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs, 344 const float* another, float* results ); 345 CvSVMKernel(); 346 CvSVMKernel( const CvSVMParams* _params, Calc _calc_func ); 347 virtual bool create( const CvSVMParams* _params, Calc _calc_func ); 348 virtual ~CvSVMKernel(); 349 350 virtual void clear(); 351 virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results ); 352 353 const CvSVMParams* params; 354 Calc calc_func; 355 356 virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs, 357 const float* another, float* results, 358 double alpha, double beta ); 359 360 virtual void calc_linear( int vec_count, int vec_size, const float** vecs, 361 const float* another, float* results ); 362 virtual void calc_rbf( int vec_count, int vec_size, const float** vecs, 363 const float* another, float* results ); 364 virtual void calc_poly( int vec_count, int vec_size, const float** vecs, 365 const float* another, float* results ); 366 virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs, 367 const float* another, float* results ); 368 }; 369 370 371 struct CvSVMKernelRow 372 { 373 CvSVMKernelRow* prev; 374 CvSVMKernelRow* next; 375 float* data; 376 }; 377 378 379 struct CvSVMSolutionInfo 380 { 381 double obj; 382 double rho; 383 double upper_bound_p; 384 double upper_bound_n; 385 double r; // for Solver_NU 386 }; 387 388 class CV_EXPORTS CvSVMSolver 389 { 390 public: 391 typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j ); 392 typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed ); 393 typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r ); 394 395 CvSVMSolver(); 396 397 CvSVMSolver( int count, int var_count, const float** samples, schar* y, 398 int alpha_count, double* alpha, double Cp, double Cn, 399 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row, 400 SelectWorkingSet select_working_set, 
CalcRho calc_rho ); 401 virtual bool create( int count, int var_count, const float** samples, schar* y, 402 int alpha_count, double* alpha, double Cp, double Cn, 403 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row, 404 SelectWorkingSet select_working_set, CalcRho calc_rho ); 405 virtual ~CvSVMSolver(); 406 407 virtual void clear(); 408 virtual bool solve_generic( CvSVMSolutionInfo& si ); 409 410 virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y, 411 double Cp, double Cn, CvMemStorage* storage, 412 CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si ); 413 virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y, 414 CvMemStorage* storage, CvSVMKernel* kernel, 415 double* alpha, CvSVMSolutionInfo& si ); 416 virtual bool solve_one_class( int count, int var_count, const float** samples, 417 CvMemStorage* storage, CvSVMKernel* kernel, 418 double* alpha, CvSVMSolutionInfo& si ); 419 420 virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y, 421 CvMemStorage* storage, CvSVMKernel* kernel, 422 double* alpha, CvSVMSolutionInfo& si ); 423 424 virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y, 425 CvMemStorage* storage, CvSVMKernel* kernel, 426 double* alpha, CvSVMSolutionInfo& si ); 427 428 virtual float* get_row_base( int i, bool* _existed ); 429 virtual float* get_row( int i, float* dst ); 430 431 int sample_count; 432 int var_count; 433 int cache_size; 434 int cache_line_size; 435 const float** samples; 436 const CvSVMParams* params; 437 CvMemStorage* storage; 438 CvSVMKernelRow lru_list; 439 CvSVMKernelRow* rows; 440 441 int alpha_count; 442 443 double* G; 444 double* alpha; 445 446 // -1 - lower bound, 0 - free, 1 - upper bound 447 schar* alpha_status; 448 449 schar* y; 450 double* b; 451 float* buf[2]; 452 double eps; 453 int max_iter; 454 double C[2]; // C[0] == Cn, C[1] == Cp 455 CvSVMKernel* kernel; 456 457 
SelectWorkingSet select_working_set_func; 458 CalcRho calc_rho_func; 459 GetRow get_row_func; 460 461 virtual bool select_working_set( int& i, int& j ); 462 virtual bool select_working_set_nu_svm( int& i, int& j ); 463 virtual void calc_rho( double& rho, double& r ); 464 virtual void calc_rho_nu_svm( double& rho, double& r ); 465 466 virtual float* get_row_svc( int i, float* row, float* dst, bool existed ); 467 virtual float* get_row_one_class( int i, float* row, float* dst, bool existed ); 468 virtual float* get_row_svr( int i, float* row, float* dst, bool existed ); 469 }; 470 471 472 struct CvSVMDecisionFunc 473 { 474 double rho; 475 int sv_count; 476 double* alpha; 477 int* sv_index; 478 }; 479 480 481 // SVM model 482 class CV_EXPORTS CvSVM : public CvStatModel 483 { 484 public: 485 // SVM type 486 enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 }; 487 488 // SVM kernel type 489 enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3 }; 490 491 // SVM params type 492 enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 }; 493 494 CvSVM(); 495 virtual ~CvSVM(); 496 497 CvSVM( const CvMat* _train_data, const CvMat* _responses, 498 const CvMat* _var_idx=0, const CvMat* _sample_idx=0, 499 CvSVMParams _params=CvSVMParams() ); 500 501 virtual bool train( const CvMat* _train_data, const CvMat* _responses, 502 const CvMat* _var_idx=0, const CvMat* _sample_idx=0, 503 CvSVMParams _params=CvSVMParams() ); 504 virtual bool train_auto( const CvMat* _train_data, const CvMat* _responses, 505 const CvMat* _var_idx, const CvMat* _sample_idx, CvSVMParams _params, 506 int k_fold = 10, 507 CvParamGrid C_grid = get_default_grid(CvSVM::C), 508 CvParamGrid gamma_grid = get_default_grid(CvSVM::GAMMA), 509 CvParamGrid p_grid = get_default_grid(CvSVM::P), 510 CvParamGrid nu_grid = get_default_grid(CvSVM::NU), 511 CvParamGrid coef_grid = get_default_grid(CvSVM::COEF), 512 CvParamGrid degree_grid = get_default_grid(CvSVM::DEGREE) ); 513 514 virtual float predict( const CvMat* _sample 
) const; 515 516 virtual int get_support_vector_count() const; 517 virtual const float* get_support_vector(int i) const; 518 virtual CvSVMParams get_params() const { return params; }; 519 virtual void clear(); 520 521 static CvParamGrid get_default_grid( int param_id ); 522 523 virtual void write( CvFileStorage* storage, const char* name ); 524 virtual void read( CvFileStorage* storage, CvFileNode* node ); 525 int get_var_count() const { return var_idx ? var_idx->cols : var_all; } 526 527 protected: 528 529 virtual bool set_params( const CvSVMParams& _params ); 530 virtual bool train1( int sample_count, int var_count, const float** samples, 531 const void* _responses, double Cp, double Cn, 532 CvMemStorage* _storage, double* alpha, double& rho ); 533 virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples, 534 const CvMat* _responses, CvMemStorage* _storage, double* alpha ); 535 virtual void create_kernel(); 536 virtual void create_solver(); 537 538 virtual void write_params( CvFileStorage* fs ); 539 virtual void read_params( CvFileStorage* fs, CvFileNode* node ); 540 541 CvSVMParams params; 542 CvMat* class_labels; 543 int var_all; 544 float** sv; 545 int sv_total; 546 CvMat* var_idx; 547 CvMat* class_weights; 548 CvSVMDecisionFunc* decision_func; 549 CvMemStorage* storage; 550 551 CvSVMSolver* solver; 552 CvSVMKernel* kernel; 553 }; 554 555 /****************************************************************************************\ 556 * Expectation - Maximization * 557 \****************************************************************************************/ 558 559 struct CV_EXPORTS CvEMParams 560 { 561 CvEMParams() : nclusters(10), cov_mat_type(1/*CvEM::COV_MAT_DIAGONAL*/), 562 start_step(0/*CvEM::START_AUTO_STEP*/), probs(0), weights(0), means(0), covs(0) 563 { 564 term_crit=cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON ); 565 } 566 567 CvEMParams( int _nclusters, int 
_cov_mat_type=1/*CvEM::COV_MAT_DIAGONAL*/, 568 int _start_step=0/*CvEM::START_AUTO_STEP*/, 569 CvTermCriteria _term_crit=cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 100, FLT_EPSILON), 570 const CvMat* _probs=0, const CvMat* _weights=0, const CvMat* _means=0, const CvMat** _covs=0 ) : 571 nclusters(_nclusters), cov_mat_type(_cov_mat_type), start_step(_start_step), 572 probs(_probs), weights(_weights), means(_means), covs(_covs), term_crit(_term_crit) 573 {} 574 575 int nclusters; 576 int cov_mat_type; 577 int start_step; 578 const CvMat* probs; 579 const CvMat* weights; 580 const CvMat* means; 581 const CvMat** covs; 582 CvTermCriteria term_crit; 583 }; 584 585 586 class CV_EXPORTS CvEM : public CvStatModel 587 { 588 public: 589 // Type of covariation matrices 590 enum { COV_MAT_SPHERICAL=0, COV_MAT_DIAGONAL=1, COV_MAT_GENERIC=2 }; 591 592 // The initial step 593 enum { START_E_STEP=1, START_M_STEP=2, START_AUTO_STEP=0 }; 594 595 CvEM(); 596 CvEM( const CvMat* samples, const CvMat* sample_idx=0, 597 CvEMParams params=CvEMParams(), CvMat* labels=0 ); 598 599 virtual ~CvEM(); 600 601 virtual bool train( const CvMat* samples, const CvMat* sample_idx=0, 602 CvEMParams params=CvEMParams(), CvMat* labels=0 ); 603 604 virtual float predict( const CvMat* sample, CvMat* probs ) const; 605 virtual void clear(); 606 607 int get_nclusters() const; 608 const CvMat* get_means() const; 609 const CvMat** get_covs() const; 610 const CvMat* get_weights() const; 611 const CvMat* get_probs() const; 612 613 inline double get_log_likelihood () const { return log_likelihood; }; 614 615 protected: 616 617 virtual void set_params( const CvEMParams& params, 618 const CvVectors& train_data ); 619 virtual void init_em( const CvVectors& train_data ); 620 virtual double run_em( const CvVectors& train_data ); 621 virtual void init_auto( const CvVectors& samples ); 622 virtual void kmeans( const CvVectors& train_data, int nclusters, 623 CvMat* labels, CvTermCriteria criteria, 624 const CvMat* 
means ); 625 CvEMParams params; 626 double log_likelihood; 627 628 CvMat* means; 629 CvMat** covs; 630 CvMat* weights; 631 CvMat* probs; 632 633 CvMat* log_weight_div_det; 634 CvMat* inv_eigen_values; 635 CvMat** cov_rotate_mats; 636 }; 637 638 /****************************************************************************************\ 639 * Decision Tree * 640 \****************************************************************************************/ 641 642 struct CvPair32s32f 643 { 644 int i; 645 float val; 646 }; 647 648 649 #define CV_DTREE_CAT_DIR(idx,subset) \ 650 (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1) 651 652 struct CvDTreeSplit 653 { 654 int var_idx; 655 int inversed; 656 float quality; 657 CvDTreeSplit* next; 658 union 659 { 660 int subset[2]; 661 struct 662 { 663 float c; 664 int split_point; 665 } 666 ord; 667 }; 668 }; 669 670 671 struct CvDTreeNode 672 { 673 int class_idx; 674 int Tn; 675 double value; 676 677 CvDTreeNode* parent; 678 CvDTreeNode* left; 679 CvDTreeNode* right; 680 681 CvDTreeSplit* split; 682 683 int sample_count; 684 int depth; 685 int* num_valid; 686 int offset; 687 int buf_idx; 688 double maxlr; 689 690 // global pruning data 691 int complexity; 692 double alpha; 693 double node_risk, tree_risk, tree_error; 694 695 // cross-validation pruning data 696 int* cv_Tn; 697 double* cv_node_risk; 698 double* cv_node_error; 699 700 int get_num_valid(int vi) { return num_valid ? 
num_valid[vi] : sample_count; } 701 void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; } 702 }; 703 704 705 struct CV_EXPORTS CvDTreeParams 706 { 707 int max_categories; 708 int max_depth; 709 int min_sample_count; 710 int cv_folds; 711 bool use_surrogates; 712 bool use_1se_rule; 713 bool truncate_pruned_tree; 714 float regression_accuracy; 715 const float* priors; 716 717 CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10), 718 cv_folds(10), use_surrogates(true), use_1se_rule(true), 719 truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0) 720 {} 721 722 CvDTreeParams( int _max_depth, int _min_sample_count, 723 float _regression_accuracy, bool _use_surrogates, 724 int _max_categories, int _cv_folds, 725 bool _use_1se_rule, bool _truncate_pruned_tree, 726 const float* _priors ) : 727 max_categories(_max_categories), max_depth(_max_depth), 728 min_sample_count(_min_sample_count), cv_folds (_cv_folds), 729 use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule), 730 truncate_pruned_tree(_truncate_pruned_tree), 731 regression_accuracy(_regression_accuracy), 732 priors(_priors) 733 {} 734 }; 735 736 737 struct CV_EXPORTS CvDTreeTrainData 738 { 739 CvDTreeTrainData(); 740 CvDTreeTrainData( const CvMat* _train_data, int _tflag, 741 const CvMat* _responses, const CvMat* _var_idx=0, 742 const CvMat* _sample_idx=0, const CvMat* _var_type=0, 743 const CvMat* _missing_mask=0, 744 const CvDTreeParams& _params=CvDTreeParams(), 745 bool _shared=false, bool _add_labels=false ); 746 virtual ~CvDTreeTrainData(); 747 748 virtual void set_data( const CvMat* _train_data, int _tflag, 749 const CvMat* _responses, const CvMat* _var_idx=0, 750 const CvMat* _sample_idx=0, const CvMat* _var_type=0, 751 const CvMat* _missing_mask=0, 752 const CvDTreeParams& _params=CvDTreeParams(), 753 bool _shared=false, bool _add_labels=false, 754 bool _update_data=false ); 755 756 virtual void get_vectors( const CvMat* _subsample_idx, 757 
float* values, uchar* missing, float* responses, bool get_class_idx=false ); 758 759 virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx ); 760 761 virtual void write_params( CvFileStorage* fs ); 762 virtual void read_params( CvFileStorage* fs, CvFileNode* node ); 763 764 // release all the data 765 virtual void clear(); 766 767 int get_num_classes() const; 768 int get_var_type(int vi) const; 769 int get_work_var_count() const; 770 771 virtual int* get_class_labels( CvDTreeNode* n ); 772 virtual float* get_ord_responses( CvDTreeNode* n ); 773 virtual int* get_labels( CvDTreeNode* n ); 774 virtual int* get_cat_var_data( CvDTreeNode* n, int vi ); 775 virtual CvPair32s32f* get_ord_var_data( CvDTreeNode* n, int vi ); 776 virtual int get_child_buf_idx( CvDTreeNode* n ); 777 778 //////////////////////////////////// 779 780 virtual bool set_params( const CvDTreeParams& params ); 781 virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count, 782 int storage_idx, int offset ); 783 784 virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val, 785 int split_point, int inversed, float quality ); 786 virtual CvDTreeSplit* new_split_cat( int vi, float quality ); 787 virtual void free_node_data( CvDTreeNode* node ); 788 virtual void free_train_data(); 789 virtual void free_node( CvDTreeNode* node ); 790 791 int sample_count, var_all, var_count, max_c_count; 792 int ord_var_count, cat_var_count; 793 bool have_labels, have_priors; 794 bool is_classifier; 795 796 int buf_count, buf_size; 797 bool shared; 798 799 CvMat* cat_count; 800 CvMat* cat_ofs; 801 CvMat* cat_map; 802 803 CvMat* counts; 804 CvMat* buf; 805 CvMat* direction; 806 CvMat* split_buf; 807 808 CvMat* var_idx; 809 CvMat* var_type; // i-th element = 810 // k<0 - ordered 811 // k>=0 - categorical, see k-th element of cat_* arrays 812 CvMat* priors; 813 CvMat* priors_mult; 814 815 CvDTreeParams params; 816 817 CvMemStorage* tree_storage; 818 CvMemStorage* temp_storage; 819 820 CvDTreeNode* data_root; 
821 822 CvSet* node_heap; 823 CvSet* split_heap; 824 CvSet* cv_heap; 825 CvSet* nv_heap; 826 827 CvRNG rng; 828 }; 829 830 831 class CV_EXPORTS CvDTree : public CvStatModel 832 { 833 public: 834 CvDTree(); 835 virtual ~CvDTree(); 836 837 virtual bool train( const CvMat* _train_data, int _tflag, 838 const CvMat* _responses, const CvMat* _var_idx=0, 839 const CvMat* _sample_idx=0, const CvMat* _var_type=0, 840 const CvMat* _missing_mask=0, 841 CvDTreeParams params=CvDTreeParams() ); 842 843 virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx ); 844 845 virtual CvDTreeNode* predict( const CvMat* _sample, const CvMat* _missing_data_mask=0, 846 bool preprocessed_input=false ) const; 847 virtual const CvMat* get_var_importance(); 848 virtual void clear(); 849 850 virtual void read( CvFileStorage* fs, CvFileNode* node ); 851 virtual void write( CvFileStorage* fs, const char* name ); 852 853 // special read & write methods for trees in the tree ensembles 854 virtual void read( CvFileStorage* fs, CvFileNode* node, 855 CvDTreeTrainData* data ); 856 virtual void write( CvFileStorage* fs ); 857 858 const CvDTreeNode* get_root() const; 859 int get_pruned_tree_idx() const; 860 CvDTreeTrainData* get_data(); 861 862 protected: 863 864 virtual bool do_train( const CvMat* _subsample_idx ); 865 866 virtual void try_split_node( CvDTreeNode* n ); 867 virtual void split_node_data( CvDTreeNode* n ); 868 virtual CvDTreeSplit* find_best_split( CvDTreeNode* n ); 869 virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi ); 870 virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi ); 871 virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi ); 872 virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi ); 873 virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi ); 874 virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi ); 875 virtual double calc_node_dir( CvDTreeNode* node ); 876 
virtual void complete_node_dir( CvDTreeNode* node ); 877 virtual void cluster_categories( const int* vectors, int vector_count, 878 int var_count, int* sums, int k, int* cluster_labels ); 879 880 virtual void calc_node_value( CvDTreeNode* node ); 881 882 virtual void prune_cv(); 883 virtual double update_tree_rnc( int T, int fold ); 884 virtual int cut_tree( int T, int fold, double min_alpha ); 885 virtual void free_prune_data(bool cut_tree); 886 virtual void free_tree(); 887 888 virtual void write_node( CvFileStorage* fs, CvDTreeNode* node ); 889 virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split ); 890 virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent ); 891 virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node ); 892 virtual void write_tree_nodes( CvFileStorage* fs ); 893 virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node ); 894 895 CvDTreeNode* root; 896 897 int pruned_tree_idx; 898 CvMat* var_importance; 899 900 CvDTreeTrainData* data; 901 }; 902 903 904 /****************************************************************************************\ 905 * Random Trees Classifier * 906 \****************************************************************************************/ 907 908 class CvRTrees; 909 910 class CV_EXPORTS CvForestTree: public CvDTree 911 { 912 public: 913 CvForestTree(); 914 virtual ~CvForestTree(); 915 916 virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx, CvRTrees* forest ); 917 918 virtual int get_var_count() const {return data ? 
data->var_count : 0;}
    virtual void read( CvFileStorage* fs, CvFileNode* node, CvRTrees* forest, CvDTreeTrainData* _data );

    /* dummy methods to avoid warnings: BEGIN */
    /* NOTE(review): presumably re-declared so that the class-specific
       overloads above do not hide the base-class virtuals of the same name
       (e.g. -Woverloaded-virtual) - confirm against the implementation. */
    virtual bool train( const CvMat* _train_data, int _tflag,
                        const CvMat* _responses, const CvMat* _var_idx=0,
                        const CvMat* _sample_idx=0, const CvMat* _var_type=0,
                        const CvMat* _missing_mask=0,
                        CvDTreeParams params=CvDTreeParams() );

    virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    /* dummy methods to avoid warnings: END */

protected:
    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    // the CvRTrees ensemble this tree is associated with
    // (presumably the `forest` passed to read()/train() - confirm)
    CvRTrees* forest;
};


/* Training parameters for CvRTrees: the per-tree CvDTreeParams plus
   forest-level settings. */
struct CV_EXPORTS CvRTParams : public CvDTreeParams
{
    //Parameters for the forest
    bool calc_var_importance; // true <=> the random forest computes variable importance
    // NOTE(review): presumably the number of variables considered at each
    // split; defaults to 0 - confirm what 0 means in the implementation.
    int nactive_vars;
    // termination criteria for growing the forest; the 11-argument
    // constructor builds it from (termcrit_type, max number of trees,
    // forest accuracy)
    CvTermCriteria term_crit;

    /* Default parameters. Mapping the CvDTreeParams arguments by the order
       used in the 11-argument constructor below: max_depth=5,
       min_sample_count=10, regression_accuracy=0, use_surrogates=false,
       max_categories=10, priors=0; arguments 6-8 are fixed to 0/false/false
       in both constructors. The forest stops after 50 trees or when the
       accuracy reaches 0.1, whichever comes first (ITER+EPS). */
    CvRTParams() : CvDTreeParams( 5, 10, 0, false, 10, 0, false, false, 0 ),
        calc_var_importance(false), nactive_vars(0)
    {
        term_crit = cvTermCriteria( CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, 50, 0.1 );
    }

    /* Fully specified parameters. The tree-level arguments are forwarded
       to CvDTreeParams (with arguments 6-8 fixed to 0/false/false); the
       last three arguments build term_crit. */
    CvRTParams( int _max_depth, int _min_sample_count,
                float _regression_accuracy, bool _use_surrogates,
                int _max_categories, const float* _priors, bool _calc_var_importance,
                int _nactive_vars, int max_num_of_trees_in_the_forest,
                float forest_accuracy, int termcrit_type ) :
        CvDTreeParams( _max_depth, _min_sample_count, _regression_accuracy,
                       _use_surrogates, _max_categories, 0,
                       false, false, _priors ),
        calc_var_importance(_calc_var_importance),
        nactive_vars(_nactive_vars)
    {
        term_crit = cvTermCriteria(termcrit_type,
            max_num_of_trees_in_the_forest, forest_accuracy);
    }
};


/* Random trees (random forest) model: an array of CvForestTree instances
   trained with CvRTParams. */
class CV_EXPORTS CvRTrees : public CvStatModel
{
public:
    CvRTrees();
    virtual ~CvRTrees();
    // Trains the forest; the trailing CvMat arguments are optional
    // (default 0) and params defaults to CvRTParams().
    virtual bool train( const CvMat* _train_data, int _tflag,
                        const CvMat* _responses, const CvMat* _var_idx=0,
                        const CvMat* _sample_idx=0, const CvMat* _var_type=0,
                        const CvMat* _missing_mask=0,
                        CvRTParams params=CvRTParams() );
    // Returns the forest's response for one sample; `missing` is an
    // optional missing-value mask (default 0).
    virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
    virtual void clear();

    // NOTE(review): presumably only populated when
    // CvRTParams::calc_var_importance was set - confirm.
    virtual const CvMat* get_var_importance();
    // Proximity measure between two samples, each with an optional
    // missing-value mask.
    virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
        const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name );

    CvMat* get_active_var_mask();
    CvRNG* get_rng();

    // Per-tree access; NOTE(review): presumably valid i is
    // [0, get_tree_count()) - confirm bounds handling in the implementation.
    int get_tree_count() const;
    CvForestTree* get_tree(int i) const;

protected:

    // grows trees until term_crit is satisfied (presumably called by train)
    bool grow_forest( const CvTermCriteria term_crit );

    // array of the trees of the forest
    CvForestTree** trees;
    CvDTreeTrainData* data;
    int ntrees;             // presumably the number of entries in `trees`
    int nclasses;
    double oob_error;
    CvMat* var_importance;  // returned by get_var_importance()
    int nsamples;

    CvRNG rng;              // exposed via get_rng()
    CvMat* active_var_mask; // exposed via get_active_var_mask()
};


/****************************************************************************************\
*                                   Boosted tree classifier                              *
\****************************************************************************************/

/* Training parameters for CvBoost: the CvDTreeParams used for each weak
   tree plus boosting settings. boost_type and split_criteria presumably take
   the CvBoost "Boosting type" / "Splitting criteria" enum values. */
struct CV_EXPORTS CvBoostParams : public CvDTreeParams
{
    int boost_type;
    int weak_count;
    int split_criteria;
    double weight_trim_rate;

    CvBoostParams();
    CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
                   int max_depth, bool use_surrogates, const float* priors );
};


class CvBoost;

/* Weak learner used by CvBoost: a CvDTree that keeps a pointer to the
   CvBoost ensemble it was trained for. */
class CV_EXPORTS CvBoostTree: public CvDTree
{
public:
    CvBoostTree();
    virtual ~CvBoostTree();

    virtual bool train( CvDTreeTrainData* _train_data,
                        const CvMat* subsample_idx, CvBoost* ensemble );

    // NOTE(review): presumably scales the tree's node values by s - confirm.
    virtual void scale( double s );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvBoost* ensemble, CvDTreeTrainData* _data );
    virtual void clear();

    /* dummy methods to avoid warnings: BEGIN */
    virtual bool train( const CvMat* _train_data, int _tflag,
                        const CvMat* _responses, const CvMat* _var_idx=0,
                        const CvMat* _sample_idx=0, const CvMat* _var_type=0,
                        const CvMat* _missing_mask=0,
                        CvDTreeParams params=CvDTreeParams() );

    virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );
    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    /* dummy methods to avoid warnings: END */

protected:

    // CvDTree split/value hooks overridden for boosting
    virtual void try_split_node( CvDTreeNode* n );
    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi );
    virtual void calc_node_value( CvDTreeNode* n );
    virtual double calc_node_dir( CvDTreeNode* n );

    CvBoost* ensemble;
};


/* Boosted ensemble of CvBoostTree weak learners. */
class CV_EXPORTS CvBoost : public CvStatModel
{
public:
    // Boosting type
    enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };

    // Splitting criteria (note: the value 2 is not used)
    enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };

    CvBoost();
    virtual ~CvBoost();

    // Training constructor: same arguments as train() without `update`.
    CvBoost( const CvMat* _train_data, int _tflag,
             const CvMat* _responses, const CvMat* _var_idx=0,
             const CvMat* _sample_idx=0, const CvMat* _var_type=0,
             const CvMat* _missing_mask=0,
             CvBoostParams params=CvBoostParams() );

    // NOTE(review): presumably update=true continues training the existing
    // ensemble instead of starting over - confirm.
    virtual bool train( const CvMat* _train_data, int _tflag,
             const CvMat* _responses, const CvMat* _var_idx=0,
             const CvMat* _sample_idx=0, const CvMat* _var_type=0,
             const CvMat* _missing_mask=0,
             CvBoostParams params=CvBoostParams(),
             bool update=false );

    // `slice` presumably restricts which weak predictors are evaluated
    // (default CV_WHOLE_SEQ); `weak_responses`, when non-null, presumably
    // receives the individual weak responses - confirm, incl. raw_mode.
    virtual float predict( const CvMat* _sample, const CvMat* _missing=0,
                           CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
                           bool raw_mode=false ) const;

    // NOTE(review): presumably removes the weak predictors in `slice`.
    virtual void prune( CvSlice slice );

    virtual void clear();

    virtual void write( CvFileStorage* storage, const char* name );
    virtual void read( CvFileStorage* storage, CvFileNode* node );

    CvSeq* get_weak_predictors();

    CvMat* get_weights();
    CvMat* get_subtree_weights();
    CvMat* get_weak_response();
    const CvBoostParams& get_params() const;

protected:

    virtual bool set_params( const CvBoostParams& _params );
    virtual void update_weights( CvBoostTree* tree );
    virtual void trim_weights();
    virtual void write_params( CvFileStorage* fs );
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    CvDTreeTrainData* data;
    CvBoostParams params;
    CvSeq* weak;            // weak predictors (presumably CvBoostTree*); see get_weak_predictors()

    CvMat* orig_response;
    CvMat* sum_response;
    CvMat* weak_eval;
    CvMat* subsample_mask;
    CvMat* weights;
    CvMat* subtree_weights;
    bool have_subsample;
};


/****************************************************************************************\
*                                   Artificial Neural Networks (ANN)                     *
\****************************************************************************************/

/////////////////////////////////// Multi-Layer Perceptrons //////////////////////////////

/* Training parameters for CvANN_MLP. train_method selects BACKPROP or
   RPROP; the bp_* fields belong to backpropagation and the rp_* fields
   to RPROP. NOTE(review): param1/param2 presumably map onto the selected
   method's fields - confirm in the implementation. */
struct CV_EXPORTS CvANN_MLP_TrainParams
{
    CvANN_MLP_TrainParams();
    CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
                           double param1, double param2=0 );
    ~CvANN_MLP_TrainParams();

    enum { BACKPROP=0, RPROP=1 };

    CvTermCriteria term_crit;
    int train_method;

    // backpropagation parameters
    double bp_dw_scale, bp_moment_scale;

    // rprop parameters
    double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
};


/* Multi-layer perceptron. The layer sizes are kept in a CvMat whose `cols`
   is the number of layers (see get_layer_count()). */
class CV_EXPORTS CvANN_MLP : public CvStatModel
{
public:
    CvANN_MLP();
    // Default arguments may name SIGMOID_SYM even though the enum is declared
    // further down: member default arguments are looked up in the
    // complete-class context.
    CvANN_MLP( const CvMat* _layer_sizes,
               int _activ_func=SIGMOID_SYM,
               double _f_param1=0, double _f_param2=0 );

    virtual ~CvANN_MLP();

    virtual void create( const CvMat* _layer_sizes,
                         int _activ_func=SIGMOID_SYM,
                         double _f_param1=0, double _f_param2=0 );

    // `flags` is a bitwise OR of the training flags below (powers of two).
    virtual int train( const CvMat* _inputs, const CvMat* _outputs,
                       const CvMat* _sample_weights, const CvMat* _sample_idx=0,
                       CvANN_MLP_TrainParams _params = CvANN_MLP_TrainParams(),
                       int flags=0 );
    virtual float predict( const CvMat* _inputs,
                           CvMat* _outputs ) const;

    virtual void clear();

    // possible activation functions
    enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };

    // available training flags
    enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* storage, const char* name );

    // Number of layers (layer_sizes->cols), or 0 when no network exists.
    int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
    const CvMat* get_layer_sizes() { return layer_sizes; }
    // Returns weights[layer], or 0 when the network has not been created or
    // `layer` is outside [0, layer_sizes->cols]. The unsigned cast makes a
    // negative `layer` compare as huge, so it is rejected too.
    // NOTE(review): the upper bound is inclusive, so `weights` must hold at
    // least layer_sizes->cols + 1 entries - confirm against create().
    double* get_weights(int layer)
    {
        return layer_sizes && weights &&
            (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
    }

protected:

    virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
            const CvMat* _sample_weights, const CvMat* _sample_idx,
            CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );

    // sequential random backpropagation
    virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );

    // RPROP algorithm
    virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );

    virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
    virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;
    virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
                                 double _f_param1=0, double _f_param2=0 );
    virtual void init_weights();
    virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
    virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
    virtual void calc_input_scale( const CvVectors* vecs, int flags );
    virtual void calc_output_scale( const CvVectors* vecs, int flags );

    virtual void write_params( CvFileStorage* fs );
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    CvMat* layer_sizes;
    CvMat* wbuf;
    CvMat* sample_weights;
    double** weights;       // indexed by get_weights(layer)
    double f_param1, f_param2;
    double min_val, max_val, min_val1, max_val1;
    int activ_func;
    int max_count, max_buf_sz;
    CvANN_MLP_TrainParams params;
    CvRNG rng;
};

/* The convolutional-network and cross-validation API below is disabled by
   the following #if 0 and is not compiled. */
#if 0
/****************************************************************************************\
*                            Convolutional Neural Network                                *
\****************************************************************************************/
typedef struct CvCNNLayer CvCNNLayer;
typedef struct CvCNNetwork CvCNNetwork;

/* Learning-rate decay schedules (presumably values for the
   learn_rate_decrease_type layer field / argument). */
#define CV_CNN_LEARN_RATE_DECREASE_HYPERBOLICALLY  1
#define CV_CNN_LEARN_RATE_DECREASE_SQRT_INV        2
#define CV_CNN_LEARN_RATE_DECREASE_LOG_INV 3 1254 1255 #define CV_CNN_GRAD_ESTIM_RANDOM 0 1256 #define CV_CNN_GRAD_ESTIM_BY_WORST_IMG 1 1257 1258 #define ICV_CNN_LAYER 0x55550000 1259 #define ICV_CNN_CONVOLUTION_LAYER 0x00001111 1260 #define ICV_CNN_SUBSAMPLING_LAYER 0x00002222 1261 #define ICV_CNN_FULLCONNECT_LAYER 0x00003333 1262 1263 #define ICV_IS_CNN_LAYER( layer ) \ 1264 ( ((layer) != NULL) && ((((CvCNNLayer*)(layer))->flags & CV_MAGIC_MASK)\ 1265 == ICV_CNN_LAYER )) 1266 1267 #define ICV_IS_CNN_CONVOLUTION_LAYER( layer ) \ 1268 ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags \ 1269 & ~CV_MAGIC_MASK) == ICV_CNN_CONVOLUTION_LAYER ) 1270 1271 #define ICV_IS_CNN_SUBSAMPLING_LAYER( layer ) \ 1272 ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags \ 1273 & ~CV_MAGIC_MASK) == ICV_CNN_SUBSAMPLING_LAYER ) 1274 1275 #define ICV_IS_CNN_FULLCONNECT_LAYER( layer ) \ 1276 ( (ICV_IS_CNN_LAYER( layer )) && (((CvCNNLayer*) (layer))->flags \ 1277 & ~CV_MAGIC_MASK) == ICV_CNN_FULLCONNECT_LAYER ) 1278 1279 typedef void (CV_CDECL *CvCNNLayerForward) 1280 ( CvCNNLayer* layer, const CvMat* input, CvMat* output ); 1281 1282 typedef void (CV_CDECL *CvCNNLayerBackward) 1283 ( CvCNNLayer* layer, int t, const CvMat* X, const CvMat* dE_dY, CvMat* dE_dX ); 1284 1285 typedef void (CV_CDECL *CvCNNLayerRelease) 1286 (CvCNNLayer** layer); 1287 1288 typedef void (CV_CDECL *CvCNNetworkAddLayer) 1289 (CvCNNetwork* network, CvCNNLayer* layer); 1290 1291 typedef void (CV_CDECL *CvCNNetworkRelease) 1292 (CvCNNetwork** network); 1293 1294 #define CV_CNN_LAYER_FIELDS() \ 1295 /* Indicator of the layer's type */ \ 1296 int flags; \ 1297 \ 1298 /* Number of input images */ \ 1299 int n_input_planes; \ 1300 /* Height of each input image */ \ 1301 int input_height; \ 1302 /* Width of each input image */ \ 1303 int input_width; \ 1304 \ 1305 /* Number of output images */ \ 1306 int n_output_planes; \ 1307 /* Height of each output image */ \ 1308 int output_height; \ 1309 /* 
Width of each output image */ \ 1310 int output_width; \ 1311 \ 1312 /* Learning rate at the first iteration */ \ 1313 float init_learn_rate; \ 1314 /* Dynamics of learning rate decreasing */ \ 1315 int learn_rate_decrease_type; \ 1316 /* Trainable weights of the layer (including bias) */ \ 1317 /* i-th row is a set of weights of the i-th output plane */ \ 1318 CvMat* weights; \ 1319 \ 1320 CvCNNLayerForward forward; \ 1321 CvCNNLayerBackward backward; \ 1322 CvCNNLayerRelease release; \ 1323 /* Pointers to the previous and next layers in the network */ \ 1324 CvCNNLayer* prev_layer; \ 1325 CvCNNLayer* next_layer 1326 1327 typedef struct CvCNNLayer 1328 { 1329 CV_CNN_LAYER_FIELDS(); 1330 }CvCNNLayer; 1331 1332 typedef struct CvCNNConvolutionLayer 1333 { 1334 CV_CNN_LAYER_FIELDS(); 1335 // Kernel size (height and width) for convolution. 1336 int K; 1337 // connections matrix, (i,j)-th element is 1 iff there is a connection between 1338 // i-th plane of the current layer and j-th plane of the previous layer; 1339 // (i,j)-th element is equal to 0 otherwise 1340 CvMat *connect_mask; 1341 // value of the learning rate for updating weights at the first iteration 1342 }CvCNNConvolutionLayer; 1343 1344 typedef struct CvCNNSubSamplingLayer 1345 { 1346 CV_CNN_LAYER_FIELDS(); 1347 // ratio between the heights (or widths - ratios are supposed to be equal) 1348 // of the input and output planes 1349 int sub_samp_scale; 1350 // amplitude of sigmoid activation function 1351 float a; 1352 // scale parameter of sigmoid activation function 1353 float s; 1354 // exp2ssumWX = exp(2<s>*(bias+w*(x1+...+x4))), where x1,...x4 are some elements of X 1355 // - is the vector used in computing of the activation function in backward 1356 CvMat* exp2ssumWX; 1357 // (x1+x2+x3+x4), where x1,...x4 are some elements of X 1358 // - is the vector used in computing of the activation function in backward 1359 CvMat* sumX; 1360 }CvCNNSubSamplingLayer; 1361 1362 // Structure of the last layer. 
typedef struct CvCNNFullConnectLayer
{
    CV_CNN_LAYER_FIELDS();
    // amplitude of sigmoid activation function
    float a;
    // scale parameter of sigmoid activation function
    float s;
    // exp2ssumWX = exp(2*<s>*(W*X)) - is the vector used in computing of the
    // activation function and its derivative by the formulae
    // activ.func. = <a>(exp(2<s>WX)-1)/(exp(2<s>WX)+1) == <a> - 2<a>/(<exp2ssumWX> + 1)
    // (activ.func.)' = 4<a><s>exp(2<s>WX)/(exp(2<s>WX)+1)^2
    CvMat* exp2ssumWX;
}CvCNNFullConnectLayer;

typedef struct CvCNNetwork
{
    int n_layers;
    CvCNNLayer* layers;
    CvCNNetworkAddLayer add_layer;
    CvCNNetworkRelease release;
}CvCNNetwork;

typedef struct CvCNNStatModel
{
    CV_STAT_MODEL_FIELDS();
    CvCNNetwork* network;
    // etalons are allocated as rows, the i-th etalon has label cls_labels[i]
    CvMat* etalons;
    // classes labels
    CvMat* cls_labels;
}CvCNNStatModel;

typedef struct CvCNNStatModelParams
{
    CV_STAT_MODEL_PARAM_FIELDS();
    // network must be created by the functions cvCreateCNNetwork and <add_layer>
    CvCNNetwork* network;
    CvMat* etalons;
    // termination criteria
    int max_iter;
    int start_iter;
    int grad_estim_type;
}CvCNNStatModelParams;

CVAPI(CvCNNLayer*) cvCreateCNNConvolutionLayer(
    int n_input_planes, int input_height, int input_width,
    int n_output_planes, int K,
    float init_learn_rate, int learn_rate_decrease_type,
    CvMat* connect_mask CV_DEFAULT(0), CvMat* weights CV_DEFAULT(0) );

CVAPI(CvCNNLayer*) cvCreateCNNSubSamplingLayer(
    int n_input_planes, int input_height, int input_width,
    int sub_samp_scale, float a, float s,
    float init_learn_rate, int learn_rate_decrease_type, CvMat* weights CV_DEFAULT(0) );

CVAPI(CvCNNLayer*) cvCreateCNNFullConnectLayer(
    int n_inputs, int n_outputs, float a, float s,
    float init_learn_rate, int learning_type, CvMat* weights CV_DEFAULT(0) );

CVAPI(CvCNNetwork*) cvCreateCNNetwork( CvCNNLayer* first_layer );

/* NOTE(review): three CvMat* parameters are unnamed. Judging by the
   createClassifier callback of cvCrossValidation, they are presumably
   comp_idx, type_mask and missing mask - confirm. */
CVAPI(CvStatModel*) cvTrainCNNClassifier(
            const CvMat* train_data, int tflag,
            const CvMat* responses,
            const CvStatModelParams* params,
            const CvMat* CV_DEFAULT(0),
            const CvMat* sample_idx CV_DEFAULT(0),
            const CvMat* CV_DEFAULT(0), const CvMat* CV_DEFAULT(0) );

/****************************************************************************************\
*                               Estimate classifiers algorithms                          *
\****************************************************************************************/
/* Callback types used by the estimator model below (the first six are
   stored in CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS()). */
typedef const CvMat* (CV_CDECL *CvStatModelEstimateGetMat)
                    ( const CvStatModel* estimateModel );

typedef int (CV_CDECL *CvStatModelEstimateNextStep)
                    ( CvStatModel* estimateModel );

typedef void (CV_CDECL *CvStatModelEstimateCheckClassifier)
                    ( CvStatModel* estimateModel,
                const CvStatModel* model,
                const CvMat*       features,
                      int          sample_t_flag,
                const CvMat*       responses );

typedef void (CV_CDECL *CvStatModelEstimateCheckClassifierEasy)
                    ( CvStatModel* estimateModel,
                const CvStatModel* model );

typedef float (CV_CDECL *CvStatModelEstimateGetCurrentResult)
                    ( const CvStatModel* estimateModel,
                            float*       correlation );

typedef void (CV_CDECL *CvStatModelEstimateReset)
                    ( CvStatModel* estimateModel );

//-------------------------------- Cross-validation --------------------------------------
#define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS()    \
    CV_STAT_MODEL_PARAM_FIELDS();                                 \
    int     k_fold;                                               \
    int     is_regression;                                        \
    CvRNG*  rng

typedef struct CvCrossValidationParams
{
    CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_PARAM_FIELDS();
} CvCrossValidationParams;

#define CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS()    \
    CvStatModelEstimateGetMat               getTrainIdxMat; \
    CvStatModelEstimateGetMat               getCheckIdxMat; \
    CvStatModelEstimateNextStep             nextStep;       \
    CvStatModelEstimateCheckClassifier      check;          \
    CvStatModelEstimateGetCurrentResult     getResult;      \
    CvStatModelEstimateReset                reset;          \
    int     is_regression;                                  \
    int     folds_all;                                      \
    int     samples_all;                                    \
    int*    sampleIdxAll;                                   \
    int*    folds;                                          \
    int     max_fold_size;                                  \
    int         current_fold;                               \
    int         is_checked;                                 \
    CvMat*      sampleIdxTrain;                             \
    CvMat*      sampleIdxEval;                              \
    CvMat*      predict_results;                            \
    int     correct_results;                                \
    int     all_results;                                    \
    double  sq_error;                                       \
    double  sum_correct;                                    \
    double  sum_predict;                                    \
    double  sum_cc;                                         \
    double  sum_pp;                                         \
    double  sum_cp

typedef struct CvCrossValidationModel
{
    CV_STAT_MODEL_FIELDS();
    CV_CROSS_VALIDATION_ESTIMATE_CLASSIFIER_FIELDS();
} CvCrossValidationModel;

CVAPI(CvStatModel*)
cvCreateCrossValidationEstimateModel
           ( int                samples_all,
             const CvStatModelParams* estimateParams CV_DEFAULT(0),
             const CvMat*             sampleIdx CV_DEFAULT(0) );

/* The createClassifier callback has the same parameter shape as
   cvTrainCNNClassifier above. */
CVAPI(float)
cvCrossValidation( const CvMat*             trueData,
                   int                      tflag,
                   const CvMat*             trueClasses,
                   CvStatModel*     (*createClassifier)( const CvMat*,
                                                         int,
                                                         const CvMat*,
                                                         const CvStatModelParams*,
                                                         const CvMat*,
                                                         const CvMat*,
                                                         const CvMat*,
                                                         const CvMat* ),
                   const CvStatModelParams* estimateParams CV_DEFAULT(0),
                   const CvStatModelParams* trainParams CV_DEFAULT(0),
                   const CvMat*             compIdx CV_DEFAULT(0),
                   const CvMat*             sampleIdx CV_DEFAULT(0),
                   CvStatModel**            pCrValModel CV_DEFAULT(0),
                   const CvMat*             typeMask CV_DEFAULT(0),
                   const CvMat*             missedMeasurementMask CV_DEFAULT(0) );
#endif /* end of the disabled (#if 0) CNN / cross-validation section */

/****************************************************************************************\
*                           Auxiliary functions declarations                             *
\****************************************************************************************/

/* Generates <sample> from a multivariate normal distribution, where <mean> is the
   average row vector and <cov> is the symmetric covariance matrix.
   <rng> is optional (default 0). */
CVAPI(void) cvRandMVNormal( CvMat* mean, CvMat* cov, CvMat* sample,
                           CvRNG* rng CV_DEFAULT(0) );

/* Generates a sample from a Gaussian mixture distribution with <clsnum>
   components given by <means>, <covs> and <weights>; <sampClasses> is
   optional (default 0) and presumably receives the component of each
   generated row - confirm. */
CVAPI(void) cvRandGaussMixture( CvMat* means[],
                               CvMat* covs[],
                               float weights[],
                               int clsnum,
                               CvMat* sample,
                               CvMat* sampClasses CV_DEFAULT(0) );

/* Test-set type for cvCreateTestSet (presumably its `type` argument). */
#define CV_TS_CONCENTRIC_SPHERES 0

/* creates test set; NOTE(review): the variadic tail presumably carries
   type-specific arguments - confirm in the implementation. */
CVAPI(void) cvCreateTestSet( int type, CvMat** samples,
                 int num_samples,
                 int num_features,
                 CvMat** responses,
                 int num_classes, ... );

/* Aij <- Aji for i > j if lower_to_upper != 0
              for i < j if lower_to_upper = 0
   NOTE(review): as written, a nonzero lower_to_upper overwrites the lower
   triangle (i > j) from the upper one, which is the opposite of what the
   parameter name suggests - confirm against the implementation before
   relying on either reading. */
CVAPI(void) cvCompleteSymm( CvMat* matrix, int lower_to_upper );

#endif /* closes a conditional opened above this chunk */

#endif /*__ML_H__*/
/* End of file. */