1 /*M/////////////////////////////////////////////////////////////////////////////////////// 2 // 3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. 4 // 5 // By downloading, copying, installing or using the software you agree to this license. 6 // If you do not agree to this license, do not download, install, 7 // copy or use the software. 8 // 9 // 10 // Intel License Agreement 11 // 12 // Copyright (C) 2000, Intel Corporation, all rights reserved. 13 // Third party copyrights are property of their respective owners. 14 // 15 // Redistribution and use in source and binary forms, with or without modification, 16 // are permitted provided that the following conditions are met: 17 // 18 // * Redistribution's of source code must retain the above copyright notice, 19 // this list of conditions and the following disclaimer. 20 // 21 // * Redistribution's in binary form must reproduce the above copyright notice, 22 // this list of conditions and the following disclaimer in the documentation 23 // and/or other materials provided with the distribution. 24 // 25 // * The name of Intel Corporation may not be used to endorse or promote products 26 // derived from this software without specific prior written permission. 27 // 28 // This software is provided by the copyright holders and contributors "as is" and 29 // any express or implied warranties, including, but not limited to, the implied 30 // warranties of merchantability and fitness for a particular purpose are disclaimed. 
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "old_ml_precomp.hpp"
#include <ctype.h>

using namespace cv;

// Sentinel value used to mark missing values of ordered (numeric) variables.
static const float ord_nan = FLT_MAX*0.5f;
// Minimum size and growth increment (bytes) for the memory storages below.
static const int min_block_size = 1 << 16;
static const int block_size_delta = 1 << 10;

// Default constructor: zero every owned pointer, then reset all state via clear().
CvDTreeTrainData::CvDTreeTrainData()
{
    var_idx = var_type = cat_count = cat_ofs = cat_map =
        priors = priors_mult = counts = direction = split_buf = responses_copy = 0;
    buf = 0;
    tree_storage = temp_storage = 0;

    clear();
}


// Convenience constructor: zero every owned pointer, then immediately
// ingest the training set via set_data().
CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
                      const CvMat* _responses, const CvMat* _var_idx,
                      const CvMat* _sample_idx, const CvMat* _var_type,
                      const CvMat* _missing_mask, const CvDTreeParams& _params,
                      bool _shared, bool _add_labels )
{
    var_idx = var_type = cat_count = cat_ofs = cat_map =
        priors = priors_mult = counts = direction = split_buf = responses_copy = 0;
    buf = 0;

    tree_storage = temp_storage = 0;

    set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
              _var_type, _missing_mask, _params, _shared, _add_labels );
}


CvDTreeTrainData::~CvDTreeTrainData()
{
    clear();
}


// Validate and clamp the user-supplied training parameters; stores the
// (possibly adjusted) copy in this->params. Returns false on error.
bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
{
    bool ok = false;

    CV_FUNCNAME( "CvDTreeTrainData::set_params" );

    __BEGIN__;

    // set parameters
    params = _params;

    if( params.max_categories < 2 )
        CV_ERROR(
CV_StsOutOfRange, "params.max_categories should be >= 2" );
    // Subset search over categories grows quickly with max_categories; cap it.
    params.max_categories = MIN( params.max_categories, 15 );

    if( params.max_depth < 0 )
        CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
    params.max_depth = MIN( params.max_depth, 25 );

    params.min_sample_count = MAX(params.min_sample_count,1);

    if( params.cv_folds < 0 )
        CV_ERROR( CV_StsOutOfRange,
        "params.cv_folds should be =0 (the tree is not pruned) "
        "or n>0 (tree is pruned using n-fold cross-validation)" );

    // 1-fold cross-validation is meaningless; treat it as "no pruning".
    if( params.cv_folds == 1 )
        params.cv_folds = 0;

    if( params.regression_accuracy < 0 )
        CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );

    ok = true;

    __END__;

    return ok;
}

// Comparator that orders an array of pointers by the pointed-to values.
template<typename T>
class LessThanPtr
{
public:
    bool operator()(T* a, T* b) const { return *a < *b; }
};

// Comparator that orders indices by the values they reference in 'arr'.
template<typename T, typename Idx>
class LessThanIdx
{
public:
    LessThanIdx( const T* _arr ) : arr(_arr) {}
    bool operator()(Idx a, Idx b) const { return arr[a] < arr[b]; }
    const T* arr;
};

// Comparator for (16-bit slot, 32-bit value) pairs; orders by the int value.
class LessThanPairs
{
public:
    bool operator()(const CvPair16u32s& a, const CvPair16u32s& b) const { return *a.i < *b.i; }
};

// Build the internal training-data representation: validates inputs,
// allocates the big work matrix 'buf' and auxiliary structures, converts
// categorical variables to compact 0-based class indices and pre-sorts
// ordered variables by value. See the step comments in the body.
void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
    const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
    const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
    bool _shared, bool _add_labels, bool _update_data )
{
    CvMat* sample_indices = 0;
    CvMat* var_type0 = 0;
    CvMat* tmp_map = 0;
    int** int_ptr = 0;
    CvPair16u32s* pair16u32s_ptr = 0;
    CvDTreeTrainData* data = 0;
    float *_fdst = 0;
    int *_idst = 0;
    unsigned short* udst = 0;
    int* idst = 0;

    CV_FUNCNAME( "CvDTreeTrainData::set_data" );

    __BEGIN__;

    int sample_all = 0, r_type, cv_n;
    int total_c_count = 0;
    int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
    int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
    int vi, i, size;
    char err[100];
    const int *sidx = 0, *vidx = 0;

    uint64 effective_buf_size = 0;
    int effective_buf_height = 0, effective_buf_width = 0;

    // Update path: convert the new data in a temporary object, require it
    // to be structurally identical to the old data (same variables, types
    // and categories), then steal its freshly built buffers.
    if( _update_data && data_root )
    {
        data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
            _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );

        // compare new and old train data
        if( !(data->var_count == var_count &&
            cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
            cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
            cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
            CV_ERROR( CV_StsBadArg,
            "The new training data must have the same types and the input and output variables "
            "and the same categories for categorical variables" );

        cvReleaseMat( &priors );
        cvReleaseMat( &priors_mult );
        cvReleaseMat( &buf );
        cvReleaseMat( &direction );
        cvReleaseMat( &split_buf );
        cvReleaseMemStorage( &temp_storage );

        // Take ownership of the new object's buffers, zeroing its pointers
        // so they are not released when 'data' is deleted in the cleanup.
        priors = data->priors; data->priors = 0;
        priors_mult = data->priors_mult; data->priors_mult = 0;
        buf = data->buf; data->buf = 0;
        buf_count = data->buf_count; buf_size = data->buf_size;
        sample_count = data->sample_count;

        direction = data->direction; data->direction = 0;
        split_buf = data->split_buf; data->split_buf = 0;
        temp_storage = data->temp_storage; data->temp_storage = 0;
        nv_heap = data->nv_heap; cv_heap = data->cv_heap;

        data_root = new_node( 0, sample_count, 0, 0 );
        EXIT;
    }

    clear();

    var_all = 0;
    rng = &cv::theRNG();

    CV_CALL( set_params( _params ));

    // check parameter types and sizes
    CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));

    train_data = _train_data;
    responses = _responses;

    // Element strides for walking samples (ds/ms) vs. variables (dv/mv),
    // depending on whether samples are stored as rows or as columns.
    if( _tflag == CV_ROW_SAMPLE )
    {
        ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
        dv_step = 1;
        if( _missing_mask )
            ms_step = _missing_mask->step, mv_step = 1;
    }
    else
    {
        dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
        ds_step = 1;
        if( _missing_mask )
            mv_step = _missing_mask->step, ms_step = 1;
    }
    tflag = _tflag;

    sample_count = sample_all;
    var_count = var_all;

    if( _sample_idx )
    {
        CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
        sidx = sample_indices->data.i;
        sample_count = sample_indices->rows + sample_indices->cols - 1;
    }

    if( _var_idx )
    {
        CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
        vidx = var_idx->data.i;
        var_count = var_idx->rows + var_idx->cols - 1;
    }

    // With fewer than 65536 samples the work buffer can hold 16-bit indices.
    is_buf_16u = false;
    if ( sample_count < 65536 )
        is_buf_16u = true;

    if( !CV_IS_MAT(_responses) ||
        (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
         CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
        (_responses->rows != 1 && _responses->cols != 1) ||
        _responses->rows + _responses->cols - 1 != sample_all )
        CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
                  "floating-point vector containing as many elements as "
                  "the total number of samples in the training data matrix" );

    r_type = CV_VAR_CATEGORICAL;
    if( _var_type )
        CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));

    // var_type gets two extra slots: responses and cv labels (filled below).
    CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));

    cat_var_count = 0;
    ord_var_count = -1;

    is_classifier = r_type == CV_VAR_CATEGORICAL;

    // step 0. calc the number of categorical vars
    // Categorical variables get indices 0,1,2,... while ordered variables
    // get one's-complement codes -1,-2,... so get_var_type() can tell the
    // kind and the per-kind index apart from one int.
    for( vi = 0; vi < var_count; vi++ )
    {
        char vt = var_type0 ? var_type0->data.ptr[vi] : CV_VAR_ORDERED;
        var_type->data.i[vi] = vt == CV_VAR_CATEGORICAL ?
cat_var_count++ : ord_var_count--;
    }

    ord_var_count = ~ord_var_count;
    cv_n = params.cv_folds;
    // set the two last elements of var_type array to be able
    // to locate responses and cross-validation labels using
    // the corresponding get_* functions.
    var_type->data.i[var_count] = cat_var_count;
    var_type->data.i[var_count+1] = cat_var_count+1;

    // in case of single ordered predictor we need dummy cv_labels
    // for safe split_node_data() operation
    have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;

    work_var_count = var_count + (is_classifier ? 1 : 0) // for responses class_labels
                               + (have_labels ? 1 : 0); // for cv_labels

    shared = _shared;
    buf_count = shared ? 2 : 1;

    buf_size = -1; // the member buf_size is obsolete

    effective_buf_size = (uint64)(work_var_count + 1)*(uint64)sample_count * buf_count; // this is the total size of "CvMat buf" to be allocated
    effective_buf_width = sample_count;
    effective_buf_height = work_var_count+1;

    // Fold buf_count into whichever dimension is already the larger one.
    if (effective_buf_width >= effective_buf_height)
        effective_buf_height *= buf_count;
    else
        effective_buf_width *= buf_count;

    // If the int-typed width*height product no longer matches the 64-bit
    // total, the requested buffer overflowed CvMat's integer size fields.
    if ((uint64)effective_buf_width * (uint64)effective_buf_height != effective_buf_size)
    {
        CV_Error(CV_StsBadArg, "The memory buffer cannot be allocated since its size exceeds integer fields limit");
    }



    // The scratch pointer arrays below are used when compacting
    // categorical values further down in this function.
    if ( is_buf_16u )
    {
        CV_CALL( buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_16UC1 ));
        CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
    }
    else
    {
        CV_CALL( buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_32SC1 ));
        CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
    }

    // One cat_count/cat_ofs slot per categorical var (+1 for class labels);
    // at least one slot so the matrices can always be created.
    size = is_classifier ? (cat_var_count+1) : cat_var_count;
    size = !size ?
1 : size;
    CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
    CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));

    size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
    size = !size ? 1 : size;
    CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));

    // now calculate the maximum size of split,
    // create memory storage that will keep nodes and splits of the decision tree
    // allocate root node and the buffer for the whole training data
    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
    CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));

    nv_size = var_count*sizeof(int);
    nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));

    temp_block_size = nv_size;

    if( cv_n )
    {
        if( sample_count < cv_n*MAX(params.min_sample_count,10) )
            CV_ERROR( CV_StsOutOfRange,
                "The many folds in cross-validation for such a small dataset" );
        // NOTE(review): the message above reads oddly — presumably it was
        // meant to say "Too many folds ..."; confirm before changing the
        // user-visible string.

        // Each cv_heap element holds cv_n ints (cv_Tn) plus 2*cv_n doubles
        // (per-fold risk and error) — see new_node().
        cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
        temp_block_size = MAX(temp_block_size, cv_size);
    }

    temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
    CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
    CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
    if( cv_size )
        CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));

    CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));

    max_c_count = 1;

    // Scratch arrays used while converting the variables below.
    _fdst = 0;
    _idst = 0;
    if (ord_var_count)
        _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
    if
(is_buf_16u && (cat_var_count || is_classifier))
        _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));

    // transform the training data to convenient representation
    // (one extra iteration, vi == var_count, handles the responses)
    for( vi = 0; vi <= var_count; vi++ )
    {
        int ci;
        const uchar* mask = 0;
        int64 m_step = 0, step;
        const int* idata = 0;
        const float* fdata = 0;
        int num_valid = 0;

        if( vi < var_count ) // analyze i-th input variable
        {
            int vi0 = vidx ? vidx[vi] : vi;
            ci = get_var_type(vi);
            step = ds_step; m_step = ms_step;
            if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
                idata = _train_data->data.i + vi0*dv_step;
            else
                fdata = _train_data->data.fl + vi0*dv_step;
            if( _missing_mask )
                mask = _missing_mask->data.ptr + vi0*mv_step;
        }
        else // analyze _responses
        {
            ci = cat_var_count;
            step = CV_IS_MAT_CONT(_responses->type) ?
                1 : _responses->step / CV_ELEM_SIZE(_responses->type);
            if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
                idata = _responses->data.i;
            else
                fdata = _responses->data.fl;
        }

        if( (vi < var_count && ci>=0) ||
            (vi == var_count && is_classifier) ) // process categorical variable or response
        {
            int c_count, prev_label;
            int* c_map;

            // destination row of 'buf' for this variable
            if (is_buf_16u)
                udst = (unsigned short*)(buf->data.s + (size_t)vi*sample_count);
            else
                idst = buf->data.i + (size_t)vi*sample_count;

            // copy data; missing entries keep the INT_MAX sentinel so they
            // sort to the end below
            for( i = 0; i < sample_count; i++ )
            {
                int val = INT_MAX, si = sidx ?
sidx[i] : i;
                if( !mask || !mask[(size_t)si*m_step] )
                {
                    if( idata )
                        val = idata[(size_t)si*step];
                    else
                    {
                        // categorical values supplied as floats must still
                        // be whole numbers
                        float t = fdata[(size_t)si*step];
                        val = cvRound(t);
                        if( fabs(t - val) > FLT_EPSILON )
                        {
                            sprintf( err, "%d-th value of %d-th (categorical) "
                                "variable is not an integer", i, vi );
                            CV_ERROR( CV_StsBadArg, err );
                        }
                    }

                    // INT_MAX is reserved as the missing-value sentinel
                    if( val == INT_MAX )
                    {
                        sprintf( err, "%d-th value of %d-th (categorical) "
                            "variable is too large", i, vi );
                        CV_ERROR( CV_StsBadArg, err );
                    }
                    num_valid++;
                }
                if (is_buf_16u)
                {
                    _idst[i] = val;
                    pair16u32s_ptr[i].u = udst + i;
                    pair16u32s_ptr[i].i = _idst + i;
                }
                else
                {
                    idst[i] = val;
                    int_ptr[i] = idst + i;
                }
            }

            // sort the (pointers to) values, then count distinct categories
            // among the first num_valid (non-missing) entries
            c_count = num_valid > 0;
            if (is_buf_16u)
            {
                std::sort(pair16u32s_ptr, pair16u32s_ptr + sample_count, LessThanPairs());
                // count the categories
                for( i = 1; i < num_valid; i++ )
                    if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
                        c_count ++ ;
            }
            else
            {
                std::sort(int_ptr, int_ptr + sample_count, LessThanPtr<int>());
                // count the categories
                for( i = 1; i < num_valid; i++ )
                    c_count += *int_ptr[i] != *int_ptr[i-1];
            }

            if( vi > 0 )
                max_c_count = MAX( max_c_count, c_count );
            cat_count->data.i[ci] = c_count;
            cat_ofs->data.i[ci] = total_c_count;

            // resize cat_map, if need
            if( cat_map->cols < total_c_count + c_count )
            {
                tmp_map = cat_map;
                CV_CALL( cat_map = cvCreateMat( 1,
                    MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
                for( i = 0; i < total_c_count; i++ )
                    cat_map->data.i[i] = tmp_map->data.i[i];
                cvReleaseMat( &tmp_map );
            }

            c_map = cat_map->data.i + total_c_count;
            total_c_count += c_count;

            c_count = -1;
            if (is_buf_16u)
            {
                // compact the class indices and build the map
                prev_label = ~*pair16u32s_ptr[0].i;
                for( i = 0; i < num_valid; i++ )
                {
                    int cur_label = *pair16u32s_ptr[i].i;
                    if( cur_label != prev_label )
                        c_map[++c_count] = prev_label = cur_label;
                    // write the compact 0-based category index back into 'buf'
                    *pair16u32s_ptr[i].u = (unsigned short)c_count;
                }
                // replace labels for missing values with -1
                // (65535 plays the role of -1 in the 16-bit buffer)
                for( ; i < sample_count; i++ )
                    *pair16u32s_ptr[i].u = 65535;
            }
            else
            {
                // compact the class indices and build the map
                prev_label = ~*int_ptr[0];
                for( i = 0; i < num_valid; i++ )
                {
                    int cur_label = *int_ptr[i];
                    if( cur_label != prev_label )
                        c_map[++c_count] = prev_label = cur_label;
                    *int_ptr[i] = c_count;
                }
                // replace labels for missing values with -1
                for( ; i < sample_count; i++ )
                    *int_ptr[i] = -1;
            }
        }
        else if( ci < 0 ) // process ordered variable
        {
            // destination row of 'buf' holds indices sorted by value
            if (is_buf_16u)
                udst = (unsigned short*)(buf->data.s + (size_t)vi*sample_count);
            else
                idst = buf->data.i + (size_t)vi*sample_count;

            for( i = 0; i < sample_count; i++ )
            {
                float val = ord_nan;
                int si = sidx ?
sidx[i] : i;
                if( !mask || !mask[(size_t)si*m_step] )
                {
                    if( idata )
                        val = (float)idata[(size_t)si*step];
                    else
                        val = fdata[(size_t)si*step];

                    // magnitudes at or above the ord_nan sentinel cannot be
                    // distinguished from "missing"
                    if( fabs(val) >= ord_nan )
                    {
                        sprintf( err, "%d-th value of %d-th (ordered) "
                            "variable (=%g) is too large", i, vi, val );
                        CV_ERROR( CV_StsBadArg, err );
                    }
                    num_valid++;
                }

                if (is_buf_16u)
                    udst[i] = (unsigned short)i; // TODO: memory corruption may be here
                else
                    idst[i] = i;
                _fdst[i] = val;

            }
            // sort the indices by value; missing entries (ord_nan) are the
            // largest values so they end up at the tail
            if (is_buf_16u)
                std::sort(udst, udst + sample_count, LessThanIdx<float, unsigned short>(_fdst));
            else
                std::sort(idst, idst + sample_count, LessThanIdx<float, int>(_fdst));
        }

        if( vi < var_count )
            data_root->set_num_valid(vi, num_valid);
    }

    // set sample labels
    // (the last work variable stores each row's original sample index)
    if (is_buf_16u)
        udst = (unsigned short*)(buf->data.s + (size_t)work_var_count*sample_count);
    else
        idst = buf->data.i + (size_t)work_var_count*sample_count;

    for (i = 0; i < sample_count; i++)
    {
        if (udst)
            udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
        else
            idst[i] = sidx ? sidx[i] : i;
    }

    // Generate cross-validation fold labels 0..cv_n-1 cyclically, then
    // shuffle them with random pair swaps.
    if( cv_n )
    {
        unsigned short* usdst = 0;
        int* idst2 = 0;

        if (is_buf_16u)
        {
            usdst = (unsigned short*)(buf->data.s + (size_t)(get_work_var_count()-1)*sample_count);
            for( i = vi = 0; i < sample_count; i++ )
            {
                usdst[i] = (unsigned short)vi++;
                vi &= vi < cv_n ? -1 : 0; // wrap vi back to 0 once it reaches cv_n
            }

            for( i = 0; i < sample_count; i++ )
            {
                int a = (*rng)(sample_count);
                int b = (*rng)(sample_count);
                unsigned short unsh = (unsigned short)vi;
                CV_SWAP( usdst[a], usdst[b], unsh );
            }
        }
        else
        {
            idst2 = buf->data.i + (size_t)(get_work_var_count()-1)*sample_count;
            for( i = vi = 0; i < sample_count; i++ )
            {
                idst2[i] = vi++;
                vi &= vi < cv_n ?
-1 : 0;
            }

            for( i = 0; i < sample_count; i++ )
            {
                int a = (*rng)(sample_count);
                int b = (*rng)(sample_count);
                CV_SWAP( idst2[a], idst2[b], vi );
            }
        }
    }

    // Shrink the (possibly over-allocated) category map to its actual size.
    if ( cat_map )
        cat_map->cols = MAX( total_c_count, 1 );

    // Split elements carry a bitmask of max_c_count categories.
    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
    CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));

    have_priors = is_classifier && params.priors;
    if( is_classifier )
    {
        int m = get_num_classes();
        double sum = 0;
        CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
        for( i = 0; i < m; i++ )
        {
            double val = have_priors ? params.priors[i] : 1.;
            if( val <= 0 )
                CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
            priors->data.db[i] = val;
            sum += val;
        }

        // normalize weights
        if( have_priors )
            cvScale( priors, priors, 1./sum );

        CV_CALL( priors_mult = cvCloneMat( priors ));
        CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
    }


    CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
    CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));

    __END__;

    // Cleanup runs on both the success and the error path.
    if( data )
        delete data;

    if (_fdst)
        cvFree( &_fdst );
    if (_idst)
        cvFree( &_idst );
    cvFree( &int_ptr );
    cvFree( &pair16u32s_ptr);
    cvReleaseMat( &var_type0 );
    cvReleaseMat( &sample_indices );
    cvReleaseMat( &tmp_map );
}

// Make a private, owned copy of the responses vector and point
// 'responses' at it, so the caller's matrix need not be kept alive.
void CvDTreeTrainData::do_responses_copy()
{
    responses_copy = cvCreateMat( responses->rows, responses->cols, responses->type );
    cvCopy( responses, responses_copy);
    responses = responses_copy;
}

// Create a new root node over the subset of samples listed in
// _subsample_idx (duplicates allowed). When the index set is exactly the
// identity, the root is plainly copied; otherwise the per-variable
// categorical and sorted-index buffers are re-packed for the subset.
CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
{
    CvDTreeNode* root = 0;
    CvMat* isubsample_idx = 0;
    CvMat* subsample_co = 0;

    bool isMakeRootCopy = true;

    CV_FUNCNAME(
"CvDTreeTrainData::subsample_data" );

    __BEGIN__;

    if( !data_root )
        CV_ERROR( CV_StsError, "No training data has been set" );

    if( _subsample_idx )
    {
        CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));

        // A plain root copy suffices only when the index array is exactly
        // the identity permutation 0..sample_count-1.
        if( isubsample_idx->cols + isubsample_idx->rows - 1 == sample_count )
        {
            const int* sidx = isubsample_idx->data.i;
            for( int i = 0; i < sample_count; i++ )
            {
                if( sidx[i] != i )
                {
                    isMakeRootCopy = false;
                    break;
                }
            }
        }
        else
            isMakeRootCopy = false;
    }

    if( isMakeRootCopy )
    {
        // make a copy of the root node
        CvDTreeNode temp;
        int i;
        root = new_node( 0, 1, 0, 0 );
        temp = *root;
        *root = *data_root;
        // keep the freshly allocated per-node arrays, not data_root's
        root->num_valid = temp.num_valid;
        if( root->num_valid )
        {
            for( i = 0; i < var_count; i++ )
                root->num_valid[i] = data_root->num_valid[i];
        }
        root->cv_Tn = temp.cv_Tn;
        root->cv_node_risk = temp.cv_node_risk;
        root->cv_node_error = temp.cv_node_error;
    }
    else
    {
        int* sidx = isubsample_idx->data.i;
        // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
        int* co, cur_ofs = 0;
        int vi, i;
        int workVarCount = get_work_var_count();
        int count = isubsample_idx->rows + isubsample_idx->cols - 1;

        root = new_node( 0, count, 1, 0 );

        CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
        cvZero( subsample_co );
        co = subsample_co->data.i;
        // co[i*2]   = how many times sample i occurs in the subsample;
        // co[i*2+1] = its first destination offset (-1 when not selected)
        for( i = 0; i < count; i++ )
            co[sidx[i]*2]++;
        for( i = 0; i < sample_count; i++ )
        {
            if( co[i*2] )
            {
                co[i*2+1] = cur_ofs;
                cur_ofs += co[i*2];
            }
            else
                co[i*2+1] = -1;
        }

        cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
        for( vi = 0; vi < workVarCount; vi++ )
        {
            int ci = get_var_type(vi);

            if( ci >= 0 || vi >= var_count )
            {
                // categorical variable (or labels/indices): direct gather
                int num_valid = 0;
                const int* src =
CvDTreeTrainData::get_cat_var_data( data_root, vi, (int*)(uchar*)inn_buf );

                if (is_buf_16u)
                {
                    unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                        (size_t)vi*sample_count + root->offset);
                    for( i = 0; i < count; i++ )
                    {
                        int val = src[sidx[i]];
                        udst[i] = (unsigned short)val;
                        num_valid += val >= 0;
                    }
                }
                else
                {
                    int* idst = buf->data.i + root->buf_idx*get_length_subbuf() +
                        (size_t)vi*sample_count + root->offset;
                    for( i = 0; i < count; i++ )
                    {
                        int val = src[sidx[i]];
                        idst[i] = val;
                        num_valid += val >= 0;
                    }
                }

                if( vi < var_count )
                    root->set_num_valid(vi, num_valid);
            }
            else
            {
                // Ordered variable: rebuild the sorted-index buffer for the
                // subset, preserving the sort order and expanding duplicated
                // sample indices via the co[] count/offset pairs.
                int *src_idx_buf = (int*)(uchar*)inn_buf;
                float *src_val_buf = (float*)(src_idx_buf + sample_count);
                int* sample_indices_buf = (int*)(src_val_buf + sample_count);
                const int* src_idx = 0;
                const float* src_val = 0;
                get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf );
                int j = 0, idx, count_i;
                int num_valid = data_root->get_num_valid(vi);

                if (is_buf_16u)
                {
                    // NOTE(review): this branch offsets by data_root->offset
                    // while the 32-bit branch below uses root->offset — the
                    // asymmetry looks suspicious; confirm which is intended.
                    unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                        (size_t)vi*sample_count + data_root->offset);
                    for( i = 0; i < num_valid; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                udst_idx[j] = (unsigned short)cur_ofs;
                    }

                    root->set_num_valid(vi, j);

                    // append entries for the samples with missing values
                    for( ; i < sample_count; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                udst_idx[j] = (unsigned short)cur_ofs;
                    }
                }
                else
                {
                    int* idst_idx = buf->data.i + root->buf_idx*get_length_subbuf() +
                        (size_t)vi*sample_count + root->offset;
                    for( i = 0; i < num_valid; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                idst_idx[j] = cur_ofs;
                    }

                    root->set_num_valid(vi, j);

                    // append entries for the samples with missing values
                    for( ; i < sample_count; i++ )
                    {
                        idx = src_idx[i];
                        count_i = co[idx*2];
                        if( count_i )
                            for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                                idst_idx[j] = cur_ofs;
                    }
                }
            }
        }
        // sample indices subsampling
        const int* sample_idx_src = get_sample_indices(data_root, (int*)(uchar*)inn_buf);
        if (is_buf_16u)
        {
            unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                (size_t)workVarCount*sample_count + root->offset);
            for (i = 0; i < count; i++)
                sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];
        }
        else
        {
            int* sample_idx_dst = buf->data.i + root->buf_idx*get_length_subbuf() +
                (size_t)workVarCount*sample_count + root->offset;
            for (i = 0; i < count; i++)
                sample_idx_dst[i] = sample_idx_src[sidx[i]];
        }
    }

    __END__;

    cvReleaseMat( &isubsample_idx );
    cvReleaseMat( &subsample_co );

    return root;
}


// Extract the (sub)sampled training vectors into a dense row-major
// 'values' array (count x var_count), an optional 'missing' byte mask of
// the same shape, and an optional '_responses' vector. When get_class_idx
// is true, class labels are returned as compact 0-based indices instead
// of being mapped back through cat_map to the original values.
void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
                                    float* values, uchar* missing,
                                    float* _responses, bool get_class_idx )
{
    CvMat* subsample_idx = 0;
    CvMat* subsample_co = 0;

    CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );

    __BEGIN__;

    int i, vi, total = sample_count, count = total, cur_ofs = 0;
    int* sidx = 0;
    int* co = 0;

    cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
    if( _subsample_idx )
    {
        CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
        sidx = subsample_idx->data.i;
        CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
        co = subsample_co->data.i;
        cvZero( subsample_co );
        count =
subsample_idx->cols + subsample_idx->rows - 1;
        // build count/offset pairs so duplicated sample indices are expanded
        for( i = 0; i < count; i++ )
            co[sidx[i]*2]++;
        for( i = 0; i < total; i++ )
        {
            int count_i = co[i*2];
            if( count_i )
            {
                // offsets are pre-multiplied by var_count (row stride)
                co[i*2+1] = cur_ofs*var_count;
                cur_ofs += count_i;
            }
        }
    }

    // mark everything missing first; valid entries are cleared below
    if( missing )
        memset( missing, 1, count*var_count );

    for( vi = 0; vi < var_count; vi++ )
    {
        int ci = get_var_type(vi);
        if( ci >= 0 ) // categorical
        {
            float* dst = values + vi;
            uchar* m = missing ? missing + vi : 0;
            const int* src = get_cat_var_data(data_root, vi, (int*)(uchar*)inn_buf);

            for( i = 0; i < count; i++, dst += var_count )
            {
                int idx = sidx ? sidx[i] : i;
                int val = src[idx];
                *dst = (float)val;
                if( m )
                {
                    // missing categorical values are stored as -1 in the
                    // 32-bit buffer and as 65535 in the 16-bit buffer
                    *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
                    m += var_count;
                }
            }
        }
        else // ordered
        {
            float* dst = values + vi;
            uchar* m = missing ? missing + vi : 0;
            int count1 = data_root->get_num_valid(vi);
            float *src_val_buf = (float*)(uchar*)inn_buf;
            int* src_idx_buf = (int*)(src_val_buf + sample_count);
            int* sample_indices_buf = src_idx_buf + sample_count;
            const float *src_val = 0;
            const int* src_idx = 0;
            get_ord_var_data(data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf);

            // scatter the valid (non-missing) values to their destinations
            for( i = 0; i < count1; i++ )
            {
                int idx = src_idx[i];
                int count_i = 1;
                if( co )
                {
                    count_i = co[idx*2];
                    cur_ofs = co[idx*2+1];
                }
                else
                    cur_ofs = idx*var_count;
                if( count_i )
                {
                    float val = src_val[i];
                    for( ; count_i > 0; count_i--, cur_ofs += var_count )
                    {
                        dst[cur_ofs] = val;
                        if( m )
                            m[cur_ofs] = 0;
                    }
                }
            }
        }
    }

    // copy responses
    if( _responses )
    {
        if( is_classifier )
        {
            const int* src = get_class_labels(data_root, (int*)(uchar*)inn_buf);
            for( i = 0; i < count; i++ )
            {
                int idx = sidx ?
sidx[i] : i;
                // map the compact class index back through cat_map unless
                // the caller asked for the raw index
                int val = get_class_idx ? src[idx] :
                    cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
                _responses[i] = (float)val;
            }
        }
        else
        {
            float* val_buf = (float*)(uchar*)inn_buf;
            int* sample_idx_buf = (int*)(val_buf + sample_count);
            const float* _values = get_ord_responses(data_root, val_buf, sample_idx_buf);
            for( i = 0; i < count; i++ )
            {
                int idx = sidx ? sidx[i] : i;
                _responses[i] = _values[idx];
            }
        }
    }

    __END__;

    cvReleaseMat( &subsample_idx );
    cvReleaseMat( &subsample_co );
}


// Allocate and initialize a tree node from node_heap. 'storage_idx' and
// 'offset' locate the node's slice inside the shared 'buf' matrix.
CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
                                         int storage_idx, int offset )
{
    CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );

    node->sample_count = count;
    node->depth = parent ? parent->depth + 1 : 0;
    node->parent = parent;
    node->left = node->right = 0;
    node->split = 0;
    node->value = 0;
    node->class_idx = 0;
    node->maxlr = 0.;

    node->buf_idx = storage_idx;
    node->offset = offset;
    if( nv_heap )
        node->num_valid = (int*)cvSetNew( nv_heap );
    else
        node->num_valid = 0;
    node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
    node->complexity = 0;

    if( params.cv_folds > 0 && cv_heap )
    {
        // one cv_heap element packs cv_n ints (cv_Tn) followed by
        // 2*cv_n doubles (per-fold risk and error)
        int cv_n = params.cv_folds;
        node->Tn = INT_MAX;
        node->cv_Tn = (int*)cvSetNew( cv_heap );
        node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
        node->cv_node_error = node->cv_node_risk + cv_n;
    }
    else
    {
        node->Tn = 0;
        node->cv_Tn = 0;
        node->cv_node_risk = 0;
        node->cv_node_error = 0;
    }

    return node;
}


// Allocate a split on an ordered variable from split_heap.
CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
                int split_point, int inversed, float quality )
{
    CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
    split->var_idx = vi;
    split->condensed_idx = INT_MIN;
    split->ord.c = cmp_val;
    split->ord.split_point = split_point;
    split->inversed = inversed;
    split->quality = quality;
    split->next = 0;

    return split;
}


// Allocate a split on a categorical variable with an all-zero category
// subset bitmask (one bit per category, rounded up to whole 32-bit words).
CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
{
    CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
    int i, n = (max_c_count + 31)/32;

    split->var_idx = vi;
    split->condensed_idx = INT_MIN;
    split->inversed = 0;
    split->quality = quality;
    for( i = 0; i < n; i++ )
        split->subset[i] = 0;
    split->next = 0;

    return split;
}


// Return a node, its per-node data and its whole split list to their heaps.
void CvDTreeTrainData::free_node( CvDTreeNode* node )
{
    CvDTreeSplit* split = node->split;
    free_node_data( node );
    while( split )
    {
        CvDTreeSplit* next = split->next;
        cvSetRemoveByPtr( split_heap, split );
        split = next;
    }
    node->split = 0;
    cvSetRemoveByPtr( node_heap, node );
}


// Release the per-node num_valid array (if any) back to nv_heap.
void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
{
    if( node->num_valid )
    {
        cvSetRemoveByPtr( nv_heap, node->num_valid );
        node->num_valid = 0;
    }
    // do not free cv_* fields, as all the cross-validation related data is released at once.
1130 } 1131 1132 1133 void CvDTreeTrainData::free_train_data() 1134 { 1135 cvReleaseMat( &counts ); 1136 cvReleaseMat( &buf ); 1137 cvReleaseMat( &direction ); 1138 cvReleaseMat( &split_buf ); 1139 cvReleaseMemStorage( &temp_storage ); 1140 cvReleaseMat( &responses_copy ); 1141 cv_heap = nv_heap = 0; 1142 } 1143 1144 1145 void CvDTreeTrainData::clear() 1146 { 1147 free_train_data(); 1148 1149 cvReleaseMemStorage( &tree_storage ); 1150 1151 cvReleaseMat( &var_idx ); 1152 cvReleaseMat( &var_type ); 1153 cvReleaseMat( &cat_count ); 1154 cvReleaseMat( &cat_ofs ); 1155 cvReleaseMat( &cat_map ); 1156 cvReleaseMat( &priors ); 1157 cvReleaseMat( &priors_mult ); 1158 1159 node_heap = split_heap = 0; 1160 1161 sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0; 1162 have_labels = have_priors = is_classifier = false; 1163 1164 buf_count = buf_size = 0; 1165 shared = false; 1166 1167 data_root = 0; 1168 1169 rng = &cv::theRNG(); 1170 } 1171 1172 1173 int CvDTreeTrainData::get_num_classes() const 1174 { 1175 return is_classifier ? cat_count->data.i[cat_var_count] : 0; 1176 } 1177 1178 1179 int CvDTreeTrainData::get_var_type(int vi) const 1180 { 1181 return var_type->data.i[vi]; 1182 } 1183 1184 void CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf, 1185 const float** ord_values, const int** sorted_indices, int* sample_indices_buf ) 1186 { 1187 int vidx = var_idx ? 
var_idx->data.i[vi] : vi; 1188 int node_sample_count = n->sample_count; 1189 int td_step = train_data->step/CV_ELEM_SIZE(train_data->type); 1190 1191 const int* sample_indices = get_sample_indices(n, sample_indices_buf); 1192 1193 if( !is_buf_16u ) 1194 *sorted_indices = buf->data.i + n->buf_idx*get_length_subbuf() + 1195 (size_t)vi*sample_count + n->offset; 1196 else { 1197 const unsigned short* short_indices = (const unsigned short*)(buf->data.s + n->buf_idx*get_length_subbuf() + 1198 (size_t)vi*sample_count + n->offset ); 1199 for( int i = 0; i < node_sample_count; i++ ) 1200 sorted_indices_buf[i] = short_indices[i]; 1201 *sorted_indices = sorted_indices_buf; 1202 } 1203 1204 if( tflag == CV_ROW_SAMPLE ) 1205 { 1206 for( int i = 0; i < node_sample_count && 1207 ((((*sorted_indices)[i] >= 0) && !is_buf_16u) || (((*sorted_indices)[i] != 65535) && is_buf_16u)); i++ ) 1208 { 1209 int idx = (*sorted_indices)[i]; 1210 idx = sample_indices[idx]; 1211 ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx); 1212 } 1213 } 1214 else 1215 for( int i = 0; i < node_sample_count && 1216 ((((*sorted_indices)[i] >= 0) && !is_buf_16u) || (((*sorted_indices)[i] != 65535) && is_buf_16u)); i++ ) 1217 { 1218 int idx = (*sorted_indices)[i]; 1219 idx = sample_indices[idx]; 1220 ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx); 1221 } 1222 1223 *ord_values = ord_values_buf; 1224 } 1225 1226 1227 const int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n, int* labels_buf ) 1228 { 1229 if (is_classifier) 1230 return get_cat_var_data( n, var_count, labels_buf); 1231 return 0; 1232 } 1233 1234 const int* CvDTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf ) 1235 { 1236 return get_cat_var_data( n, get_work_var_count(), indices_buf ); 1237 } 1238 1239 const float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n, float* values_buf, int*sample_indices_buf ) 1240 { 1241 int _sample_count = n->sample_count; 1242 int r_step = 
CV_IS_MAT_CONT(responses->type) ? 1 : responses->step/CV_ELEM_SIZE(responses->type); 1243 const int* indices = get_sample_indices(n, sample_indices_buf); 1244 1245 for( int i = 0; i < _sample_count && 1246 (((indices[i] >= 0) && !is_buf_16u) || ((indices[i] != 65535) && is_buf_16u)); i++ ) 1247 { 1248 int idx = indices[i]; 1249 values_buf[i] = *(responses->data.fl + idx * r_step); 1250 } 1251 1252 return values_buf; 1253 } 1254 1255 1256 const int* CvDTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf ) 1257 { 1258 if (have_labels) 1259 return get_cat_var_data( n, get_work_var_count()- 1, labels_buf); 1260 return 0; 1261 } 1262 1263 1264 const int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf) 1265 { 1266 const int* cat_values = 0; 1267 if( !is_buf_16u ) 1268 cat_values = buf->data.i + n->buf_idx*get_length_subbuf() + 1269 (size_t)vi*sample_count + n->offset; 1270 else { 1271 const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*get_length_subbuf() + 1272 (size_t)vi*sample_count + n->offset); 1273 for( int i = 0; i < n->sample_count; i++ ) 1274 cat_values_buf[i] = short_values[i]; 1275 cat_values = cat_values_buf; 1276 } 1277 return cat_values; 1278 } 1279 1280 1281 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n ) 1282 { 1283 int idx = n->buf_idx + 1; 1284 if( idx >= buf_count ) 1285 idx = shared ? 1 : 0; 1286 return idx; 1287 } 1288 1289 1290 void CvDTreeTrainData::write_params( CvFileStorage* fs ) const 1291 { 1292 CV_FUNCNAME( "CvDTreeTrainData::write_params" ); 1293 1294 __BEGIN__; 1295 1296 int vi, vcount = var_count; 1297 1298 cvWriteInt( fs, "is_classifier", is_classifier ? 
1 : 0 );
    cvWriteInt( fs, "var_all", var_all );
    cvWriteInt( fs, "var_count", var_count );
    cvWriteInt( fs, "ord_var_count", ord_var_count );
    cvWriteInt( fs, "cat_var_count", cat_var_count );

    cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
    cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );

    if( is_classifier )
    {
        cvWriteInt( fs, "max_categories", params.max_categories );
    }
    else
    {
        cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );
    }

    cvWriteInt( fs, "max_depth", params.max_depth );
    cvWriteInt( fs, "min_sample_count", params.min_sample_count );
    cvWriteInt( fs, "cross_validation_folds", params.cv_folds );

    if( params.cv_folds > 1 )
    {
        cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
        cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );
    }

    if( priors )
        cvWrite( fs, "priors", priors );

    cvEndWriteStruct( fs );

    if( var_idx )
        cvWrite( fs, "var_idx", var_idx );

    // one 0/1 flag per variable: 1 = categorical, 0 = ordered
    cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );

    for( vi = 0; vi < vcount; vi++ )
        cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );

    cvEndWriteStruct( fs );

    if( cat_count && (cat_var_count > 0 || is_classifier) )
    {
        CV_ASSERT( cat_count != 0 );
        cvWrite( fs, "cat_count", cat_count );
        cvWrite( fs, "cat_map", cat_map );
    }

    __END__;
}


// Restores the training-data descriptor written by write_params, validating
// every matrix it reads, and rebuilds the derived tables (var_type encoding,
// cat_ofs offsets) plus the tree storage and node/split heaps.
void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
{
    CV_FUNCNAME( "CvDTreeTrainData::read_params" );

    __BEGIN__;

    CvFileNode *tparams_node, *vartype_node;
    CvSeqReader reader;
    int vi, max_split_size, tree_block_size;

    is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
    var_all = cvReadIntByName( fs, node, "var_all" );
    var_count = cvReadIntByName( fs, node, "var_count", var_all );
    cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
    ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );

    tparams_node = cvGetFileNodeByName( fs, node, "training_params" );

    if( tparams_node ) // training parameters are not necessary
    {
        params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;

        if( is_classifier )
        {
            params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
        }
        else
        {
            params.regression_accuracy =
                (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
        }

        params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
        params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
        params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );

        if( params.cv_folds > 1 )
        {
            params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
            params.truncate_pruned_tree =
                cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
        }

        priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
        if( priors )
        {
            if( !CV_IS_MAT(priors) )
                CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
            priors_mult = cvCloneMat( priors );
        }
    }

    CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
    if( var_idx )
    {
        if( !CV_IS_MAT(var_idx) ||
            (var_idx->cols != 1 && var_idx->rows != 1) ||
            var_idx->cols + var_idx->rows - 1 != var_count ||
            CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
            CV_ERROR( CV_StsParseError,
                "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );

        for( vi = 0; vi < var_count; vi++ )
            if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )
                CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
    }

    ////// read var type
    CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));

    // var_type encoding: categorical vars get 0,1,2,... (index into the
    // category tables); ordered vars get -1,-2,... (bit-complement of their
    // ordinal index). cat_var_count/ord_var_count double as running counters.
    cat_var_count = 0;
    ord_var_count = -1;
    vartype_node = cvGetFileNodeByName( fs, node, "var_type" );

    if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
        var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
    else
    {
        if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
            vartype_node->data.seq->total != var_count )
            CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );

        cvStartReadSeq( vartype_node->data.seq, &reader );

        for( vi = 0; vi < var_count; vi++ )
        {
            CvFileNode* n = (CvFileNode*)reader.ptr;
            if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
                CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
            var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
            CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
        }
    }
    var_type->data.i[var_count] = cat_var_count;

    // counter ran -1,-2,... so ~ord_var_count is the number of ordered vars
    ord_var_count = ~ord_var_count;
    //////

    if( cat_var_count > 0 || is_classifier )
    {
        int ccount, total_c_count = 0;
        CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
        CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));

        if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
            (cat_count->cols != 1 && cat_count->rows != 1) ||
            CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
            cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||
            (cat_map->cols != 1 && cat_map->rows != 1) ||
            CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
            CV_ERROR( CV_StsParseError,
                "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );

        // classification adds one extra "categorical var" for class labels
        ccount = cat_var_count + is_classifier;

        // cat_ofs[i] = start of variable i's categories inside cat_map
        CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
        cat_ofs->data.i[0] = 0;
        max_c_count = 1;

        for( vi = 0; vi < ccount; vi++ )
        {
            int val = cat_count->data.i[vi];
            if( val <= 0 )
                CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
            max_c_count = MAX( max_c_count, val );
            cat_ofs->data.i[vi+1] = total_c_count += val;
        }

        if( cat_map->cols + cat_map->rows - 1 != total_c_count )
            CV_ERROR( CV_StsBadSize,
                "cat_map vector length is not equal to the total number of categories in all categorical vars" );
    }

    // split record size grows with the category subset bitmask
    max_split_size = cvAlign(sizeof(CvDTreeSplit) +
        (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));

    tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
    tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
    CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
    CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
        sizeof(CvDTreeNode), tree_storage ));
    CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
        max_split_size, tree_storage ));

    __END__;
}

/////////////////////// Decision Tree /////////////////////////
// Default parameters: 10 categories max, unlimited depth, 10-fold
// cross-validation pruning with the 1SE rule and surrogates enabled.
CvDTreeParams::CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
    cv_folds(10), use_surrogates(true), use_1se_rule(true),
    truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
{}

CvDTreeParams::CvDTreeParams( int _max_depth, int _min_sample_count,
                              float _regression_accuracy, bool _use_surrogates,
                              int _max_categories, int _cv_folds,
                              bool _use_1se_rule, bool _truncate_pruned_tree,
                              const float* _priors ) :
    max_categories(_max_categories), max_depth(_max_depth),
    min_sample_count(_min_sample_count), cv_folds (_cv_folds),
    use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
    truncate_pruned_tree(_truncate_pruned_tree),
    regression_accuracy(_regression_accuracy),
    priors(_priors)
{}

CvDTree::CvDTree()
{
    data = 0;
    var_importance = 0;
    default_model_name = "my_tree";

    clear();
}


// Releases the tree; owned (non-shared) training data is deleted outright,
// shared data only has this tree's nodes/splits freed from it.
void CvDTree::clear()
{
    cvReleaseMat( &var_importance );
    if( data )
    {
        if( !data->shared )
            delete data;
        else
            free_tree();
        data = 0;
    }
    root = 0;
    pruned_tree_idx = -1;
}


CvDTree::~CvDTree()
{
    clear();
}


const CvDTreeNode* CvDTree::get_root() const
{
    return root;
}


int CvDTree::get_pruned_tree_idx() const
{
    return pruned_tree_idx;
}


CvDTreeTrainData* CvDTree::get_data()
{
    return data;
}


// Trains on raw CvMat inputs: builds a private (non-shared) CvDTreeTrainData
// and runs the common do_train path on the full sample set.
bool CvDTree::train( const CvMat* _train_data, int _tflag,
                     const CvMat* _responses, const CvMat* _var_idx,
                     const CvMat* _sample_idx, const CvMat* _var_type,
                     const CvMat* _missing_mask, CvDTreeParams _params )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::train" );

    __BEGIN__;

    clear();
    data = new CvDTreeTrainData( _train_data, _tflag, _responses,
                                 _var_idx, _sample_idx, _var_type,
                                 _missing_mask, _params, false );
    CV_CALL( result = do_train(0) );

    __END__;

    return result;
}

// C++ (cv::Mat) wrapper: keeps copies of the data/response matrices alive in
// members, wraps the optional arguments as CvMat headers and forwards to the
// CvMat overload (empty matrices are passed as null pointers).
bool CvDTree::train( const Mat& _train_data, int _tflag,
                     const Mat& _responses, const Mat& _var_idx,
                     const Mat& _sample_idx, const Mat& _var_type,
                     const Mat& _missing_mask, CvDTreeParams _params )
{
    train_data_hdr = _train_data;
    train_data_mat = _train_data;
    responses_hdr = _responses;
    responses_mat = _responses;

    CvMat vidx=_var_idx, sidx=_sample_idx, vtype=_var_type, mmask=_missing_mask;

    return train(&train_data_hdr, _tflag, &responses_hdr, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0,
                 vtype.data.ptr ? &vtype : 0, mmask.data.ptr ? &mmask : 0, _params);
}


// CvMLData wrapper: pulls all the pieces out of the dataset object and
// forwards to the CvMat overload (row-sample layout).
bool CvDTree::train( CvMLData* _data, CvDTreeParams _params )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::train" );

    __BEGIN__;

    const CvMat* values = _data->get_values();
    const CvMat* response = _data->get_responses();
    const CvMat* missing = _data->get_missing();
    const CvMat* var_types = _data->get_var_types();
    const CvMat* train_sidx = _data->get_train_sample_idx();
    const CvMat* var_idx = _data->get_var_idx();

    CV_CALL( result = train( values, CV_ROW_SAMPLE, response, var_idx,
                             train_sidx, var_types, missing, _params ) );

    __END__;

    return result;
}

// Trains on externally owned (shared) training data, e.g. from an ensemble;
// the data object is marked shared so clear() will not delete it.
bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::train" );

    __BEGIN__;

    clear();
    data = _data;
    data->shared = true;
    CV_CALL( result = do_train(_subsample_idx));

    __END__;

    return result;
}


// Common training driver: subsamples the data into the root node, grows the
// tree recursively, then optionally prunes by cross-validation and drops the
// training buffers when the data is privately owned.
bool CvDTree::do_train( const CvMat* _subsample_idx )
{
    bool result = false;

    CV_FUNCNAME( "CvDTree::do_train" );

    __BEGIN__;

    root = data->subsample_data( _subsample_idx );

    CV_CALL( try_split_node(root));

    if( root->split )
    {
        CV_Assert( root->left );
        CV_Assert( root->right );

        if( data->params.cv_folds > 0 )
            CV_CALL( prune_cv() );

        if( !data->shared )
            data->free_train_data();

        result = true;
    }

    __END__;

    return result;
}


// Recursively grows the tree from 'node': computes the node value, applies
// the stopping criteria, finds the best split (plus surrogates if enabled),
// partitions the data and recurses into both children.
void CvDTree::try_split_node( CvDTreeNode* node )
{
    CvDTreeSplit* best_split = 0;
    int i, n = node->sample_count, vi;
    bool can_split = true;
    double quality_scale;

    calc_node_value( node );

    // stopping criteria: too few samples or maximum depth reached
    if( node->sample_count <= data->params.min_sample_count ||
        node->depth >= data->params.max_depth )
        can_split =
false;

    if( can_split && data->is_classifier )
    {
        // check if we have a "pure" node,
        // we assume that cls_count is filled by calc_node_value()
        int* cls_count = data->counts->data.i;
        int nz = 0, m = data->get_num_classes();
        for( i = 0; i < m; i++ )
            nz += cls_count[i] != 0;
        if( nz == 1 ) // there is only one class
            can_split = false;
    }
    else if( can_split )
    {
        // regression: stop when the node's RMS error is already small enough
        if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
            can_split = false;
    }

    if( can_split )
    {
        best_split = find_best_split(node);
        // TODO: check the split quality ...
        node->split = best_split;
    }
    if( !can_split || !best_split )
    {
        // leaf node: the per-node training scratch is no longer needed
        data->free_node_data(node);
        return;
    }

    quality_scale = calc_node_dir( node );
    if( data->params.use_surrogates )
    {
        // find all the surrogate splits
        // and sort them by their similarity to the primary one
        for( vi = 0; vi < data->var_count; vi++ )
        {
            CvDTreeSplit* split;
            int ci = data->get_var_type(vi);

            if( vi == best_split->var_idx )
                continue;

            if( ci >= 0 )
                split = find_surrogate_split_cat( node, vi );
            else
                split = find_surrogate_split_ord( node, vi );

            if( split )
            {
                // insert the split into the list, kept sorted by
                // descending (scaled) quality after the primary split
                CvDTreeSplit* prev_split = node->split;
                split->quality = (float)(split->quality*quality_scale);

                while( prev_split->next &&
                       prev_split->next->quality > split->quality )
                    prev_split = prev_split->next;
                split->next = prev_split->next;
                prev_split->next = split;
            }
        }
    }
    split_node_data( node );
    try_split_node( node->left );
    try_split_node( node->right );
}


// calculate direction (left(-1),right(1),missing(0))
// for each sample using the best split
// the function returns scale coefficients for surrogate split quality factors.
// the scale is applied to normalize surrogate split quality relatively to the
// best (primary) split quality. That is, if a surrogate split is absolutely
// identical to the primary split, its quality will be set to the maximum value =
// quality of the primary split; otherwise, it will be lower.
// besides, the function compute node->maxlr,
// minimum possible quality (w/o considering the above mentioned scale)
// for a surrogate split. Surrogate splits with quality less than node->maxlr
// are not discarded.
double CvDTree::calc_node_dir( CvDTreeNode* node )
{
    char* dir = (char*)data->direction->data.ptr;
    int i, n = node->sample_count, vi = node->split->var_idx;
    double L, R;

    assert( !node->split->inversed );

    if( data->get_var_type(vi) >= 0 ) // split on categorical var
    {
        cv::AutoBuffer<int> inn_buf(n*(!data->have_priors ? 1 : 2));
        int* labels_buf = (int*)inn_buf;
        const int* labels = data->get_cat_var_data( node, vi, labels_buf );
        const int* subset = node->split->subset;
        if( !data->have_priors )
        {
            int sum = 0, sum_abs = 0;

            for( i = 0; i < n; i++ )
            {
                // d is +1/-1 for valid labels (by subset bitmask), 0 for missing
                int idx = labels[i];
                int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
                    CV_DTREE_CAT_DIR(idx,subset) : 0;
                sum += d; sum_abs += d & 1;
                dir[i] = (char)d;
            }

            // sum_abs = R + L, sum = R - L
            R = (sum_abs + sum) >> 1;
            L = (sum_abs - sum) >> 1;
        }
        else
        {
            // weighted counterpart: each sample counts with its class prior
            const double* priors = data->priors_mult->data.db;
            double sum = 0, sum_abs = 0;
            int* responses_buf = labels_buf + n;
            const int* responses = data->get_class_labels(node, responses_buf);

            for( i = 0; i < n; i++ )
            {
                int idx = labels[i];
                double w = priors[responses[i]];
                int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
                sum += d*w; sum_abs += (d & 1)*w;
                dir[i] = (char)d;
            }

            R = (sum_abs + sum) * 0.5;
            L = (sum_abs - sum) * 0.5;
        }
    }
    else // split on ordered var
    {
        int split_point = node->split->ord.split_point;
        int n1 = node->get_num_valid(vi);
        cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)));
        float* val_buf = (float*)(uchar*)inn_buf;
        int* sorted_buf = (int*)(val_buf + n);
        int* sample_idx_buf = sorted_buf + n;
        const float* val = 0;
        const int* sorted = 0;
        data->get_ord_var_data( node, vi, val_buf, sorted_buf, &val, &sorted, sample_idx_buf);

        assert( 0 <= split_point && split_point < n1-1 );

        if( !data->have_priors )
        {
            // first split_point+1 sorted samples go left, the rest of the
            // valid ones right, missing values (index >= n1) get 0
            for( i = 0; i <= split_point; i++ )
                dir[sorted[i]] = (char)-1;
            for( ; i < n1; i++ )
                dir[sorted[i]] = (char)1;
            for( ; i < n; i++ )
                dir[sorted[i]] = (char)0;

            // NOTE(review): the loops above send split_point+1 samples left and
            // n1-split_point-1 right, so these counts look off by -/+2 each.
            // L+R still equals n1, so the returned quality scale is unaffected,
            // but node->maxlr may be skewed -- confirm before changing.
            L = split_point-1;
            R = n1 - split_point + 1;
        }
        else
        {
            const double* priors = data->priors_mult->data.db;
            int* responses_buf = sample_idx_buf + n;
            const int* responses = data->get_class_labels(node, responses_buf);
            L = R = 0;

            for( i = 0; i <= split_point; i++ )
            {
                int idx = sorted[i];
                double w = priors[responses[idx]];
                dir[idx] = (char)-1;
                L += w;
            }

            for( ; i < n1; i++ )
            {
                int idx = sorted[i];
                double w = priors[responses[idx]];
                dir[idx] = (char)1;
                R += w;
            }

            for( ; i < n; i++ )
                dir[sorted[i]] = (char)0;
        }
    }
    node->maxlr = MAX( L, R );
    return node->split->quality/(L + R);
}


namespace cv
{

// CvDTreeSplit copies held by DTreeBestSplitFinder are allocated with
// fastMalloc, so the matching deleter is fastFree.
template<> CV_EXPORTS void DefaultDeleter<CvDTreeSplit>::operator ()(CvDTreeSplit* obj) const
{
    fastFree(obj);
}

// Parallel-reduction functor that searches for the best split of 'node'
// over a range of variables. 'bestSplit' accumulates the winner, 'split' is
// per-variable scratch; both are raw fastMalloc'ed copies of the heap's
// variable-size split records (elem_size includes the subset bitmask tail).
DTreeBestSplitFinder::DTreeBestSplitFinder( CvDTree* _tree, CvDTreeNode* _node)
{
    tree = _tree;
    node = _node;
    splitSize = tree->get_data()->split_heap->elem_size;

    bestSplit.reset((CvDTreeSplit*)fastMalloc(splitSize));
    memset(bestSplit.get(), 0, splitSize);
    bestSplit->quality = -1;
    bestSplit->condensed_idx = INT_MIN;
    split.reset((CvDTreeSplit*)fastMalloc(splitSize));
    memset(split.get(), 0, splitSize);
    //haveSplit = false;
}

// Split constructor for parallel_reduce: clones the current best so each
// worker starts from the same baseline quality.
DTreeBestSplitFinder::DTreeBestSplitFinder( const DTreeBestSplitFinder& finder, Split )
{
    tree = finder.tree;
    node = finder.node;
    splitSize = tree->get_data()->split_heap->elem_size;

    bestSplit.reset((CvDTreeSplit*)fastMalloc(splitSize));
    memcpy(bestSplit.get(), finder.bestSplit.get(), splitSize);
    split.reset((CvDTreeSplit*)fastMalloc(splitSize));
    memset(split.get(), 0, splitSize);
}

// Evaluates every variable in [range.begin(), range.end()): dispatches to
// the ordered/categorical x classification/regression split finder and keeps
// the highest-quality result in bestSplit.
void DTreeBestSplitFinder::operator()(const BlockedRange& range)
{
    int vi, vi1 = range.begin(), vi2 = range.end();
    int n = node->sample_count;
    CvDTreeTrainData* data = tree->get_data();
    AutoBuffer<uchar> inn_buf(2*n*(sizeof(int) + sizeof(float)));

    for( vi = vi1; vi < vi2; vi++ )
    {
        CvDTreeSplit *res;
        int ci = data->get_var_type(vi);
        // a variable with <= 1 valid value cannot separate the samples
        if( node->get_num_valid(vi) <= 1 )
            continue;

        if( data->is_classifier )
        {
            if( ci >= 0 )
                res = tree->find_split_cat_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
            else
                res = tree->find_split_ord_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
        }
        else
        {
            if( ci >= 0 )
                res = tree->find_split_cat_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
            else
                res = tree->find_split_ord_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
        }

        if( res && bestSplit->quality < split->quality )
            memcpy( bestSplit.get(), split.get(), splitSize );
    }
}

// Reduction step: adopt the other worker's best split if it is better.
void DTreeBestSplitFinder::join( DTreeBestSplitFinder& rhs )
{
    if( bestSplit->quality < rhs.bestSplit->quality )
        memcpy( bestSplit.get(), rhs.bestSplit.get(), splitSize );
}
}


// Runs the parallel best-split search over all variables and, if a split of
// positive quality was found, copies it into a heap-allocated split record
// (new_split_cat is used only because it allocates the full-size record).
CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
{
    DTreeBestSplitFinder finder( this, node );

    cv::parallel_reduce(cv::BlockedRange(0, data->var_count), finder);

    CvDTreeSplit *bestSplit = 0;
    if( finder.bestSplit->quality > 0 )
    {
        bestSplit = data->new_split_cat( 0, -1.0f );
        memcpy( bestSplit, finder.bestSplit, finder.splitSize );
    }

    return bestSplit;
}

// Finds the best threshold on ordered variable vi for classification.
// Scans the samples in sorted order, incrementally maintaining per-class
// counts (lc/rc) and sums of squared counts (lsum2/rsum2) on both sides;
// the score maximized is (lsum2*R + rsum2*L)/(L*R). Only positions where
// consecutive values actually differ are candidate thresholds. Returns 0 if
// nothing beats init_quality. _ext_buf, if given, provides the scratch space.
CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi,
                                             float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
{
    const float epsilon = FLT_EPSILON*2;
    int n = node->sample_count;
    int n1 = node->get_num_valid(vi);
    int m = data->get_num_classes();

    // base_buf: lc + rc class counters; ext_buf: per-sample work arrays
    int base_size = 2*m*sizeof(int);
    cv::AutoBuffer<uchar> inn_buf(base_size);
    if( !_ext_buf )
        inn_buf.allocate(base_size + n*(3*sizeof(int)+sizeof(float)));
    uchar* base_buf = (uchar*)inn_buf;
    uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
    float* values_buf = (float*)ext_buf;
    int* sorted_indices_buf = (int*)(values_buf + n);
    int* sample_indices_buf = sorted_indices_buf + n;
    const float* values = 0;
    const int* sorted_indices = 0;
    data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values,
                            &sorted_indices, sample_indices_buf );
    int* responses_buf = sample_indices_buf + n;
    const int* responses = data->get_class_labels( node, responses_buf );

    const int* rc0 = data->counts->data.i;
    int* lc = (int*)base_buf;
    int* rc = lc + m;
    int i, best_i = -1;
    double lsum2 = 0, rsum2 = 0, best_val = init_quality;
    const double* priors = data->have_priors ?
data->priors_mult->data.db : 0;

    // init arrays of class instance counters on both sides of the split
    for( i = 0; i < m; i++ )
    {
        lc[i] = 0;
        rc[i] = rc0[i];
    }

    // compensate for missing values
    for( i = n1; i < n; i++ )
    {
        rc[responses[sorted_indices[i]]]--;
    }

    if( !priors )
    {
        // unweighted case: integer class counters
        int L = 0, R = n1;

        for( i = 0; i < m; i++ )
            rsum2 += (double)rc[i]*rc[i];

        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = responses[sorted_indices[i]];
            int lv, rv;
            L++; R--;
            // move one sample of class idx from right to left;
            // (c+1)^2 = c^2 + 2c + 1 and (c-1)^2 = c^2 - 2c + 1 give the
            // incremental updates of the squared-count sums
            lv = lc[idx]; rv = rc[idx];
            lsum2 += lv*2 + 1;
            rsum2 -= rv*2 - 1;
            lc[idx] = lv + 1; rc[idx] = rv - 1;

            // only consider a threshold where the sorted values differ
            if( values[i] + epsilon < values[i+1] )
            {
                double val = (lsum2*R + rsum2*L)/((double)L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_i = i;
                }
            }
        }
    }
    else
    {
        // weighted case: every class-idx sample carries weight priors[idx]
        double L = 0, R = 0;
        for( i = 0; i < m; i++ )
        {
            double wv = rc[i]*priors[i];
            R += wv;
            rsum2 += wv*wv;
        }

        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = responses[sorted_indices[i]];
            int lv, rv;
            double p = priors[idx], p2 = p*p;
            L += p; R -= p;
            lv = lc[idx]; rv = rc[idx];
            lsum2 += p2*(lv*2 + 1);
            rsum2 -= p2*(rv*2 - 1);
            lc[idx] = lv + 1; rc[idx] = rv - 1;

            if( values[i] + epsilon < values[i+1] )
            {
                double val = (lsum2*R + rsum2*L)/((double)L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_i = i;
                }
            }
        }
    }

    CvDTreeSplit* split = 0;
    if( best_i >= 0 )
    {
        // threshold is the midpoint between the two adjacent sorted values
        split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
        split->var_idx = vi;
        split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
        split->ord.split_point = best_i;
        split->inversed = 0;
        split->quality = (float)best_val;
    }
    return split;
}


// Groups n category count-vectors (m bins each) into k clusters with a
// k-means-style iteration: random initial labels, cluster sums in csums,
// vectors/clusters compared through their normalized (weighted) profiles.
// Runs at most max_iters rounds or until the assignment stabilizes.
void CvDTree::cluster_categories( const int* vectors, int n, int m,
                                  int* csums, int k, int* labels )
{
    // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
    int iters = 0, max_iters = 100;
    int i, j, idx;
    cv::AutoBuffer<double> buf(n + k);
    double *v_weights = buf, *c_weights = buf + n;
    bool modified = true;
    RNG* r = data->rng;

    // assign labels randomly
    for( i = 0; i < n; i++ )
    {
        int sum = 0;
        const int* v = vectors + i*m;
        // first k vectors seed distinct clusters; the rest are random
        labels[i] = i < k ? i : r->uniform(0, k);

        // compute weight of each vector
        for( j = 0; j < m; j++ )
            sum += v[j];
        v_weights[i] = sum ? 1./sum : 0.;
    }

    // random transpositions decouple the seed labels from the input order
    // (CV_SWAP uses j as its scratch variable)
    for( i = 0; i < n; i++ )
    {
        int i1 = (*r)(n);
        int i2 = (*r)(n);
        CV_SWAP( labels[i1], labels[i2], j );
    }

    for( iters = 0; iters <= max_iters; iters++ )
    {
        // calculate csums
        for( i = 0; i < k; i++ )
        {
            for( j = 0; j < m; j++ )
                csums[i*m + j] = 0;
        }

        for( i = 0; i < n; i++ )
        {
            const int* v = vectors + i*m;
            int* s = csums + labels[i]*m;
            for( j = 0; j < m; j++ )
                s[j] += v[j];
        }

        // exit the loop here, when we have up-to-date csums
        if( iters == max_iters || !modified )
            break;

        modified = false;

        // calculate weight of each cluster
        for( i = 0; i < k; i++ )
        {
            const int* s = csums + i*m;
            int sum = 0;
            for( j = 0; j < m; j++ )
                sum += s[j];
            c_weights[i] = sum ?
1./sum : 0;
        }

        // now for each vector determine the closest cluster
        for( i = 0; i < n; i++ )
        {
            const int* v = vectors + i*m;
            double alpha = v_weights[i];
            double min_dist2 = DBL_MAX;
            int min_idx = -1;

            // squared distance between the (weight-normalized) vector and each cluster sum
            for( idx = 0; idx < k; idx++ )
            {
                const int* s = csums + idx*m;
                double dist2 = 0., beta = c_weights[idx];
                for( j = 0; j < m; j++ )
                {
                    double t = v[j]*alpha - s[j]*beta;
                    dist2 += t*t;
                }
                if( min_dist2 > dist2 )
                {
                    min_dist2 = dist2;
                    min_idx = idx;
                }
            }

            if( min_idx != labels[i] )
                modified = true;   // at least one label changed -> another iteration is needed
            labels[i] = min_idx;
        }
    }
}


// Finds the best "subset of categories" split of categorical variable vi for a
// classification tree node. For the 2-class case the categories are sorted by
// class-1 frequency and only mi-1 "prefix" subsets are checked; for m > 2 classes
// all 2^mi subsets are enumerated in gray-code order (categories are first
// clustered down to at most params.max_categories when mi is too large).
// Returns a new (or the reused _split) CvDTreeSplit, or 0 if nothing beats
// init_quality. _ext_buf, if given, supplies external scratch memory.
CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality,
                                             CvDTreeSplit* _split, uchar* _ext_buf )
{
    int ci = data->get_var_type(vi);
    int n = node->sample_count;
    int m = data->get_num_classes();
    int _mi = data->cat_count->data.i[ci], mi = _mi;

    // scratch layout: lc[m] | rc[m] | <one spare m-row for the j==-1 bin> | cjk[mi*m]
    // | c_weights[mi] (+ clustering area or int_ptr[] depending on the branch below)
    int base_size = m*(3 + mi)*sizeof(int) + (mi+1)*sizeof(double);
    if( m > 2 && mi > data->params.max_categories )
        base_size += (m*std::min(data->params.max_categories, n) + mi)*sizeof(int);
    else
        base_size += mi*sizeof(int*);
    cv::AutoBuffer<uchar> inn_buf(base_size);
    if( !_ext_buf )
        inn_buf.allocate(base_size + 2*n*sizeof(int));
    uchar* base_buf = (uchar*)inn_buf;
    uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;

    int* lc = (int*)base_buf;
    int* rc = lc + m;
    int* _cjk = rc + m*2, *cjk = _cjk;   // one m-sized row is reserved before cjk for j==-1
    double* c_weights = (double*)alignPtr(cjk + m*mi, sizeof(double));

    int* labels_buf = (int*)ext_buf;
    const int* labels = data->get_cat_var_data(node, vi, labels_buf);
    int* responses_buf = labels_buf + n;
    const int* responses = data->get_class_labels(node, responses_buf);

    int* cluster_labels = 0;
    int** int_ptr = 0;
    int i, j, k, idx;
    double L = 0, R = 0;
    double best_val = init_quality;
    int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
    const double* priors = data->priors_mult->data.db;

    // init array of counters:
    // c_{jk} - number of samples that have vi-th input variable = j and response = k.
    // The j==-1 row (samples with a missing category) lives just before cjk.
    for( j = -1; j < mi; j++ )
        for( k = 0; k < m; k++ )
            cjk[j*m + k] = 0;

    for( i = 0; i < n; i++ )
    {
        // in 16-bit buffers the value 65535 marks a missing category -> bin j==-1
        j = ( labels[i] == 65535 && data->is_buf_16u) ? -1 : labels[i];
        k = responses[i];
        cjk[j*m + k]++;
    }

    if( m > 2 )
    {
        if( mi > data->params.max_categories )
        {
            // too many categories to enumerate 2^mi subsets:
            // cluster them down to mi <= max_categories super-categories first
            mi = MIN(data->params.max_categories, n);
            cjk = (int*)(c_weights + _mi);
            cluster_labels = cjk + m*mi;
            cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
        }
        subset_i = 1;
        subset_n = 1 << mi;
    }
    else
    {
        assert( m == 2 );
        // 2-class shortcut: sort categories by class-1 count; only the mi-1
        // prefix subsets of that order need to be evaluated
        int_ptr = (int**)(c_weights + _mi);
        for( j = 0; j < mi; j++ )
            int_ptr[j] = cjk + j*2 + 1;
        std::sort(int_ptr, int_ptr + mi, LessThanPtr<int>());
        subset_i = 0;
        subset_n = mi;
    }

    // everything starts on the right side: rc[k] = total per-class counts
    for( k = 0; k < m; k++ )
    {
        int sum = 0;
        for( j = 0; j < mi; j++ )
            sum += cjk[j*m + k];
        rc[k] = sum;
        lc[k] = 0;
    }

    // per-category prior-weighted sample weight; R = total weight
    for( j = 0; j < mi; j++ )
    {
        double sum = 0;
        for( k = 0; k < m; k++ )
            sum += cjk[j*m + k]*priors[k];
        c_weights[j] = sum;
        R += c_weights[j];
    }

    for( ; subset_i < subset_n; subset_i++ )
    {
        double weight;
        int* crow;
        double lsum2 = 0, rsum2 = 0;

        if( m == 2 )
            idx = (int)(int_ptr[subset_i] - cjk)/2;
        else
        {
            // gray-code enumeration: consecutive subsets differ by exactly one
            // category, so the counters are updated incrementally
            int graycode = (subset_i>>1)^subset_i;
            int diff = graycode ^ prevcode;

            // determine index of the changed bit.
            // The float trick below reads the IEEE-754 exponent field to get
            // the position of the highest (only) set bit of `diff`.
            Cv32suf u;
            idx = diff >= (1 << 16) ? 16 : 0;
            u.f = (float)(((diff >> 16) | diff) & 65535);
            idx += (u.i >> 23) - 127;
            subtract = graycode < prevcode;   // bit cleared -> move category back to the right
            prevcode = graycode;
        }

        crow = cjk + idx*m;
        weight = c_weights[idx];
        if( weight < FLT_EPSILON )
            continue;

        if( !subtract )
        {
            // move category idx from the right branch to the left branch
            for( k = 0; k < m; k++ )
            {
                int t = crow[k];
                int lval = lc[k] + t;
                int rval = rc[k] - t;
                double p = priors[k], p2 = p*p;
                lsum2 += p2*lval*lval;
                rsum2 += p2*rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L += weight;
            R -= weight;
        }
        else
        {
            // move category idx back from the left branch to the right branch
            for( k = 0; k < m; k++ )
            {
                int t = crow[k];
                int lval = lc[k] - t;
                int rval = rc[k] + t;
                double p = priors[k], p2 = p*p;
                lsum2 += p2*lval*lval;
                rsum2 += p2*rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L -= weight;
            R += weight;
        }

        if( L > FLT_EPSILON && R > FLT_EPSILON )
        {
            // Gini-style quality: sum of squared class weights, normalized by branch weight
            double val = (lsum2*R + rsum2*L)/((double)L*R);
            if( best_val < val )
            {
                best_val = val;
                best_subset = subset_i;
            }
        }
    }

    CvDTreeSplit* split = 0;
    if( best_subset >= 0 )
    {
        split = _split ? _split : data->new_split_cat( 0, -1.0f );
        split->var_idx = vi;
        split->quality = (float)best_val;
        memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
        if( m == 2 )
        {
            // mark the first best_subset+1 categories of the sorted order
            for( i = 0; i <= best_subset; i++ )
            {
                idx = (int)(int_ptr[i] - cjk) >> 1;
                split->subset[idx >> 5] |= 1 << (idx & 31);
            }
        }
        else
        {
            // expand the winning subset from (possibly clustered) categories
            // back to the original category indices
            for( i = 0; i < _mi; i++ )
            {
                idx = cluster_labels ? cluster_labels[i] : i;
                if( best_subset & (1 << idx) )
                    split->subset[i >> 5] |= 1 << (i & 31);
            }
        }
    }
    return split;
}


// Finds the best threshold split of ordered variable vi for a regression tree
// node by scanning the pre-sorted values and maximizing
// (lsum^2*R + rsum^2*L)/(L*R); equivalent to minimizing the within-branch SSE.
// Returns a new (or reused) split, or 0 if nothing beats init_quality.
CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
{
    const float epsilon = FLT_EPSILON*2;
    int n = node->sample_count;
    int n1 = node->get_num_valid(vi);   // samples with a non-missing value of vi

    cv::AutoBuffer<uchar> inn_buf;
    if( !_ext_buf )
        inn_buf.allocate(2*n*(sizeof(int) + sizeof(float)));
    uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
    float* values_buf = (float*)ext_buf;
    int* sorted_indices_buf = (int*)(values_buf + n);
    int* sample_indices_buf = sorted_indices_buf + n;
    const float* values = 0;
    const int* sorted_indices = 0;
    data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
    float* responses_buf = (float*)(sample_indices_buf + n);
    const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );

    int i, best_i = -1;
    double best_val = init_quality, lsum = 0, rsum = node->value*n;
    int L = 0, R = n1;

    // compensate for missing values: node->value*n includes all samples,
    // but only the n1 valid ones take part in the scan below
    for( i = n1; i < n; i++ )
        rsum -= responses[sorted_indices[i]];

    // find the optimal split
    for( i = 0; i < n1 - 1; i++ )
    {
        float t = responses[sorted_indices[i]];
        L++; R--;
        lsum += t;
        rsum -= t;

        // only cut between strictly distinct values
        if( values[i] + epsilon < values[i+1] )
        {
            double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    CvDTreeSplit* split = 0;
    if( best_i >= 0 )
    {
        split = _split ?
_split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
        split->var_idx = vi;
        // threshold halfway between the two neighbouring distinct values
        split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
        split->ord.split_point = best_i;
        split->inversed = 0;
        split->quality = (float)best_val;
    }
    return split;
}

// Finds the best "subset of categories" split of categorical variable vi for a
// regression tree node: categories are sorted by mean response, and only the
// mi-1 prefix subsets of that order are evaluated (optimal for squared error).
// Returns a new (or reused) split, or 0 if nothing beats init_quality.
CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
{
    int ci = data->get_var_type(vi);
    int n = node->sample_count;
    int mi = data->cat_count->data.i[ci];

    int base_size = (mi+2)*sizeof(double) + (mi+1)*(sizeof(int) + sizeof(double*));
    cv::AutoBuffer<uchar> inn_buf(base_size);
    if( !_ext_buf )
        inn_buf.allocate(base_size + n*(2*sizeof(int) + sizeof(float)));
    uchar* base_buf = (uchar*)inn_buf;
    uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
    int* labels_buf = (int*)ext_buf;
    const int* labels = data->get_cat_var_data(node, vi, labels_buf);
    float* responses_buf = (float*)(labels_buf + n);
    int* sample_indices_buf = (int*)(responses_buf + n);
    const float* responses = data->get_ord_responses(node, responses_buf, sample_indices_buf);

    // sum and counts are biased by +1 so that index -1 (missing category) is valid
    double* sum = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
    int* counts = (int*)(sum + mi) + 1;
    double** sum_ptr = (double**)(counts + mi);
    int i, L = 0, R = 0;
    double best_val = init_quality, lsum = 0, rsum = 0;
    int best_subset = -1, subset_i;

    for( i = -1; i < mi; i++ )
        sum[i] = counts[i] = 0;

    // calculate sum response and weight of each category of the input var
    for( i = 0; i < n; i++ )
    {
        // 65535 marks a missing category in 16-bit buffers -> slot -1
        int idx = ( (labels[i] == 65535) && data->is_buf_16u ) ? -1 : labels[i];
        double s = sum[idx] + responses[i];
        int nc = counts[idx] + 1;
        sum[idx] = s;
        counts[idx] = nc;
    }

    // calculate average response in each category
    for( i = 0; i < mi; i++ )
    {
        R += counts[i];
        rsum += sum[i];
        sum[i] /= MAX(counts[i],1);
        sum_ptr[i] = sum + i;
    }

    // sort category indices by mean response (via pointers into sum[])
    std::sort(sum_ptr, sum_ptr + mi, LessThanPtr<double>());

    // revert back to unnormalized sums
    // (there should be a very little loss of accuracy)
    for( i = 0; i < mi; i++ )
        sum[i] *= counts[i];

    // move categories one by one (in mean-response order) to the left branch
    for( subset_i = 0; subset_i < mi-1; subset_i++ )
    {
        int idx = (int)(sum_ptr[subset_i] - sum);
        int ni = counts[idx];

        if( ni )
        {
            double s = sum[idx];
            lsum += s; L += ni;
            rsum -= s; R -= ni;

            if( L && R )
            {
                // maximizing this is equivalent to minimizing within-branch SSE
                double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_subset = subset_i;
                }
            }
        }
    }

    CvDTreeSplit* split = 0;
    if( best_subset >= 0 )
    {
        split = _split ? _split : data->new_split_cat( 0, -1.0f);
        split->var_idx = vi;
        split->quality = (float)best_val;
        memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
        for( i = 0; i <= best_subset; i++ )
        {
            int idx = (int)(sum_ptr[i] - sum);
            split->subset[idx >> 5] |= 1 << (idx & 31);
        }
    }
    return split;
}

// Finds the best surrogate split on ordered variable vi: the threshold that
// best reproduces the directions (dir[]) already assigned by the node's
// primary split. Returns a new split or 0 when no threshold beats node->maxlr.
CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi, uchar* _ext_buf )
{
    const float epsilon = FLT_EPSILON*2;
    const char* dir = (char*)data->direction->data.ptr;
    int n = node->sample_count, n1 = node->get_num_valid(vi);
    cv::AutoBuffer<uchar> inn_buf;
    if( !_ext_buf )
        inn_buf.allocate( n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)) );
    uchar* ext_buf = _ext_buf ?
_ext_buf : (uchar*)inn_buf;
    float* values_buf = (float*)ext_buf;
    int* sorted_indices_buf = (int*)(values_buf + n);
    int* sample_indices_buf = sorted_indices_buf + n;
    const float* values = 0;
    const int* sorted_indices = 0;
    data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
    // LL - number of samples that both the primary and the surrogate splits send to the left
    // LR - ... primary split sends to the left and the surrogate split sends to the right
    // RL - ... primary split sends to the right and the surrogate split sends to the left
    // RR - ... both send to the right
    int i, best_i = -1, best_inversed = 0;
    double best_val;

    if( !data->have_priors )
    {
        // unweighted case: plain sample counts
        int LL = 0, RL = 0, LR, RR;
        int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
        int sum = 0, sum_abs = 0;

        for( i = 0; i < n1; i++ )
        {
            int d = dir[sorted_indices[i]];
            sum += d; sum_abs += d & 1;   // d is -1/0/+1; d&1 counts the non-zero ones
        }

        // sum_abs = R + L; sum = R - L
        RR = (sum_abs + sum) >> 1;
        LR = (sum_abs - sum) >> 1;

        // initially all the samples are sent to the right by the surrogate split,
        // LR of them are sent to the left by primary split, and RR - to the right.
        // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
        //
        // NOTE(review): in this branch `_best_val` is never updated inside the
        // loop (only `best_val` is), and `best_val = _best_val;` after the loop
        // resets it to cvFloor(node->maxlr); the final `best_val > node->maxlr`
        // check then appears to never pass, so no surrogate is returned when
        // priors are absent. Looks like `_best_val` was meant to track the
        // running maximum — verify against upstream before changing; kept
        // byte-identical here to preserve behavior.
        for( i = 0; i < n1 - 1; i++ )
        {
            int d = dir[sorted_indices[i]];

            if( d < 0 )
            {
                LL++; LR--;
                if( LL + RR > _best_val && values[i] + epsilon < values[i+1] )
                {
                    best_val = LL + RR;
                    best_i = i; best_inversed = 0;
                }
            }
            else if( d > 0 )
            {
                RL++; RR--;
                if( RL + LR > _best_val && values[i] + epsilon < values[i+1] )
                {
                    best_val = RL + LR;
                    best_i = i; best_inversed = 1;
                }
            }
        }
        best_val = _best_val;
    }
    else
    {
        // weighted case: each sample counts with its class prior weight
        double LL = 0, RL = 0, LR, RR;
        double worst_val = node->maxlr;
        double sum = 0, sum_abs = 0;
        const double* priors = data->priors_mult->data.db;
        int* responses_buf = sample_indices_buf + n;
        const int* responses = data->get_class_labels(node, responses_buf);
        best_val = worst_val;

        for( i = 0; i < n1; i++ )
        {
            int idx = sorted_indices[i];
            double w = priors[responses[idx]];
            int d = dir[idx];
            sum += d*w; sum_abs += (d & 1)*w;
        }

        // sum_abs = R + L; sum = R - L
        RR = (sum_abs + sum)*0.5;
        LR = (sum_abs - sum)*0.5;

        // initially all the samples are sent to the right by the surrogate split,
        // LR of them are sent to the left by primary split, and RR - to the right.
        // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
        for( i = 0; i < n1 - 1; i++ )
        {
            int idx = sorted_indices[i];
            double w = priors[responses[idx]];
            int d = dir[idx];

            if( d < 0 )
            {
                LL += w; LR -= w;
                if( LL + RR > best_val && values[i] + epsilon < values[i+1] )
                {
                    best_val = LL + RR;
                    best_i = i; best_inversed = 0;
                }
            }
            else if( d > 0 )
            {
                RL += w; RR -= w;
                if( RL + LR > best_val && values[i] + epsilon < values[i+1] )
                {
                    best_val = RL + LR;
                    best_i = i; best_inversed = 1;
                }
            }
        }
    }
    // only accept the surrogate if it beats the trivial "majority direction" rule
    return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
        (values[best_i] + values[best_i+1])*0.5f, best_i, best_inversed, (float)best_val ) : 0;
}


// Finds the best surrogate split on categorical variable vi: each category is
// sent in the direction (left/right) taken by the majority of its samples under
// the primary split. Returns the split or 0 if it is no better than the
// majority-direction baseline (node->maxlr) or degenerate (all-left/all-right).
CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi, uchar* _ext_buf )
{
    const char* dir = (char*)data->direction->data.ptr;
    int n = node->sample_count;
    int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;

    int base_size = (2*(mi+1)+1)*sizeof(double) + (!data->have_priors ? 2*(mi+1)*sizeof(int) : 0);
    cv::AutoBuffer<uchar> inn_buf(base_size);
    if( !_ext_buf )
        inn_buf.allocate(base_size + n*(sizeof(int) + (data->have_priors ? sizeof(int) : 0)));
    uchar* base_buf = (uchar*)inn_buf;
    uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;

    int* labels_buf = (int*)ext_buf;
    const int* labels = data->get_cat_var_data(node, vi, labels_buf);
    // LL - number of samples that both the primary and the surrogate splits send to the left
    // LR - ... primary split sends to the left and the surrogate split sends to the right
    // RL - ... primary split sends to the right and the surrogate split sends to the left
    // RR - ... both send to the right
    CvDTreeSplit* split = data->new_split_cat( vi, 0 );
    double best_val = 0;
    // lc/rc are biased by +1 so that slot -1 (missing category) is addressable
    double* lc = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
    double* rc = lc + mi + 1;

    for( i = -1; i < mi; i++ )
        lc[i] = rc[i] = 0;

    // for each category calculate the weight of samples
    // sent to the left (lc) and to the right (rc) by the primary split
    if( !data->have_priors )
    {
        // integer accumulators: _lc holds sum of directions, _rc the count of
        // directed samples; converted to left/right counts below
        int* _lc = (int*)rc + 1;
        int* _rc = _lc + mi + 1;

        for( i = -1; i < mi; i++ )
            _lc[i] = _rc[i] = 0;

        for( i = 0; i < n; i++ )
        {
            int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ?
-1 : labels[i];
            int d = dir[i];
            int sum = _lc[idx] + d;
            int sum_abs = _rc[idx] + (d & 1);
            _lc[idx] = sum; _rc[idx] = sum_abs;
        }

        for( i = 0; i < mi; i++ )
        {
            // sum_abs = R + L, sum = R - L  =>  L = (sum_abs - sum)/2, R = (sum_abs + sum)/2
            int sum = _lc[i];
            int sum_abs = _rc[i];
            lc[i] = (sum_abs - sum) >> 1;
            rc[i] = (sum_abs + sum) >> 1;
        }
    }
    else
    {
        // weighted variant: each sample contributes its class prior weight
        const double* priors = data->priors_mult->data.db;
        int* responses_buf = labels_buf + n;
        const int* responses = data->get_class_labels(node, responses_buf);

        for( i = 0; i < n; i++ )
        {
            int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
            double w = priors[responses[i]];
            int d = dir[i];
            double sum = lc[idx] + d*w;
            double sum_abs = rc[idx] + (d & 1)*w;
            lc[idx] = sum; rc[idx] = sum_abs;
        }

        for( i = 0; i < mi; i++ )
        {
            double sum = lc[i];
            double sum_abs = rc[i];
            lc[i] = (sum_abs - sum) * 0.5;
            rc[i] = (sum_abs + sum) * 0.5;
        }
    }

    // 2. now form the split.
    // in each category send all the samples to the same direction as majority
    for( i = 0; i < mi; i++ )
    {
        double lval = lc[i], rval = rc[i];
        if( lval > rval )
        {
            split->subset[i >> 5] |= 1 << (i & 31);
            best_val += lval;
            l_win++;
        }
        else
            best_val += rval;
    }

    split->quality = (float)best_val;
    // reject surrogates that don't beat the baseline or that are degenerate
    // (every category sent to the same side)
    if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
        cvSetRemoveByPtr( data->split_heap, split ), split = 0;

    return split;
}


// Computes the node value, risk and per-fold cross-validation statistics
// (value/risk/error) for both classification and regression trees; see the
// in-body comments for the exact definitions.
void CvDTree::calc_node_value( CvDTreeNode* node )
{
    int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
    int m = data->get_num_classes();

    int base_size = data->is_classifier ? m*cv_n*sizeof(int) : 2*cv_n*sizeof(double)+cv_n*sizeof(int);
    int ext_size = n*(sizeof(int) + (data->is_classifier ? sizeof(int) : sizeof(int)+sizeof(float)));
    cv::AutoBuffer<uchar> inn_buf(base_size + ext_size);
    uchar* base_buf = (uchar*)inn_buf;
    uchar* ext_buf = base_buf + base_size;

    int* cv_labels_buf = (int*)ext_buf;
    const int* cv_labels = data->get_cv_labels(node, cv_labels_buf);

    if( data->is_classifier )
    {
        // in case of classification tree:
        //  * node value is the label of the class that has the largest weight in the node.
        //  * node risk is the weighted number of misclassified samples,
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated as the weighted number of
        //    misclassified samples with cv_labels(*)==j.

        // compute the number of instances of each class
        int* cls_count = data->counts->data.i;
        int* responses_buf = cv_labels_buf + n;
        const int* responses = data->get_class_labels(node, responses_buf);
        int* cv_cls_count = (int*)base_buf;
        double max_val = -1, total_weight = 0;
        int max_k = -1;
        double* priors = data->priors_mult->data.db;

        for( k = 0; k < m; k++ )
            cls_count[k] = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
                cls_count[responses[i]]++;
        }
        else
        {
            // count per (fold, class), then fold the table into cls_count
            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cv_cls_count[j*m + k] = 0;

            for( i = 0; i < n; i++ )
            {
                j = cv_labels[i]; k = responses[i];
                cv_cls_count[j*m + k]++;
            }

            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cls_count[k] += cv_cls_count[j*m + k];
        }

        // only at the root (parent == 0): rescale the user priors by the
        // actual class frequencies and renormalize them to sum to 1
        if( data->have_priors && node->parent == 0 )
        {
            // compute priors_mult from priors, take the sample ratio into account.
            double sum = 0;
            for( k = 0; k < m; k++ )
            {
                int n_k = cls_count[k];
                priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);
                sum += priors[k];
            }
            sum = 1./sum;
            for( k = 0; k < m; k++ )
                priors[k] *= sum;
        }

        for( k = 0; k < m; k++ )
        {
            double val = cls_count[k]*priors[k];
            total_weight += val;
            if( max_val < val )
            {
                max_val = val;
                max_k = k;
            }
        }

        node->class_idx = max_k;
        // map the internal class index back to the original label
        node->value = data->cat_map->data.i[
            data->cat_ofs->data.i[data->cat_var_count] + max_k];
        node->node_risk = total_weight - max_val;

        for( j = 0; j < cv_n; j++ )
        {
            double sum_k = 0, sum = 0, max_val_k = 0;
            max_val = -1; max_k = -1;

            for( k = 0; k < m; k++ )
            {
                double w = priors[k];
                double val_k = cv_cls_count[j*m + k]*w;   // weight of class k inside fold j
                double val = cls_count[k]*w - val_k;      // weight of class k outside fold j
                sum_k += val_k;
                sum += val;
                if( max_val < val )
                {
                    max_val = val;
                    max_val_k = val_k;
                    max_k = k;
                }
            }

            node->cv_Tn[j] = INT_MAX;
            node->cv_node_risk[j] = sum - max_val;
            node->cv_node_error[j] = sum_k - max_val_k;
        }
    }
    else
    {
        // in case of regression tree:
        //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
        //    n is the number of samples in the node.
        //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated
        //    using samples with cv_labels(*)==j as the test subset:
        //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
        //    where node_value_j is the node value calculated
        //    as described in the previous bullet, and summation is done
        //    over the samples with cv_labels(*)==j.
        // (continuation of CvDTree::calc_node_value - regression branch)
        double sum = 0, sum2 = 0;
        float* values_buf = (float*)(cv_labels_buf + n);
        int* sample_indices_buf = (int*)(values_buf + n);
        const float* values = data->get_ord_responses(node, values_buf, sample_indices_buf);
        double *cv_sum = 0, *cv_sum2 = 0;
        int* cv_count = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                double t = values[i];
                sum += t;
                sum2 += t*t;
            }
        }
        else
        {
            // accumulate sum/sum-of-squares/count per cross-validation fold
            cv_sum = (double*)base_buf;
            cv_sum2 = cv_sum + cv_n;
            cv_count = (int*)(cv_sum2 + cv_n);

            for( j = 0; j < cv_n; j++ )
            {
                cv_sum[j] = cv_sum2[j] = 0.;
                cv_count[j] = 0;
            }

            for( i = 0; i < n; i++ )
            {
                j = cv_labels[i];
                double t = values[i];
                double s = cv_sum[j] + t;
                double s2 = cv_sum2[j] + t*t;
                int nc = cv_count[j] + 1;
                cv_sum[j] = s;
                cv_sum2[j] = s2;
                cv_count[j] = nc;
            }

            for( j = 0; j < cv_n; j++ )
            {
                sum += cv_sum[j];
                sum2 += cv_sum2[j];
            }
        }

        // SSE = sum of squares minus n*mean^2
        node->node_risk = sum2 - (sum/n)*sum;
        node->value = sum/n;

        // per-fold statistics computed from the complement of each fold
        // (this loop does not run when cv_n == 0, so the null cv_* are safe)
        for( j = 0; j < cv_n; j++ )
        {
            double s = cv_sum[j], si = sum - s;
            double s2 = cv_sum2[j], s2i = sum2 - s2;
            int c = cv_count[j], ci = n - c;
            double r = si/MAX(ci,1);   // mean response over the training part (folds != j)
            node->cv_node_risk[j] = s2i - r*r*ci;
            node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
            node->cv_Tn[j] = INT_MAX;
        }
    }
}


// Assigns a left/right direction to every sample of the node that the primary
// split left undirected (missing primary-split value): first via surrogate
// splits, then via the majority direction, finally by alternating left/right.
// On exit dir[] holds 0 (left) or 1 (right) for every sample.
void CvDTree::complete_node_dir( CvDTreeNode* node )
{
    int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
    int nz = n - node->get_num_valid(node->split->var_idx);   // samples still undirected
    char* dir = (char*)data->direction->data.ptr;

    // try to complete direction using surrogate splits
    if( nz && data->params.use_surrogates )
    {
        cv::AutoBuffer<uchar> inn_buf(n*(2*sizeof(int)+sizeof(float)));
        CvDTreeSplit* split = node->split->next;
        for( ; split != 0 && nz; split = split->next )
        {
            int inversed_mask = split->inversed ? -1 : 0;
            vi = split->var_idx;

            if( data->get_var_type(vi) >= 0 ) // split on categorical var
            {
                int* labels_buf = (int*)(uchar*)inn_buf;
                const int* labels = data->get_cat_var_data(node, vi, labels_buf);
                const int* subset = split->subset;

                for( i = 0; i < n; i++ )
                {
                    // only fill samples with no direction yet and a valid
                    // (non-missing) category for this surrogate variable
                    int idx = labels[i];
                    if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))
                    {
                        int d = CV_DTREE_CAT_DIR(idx,subset);
                        // (d ^ mask) - mask negates d when the surrogate is inversed
                        dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
                        // NOTE(review): `if( --nz ) break;` exits the scan while
                        // undirected samples remain (i.e. after completing just
                        // one); `if( --nz == 0 )` would match the apparent
                        // intent of stopping only when all are filled. Verify
                        // against upstream; kept byte-identical here.
                        if( --nz )
                            break;
                    }
                }
            }
            else // split on ordered var
            {
                float* values_buf = (float*)(uchar*)inn_buf;
                int* sorted_indices_buf = (int*)(values_buf + n);
                int* sample_indices_buf = sorted_indices_buf + n;
                const float* values = 0;
                const int* sorted_indices = 0;
                data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
                int split_point = split->ord.split_point;
                int n1 = node->get_num_valid(vi);

                assert( 0 <= split_point && split_point < n-1 );

                for( i = 0; i < n1; i++ )
                {
                    int idx = sorted_indices[i];
                    if( !dir[idx] )
                    {
                        int d = i <= split_point ? -1 : 1;
                        dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
                        // NOTE(review): same early-exit condition as in the
                        // categorical branch above — see the note there.
                        if( --nz )
                            break;
                    }
                }
            }
        }
    }

    // find the default direction for the rest
    if( nz )
    {
        for( i = nr = 0; i < n; i++ )
            nr += dir[i] > 0;
        nl = n - nr - nz;
        d0 = nl > nr ? -1 : nr > nl;   // majority side; 0 on a tie
    }

    // make sure that every sample is directed either to the left or to the right
    for( i = 0; i < n; i++ )
    {
        int d = dir[i];
        if( !d )
        {
            d = d0;
            if( !d )
                d = d1, d1 = -d1;   // tie: alternate left/right
        }
        d = d > 0;
        dir[i] = (char)d; // remap (-1,1) to (0,1)
    }
}


// Physically splits the node's per-variable buffers between the two children
// according to dir[]: ordered variables keep both halves sorted, categorical
// variables / responses / cv labels are relocated via the new_idx table, and
// finally the parent's now-redundant data is released.
void CvDTree::split_node_data( CvDTreeNode* node )
{
    int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
    char* dir = (char*)data->direction->data.ptr;
    CvDTreeNode *left = 0, *right = 0;
    int* new_idx = data->split_buf->data.i;
    int new_buf_idx = data->get_child_buf_idx( node );
    int work_var_count = data->get_work_var_count();
    CvMat* buf = data->buf;
    size_t length_buf_row = data->get_length_subbuf();
    cv::AutoBuffer<uchar> inn_buf(n*(3*sizeof(int) + sizeof(float)));
    int* temp_buf = (int*)(uchar*)inn_buf;

    // every sample must have a direction before the buffers can be split
    complete_node_dir(node);

    for( i = nl = nr = 0; i < n; i++ )
    {
        int d = dir[i];
        // initialize new indices for splitting ordered variables
        new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
        nr += d;
        nl += d^1;
    }

    bool split_input_data;
    node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
    node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );

    // input variables only need to be split if at least one child may be split further
    split_input_data = node->depth + 1 < data->params.max_depth &&
        (node->left->sample_count > data->params.min_sample_count ||
        node->right->sample_count > data->params.min_sample_count);

    // split ordered variables, keep both halves sorted.
    // (continuation of CvDTree::split_node_data)
    for( vi = 0; vi < data->var_count; vi++ )
    {
        int ci = data->get_var_type(vi);

        if( ci >= 0 || !split_input_data )   // skip categorical vars here
            continue;

        int n1 = node->get_num_valid(vi);
        float* src_val_buf = (float*)(uchar*)(temp_buf + n);
        int* src_sorted_idx_buf = (int*)(src_val_buf + n);
        int* src_sample_idx_buf = src_sorted_idx_buf + n;
        const float* src_val = 0;
        const int* src_sorted_idx = 0;
        data->get_ord_var_data(node, vi, src_val_buf, src_sorted_idx_buf, &src_val, &src_sorted_idx, src_sample_idx_buf);

        // snapshot the sorted order: the destination may alias the source buffer
        for(i = 0; i < n; i++)
            temp_buf[i] = src_sorted_idx[i];

        if (data->is_buf_16u)
        {
            unsigned short *ldst, *rdst, *ldst0, *rdst0;
            //unsigned short tl, tr;
            ldst0 = ldst = (unsigned short*)(buf->data.s + left->buf_idx*length_buf_row +
                vi*scount + left->offset);
            // right child's block starts immediately after the left one:
            // same buf row, right->offset == left->offset + nl
            rdst0 = rdst = (unsigned short*)(ldst + nl);

            // split sorted
            for( i = 0; i < n1; i++ )
            {
                int idx = temp_buf[i];
                int d = dir[idx];
                idx = new_idx[idx];   // remap parent index to child-local index
                if (d)
                {
                    *rdst = (unsigned short)idx;
                    rdst++;
                }
                else
                {
                    *ldst = (unsigned short)idx;
                    ldst++;
                }
            }

            left->set_num_valid(vi, (int)(ldst - ldst0));
            right->set_num_valid(vi, (int)(rdst - rdst0));

            // split missing
            for( ; i < n; i++ )
            {
                int idx = temp_buf[i];
                int d = dir[idx];
                idx = new_idx[idx];
                if (d)
                {
                    *rdst = (unsigned short)idx;
                    rdst++;
                }
                else
                {
                    *ldst = (unsigned short)idx;
                    ldst++;
                }
            }
        }
        else
        {
            int *ldst0, *ldst, *rdst0, *rdst;
            ldst0 = ldst = buf->data.i + left->buf_idx*length_buf_row +
                vi*scount + left->offset;
            rdst0 = rdst = buf->data.i + right->buf_idx*length_buf_row +
                vi*scount + right->offset;

            // split sorted
            for( i = 0; i < n1; i++ )
            {
                int idx = temp_buf[i];
                int d = dir[idx];
                idx = new_idx[idx];
                if (d)
                {
                    *rdst = idx;
                    rdst++;
                }
                else
                {
                    *ldst = idx;
                    ldst++;
                }
            }

            left->set_num_valid(vi, (int)(ldst - ldst0));
            right->set_num_valid(vi, (int)(rdst - rdst0));

            // split missing
            for( ; i < n; i++ )
            {
                int idx = temp_buf[i];
                int d = dir[idx];
                idx = new_idx[idx];
                if (d)
                {
                    *rdst = idx;
                    rdst++;
                }
                else
                {
                    *ldst = idx;
                    ldst++;
                }
            }
        }
    }

    // split categorical vars, responses and cv_labels using new_idx relocation table
    for( vi = 0; vi < work_var_count; vi++ )
    {
        int ci = data->get_var_type(vi);
        int n1 = node->get_num_valid(vi), nr1 = 0;

        // work variables past var_count (responses, labels) are always split
        if( ci < 0 || (vi < data->var_count && !split_input_data) )
            continue;

        int *src_lbls_buf = temp_buf + n;
        const int* src_lbls = data->get_cat_var_data(node, vi, src_lbls_buf);

        for(i = 0; i < n; i++)
            temp_buf[i] = src_lbls[i];

        if (data->is_buf_16u)
        {
            unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*length_buf_row +
                vi*scount + left->offset);
            unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*length_buf_row +
                vi*scount + right->offset);

            for( i = 0; i < n; i++ )
            {
                int d = dir[i];
                int idx = temp_buf[i];
                if (d)
                {
                    *rdst = (unsigned short)idx;
                    rdst++;
                    nr1 += (idx != 65535 )&d;   // count valid (non-missing) values going right
                }
                else
                {
                    *ldst = (unsigned short)idx;
                    ldst++;
                }
            }

            if( vi < data->var_count )
            {
                left->set_num_valid(vi, n1 - nr1);
                right->set_num_valid(vi, nr1);
            }
        }
        else
        {
            int *ldst = buf->data.i + left->buf_idx*length_buf_row +
                vi*scount + left->offset;
            int *rdst = buf->data.i + right->buf_idx*length_buf_row +
                vi*scount + right->offset;

            for( i = 0; i < n; i++ )
            {
                int d = dir[i];
                int idx = temp_buf[i];
                if (d)
                {
                    *rdst = idx;
                    rdst++;
                    nr1 += (idx >= 0)&d;   // in int buffers, negative idx marks missing
                }
                else
                {
                    *ldst = idx;
                    ldst++;
                }

            }

            if( vi < data->var_count )
            {
                left->set_num_valid(vi, n1 - nr1);
                right->set_num_valid(vi, nr1);
            }
        }
    }


    // split sample indices
    int *sample_idx_src_buf = temp_buf + n;
    const int* sample_idx_src = data->get_sample_indices(node, sample_idx_src_buf);

    for(i = 0; i < n; i++)
        temp_buf[i] = sample_idx_src[i];

    int pos = data->get_work_var_count();   // sample indices live right after the work vars
    if (data->is_buf_16u)
    {
        unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*length_buf_row +
            pos*scount + left->offset);
        unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*length_buf_row +
            pos*scount + right->offset);
        for (i = 0; i < n; i++)
        {
            int d = dir[i];
            unsigned short idx = (unsigned short)temp_buf[i];
            if (d)
            {
                *rdst = idx;
                rdst++;
            }
            else
            {
                *ldst = idx;
                ldst++;
            }
        }
    }
    else
    {
        int* ldst = buf->data.i + left->buf_idx*length_buf_row +
            pos*scount + left->offset;
        int* rdst = buf->data.i + right->buf_idx*length_buf_row +
            pos*scount + right->offset;
        for (i = 0; i < n; i++)
        {
            int d = dir[i];
            int idx = temp_buf[i];
            if (d)
            {
                *rdst = idx;
                rdst++;
            }
            else
            {
                *ldst = idx;
                ldst++;
            }
        }
    }

    // deallocate the parent node data that is not needed anymore
    data->free_node_data(node);
}

// Computes the train/test error of the tree over a CvMLData set:
// percentage of misclassified samples for classification, mean squared
// error for regression. Optionally stores per-sample predictions in *resp.
float CvDTree::calc_error( CvMLData* _data, int type, std::vector<float> *resp )
{
    float err = 0;
    const CvMat* values = _data->get_values();
    const CvMat* response = _data->get_responses();
    const CvMat* missing = _data->get_missing();
    const CvMat* sample_idx = (type == CV_TEST_ERROR) ?
_data->get_test_sample_idx() : _data->get_train_sample_idx();
    const CvMat* var_types = _data->get_var_types();
    int* sidx = sample_idx ? sample_idx->data.i : 0;
    // element step of the response vector (it may be non-contiguous)
    int r_step = CV_IS_MAT_CONT(response->type) ?
        1 : response->step / CV_ELEM_SIZE(response->type);
    bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
    int sample_count = sample_idx ? sample_idx->cols : 0;
    // no explicit sample index for train error => evaluate over all rows
    sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
    float* pred_resp = 0;
    if( resp && (sample_count > 0) )
    {
        resp->resize( sample_count );
        pred_resp = &((*resp)[0]);
    }

    if ( is_classifier )
    {
        for( int i = 0; i < sample_count; i++ )
        {
            CvMat sample, miss;
            int si = sidx ? sidx[i] : i;
            cvGetRow( values, &sample, si );
            if( missing )
                cvGetRow( missing, &miss, si );
            float r = (float)predict( &sample, missing ? &miss : 0 )->value;
            if( pred_resp )
                pred_resp[i] = r;
            // labels are stored as floats; compare with an epsilon tolerance
            int d = fabs((double)r - response->data.fl[(size_t)si*r_step]) <= FLT_EPSILON ? 0 : 1;
            err += d;
        }
        // misclassification rate in percent; -FLT_MAX flags an empty sample set
        err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
    }
    else
    {
        for( int i = 0; i < sample_count; i++ )
        {
            CvMat sample, miss;
            int si = sidx ? sidx[i] : i;
            cvGetRow( values, &sample, si );
            if( missing )
                cvGetRow( missing, &miss, si );
            float r = (float)predict( &sample, missing ? &miss : 0 )->value;
            if( pred_resp )
                pred_resp[i] = r;
            float d = r - response->data.fl[(size_t)si*r_step];
            err += d*d;
        }
        // mean squared error; -FLT_MAX flags an empty sample set
        err = sample_count ? err / (float)sample_count : -FLT_MAX;
    }
    return err;
}

// Cost-complexity pruning driven by the built-in cross-validation folds:
// builds the nested tree sequence, evaluates each fold's error along the
// sequence and cuts the tree at the best complexity level.
void CvDTree::prune_cv()
{
    CvMat* ab = 0;
    CvMat* temp = 0;
    CvMat* err_jk = 0;

    // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
    // 2. choose the best tree index (if need, apply 1SE rule).
    // 3. store the best index and cut the branches.

    CV_FUNCNAME( "CvDTree::prune_cv" );

    __BEGIN__;

    int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
    // currently, 1SE for regression is not implemented
    bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
    double* err;
    double min_err = 0, min_err_se = 0;
    int min_idx = -1;

    CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));

    // build the main tree sequence, calculate alpha's
    for(;;tree_count++)
    {
        double min_alpha = update_tree_rnc(tree_count, -1);
        if( cut_tree(tree_count, -1, min_alpha) )
            break;

        // grow the alpha array geometrically when it fills up
        if( ab->cols <= tree_count )
        {
            CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
            for( ti = 0; ti < ab->cols; ti++ )
                temp->data.db[ti] = ab->data.db[ti];
            cvReleaseMat( &ab );
            ab = temp;
            temp = 0;
        }

        ab->data.db[tree_count] = min_alpha;
    }

    ab->data.db[0] = 0.;

    if( tree_count > 0 )
    {
        // use geometric means of adjacent alphas as the test points
        for( ti = 1; ti < tree_count-1; ti++ )
            ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
        ab->data.db[tree_count-1] = DBL_MAX*0.5;

        CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
        err = err_jk->data.db;

        for( j = 0; j < cv_n; j++ )
        {
            int tj = 0, tk = 0;
            for( ; tk < tree_count; tj++ )
            {
                double min_alpha = update_tree_rnc(tj, j);
                if( cut_tree(tj, j, min_alpha) )
                    min_alpha = DBL_MAX;

                // the fold-j tree at complexity tj covers all main-sequence
                // indices tk whose alpha does not exceed min_alpha
                for( ; tk < tree_count; tk++ )
                {
                    if( ab->data.db[tk] > min_alpha )
                        break;
                    err[j*tree_count + tk] = root->tree_error;
                }
            }
        }

        for( ti = 0; ti < tree_count; ti++ )
        {
            double sum_err = 0;
            for( j = 0; j < cv_n; j++ )
                sum_err += err[j*tree_count + ti];
            if( ti == 0 || sum_err < min_err )
            {
                min_err = sum_err;
                min_idx = ti;
                if( use_1se )
                    min_err_se = sqrt( sum_err*(n -
sum_err) ); 3458 } 3459 else if( sum_err < min_err + min_err_se ) 3460 min_idx = ti; 3461 } 3462 } 3463 3464 pruned_tree_idx = min_idx; 3465 free_prune_data(data->params.truncate_pruned_tree != 0); 3466 3467 __END__; 3468 3469 cvReleaseMat( &err_jk ); 3470 cvReleaseMat( &ab ); 3471 cvReleaseMat( &temp ); 3472 } 3473 3474 3475 double CvDTree::update_tree_rnc( int T, int fold ) 3476 { 3477 CvDTreeNode* node = root; 3478 double min_alpha = DBL_MAX; 3479 3480 for(;;) 3481 { 3482 CvDTreeNode* parent; 3483 for(;;) 3484 { 3485 int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn; 3486 if( t <= T || !node->left ) 3487 { 3488 node->complexity = 1; 3489 node->tree_risk = node->node_risk; 3490 node->tree_error = 0.; 3491 if( fold >= 0 ) 3492 { 3493 node->tree_risk = node->cv_node_risk[fold]; 3494 node->tree_error = node->cv_node_error[fold]; 3495 } 3496 break; 3497 } 3498 node = node->left; 3499 } 3500 3501 for( parent = node->parent; parent && parent->right == node; 3502 node = parent, parent = parent->parent ) 3503 { 3504 parent->complexity += node->complexity; 3505 parent->tree_risk += node->tree_risk; 3506 parent->tree_error += node->tree_error; 3507 3508 parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk) 3509 - parent->tree_risk)/(parent->complexity - 1); 3510 min_alpha = MIN( min_alpha, parent->alpha ); 3511 } 3512 3513 if( !parent ) 3514 break; 3515 3516 parent->complexity = node->complexity; 3517 parent->tree_risk = node->tree_risk; 3518 parent->tree_error = node->tree_error; 3519 node = parent->right; 3520 } 3521 3522 return min_alpha; 3523 } 3524 3525 3526 int CvDTree::cut_tree( int T, int fold, double min_alpha ) 3527 { 3528 CvDTreeNode* node = root; 3529 if( !node->left ) 3530 return 1; 3531 3532 for(;;) 3533 { 3534 CvDTreeNode* parent; 3535 for(;;) 3536 { 3537 int t = fold >= 0 ? 
node->cv_Tn[fold] : node->Tn; 3538 if( t <= T || !node->left ) 3539 break; 3540 if( node->alpha <= min_alpha + FLT_EPSILON ) 3541 { 3542 if( fold >= 0 ) 3543 node->cv_Tn[fold] = T; 3544 else 3545 node->Tn = T; 3546 if( node == root ) 3547 return 1; 3548 break; 3549 } 3550 node = node->left; 3551 } 3552 3553 for( parent = node->parent; parent && parent->right == node; 3554 node = parent, parent = parent->parent ) 3555 ; 3556 3557 if( !parent ) 3558 break; 3559 3560 node = parent->right; 3561 } 3562 3563 return 0; 3564 } 3565 3566 3567 void CvDTree::free_prune_data(bool _cut_tree) 3568 { 3569 CvDTreeNode* node = root; 3570 3571 for(;;) 3572 { 3573 CvDTreeNode* parent; 3574 for(;;) 3575 { 3576 // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn ) 3577 // as we will clear the whole cross-validation heap at the end 3578 node->cv_Tn = 0; 3579 node->cv_node_error = node->cv_node_risk = 0; 3580 if( !node->left ) 3581 break; 3582 node = node->left; 3583 } 3584 3585 for( parent = node->parent; parent && parent->right == node; 3586 node = parent, parent = parent->parent ) 3587 { 3588 if( _cut_tree && parent->Tn <= pruned_tree_idx ) 3589 { 3590 data->free_node( parent->left ); 3591 data->free_node( parent->right ); 3592 parent->left = parent->right = 0; 3593 } 3594 } 3595 3596 if( !parent ) 3597 break; 3598 3599 node = parent->right; 3600 } 3601 3602 if( data->cv_heap ) 3603 cvClearSet( data->cv_heap ); 3604 } 3605 3606 3607 void CvDTree::free_tree() 3608 { 3609 if( root && data && data->shared ) 3610 { 3611 pruned_tree_idx = INT_MIN; 3612 free_prune_data(true); 3613 data->free_node(root); 3614 root = 0; 3615 } 3616 } 3617 3618 CvDTreeNode* CvDTree::predict( const CvMat* _sample, 3619 const CvMat* _missing, bool preprocessed_input ) const 3620 { 3621 cv::AutoBuffer<int> catbuf; 3622 3623 int i, mstep = 0; 3624 const uchar* m = 0; 3625 CvDTreeNode* node = root; 3626 3627 if( !node ) 3628 CV_Error( CV_StsError, "The tree has not been trained yet" ); 3629 3630 if( 
!CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 || 3631 (_sample->cols != 1 && _sample->rows != 1) || 3632 (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) || 3633 (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) ) 3634 CV_Error( CV_StsBadArg, 3635 "the input sample must be 1d floating-point vector with the same " 3636 "number of elements as the total number of variables used for training" ); 3637 3638 const float* sample = _sample->data.fl; 3639 int step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]); 3640 3641 if( data->cat_count && !preprocessed_input ) // cache for categorical variables 3642 { 3643 int n = data->cat_count->cols; 3644 catbuf.allocate(n); 3645 for( i = 0; i < n; i++ ) 3646 catbuf[i] = -1; 3647 } 3648 3649 if( _missing ) 3650 { 3651 if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) || 3652 !CV_ARE_SIZES_EQ(_missing, _sample) ) 3653 CV_Error( CV_StsBadArg, 3654 "the missing data mask must be 8-bit vector of the same size as input sample" ); 3655 m = _missing->data.ptr; 3656 mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]); 3657 } 3658 3659 const int* vtype = data->var_type->data.i; 3660 const int* vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0; 3661 const int* cmap = data->cat_map ? data->cat_map->data.i : 0; 3662 const int* cofs = data->cat_ofs ? data->cat_ofs->data.i : 0; 3663 3664 while( node->Tn > pruned_tree_idx && node->left ) 3665 { 3666 CvDTreeSplit* split = node->split; 3667 int dir = 0; 3668 for( ; !dir && split != 0; split = split->next ) 3669 { 3670 int vi = split->var_idx; 3671 int ci = vtype[vi]; 3672 i = vidx ? vidx[vi] : vi; 3673 float val = sample[(size_t)i*step]; 3674 if( m && m[(size_t)i*mstep] ) 3675 continue; 3676 if( ci < 0 ) // ordered 3677 dir = val <= split->ord.c ? 
-1 : 1; 3678 else // categorical 3679 { 3680 int c; 3681 if( preprocessed_input ) 3682 c = cvRound(val); 3683 else 3684 { 3685 c = catbuf[ci]; 3686 if( c < 0 ) 3687 { 3688 int a = c = cofs[ci]; 3689 int b = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1]; 3690 3691 int ival = cvRound(val); 3692 if( ival != val ) 3693 CV_Error( CV_StsBadArg, 3694 "one of input categorical variable is not an integer" ); 3695 3696 int sh = 0; 3697 while( a < b ) 3698 { 3699 sh++; 3700 c = (a + b) >> 1; 3701 if( ival < cmap[c] ) 3702 b = c; 3703 else if( ival > cmap[c] ) 3704 a = c+1; 3705 else 3706 break; 3707 } 3708 3709 if( c < 0 || ival != cmap[c] ) 3710 continue; 3711 3712 catbuf[ci] = c -= cofs[ci]; 3713 } 3714 } 3715 c = ( (c == 65535) && data->is_buf_16u ) ? -1 : c; 3716 dir = CV_DTREE_CAT_DIR(c, split->subset); 3717 } 3718 3719 if( split->inversed ) 3720 dir = -dir; 3721 } 3722 3723 if( !dir ) 3724 { 3725 double diff = node->right->sample_count - node->left->sample_count; 3726 dir = diff < 0 ? -1 : 1; 3727 } 3728 node = dir < 0 ? node->left : node->right; 3729 } 3730 3731 return node; 3732 } 3733 3734 3735 CvDTreeNode* CvDTree::predict( const Mat& _sample, const Mat& _missing, bool preprocessed_input ) const 3736 { 3737 CvMat sample = _sample, mmask = _missing; 3738 return predict(&sample, mmask.data.ptr ? 
&mmask : 0, preprocessed_input); 3739 } 3740 3741 3742 const CvMat* CvDTree::get_var_importance() 3743 { 3744 if( !var_importance ) 3745 { 3746 CvDTreeNode* node = root; 3747 double* importance; 3748 if( !node ) 3749 return 0; 3750 var_importance = cvCreateMat( 1, data->var_count, CV_64F ); 3751 cvZero( var_importance ); 3752 importance = var_importance->data.db; 3753 3754 for(;;) 3755 { 3756 CvDTreeNode* parent; 3757 for( ;; node = node->left ) 3758 { 3759 CvDTreeSplit* split = node->split; 3760 3761 if( !node->left || node->Tn <= pruned_tree_idx ) 3762 break; 3763 3764 for( ; split != 0; split = split->next ) 3765 importance[split->var_idx] += split->quality; 3766 } 3767 3768 for( parent = node->parent; parent && parent->right == node; 3769 node = parent, parent = parent->parent ) 3770 ; 3771 3772 if( !parent ) 3773 break; 3774 3775 node = parent->right; 3776 } 3777 3778 cvNormalize( var_importance, var_importance, 1., 0, CV_L1 ); 3779 } 3780 3781 return var_importance; 3782 } 3783 3784 3785 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split ) const 3786 { 3787 int ci; 3788 3789 cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW ); 3790 cvWriteInt( fs, "var", split->var_idx ); 3791 cvWriteReal( fs, "quality", split->quality ); 3792 3793 ci = data->get_var_type(split->var_idx); 3794 if( ci >= 0 ) // split on a categorical var 3795 { 3796 int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir; 3797 for( i = 0; i < n; i++ ) 3798 to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0; 3799 3800 // ad-hoc rule when to use inverse categorical split notation 3801 // to achieve more compact and clear representation 3802 default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1; 3803 3804 cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ? 
3805 "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW ); 3806 3807 for( i = 0; i < n; i++ ) 3808 { 3809 int dir = CV_DTREE_CAT_DIR(i,split->subset); 3810 if( dir*default_dir < 0 ) 3811 cvWriteInt( fs, 0, i ); 3812 } 3813 cvEndWriteStruct( fs ); 3814 } 3815 else 3816 cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c ); 3817 3818 cvEndWriteStruct( fs ); 3819 } 3820 3821 3822 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node ) const 3823 { 3824 CvDTreeSplit* split; 3825 3826 cvStartWriteStruct( fs, 0, CV_NODE_MAP ); 3827 3828 cvWriteInt( fs, "depth", node->depth ); 3829 cvWriteInt( fs, "sample_count", node->sample_count ); 3830 cvWriteReal( fs, "value", node->value ); 3831 3832 if( data->is_classifier ) 3833 cvWriteInt( fs, "norm_class_idx", node->class_idx ); 3834 3835 cvWriteInt( fs, "Tn", node->Tn ); 3836 cvWriteInt( fs, "complexity", node->complexity ); 3837 cvWriteReal( fs, "alpha", node->alpha ); 3838 cvWriteReal( fs, "node_risk", node->node_risk ); 3839 cvWriteReal( fs, "tree_risk", node->tree_risk ); 3840 cvWriteReal( fs, "tree_error", node->tree_error ); 3841 3842 if( node->left ) 3843 { 3844 cvStartWriteStruct( fs, "splits", CV_NODE_SEQ ); 3845 3846 for( split = node->split; split != 0; split = split->next ) 3847 write_split( fs, split ); 3848 3849 cvEndWriteStruct( fs ); 3850 } 3851 3852 cvEndWriteStruct( fs ); 3853 } 3854 3855 3856 void CvDTree::write_tree_nodes( CvFileStorage* fs ) const 3857 { 3858 //CV_FUNCNAME( "CvDTree::write_tree_nodes" ); 3859 3860 __BEGIN__; 3861 3862 CvDTreeNode* node = root; 3863 3864 // traverse the tree and save all the nodes in depth-first order 3865 for(;;) 3866 { 3867 CvDTreeNode* parent; 3868 for(;;) 3869 { 3870 write_node( fs, node ); 3871 if( !node->left ) 3872 break; 3873 node = node->left; 3874 } 3875 3876 for( parent = node->parent; parent && parent->right == node; 3877 node = parent, parent = parent->parent ) 3878 ; 3879 3880 if( !parent ) 3881 break; 3882 3883 node = parent->right; 3884 } 3885 3886 
__END__; 3887 } 3888 3889 3890 void CvDTree::write( CvFileStorage* fs, const char* name ) const 3891 { 3892 //CV_FUNCNAME( "CvDTree::write" ); 3893 3894 __BEGIN__; 3895 3896 cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE ); 3897 3898 //get_var_importance(); 3899 data->write_params( fs ); 3900 //if( var_importance ) 3901 //cvWrite( fs, "var_importance", var_importance ); 3902 write( fs ); 3903 3904 cvEndWriteStruct( fs ); 3905 3906 __END__; 3907 } 3908 3909 3910 void CvDTree::write( CvFileStorage* fs ) const 3911 { 3912 //CV_FUNCNAME( "CvDTree::write" ); 3913 3914 __BEGIN__; 3915 3916 cvWriteInt( fs, "best_tree_idx", pruned_tree_idx ); 3917 3918 cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ ); 3919 write_tree_nodes( fs ); 3920 cvEndWriteStruct( fs ); 3921 3922 __END__; 3923 } 3924 3925 3926 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode ) 3927 { 3928 CvDTreeSplit* split = 0; 3929 3930 CV_FUNCNAME( "CvDTree::read_split" ); 3931 3932 __BEGIN__; 3933 3934 int vi, ci; 3935 3936 if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP ) 3937 CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" ); 3938 3939 vi = cvReadIntByName( fs, fnode, "var", -1 ); 3940 if( (unsigned)vi >= (unsigned)data->var_count ) 3941 CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" ); 3942 3943 ci = data->get_var_type(vi); 3944 if( ci >= 0 ) // split on categorical var 3945 { 3946 int i, n = data->cat_count->data.i[ci], inversed = 0, val; 3947 CvSeqReader reader; 3948 CvFileNode* inseq; 3949 split = data->new_split_cat( vi, 0 ); 3950 inseq = cvGetFileNodeByName( fs, fnode, "in" ); 3951 if( !inseq ) 3952 { 3953 inseq = cvGetFileNodeByName( fs, fnode, "not_in" ); 3954 inversed = 1; 3955 } 3956 if( !inseq || 3957 (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT)) 3958 CV_ERROR( CV_StsParseError, 3959 "Either 'in' or 'not_in' tags should be inside a categorical split data" ); 3960 3961 if( 
CV_NODE_TYPE(inseq->tag) == CV_NODE_INT ) 3962 { 3963 val = inseq->data.i; 3964 if( (unsigned)val >= (unsigned)n ) 3965 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" ); 3966 3967 split->subset[val >> 5] |= 1 << (val & 31); 3968 } 3969 else 3970 { 3971 cvStartReadSeq( inseq->data.seq, &reader ); 3972 3973 for( i = 0; i < reader.seq->total; i++ ) 3974 { 3975 CvFileNode* inode = (CvFileNode*)reader.ptr; 3976 val = inode->data.i; 3977 if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n ) 3978 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" ); 3979 3980 split->subset[val >> 5] |= 1 << (val & 31); 3981 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader ); 3982 } 3983 } 3984 3985 // for categorical splits we do not use inversed splits, 3986 // instead we inverse the variable set in the split 3987 if( inversed ) 3988 for( i = 0; i < (n + 31) >> 5; i++ ) 3989 split->subset[i] ^= -1; 3990 } 3991 else 3992 { 3993 CvFileNode* cmp_node; 3994 split = data->new_split_ord( vi, 0, 0, 0, 0 ); 3995 3996 cmp_node = cvGetFileNodeByName( fs, fnode, "le" ); 3997 if( !cmp_node ) 3998 { 3999 cmp_node = cvGetFileNodeByName( fs, fnode, "gt" ); 4000 split->inversed = 1; 4001 } 4002 4003 split->ord.c = (float)cvReadReal( cmp_node ); 4004 } 4005 4006 split->quality = (float)cvReadRealByName( fs, fnode, "quality" ); 4007 4008 __END__; 4009 4010 return split; 4011 } 4012 4013 4014 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent ) 4015 { 4016 CvDTreeNode* node = 0; 4017 4018 CV_FUNCNAME( "CvDTree::read_node" ); 4019 4020 __BEGIN__; 4021 4022 CvFileNode* splits; 4023 int i, depth; 4024 4025 if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP ) 4026 CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" ); 4027 4028 CV_CALL( node = data->new_node( parent, 0, 0, 0 )); 4029 depth = cvReadIntByName( fs, fnode, "depth", -1 ); 4030 if( depth != node->depth ) 
4031 CV_ERROR( CV_StsParseError, "incorrect node depth" ); 4032 4033 node->sample_count = cvReadIntByName( fs, fnode, "sample_count" ); 4034 node->value = cvReadRealByName( fs, fnode, "value" ); 4035 if( data->is_classifier ) 4036 node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" ); 4037 4038 node->Tn = cvReadIntByName( fs, fnode, "Tn" ); 4039 node->complexity = cvReadIntByName( fs, fnode, "complexity" ); 4040 node->alpha = cvReadRealByName( fs, fnode, "alpha" ); 4041 node->node_risk = cvReadRealByName( fs, fnode, "node_risk" ); 4042 node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" ); 4043 node->tree_error = cvReadRealByName( fs, fnode, "tree_error" ); 4044 4045 splits = cvGetFileNodeByName( fs, fnode, "splits" ); 4046 if( splits ) 4047 { 4048 CvSeqReader reader; 4049 CvDTreeSplit* last_split = 0; 4050 4051 if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ ) 4052 CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" ); 4053 4054 cvStartReadSeq( splits->data.seq, &reader ); 4055 for( i = 0; i < reader.seq->total; i++ ) 4056 { 4057 CvDTreeSplit* split; 4058 CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr )); 4059 if( !last_split ) 4060 node->split = last_split = split; 4061 else 4062 last_split = last_split->next = split; 4063 4064 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader ); 4065 } 4066 } 4067 4068 __END__; 4069 4070 return node; 4071 } 4072 4073 4074 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode ) 4075 { 4076 CV_FUNCNAME( "CvDTree::read_tree_nodes" ); 4077 4078 __BEGIN__; 4079 4080 CvSeqReader reader; 4081 CvDTreeNode _root; 4082 CvDTreeNode* parent = &_root; 4083 int i; 4084 parent->left = parent->right = parent->parent = 0; 4085 4086 cvStartReadSeq( fnode->data.seq, &reader ); 4087 4088 for( i = 0; i < reader.seq->total; i++ ) 4089 { 4090 CvDTreeNode* node; 4091 4092 CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? 
parent : 0 )); 4093 if( !parent->left ) 4094 parent->left = node; 4095 else 4096 parent->right = node; 4097 if( node->split ) 4098 parent = node; 4099 else 4100 { 4101 while( parent && parent->right ) 4102 parent = parent->parent; 4103 } 4104 4105 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader ); 4106 } 4107 4108 root = _root.left; 4109 4110 __END__; 4111 } 4112 4113 4114 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode ) 4115 { 4116 CvDTreeTrainData* _data = new CvDTreeTrainData(); 4117 _data->read_params( fs, fnode ); 4118 4119 read( fs, fnode, _data ); 4120 get_var_importance(); 4121 } 4122 4123 4124 // a special entry point for reading weak decision trees from the tree ensembles 4125 void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data ) 4126 { 4127 CV_FUNCNAME( "CvDTree::read" ); 4128 4129 __BEGIN__; 4130 4131 CvFileNode* tree_nodes; 4132 4133 clear(); 4134 data = _data; 4135 4136 tree_nodes = cvGetFileNodeByName( fs, node, "nodes" ); 4137 if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ ) 4138 CV_ERROR( CV_StsParseError, "nodes tag is missing" ); 4139 4140 pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 ); 4141 read_tree_nodes( fs, tree_nodes ); 4142 4143 __END__; 4144 } 4145 4146 Mat CvDTree::getVarImportance() 4147 { 4148 return cvarrToMat(get_var_importance()); 4149 } 4150 4151 /* End of file. */ 4152