1 /*M/////////////////////////////////////////////////////////////////////////////////////// 2 // 3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. 4 // 5 // By downloading, copying, installing or using the software you agree to this license. 6 // If you do not agree to this license, do not download, install, 7 // copy or use the software. 8 // 9 // 10 // License Agreement 11 // For Open Source Computer Vision Library 12 // 13 // Copyright (C) 2000, Intel Corporation, all rights reserved. 14 // Copyright (C) 2014, Itseez Inc, all rights reserved. 15 // Third party copyrights are property of their respective owners. 16 // 17 // Redistribution and use in source and binary forms, with or without modification, 18 // are permitted provided that the following conditions are met: 19 // 20 // * Redistribution's of source code must retain the above copyright notice, 21 // this list of conditions and the following disclaimer. 22 // 23 // * Redistribution's in binary form must reproduce the above copyright notice, 24 // this list of conditions and the following disclaimer in the documentation 25 // and/or other materials provided with the distribution. 26 // 27 // * The name of the copyright holders may not be used to endorse or promote products 28 // derived from this software without specific prior written permission. 29 // 30 // This software is provided by the copyright holders and contributors "as is" and 31 // any express or implied warranties, including, but not limited to, the implied 32 // warranties of merchantability and fitness for a particular purpose are disclaimed. 33 // In no event shall the Intel Corporation or contributors be liable for any direct, 34 // indirect, incidental, special, exemplary, or consequential damages 35 // (including, but not limited to, procurement of substitute goods or services; 36 // loss of use, data, or profits; or business interruption) however caused 37 // and on any theory of liability, whether in contract, strict liability, 38 // or tort (including negligence or otherwise) arising in any way out of 39 // the use of this software, even if advised of the possibility of such damage. 40 // 41 //M*/ 42 43 #include "precomp.hpp" 44 #include <ctype.h> 45 46 namespace cv { 47 namespace ml { 48 49 using std::vector; 50 51 TreeParams::TreeParams() 52 { 53 maxDepth = INT_MAX; 54 minSampleCount = 10; 55 regressionAccuracy = 0.01f; 56 useSurrogates = false; 57 maxCategories = 10; 58 CVFolds = 10; 59 use1SERule = true; 60 truncatePrunedTree = true; 61 priors = Mat(); 62 } 63 64 TreeParams::TreeParams(int _maxDepth, int _minSampleCount, 65 double _regressionAccuracy, bool _useSurrogates, 66 int _maxCategories, int _CVFolds, 67 bool _use1SERule, bool _truncatePrunedTree, 68 const Mat& _priors) 69 { 70 maxDepth = _maxDepth; 71 minSampleCount = _minSampleCount; 72 regressionAccuracy = (float)_regressionAccuracy; 73 useSurrogates = _useSurrogates; 74 maxCategories = _maxCategories; 75 CVFolds = _CVFolds; 76 use1SERule = _use1SERule; 77 truncatePrunedTree = _truncatePrunedTree; 78 priors = _priors; 79 } 80 81 DTrees::Node::Node() 82 { 83 classIdx = 0; 84 value = 0; 85 parent = left = right = split = defaultDir = -1; 86 } 87 88 DTrees::Split::Split() 89 { 90 varIdx = 0; 91 inversed = false; 92 quality = 0.f; 93 next = -1; 94 c = 0.f; 95 subsetOfs = 0; 96 } 97 98 99 DTreesImpl::WorkData::WorkData(const Ptr<TrainData>& _data) 100 { 101 data = _data; 102 vector<int> subsampleIdx; 103 Mat sidx0 = _data->getTrainSampleIdx(); 104 if( !sidx0.empty() ) 105 { 106 sidx0.copyTo(sidx); 107 std::sort(sidx.begin(), sidx.end()); 108 } 109 else 110 { 111 int n = _data->getNSamples(); 112 setRangeVector(sidx, n); 113 } 114 115 maxSubsetSize = 0; 116 } 117 118 DTreesImpl::DTreesImpl() {} 119 DTreesImpl::~DTreesImpl() {} 120 void DTreesImpl::clear() 121 { 122 varIdx.clear(); 123 compVarIdx.clear(); 124 varType.clear(); 125 catOfs.clear(); 126 catMap.clear(); 127 roots.clear(); 128 nodes.clear(); 129 splits.clear(); 130 subsets.clear(); 131 classLabels.clear(); 132 133 w.release(); 134 _isClassifier = false; 135 } 136 137 void DTreesImpl::startTraining( const Ptr<TrainData>& data, int ) 138 { 139 clear(); 140 w = makePtr<WorkData>(data); 141 142 Mat vtype = data->getVarType(); 143 vtype.copyTo(varType); 144 145 data->getCatOfs().copyTo(catOfs); 146 data->getCatMap().copyTo(catMap); 147 data->getDefaultSubstValues().copyTo(missingSubst); 148 149 int nallvars = data->getNAllVars(); 150 151 Mat vidx0 = data->getVarIdx(); 152 if( !vidx0.empty() ) 153 vidx0.copyTo(varIdx); 154 else 155 setRangeVector(varIdx, nallvars); 156 157 initCompVarIdx(); 158 159 w->maxSubsetSize = 0; 160 161 int i, nvars = (int)varIdx.size(); 162 for( i = 0; i < nvars; i++ ) 163 w->maxSubsetSize = std::max(w->maxSubsetSize, getCatCount(varIdx[i])); 164 165 w->maxSubsetSize = std::max((w->maxSubsetSize + 31)/32, 1); 166 167 data->getSampleWeights().copyTo(w->sample_weights); 168 169 _isClassifier = data->getResponseType() == VAR_CATEGORICAL; 170 171 if( _isClassifier ) 172 { 173 data->getNormCatResponses().copyTo(w->cat_responses); 174 data->getClassLabels().copyTo(classLabels); 175 int nclasses = (int)classLabels.size(); 176 177 Mat class_weights = params.priors; 178 if( !class_weights.empty() ) 179 { 180 if( class_weights.type() != CV_64F || !class_weights.isContinuous() ) 181 { 182 Mat temp; 183 class_weights.convertTo(temp, CV_64F); 184 class_weights = temp; 185 } 186 CV_Assert( class_weights.checkVector(1, CV_64F) == nclasses ); 187 188 int nsamples = (int)w->cat_responses.size(); 189 const double* cw = class_weights.ptr<double>(); 190 CV_Assert( (int)w->sample_weights.size() == nsamples ); 191 192 for( i = 0; i < nsamples; i++ ) 193 { 194 int ci = w->cat_responses[i]; 195 CV_Assert( 0 <= ci && ci < nclasses ); 196 w->sample_weights[i] *= cw[ci]; 197 } 198 } 199 } 200 else 201 data->getResponses().copyTo(w->ord_responses); 202 } 203 204 205 void DTreesImpl::initCompVarIdx() 206 { 207 int nallvars = (int)varType.size(); 208 compVarIdx.assign(nallvars, -1); 209 int i, nvars = (int)varIdx.size(), prevIdx = -1; 210 for( i = 0; i < nvars; i++ ) 211 { 212 int vi = varIdx[i]; 213 CV_Assert( 0 <= vi && vi < nallvars && vi > prevIdx ); 214 prevIdx = vi; 215 compVarIdx[vi] = i; 216 } 217 } 218 219 void DTreesImpl::endTraining() 220 { 221 w.release(); 222 } 223 224 bool DTreesImpl::train( const Ptr<TrainData>& trainData, int flags ) 225 { 226 startTraining(trainData, flags); 227 bool ok = addTree( w->sidx ) >= 0; 228 w.release(); 229 endTraining(); 230 return ok; 231 } 232 233 const vector<int>& DTreesImpl::getActiveVars() 234 { 235 return varIdx; 236 } 237 238 int DTreesImpl::addTree(const vector<int>& sidx ) 239 { 240 size_t n = (params.getMaxDepth() > 0 ? (1 << params.getMaxDepth()) : 1024) + w->wnodes.size(); 241 242 w->wnodes.reserve(n); 243 w->wsplits.reserve(n); 244 w->wsubsets.reserve(n*w->maxSubsetSize); 245 w->wnodes.clear(); 246 w->wsplits.clear(); 247 w->wsubsets.clear(); 248 249 int cv_n = params.getCVFolds(); 250 251 if( cv_n > 0 ) 252 { 253 w->cv_Tn.resize(n*cv_n); 254 w->cv_node_error.resize(n*cv_n); 255 w->cv_node_risk.resize(n*cv_n); 256 } 257 258 // build the tree recursively 259 int w_root = addNodeAndTrySplit(-1, sidx); 260 int maxdepth = INT_MAX;//pruneCV(root); 261 262 int w_nidx = w_root, pidx = -1, depth = 0; 263 int root = (int)nodes.size(); 264 265 for(;;) 266 { 267 const WNode& wnode = w->wnodes[w_nidx]; 268 Node node; 269 node.parent = pidx; 270 node.classIdx = wnode.class_idx; 271 node.value = wnode.value; 272 node.defaultDir = wnode.defaultDir; 273 274 int wsplit_idx = wnode.split; 275 if( wsplit_idx >= 0 ) 276 { 277 const WSplit& wsplit = w->wsplits[wsplit_idx]; 278 Split split; 279 split.c = wsplit.c; 280 split.quality = wsplit.quality; 281 split.inversed = wsplit.inversed; 282 split.varIdx = wsplit.varIdx; 283 split.subsetOfs = -1; 284 if( wsplit.subsetOfs >= 0 ) 285 { 286 int ssize = getSubsetSize(split.varIdx); 287 split.subsetOfs = (int)subsets.size(); 288 subsets.resize(split.subsetOfs + ssize); 289 // This check verifies that subsets index is in the correct range 290 // as in case ssize == 0 no real resize performed. 291 // Thus memory kept safe. 292 // Also this skips useless memcpy call when size parameter is zero 293 if(ssize > 0) 294 { 295 memcpy(&subsets[split.subsetOfs], &w->wsubsets[wsplit.subsetOfs], ssize*sizeof(int)); 296 } 297 } 298 node.split = (int)splits.size(); 299 splits.push_back(split); 300 } 301 int nidx = (int)nodes.size(); 302 nodes.push_back(node); 303 if( pidx >= 0 ) 304 { 305 int w_pidx = w->wnodes[w_nidx].parent; 306 if( w->wnodes[w_pidx].left == w_nidx ) 307 { 308 nodes[pidx].left = nidx; 309 } 310 else 311 { 312 CV_Assert(w->wnodes[w_pidx].right == w_nidx); 313 nodes[pidx].right = nidx; 314 } 315 } 316 317 if( wnode.left >= 0 && depth+1 < maxdepth ) 318 { 319 w_nidx = wnode.left; 320 pidx = nidx; 321 depth++; 322 } 323 else 324 { 325 int w_pidx = wnode.parent; 326 while( w_pidx >= 0 && w->wnodes[w_pidx].right == w_nidx ) 327 { 328 w_nidx = w_pidx; 329 w_pidx = w->wnodes[w_pidx].parent; 330 nidx = pidx; 331 pidx = nodes[pidx].parent; 332 depth--; 333 } 334 335 if( w_pidx < 0 ) 336 break; 337 338 w_nidx = w->wnodes[w_pidx].right; 339 CV_Assert( w_nidx >= 0 ); 340 } 341 } 342 roots.push_back(root); 343 return root; 344 } 345 346 void DTreesImpl::setDParams(const TreeParams& _params) 347 { 348 params = _params; 349 } 350 351 int DTreesImpl::addNodeAndTrySplit( int parent, const vector<int>& sidx ) 352 { 353 w->wnodes.push_back(WNode()); 354 int nidx = (int)(w->wnodes.size() - 1); 355 WNode& node = w->wnodes.back(); 356 357 node.parent = parent; 358 node.depth = parent >= 0 ? w->wnodes[parent].depth + 1 : 0; 359 int nfolds = params.getCVFolds(); 360 361 if( nfolds > 0 ) 362 { 363 w->cv_Tn.resize((nidx+1)*nfolds); 364 w->cv_node_error.resize((nidx+1)*nfolds); 365 w->cv_node_risk.resize((nidx+1)*nfolds); 366 } 367 368 int i, n = node.sample_count = (int)sidx.size(); 369 bool can_split = true; 370 vector<int> sleft, sright; 371 372 calcValue( nidx, sidx ); 373 374 if( n <= params.getMinSampleCount() || node.depth >= params.getMaxDepth() ) 375 can_split = false; 376 else if( _isClassifier ) 377 { 378 const int* responses = &w->cat_responses[0]; 379 const int* s = &sidx[0]; 380 int first = responses[s[0]]; 381 for( i = 1; i < n; i++ ) 382 if( responses[s[i]] != first ) 383 break; 384 if( i == n ) 385 can_split = false; 386 } 387 else 388 { 389 if( sqrt(node.node_risk) < params.getRegressionAccuracy() ) 390 can_split = false; 391 } 392 393 if( can_split ) 394 node.split = findBestSplit( sidx ); 395 396 //printf("depth=%d, nidx=%d, parent=%d, n=%d, %s, value=%.1f, risk=%.1f\n", node.depth, nidx, node.parent, n, (node.split < 0 ? "leaf" : varType[w->wsplits[node.split].varIdx] == VAR_CATEGORICAL ? "cat" : "ord"), node.value, node.node_risk); 397 398 if( node.split >= 0 ) 399 { 400 node.defaultDir = calcDir( node.split, sidx, sleft, sright ); 401 if( params.useSurrogates ) 402 CV_Error( CV_StsNotImplemented, "surrogate splits are not implemented yet"); 403 404 int left = addNodeAndTrySplit( nidx, sleft ); 405 int right = addNodeAndTrySplit( nidx, sright ); 406 w->wnodes[nidx].left = left; 407 w->wnodes[nidx].right = right; 408 CV_Assert( w->wnodes[nidx].left > 0 && w->wnodes[nidx].right > 0 ); 409 } 410 411 return nidx; 412 } 413 414 int DTreesImpl::findBestSplit( const vector<int>& _sidx ) 415 { 416 const vector<int>& activeVars = getActiveVars(); 417 int splitidx = -1; 418 int vi_, nv = (int)activeVars.size(); 419 AutoBuffer<int> buf(w->maxSubsetSize*2); 420 int *subset = buf, *best_subset = subset + w->maxSubsetSize; 421 WSplit split, best_split; 422 best_split.quality = 0.; 423 424 for( vi_ = 0; vi_ < nv; vi_++ ) 425 { 426 int vi = activeVars[vi_]; 427 if( varType[vi] == VAR_CATEGORICAL ) 428 { 429 if( _isClassifier ) 430 split = findSplitCatClass(vi, _sidx, 0, subset); 431 else 432 split = findSplitCatReg(vi, _sidx, 0, subset); 433 } 434 else 435 { 436 if( _isClassifier ) 437 split = findSplitOrdClass(vi, _sidx, 0); 438 else 439 split = findSplitOrdReg(vi, _sidx, 0); 440 } 441 if( split.quality > best_split.quality ) 442 { 443 best_split = split; 444 std::swap(subset, best_subset); 445 } 446 } 447 448 if( best_split.quality > 0 ) 449 { 450 int best_vi = best_split.varIdx; 451 CV_Assert( compVarIdx[best_split.varIdx] >= 0 && best_vi >= 0 ); 452 int i, prevsz = (int)w->wsubsets.size(), ssize = getSubsetSize(best_vi); 453 w->wsubsets.resize(prevsz + ssize); 454 for( i = 0; i < ssize; i++ ) 455 w->wsubsets[prevsz + i] = best_subset[i]; 456 best_split.subsetOfs = prevsz; 457 w->wsplits.push_back(best_split); 458 splitidx = (int)(w->wsplits.size()-1); 459 } 460 461 return splitidx; 462 } 463 464 void DTreesImpl::calcValue( int nidx, const vector<int>& _sidx ) 465 { 466 WNode* node = &w->wnodes[nidx]; 467 int i, j, k, n = (int)_sidx.size(), cv_n = params.getCVFolds(); 468 int m = (int)classLabels.size(); 469 470 cv::AutoBuffer<double> buf(std::max(m, 3)*(cv_n+1)); 471 472 if( cv_n > 0 ) 473 { 474 size_t sz = w->cv_Tn.size(); 475 w->cv_Tn.resize(sz + cv_n); 476 w->cv_node_risk.resize(sz + cv_n); 477 w->cv_node_error.resize(sz + cv_n); 478 } 479 480 if( _isClassifier ) 481 { 482 // in case of classification tree: 483 // * node value is the label of the class that has the largest weight in the node. 484 // * node risk is the weighted number of misclassified samples, 485 // * j-th cross-validation fold value and risk are calculated as above, 486 // but using the samples with cv_labels(*)!=j. 487 // * j-th cross-validation fold error is calculated as the weighted number of 488 // misclassified samples with cv_labels(*)==j. 489 490 // compute the number of instances of each class 491 double* cls_count = buf; 492 double* cv_cls_count = cls_count + m; 493 494 double max_val = -1, total_weight = 0; 495 int max_k = -1; 496 497 for( k = 0; k < m; k++ ) 498 cls_count[k] = 0; 499 500 if( cv_n == 0 ) 501 { 502 for( i = 0; i < n; i++ ) 503 { 504 int si = _sidx[i]; 505 cls_count[w->cat_responses[si]] += w->sample_weights[si]; 506 } 507 } 508 else 509 { 510 for( j = 0; j < cv_n; j++ ) 511 for( k = 0; k < m; k++ ) 512 cv_cls_count[j*m + k] = 0; 513 514 for( i = 0; i < n; i++ ) 515 { 516 int si = _sidx[i]; 517 j = w->cv_labels[si]; k = w->cat_responses[si]; 518 cv_cls_count[j*m + k] += w->sample_weights[si]; 519 } 520 521 for( j = 0; j < cv_n; j++ ) 522 for( k = 0; k < m; k++ ) 523 cls_count[k] += cv_cls_count[j*m + k]; 524 } 525 526 for( k = 0; k < m; k++ ) 527 { 528 double val = cls_count[k]; 529 total_weight += val; 530 if( max_val < val ) 531 { 532 max_val = val; 533 max_k = k; 534 } 535 } 536 537 node->class_idx = max_k; 538 node->value = classLabels[max_k]; 539 node->node_risk = total_weight - max_val; 540 541 for( j = 0; j < cv_n; j++ ) 542 { 543 double sum_k = 0, sum = 0, max_val_k = 0; 544 max_val = -1; max_k = -1; 545 546 for( k = 0; k < m; k++ ) 547 { 548 double val_k = cv_cls_count[j*m + k]; 549 double val = cls_count[k] - val_k; 550 sum_k += val_k; 551 sum += val; 552 if( max_val < val ) 553 { 554 max_val = val; 555 max_val_k = val_k; 556 max_k = k; 557 } 558 } 559 560 w->cv_Tn[nidx*cv_n + j] = INT_MAX; 561 w->cv_node_risk[nidx*cv_n + j] = sum - max_val; 562 w->cv_node_error[nidx*cv_n + j] = sum_k - max_val_k; 563 } 564 } 565 else 566 { 567 // in case of regression tree: 568 // * node value is 1/n*sum_i(Y_i), where Y_i is i-th response, 569 // n is the number of samples in the node. 570 // * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2) 571 // * j-th cross-validation fold value and risk are calculated as above, 572 // but using the samples with cv_labels(*)!=j. 573 // * j-th cross-validation fold error is calculated 574 // using samples with cv_labels(*)==j as the test subset: 575 // error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2), 576 // where node_value_j is the node value calculated 577 // as described in the previous bullet, and summation is done 578 // over the samples with cv_labels(*)==j. 579 double sum = 0, sum2 = 0, sumw = 0; 580 581 if( cv_n == 0 ) 582 { 583 for( i = 0; i < n; i++ ) 584 { 585 int si = _sidx[i]; 586 double wval = w->sample_weights[si]; 587 double t = w->ord_responses[si]; 588 sum += t*wval; 589 sum2 += t*t*wval; 590 sumw += wval; 591 } 592 } 593 else 594 { 595 double *cv_sum = buf, *cv_sum2 = cv_sum + cv_n; 596 double* cv_count = (double*)(cv_sum2 + cv_n); 597 598 for( j = 0; j < cv_n; j++ ) 599 { 600 cv_sum[j] = cv_sum2[j] = 0.; 601 cv_count[j] = 0; 602 } 603 604 for( i = 0; i < n; i++ ) 605 { 606 int si = _sidx[i]; 607 j = w->cv_labels[si]; 608 double wval = w->sample_weights[si]; 609 double t = w->ord_responses[si]; 610 cv_sum[j] += t*wval; 611 cv_sum2[j] += t*t*wval; 612 cv_count[j] += wval; 613 } 614 615 for( j = 0; j < cv_n; j++ ) 616 { 617 sum += cv_sum[j]; 618 sum2 += cv_sum2[j]; 619 sumw += cv_count[j]; 620 } 621 622 for( j = 0; j < cv_n; j++ ) 623 { 624 double s = sum - cv_sum[j], si = sum - s; 625 double s2 = sum2 - cv_sum2[j], s2i = sum2 - s2; 626 double c = cv_count[j], ci = sumw - c; 627 double r = si/std::max(ci, DBL_EPSILON); 628 w->cv_node_risk[nidx*cv_n + j] = s2i - r*r*ci; 629 w->cv_node_error[nidx*cv_n + j] = s2 - 2*r*s + c*r*r; 630 w->cv_Tn[nidx*cv_n + j] = INT_MAX; 631 } 632 } 633 634 node->node_risk = sum2 - (sum/sumw)*sum; 635 node->value = sum/sumw; 636 } 637 } 638 639 DTreesImpl::WSplit DTreesImpl::findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality ) 640 { 641 const double epsilon = FLT_EPSILON*2; 642 int n = (int)_sidx.size(); 643 int m = (int)classLabels.size(); 644 645 cv::AutoBuffer<uchar> buf(n*(sizeof(float) + sizeof(int)) + m*2*sizeof(double)); 646 const int* sidx = &_sidx[0]; 647 const int* responses = &w->cat_responses[0]; 648 const double* weights = &w->sample_weights[0]; 649 double* lcw = (double*)(uchar*)buf; 650 double* rcw = lcw + m; 651 float* values = (float*)(rcw + m); 652 int* sorted_idx = (int*)(values + n); 653 int i, best_i = -1; 654 double best_val = initQuality; 655 656 for( i = 0; i < m; i++ ) 657 lcw[i] = rcw[i] = 0.; 658 659 w->data->getValues( vi, _sidx, values ); 660 661 for( i = 0; i < n; i++ ) 662 { 663 sorted_idx[i] = i; 664 int si = sidx[i]; 665 rcw[responses[si]] += weights[si]; 666 } 667 668 std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values)); 669 670 double L = 0, R = 0, lsum2 = 0, rsum2 = 0; 671 for( i = 0; i < m; i++ ) 672 { 673 double wval = rcw[i]; 674 R += wval; 675 rsum2 += wval*wval; 676 } 677 678 for( i = 0; i < n - 1; i++ ) 679 { 680 int curr = sorted_idx[i]; 681 int next = sorted_idx[i+1]; 682 int si = sidx[curr]; 683 double wval = weights[si], w2 = wval*wval; 684 L += wval; R -= wval; 685 int idx = responses[si]; 686 double lv = lcw[idx], rv = rcw[idx]; 687 lsum2 += 2*lv*wval + w2; 688 rsum2 -= 2*rv*wval - w2; 689 lcw[idx] = lv + wval; rcw[idx] = rv - wval; 690 691 if( values[curr] + epsilon < values[next] ) 692 { 693 double val = (lsum2*R + rsum2*L)/(L*R); 694 if( best_val < val ) 695 { 696 best_val = val; 697 best_i = i; 698 } 699 } 700 } 701 702 WSplit split; 703 if( best_i >= 0 ) 704 { 705 split.varIdx = vi; 706 split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f; 707 split.inversed = false; 708 split.quality = (float)best_val; 709 } 710 return split; 711 } 712 713 // simple k-means, slightly modified to take into account the "weight" (L1-norm) of each vector. 714 void DTreesImpl::clusterCategories( const double* vectors, int n, int m, double* csums, int k, int* labels ) 715 { 716 int iters = 0, max_iters = 100; 717 int i, j, idx; 718 cv::AutoBuffer<double> buf(n + k); 719 double *v_weights = buf, *c_weights = buf + n; 720 bool modified = true; 721 RNG r((uint64)-1); 722 723 // assign labels randomly 724 for( i = 0; i < n; i++ ) 725 { 726 double sum = 0; 727 const double* v = vectors + i*m; 728 labels[i] = i < k ? i : r.uniform(0, k); 729 730 // compute weight of each vector 731 for( j = 0; j < m; j++ ) 732 sum += v[j]; 733 v_weights[i] = sum ? 1./sum : 0.; 734 } 735 736 for( i = 0; i < n; i++ ) 737 { 738 int i1 = r.uniform(0, n); 739 int i2 = r.uniform(0, n); 740 std::swap( labels[i1], labels[i2] ); 741 } 742 743 for( iters = 0; iters <= max_iters; iters++ ) 744 { 745 // calculate csums 746 for( i = 0; i < k; i++ ) 747 { 748 for( j = 0; j < m; j++ ) 749 csums[i*m + j] = 0; 750 } 751 752 for( i = 0; i < n; i++ ) 753 { 754 const double* v = vectors + i*m; 755 double* s = csums + labels[i]*m; 756 for( j = 0; j < m; j++ ) 757 s[j] += v[j]; 758 } 759 760 // exit the loop here, when we have up-to-date csums 761 if( iters == max_iters || !modified ) 762 break; 763 764 modified = false; 765 766 // calculate weight of each cluster 767 for( i = 0; i < k; i++ ) 768 { 769 const double* s = csums + i*m; 770 double sum = 0; 771 for( j = 0; j < m; j++ ) 772 sum += s[j]; 773 c_weights[i] = sum ? 1./sum : 0; 774 } 775 776 // now for each vector determine the closest cluster 777 for( i = 0; i < n; i++ ) 778 { 779 const double* v = vectors + i*m; 780 double alpha = v_weights[i]; 781 double min_dist2 = DBL_MAX; 782 int min_idx = -1; 783 784 for( idx = 0; idx < k; idx++ ) 785 { 786 const double* s = csums + idx*m; 787 double dist2 = 0., beta = c_weights[idx]; 788 for( j = 0; j < m; j++ ) 789 { 790 double t = v[j]*alpha - s[j]*beta; 791 dist2 += t*t; 792 } 793 if( min_dist2 > dist2 ) 794 { 795 min_dist2 = dist2; 796 min_idx = idx; 797 } 798 } 799 800 if( min_idx != labels[i] ) 801 modified = true; 802 labels[i] = min_idx; 803 } 804 } 805 } 806 807 DTreesImpl::WSplit DTreesImpl::findSplitCatClass( int vi, const vector<int>& _sidx, 808 double initQuality, int* subset ) 809 { 810 int _mi = getCatCount(vi), mi = _mi; 811 int n = (int)_sidx.size(); 812 int m = (int)classLabels.size(); 813 814 int base_size = m*(3 + mi) + mi + 1; 815 if( m > 2 && mi > params.getMaxCategories() ) 816 base_size += m*std::min(params.getMaxCategories(), n) + mi; 817 else 818 base_size += mi; 819 AutoBuffer<double> buf(base_size + n); 820 821 double* lc = (double*)buf; 822 double* rc = lc + m; 823 double* _cjk = rc + m*2, *cjk = _cjk; 824 double* c_weights = cjk + m*mi; 825 826 int* labels = (int*)(buf + base_size); 827 w->data->getNormCatValues(vi, _sidx, labels); 828 const int* responses = &w->cat_responses[0]; 829 const double* weights = &w->sample_weights[0]; 830 831 int* cluster_labels = 0; 832 double** dbl_ptr = 0; 833 int i, j, k, si, idx; 834 double L = 0, R = 0; 835 double best_val = initQuality; 836 int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0; 837 838 // init array of counters: 839 // c_{jk} - number of samples that have vi-th input variable = j and response = k. 840 for( j = -1; j < mi; j++ ) 841 for( k = 0; k < m; k++ ) 842 cjk[j*m + k] = 0; 843 844 for( i = 0; i < n; i++ ) 845 { 846 si = _sidx[i]; 847 j = labels[i]; 848 k = responses[si]; 849 cjk[j*m + k] += weights[si]; 850 } 851 852 if( m > 2 ) 853 { 854 if( mi > params.getMaxCategories() ) 855 { 856 mi = std::min(params.getMaxCategories(), n); 857 cjk = c_weights + _mi; 858 cluster_labels = (int*)(cjk + m*mi); 859 clusterCategories( _cjk, _mi, m, cjk, mi, cluster_labels ); 860 } 861 subset_i = 1; 862 subset_n = 1 << mi; 863 } 864 else 865 { 866 assert( m == 2 ); 867 dbl_ptr = (double**)(c_weights + _mi); 868 for( j = 0; j < mi; j++ ) 869 dbl_ptr[j] = cjk + j*2 + 1; 870 std::sort(dbl_ptr, dbl_ptr + mi, cmp_lt_ptr<double>()); 871 subset_i = 0; 872 subset_n = mi; 873 } 874 875 for( k = 0; k < m; k++ ) 876 { 877 double sum = 0; 878 for( j = 0; j < mi; j++ ) 879 sum += cjk[j*m + k]; 880 CV_Assert(sum > 0); 881 rc[k] = sum; 882 lc[k] = 0; 883 } 884 885 for( j = 0; j < mi; j++ ) 886 { 887 double sum = 0; 888 for( k = 0; k < m; k++ ) 889 sum += cjk[j*m + k]; 890 c_weights[j] = sum; 891 R += c_weights[j]; 892 } 893 894 for( ; subset_i < subset_n; subset_i++ ) 895 { 896 double lsum2 = 0, rsum2 = 0; 897 898 if( m == 2 ) 899 idx = (int)(dbl_ptr[subset_i] - cjk)/2; 900 else 901 { 902 int graycode = (subset_i>>1)^subset_i; 903 int diff = graycode ^ prevcode; 904 905 // determine index of the changed bit. 906 Cv32suf u; 907 idx = diff >= (1 << 16) ? 16 : 0; 908 u.f = (float)(((diff >> 16) | diff) & 65535); 909 idx += (u.i >> 23) - 127; 910 subtract = graycode < prevcode; 911 prevcode = graycode; 912 } 913 914 double* crow = cjk + idx*m; 915 double weight = c_weights[idx]; 916 if( weight < FLT_EPSILON ) 917 continue; 918 919 if( !subtract ) 920 { 921 for( k = 0; k < m; k++ ) 922 { 923 double t = crow[k]; 924 double lval = lc[k] + t; 925 double rval = rc[k] - t; 926 lsum2 += lval*lval; 927 rsum2 += rval*rval; 928 lc[k] = lval; rc[k] = rval; 929 } 930 L += weight; 931 R -= weight; 932 } 933 else 934 { 935 for( k = 0; k < m; k++ ) 936 { 937 double t = crow[k]; 938 double lval = lc[k] - t; 939 double rval = rc[k] + t; 940 lsum2 += lval*lval; 941 rsum2 += rval*rval; 942 lc[k] = lval; rc[k] = rval; 943 } 944 L -= weight; 945 R += weight; 946 } 947 948 if( L > FLT_EPSILON && R > FLT_EPSILON ) 949 { 950 double val = (lsum2*R + rsum2*L)/(L*R); 951 if( best_val < val ) 952 { 953 best_val = val; 954 best_subset = subset_i; 955 } 956 } 957 } 958 959 WSplit split; 960 if( best_subset >= 0 ) 961 { 962 split.varIdx = vi; 963 split.quality = (float)best_val; 964 memset( subset, 0, getSubsetSize(vi) * sizeof(int) ); 965 if( m == 2 ) 966 { 967 for( i = 0; i <= best_subset; i++ ) 968 { 969 idx = (int)(dbl_ptr[i] - cjk) >> 1; 970 subset[idx >> 5] |= 1 << (idx & 31); 971 } 972 } 973 else 974 { 975 for( i = 0; i < _mi; i++ ) 976 { 977 idx = cluster_labels ? cluster_labels[i] : i; 978 if( best_subset & (1 << idx) ) 979 subset[i >> 5] |= 1 << (i & 31); 980 } 981 } 982 } 983 return split; 984 } 985 986 DTreesImpl::WSplit DTreesImpl::findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality ) 987 { 988 const float epsilon = FLT_EPSILON*2; 989 const double* weights = &w->sample_weights[0]; 990 int n = (int)_sidx.size(); 991 992 AutoBuffer<uchar> buf(n*(sizeof(int) + sizeof(float))); 993 994 float* values = (float*)(uchar*)buf; 995 int* sorted_idx = (int*)(values + n); 996 w->data->getValues(vi, _sidx, values); 997 const double* responses = &w->ord_responses[0]; 998 999 int i, si, best_i = -1; 1000 double L = 0, R = 0; 1001 double best_val = initQuality, lsum = 0, rsum = 0; 1002 1003 for( i = 0; i < n; i++ ) 1004 { 1005 sorted_idx[i] = i; 1006 si = _sidx[i]; 1007 R += weights[si]; 1008 rsum += weights[si]*responses[si]; 1009 } 1010 1011 std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values)); 1012 1013 // find the optimal split 1014 for( i = 0; i < n - 1; i++ ) 1015 { 1016 int curr = sorted_idx[i]; 1017 int next = sorted_idx[i+1]; 1018 si = _sidx[curr]; 1019 double wval = weights[si]; 1020 double t = responses[si]*wval; 1021 L += wval; R -= wval; 1022 lsum += t; rsum -= t; 1023 1024 if( values[curr] + epsilon < values[next] ) 1025 { 1026 double val = (lsum*lsum*R + rsum*rsum*L)/(L*R); 1027 if( best_val < val ) 1028 { 1029 best_val = val; 1030 best_i = i; 1031 } 1032 } 1033 } 1034 1035 WSplit split; 1036 if( best_i >= 0 ) 1037 { 1038 split.varIdx = vi; 1039 split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f; 1040 split.inversed = false; 1041 split.quality = (float)best_val; 1042 } 1043 return split; 1044 } 1045 1046 DTreesImpl::WSplit DTreesImpl::findSplitCatReg( int vi, const vector<int>& _sidx, 1047 double initQuality, int* subset ) 1048 { 1049 const double* weights = &w->sample_weights[0]; 1050 const double* responses = &w->ord_responses[0]; 1051 int n = (int)_sidx.size(); 1052 int mi = getCatCount(vi); 1053 1054 AutoBuffer<double> buf(3*mi + 3 + n); 1055 double* sum = (double*)buf + 1; 1056 double* counts = sum + mi + 1; 1057 double** sum_ptr = (double**)(counts + mi); 1058 int* cat_labels = (int*)(sum_ptr + mi); 1059 1060 w->data->getNormCatValues(vi, _sidx, cat_labels); 1061 1062 double L = 0, R = 0, best_val = initQuality, lsum = 0, rsum = 0; 1063 int i, si, best_subset = -1, subset_i; 1064 1065 for( i = -1; i < mi; i++ ) 1066 sum[i] = counts[i] = 0; 1067 1068 // calculate sum response and weight of each category of the input var 1069 for( i = 0; i < n; i++ ) 1070 { 1071 int idx = cat_labels[i]; 1072 si = _sidx[i]; 1073 double wval = weights[si]; 1074 sum[idx] += responses[si]*wval; 1075 counts[idx] += wval; 1076 } 1077 1078 // calculate average response in each category 1079 for( i = 0; i < mi; i++ ) 1080 { 1081 R += counts[i]; 1082 rsum += sum[i]; 1083 sum[i] = fabs(counts[i]) > DBL_EPSILON ? sum[i]/counts[i] : 0; 1084 sum_ptr[i] = sum + i; 1085 } 1086 1087 std::sort(sum_ptr, sum_ptr + mi, cmp_lt_ptr<double>()); 1088 1089 // revert back to unnormalized sums 1090 // (there should be a very little loss in accuracy) 1091 for( i = 0; i < mi; i++ ) 1092 sum[i] *= counts[i]; 1093 1094 for( subset_i = 0; subset_i < mi-1; subset_i++ ) 1095 { 1096 int idx = (int)(sum_ptr[subset_i] - sum); 1097 double ni = counts[idx]; 1098 1099 if( ni > FLT_EPSILON ) 1100 { 1101 double s = sum[idx]; 1102 lsum += s; L += ni; 1103 rsum -= s; R -= ni; 1104 1105 if( L > FLT_EPSILON && R > FLT_EPSILON ) 1106 { 1107 double val = (lsum*lsum*R + rsum*rsum*L)/(L*R); 1108 if( best_val < val ) 1109 { 1110 best_val = val; 1111 best_subset = subset_i; 1112 } 1113 } 1114 } 1115 } 1116 1117 WSplit split; 1118 if( best_subset >= 0 ) 1119 { 1120 split.varIdx = vi; 1121 split.quality = (float)best_val; 1122 memset( subset, 0, getSubsetSize(vi) * sizeof(int)); 1123 for( i = 0; i <= best_subset; i++ ) 1124 { 1125 int idx = (int)(sum_ptr[i] - sum); 1126 subset[idx >> 5] |= 1 << (idx & 31); 1127 } 1128 } 1129 return split; 1130 } 1131 1132 int DTreesImpl::calcDir( int splitidx, const vector<int>& _sidx, 1133 vector<int>& _sleft, vector<int>& _sright ) 1134 { 1135 WSplit split = w->wsplits[splitidx]; 1136 int i, si, n = (int)_sidx.size(), vi = split.varIdx; 1137 _sleft.reserve(n); 1138 _sright.reserve(n); 1139 _sleft.clear(); 1140 _sright.clear(); 1141 1142 AutoBuffer<float> buf(n); 1143 int mi = getCatCount(vi); 1144 double wleft = 0, wright = 0; 1145 const double* weights = &w->sample_weights[0]; 1146 1147 if( mi <= 0 ) // split on an ordered variable 1148 { 1149 float c = split.c; 1150 float* values = buf; 1151 w->data->getValues(vi, _sidx, values); 1152 1153 for( i = 0; i < n; i++ ) 1154 { 1155 si = _sidx[i]; 1156 if( values[i] <= c ) 1157 { 1158 _sleft.push_back(si); 1159 wleft += weights[si]; 1160 } 1161 else 1162 { 1163 _sright.push_back(si); 1164 wright += weights[si]; 1165 } 1166 } 1167 } 1168 else 1169 { 1170 const int* subset = &w->wsubsets[split.subsetOfs]; 1171 int* cat_labels = (int*)(float*)buf; 1172 w->data->getNormCatValues(vi, _sidx, cat_labels); 1173 1174 for( i = 0; i < n; i++ ) 1175 { 1176 si = _sidx[i]; 1177 unsigned u = cat_labels[i]; 1178 if( CV_DTREE_CAT_DIR(u, subset) < 0 ) 1179 { 1180 _sleft.push_back(si); 1181 wleft += weights[si]; 1182 } 1183 else 1184 { 1185 _sright.push_back(si); 1186 wright += weights[si]; 1187 } 1188 } 1189 } 1190 CV_Assert( (int)_sleft.size() < n && (int)_sright.size() < n ); 1191 return wleft > wright ? -1 : 1; 1192 } 1193 1194 int DTreesImpl::pruneCV( int root ) 1195 { 1196 vector<double> ab; 1197 1198 // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}. 1199 // 2. choose the best tree index (if need, apply 1SE rule). 1200 // 3. store the best index and cut the branches. 1201 1202 int ti, tree_count = 0, j, cv_n = params.getCVFolds(), n = w->wnodes[root].sample_count; 1203 // currently, 1SE for regression is not implemented 1204 bool use_1se = params.use1SERule != 0 && _isClassifier; 1205 double min_err = 0, min_err_se = 0; 1206 int min_idx = -1; 1207 1208 // build the main tree sequence, calculate alpha's 1209 for(;;tree_count++) 1210 { 1211 double min_alpha = updateTreeRNC(root, tree_count, -1); 1212 if( cutTree(root, tree_count, -1, min_alpha) ) 1213 break; 1214 1215 ab.push_back(min_alpha); 1216 } 1217 1218 if( tree_count > 0 ) 1219 { 1220 ab[0] = 0.; 1221 1222 for( ti = 1; ti < tree_count-1; ti++ ) 1223 ab[ti] = std::sqrt(ab[ti]*ab[ti+1]); 1224 ab[tree_count-1] = DBL_MAX*0.5; 1225 1226 Mat err_jk(cv_n, tree_count, CV_64F); 1227 1228 for( j = 0; j < cv_n; j++ ) 1229 { 1230 int tj = 0, tk = 0; 1231 for( ; tj < tree_count; tj++ ) 1232 { 1233 double min_alpha = updateTreeRNC(root, tj, j); 1234 if( cutTree(root, tj, j, min_alpha) ) 1235 min_alpha = DBL_MAX; 1236 1237 for( ; tk < tree_count; tk++ ) 1238 { 1239 if( ab[tk] > min_alpha ) 1240 break; 1241 err_jk.at<double>(j, tk) = w->wnodes[root].tree_error; 1242 } 1243 } 1244 } 1245 1246 for( ti = 0; ti < tree_count; ti++ ) 1247 { 1248 double sum_err = 0; 1249 for( j = 0; j < cv_n; j++ ) 1250 sum_err += err_jk.at<double>(j, ti); 1251 if( ti == 0 || sum_err < min_err ) 1252 { 1253 min_err = sum_err; 1254 min_idx = ti; 1255 if( use_1se ) 1256 min_err_se = sqrt( sum_err*(n - sum_err) ); 1257 } 1258 else if( sum_err < min_err + min_err_se ) 1259 min_idx = ti; 1260 } 1261 } 1262 1263 return min_idx; 1264 } 1265 1266 double DTreesImpl::updateTreeRNC( int root, double T, int fold ) 1267 { 1268 int nidx = root, pidx = -1, cv_n = params.getCVFolds(); 1269 double min_alpha = DBL_MAX; 1270 1271 for(;;) 1272 { 1273 WNode *node = 0, *parent = 0; 1274 1275 for(;;) 1276 { 1277 node = &w->wnodes[nidx]; 1278 double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn; 1279 if( t <= T || node->left < 0 ) 1280 { 1281 node->complexity = 1; 1282 node->tree_risk = node->node_risk; 1283 node->tree_error = 0.; 1284 if( fold >= 0 ) 1285 { 1286 node->tree_risk = w->cv_node_risk[nidx*cv_n + fold]; 1287 node->tree_error = w->cv_node_error[nidx*cv_n + fold]; 1288 } 1289 break; 1290 } 1291 nidx = node->left; 1292 } 1293 1294 for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx; 1295 nidx = pidx, pidx = w->wnodes[pidx].parent ) 1296 { 1297 node = &w->wnodes[nidx]; 1298 parent = &w->wnodes[pidx]; 1299 parent->complexity += node->complexity; 1300 parent->tree_risk += node->tree_risk; 1301 parent->tree_error += node->tree_error; 1302 1303 parent->alpha = ((fold >= 0 ? w->cv_node_risk[pidx*cv_n + fold] : parent->node_risk) 1304 - parent->tree_risk)/(parent->complexity - 1); 1305 min_alpha = std::min( min_alpha, parent->alpha ); 1306 } 1307 1308 if( pidx < 0 ) 1309 break; 1310 1311 node = &w->wnodes[nidx]; 1312 parent = &w->wnodes[pidx]; 1313 parent->complexity = node->complexity; 1314 parent->tree_risk = node->tree_risk; 1315 parent->tree_error = node->tree_error; 1316 nidx = parent->right; 1317 } 1318 1319 return min_alpha; 1320 } 1321 1322 bool DTreesImpl::cutTree( int root, double T, int fold, double min_alpha ) 1323 { 1324 int cv_n = params.getCVFolds(), nidx = root, pidx = -1; 1325 WNode* node = &w->wnodes[root]; 1326 if( node->left < 0 ) 1327 return true; 1328 1329 for(;;) 1330 { 1331 for(;;) 1332 { 1333 node = &w->wnodes[nidx]; 1334 double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn; 1335 if( t <= T || node->left < 0 ) 1336 break; 1337 if( node->alpha <= min_alpha + FLT_EPSILON ) 1338 { 1339 if( fold >= 0 ) 1340 w->cv_Tn[nidx*cv_n + fold] = T; 1341 else 1342 node->Tn = T; 1343 if( nidx == root ) 1344 return true; 1345 break; 1346 } 1347 nidx = node->left; 1348 } 1349 1350 for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx; 1351 nidx = pidx, pidx = w->wnodes[pidx].parent ) 1352 ; 1353 1354 if( pidx < 0 ) 1355 break; 1356 1357 nidx = w->wnodes[pidx].right; 1358 } 1359 1360 return false; 1361 } 1362 1363 float DTreesImpl::predictTrees( const Range& range, const Mat& sample, int flags ) const 1364 { 1365 CV_Assert( sample.type() == CV_32F ); 1366 1367 int predictType = flags & PREDICT_MASK; 1368 int nvars = (int)varIdx.size(); 1369 if( nvars == 0 ) 1370 nvars = (int)varType.size(); 1371 int i, ncats = (int)catOfs.size(), nclasses = (int)classLabels.size(); 1372 int catbufsize = ncats > 0 ? nvars : 0; 1373 AutoBuffer<int> buf(nclasses + catbufsize + 1); 1374 int* votes = buf; 1375 int* catbuf = votes + nclasses; 1376 const int* cvidx = (flags & (COMPRESSED_INPUT|PREPROCESSED_INPUT)) == 0 && !varIdx.empty() ? &compVarIdx[0] : 0; 1377 const uchar* vtype = &varType[0]; 1378 const Vec2i* cofs = !catOfs.empty() ? &catOfs[0] : 0; 1379 const int* cmap = !catMap.empty() ? &catMap[0] : 0; 1380 const float* psample = sample.ptr<float>(); 1381 const float* missingSubstPtr = !missingSubst.empty() ? &missingSubst[0] : 0; 1382 size_t sstep = sample.isContinuous() ? 1 : sample.step/sizeof(float); 1383 double sum = 0.; 1384 int lastClassIdx = -1; 1385 const float MISSED_VAL = TrainData::missingValue(); 1386 1387 for( i = 0; i < catbufsize; i++ ) 1388 catbuf[i] = -1; 1389 1390 if( predictType == PREDICT_AUTO ) 1391 { 1392 predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ? 1393 PREDICT_SUM : PREDICT_MAX_VOTE; 1394 } 1395 1396 if( predictType == PREDICT_MAX_VOTE ) 1397 { 1398 for( i = 0; i < nclasses; i++ ) 1399 votes[i] = 0; 1400 } 1401 1402 for( int ridx = range.start; ridx < range.end; ridx++ ) 1403 { 1404 int nidx = roots[ridx], prev = nidx, c = 0; 1405 1406 for(;;) 1407 { 1408 prev = nidx; 1409 const Node& node = nodes[nidx]; 1410 if( node.split < 0 ) 1411 break; 1412 const Split& split = splits[node.split]; 1413 int vi = split.varIdx; 1414 int ci = cvidx ? cvidx[vi] : vi; 1415 float val = psample[ci*sstep]; 1416 if( val == MISSED_VAL ) 1417 { 1418 if( !missingSubstPtr ) 1419 { 1420 nidx = node.defaultDir < 0 ? node.left : node.right; 1421 continue; 1422 } 1423 val = missingSubstPtr[vi]; 1424 } 1425 1426 if( vtype[vi] == VAR_ORDERED ) 1427 nidx = val <= split.c ? node.left : node.right; 1428 else 1429 { 1430 if( flags & PREPROCESSED_INPUT ) 1431 c = cvRound(val); 1432 else 1433 { 1434 c = catbuf[ci]; 1435 if( c < 0 ) 1436 { 1437 int a = c = cofs[vi][0]; 1438 int b = cofs[vi][1]; 1439 1440 int ival = cvRound(val); 1441 if( ival != val ) 1442 CV_Error( CV_StsBadArg, 1443 "one of input categorical variable is not an integer" ); 1444 1445 while( a < b ) 1446 { 1447 c = (a + b) >> 1; 1448 if( ival < cmap[c] ) 1449 b = c; 1450 else if( ival > cmap[c] ) 1451 a = c+1; 1452 else 1453 break; 1454 } 1455 1456 CV_Assert( c >= 0 && ival == cmap[c] ); 1457 1458 c -= cofs[vi][0]; 1459 catbuf[ci] = c; 1460 } 1461 const int* subset = &subsets[split.subsetOfs]; 1462 unsigned u = c; 1463 nidx = CV_DTREE_CAT_DIR(u, subset) < 0 ? node.left : node.right; 1464 } 1465 } 1466 } 1467 1468 if( predictType == PREDICT_SUM ) 1469 sum += nodes[prev].value; 1470 else 1471 { 1472 lastClassIdx = nodes[prev].classIdx; 1473 votes[lastClassIdx]++; 1474 } 1475 } 1476 1477 if( predictType == PREDICT_MAX_VOTE ) 1478 { 1479 int best_idx = lastClassIdx; 1480 if( range.end - range.start > 1 ) 1481 { 1482 best_idx = 0; 1483 for( i = 1; i < nclasses; i++ ) 1484 if( votes[best_idx] < votes[i] ) 1485 best_idx = i; 1486 } 1487 sum = (flags & RAW_OUTPUT) ? (float)best_idx : classLabels[best_idx]; 1488 } 1489 1490 return (float)sum; 1491 } 1492 1493 1494 float DTreesImpl::predict( InputArray _samples, OutputArray _results, int flags ) const 1495 { 1496 CV_Assert( !roots.empty() ); 1497 Mat samples = _samples.getMat(), results; 1498 int i, nsamples = samples.rows; 1499 int rtype = CV_32F; 1500 bool needresults = _results.needed(); 1501 float retval = 0.f; 1502 bool iscls = isClassifier(); 1503 float scale = !iscls ? 1.f/(int)roots.size() : 1.f; 1504 1505 if( iscls && (flags & PREDICT_MASK) == PREDICT_MAX_VOTE ) 1506 rtype = CV_32S; 1507 1508 if( needresults ) 1509 { 1510 _results.create(nsamples, 1, rtype); 1511 results = _results.getMat(); 1512 } 1513 else 1514 nsamples = std::min(nsamples, 1); 1515 1516 for( i = 0; i < nsamples; i++ ) 1517 { 1518 float val = predictTrees( Range(0, (int)roots.size()), samples.row(i), flags )*scale; 1519 if( needresults ) 1520 { 1521 if( rtype == CV_32F ) 1522 results.at<float>(i) = val; 1523 else 1524 results.at<int>(i) = cvRound(val); 1525 } 1526 if( i == 0 ) 1527 retval = val; 1528 } 1529 return retval; 1530 } 1531 1532 void DTreesImpl::writeTrainingParams(FileStorage& fs) const 1533 { 1534 fs << "use_surrogates" << (params.useSurrogates ? 1 : 0); 1535 fs << "max_categories" << params.getMaxCategories(); 1536 fs << "regression_accuracy" << params.getRegressionAccuracy(); 1537 1538 fs << "max_depth" << params.getMaxDepth(); 1539 fs << "min_sample_count" << params.getMinSampleCount(); 1540 fs << "cross_validation_folds" << params.getCVFolds(); 1541 1542 if( params.getCVFolds() > 1 ) 1543 fs << "use_1se_rule" << (params.use1SERule ? 1 : 0); 1544 1545 if( !params.priors.empty() ) 1546 fs << "priors" << params.priors; 1547 } 1548 1549 void DTreesImpl::writeParams(FileStorage& fs) const 1550 { 1551 fs << "is_classifier" << isClassifier(); 1552 fs << "var_all" << (int)varType.size(); 1553 fs << "var_count" << getVarCount(); 1554 1555 int ord_var_count = 0, cat_var_count = 0; 1556 int i, n = (int)varType.size(); 1557 for( i = 0; i < n; i++ ) 1558 if( varType[i] == VAR_ORDERED ) 1559 ord_var_count++; 1560 else 1561 cat_var_count++; 1562 fs << "ord_var_count" << ord_var_count; 1563 fs << "cat_var_count" << cat_var_count; 1564 1565 fs << "training_params" << "{"; 1566 writeTrainingParams(fs); 1567 1568 fs << "}"; 1569 1570 if( !varIdx.empty() ) 1571 { 1572 fs << "global_var_idx" << 1; 1573 fs << "var_idx" << varIdx; 1574 } 1575 1576 fs << "var_type" << varType; 1577 1578 if( !catOfs.empty() ) 1579 fs << "cat_ofs" << catOfs; 1580 if( !catMap.empty() ) 1581 fs << "cat_map" << catMap; 1582 if( !classLabels.empty() ) 1583 fs << "class_labels" << classLabels; 1584 if( !missingSubst.empty() ) 1585 fs << "missing_subst" << missingSubst; 1586 } 1587 1588 void DTreesImpl::writeSplit( FileStorage& fs, int splitidx ) const 1589 { 1590 const Split& split = splits[splitidx]; 1591 1592 fs << "{:"; 1593 1594 int vi = split.varIdx; 1595 fs << "var" << vi; 1596 fs << "quality" << split.quality; 1597 1598 if( varType[vi] == VAR_CATEGORICAL ) // split on a categorical var 1599 { 1600 int i, n = getCatCount(vi), to_right = 0; 1601 const int* subset = &subsets[split.subsetOfs]; 1602 for( i = 0; i < n; i++ ) 1603 to_right += CV_DTREE_CAT_DIR(i, subset) > 0; 1604 1605 // ad-hoc rule when to use inverse categorical split notation 1606 // to achieve more compact and clear representation 1607 int default_dir = to_right <= 1 || to_right <= std::min(3, n/2) || to_right <= n/3 ? -1 : 1; 1608 1609 fs << (default_dir*(split.inversed ? -1 : 1) > 0 ? "in" : "not_in") << "[:"; 1610 1611 for( i = 0; i < n; i++ ) 1612 { 1613 int dir = CV_DTREE_CAT_DIR(i, subset); 1614 if( dir*default_dir < 0 ) 1615 fs << i; 1616 } 1617 1618 fs << "]"; 1619 } 1620 else 1621 fs << (!split.inversed ? "le" : "gt") << split.c; 1622 1623 fs << "}"; 1624 } 1625 1626 void DTreesImpl::writeNode( FileStorage& fs, int nidx, int depth ) const 1627 { 1628 const Node& node = nodes[nidx]; 1629 fs << "{"; 1630 fs << "depth" << depth; 1631 fs << "value" << node.value; 1632 1633 if( _isClassifier ) 1634 fs << "norm_class_idx" << node.classIdx; 1635 1636 if( node.split >= 0 ) 1637 { 1638 fs << "splits" << "["; 1639 1640 for( int splitidx = node.split; splitidx >= 0; splitidx = splits[splitidx].next ) 1641 writeSplit( fs, splitidx ); 1642 1643 fs << "]"; 1644 } 1645 1646 fs << "}"; 1647 } 1648 1649 void DTreesImpl::writeTree( FileStorage& fs, int root ) const 1650 { 1651 fs << "nodes" << "["; 1652 1653 int nidx = root, pidx = 0, depth = 0; 1654 const Node *node = 0; 1655 1656 // traverse the tree and save all the nodes in depth-first order 1657 for(;;) 1658 { 1659 for(;;) 1660 { 1661 writeNode( fs, nidx, depth ); 1662 node = &nodes[nidx]; 1663 if( node->left < 0 ) 1664 break; 1665 nidx = node->left; 1666 depth++; 1667 } 1668 1669 for( pidx = node->parent; pidx >= 0 && nodes[pidx].right == nidx; 1670 nidx = pidx, pidx = nodes[pidx].parent ) 1671 depth--; 1672 1673 if( pidx < 0 ) 1674 break; 1675 1676 nidx = nodes[pidx].right; 1677 } 1678 1679 fs << "]"; 1680 } 1681 1682 void DTreesImpl::write( FileStorage& fs ) const 1683 { 1684 writeParams(fs); 1685 writeTree(fs, roots[0]); 1686 } 1687 1688 void DTreesImpl::readParams( const FileNode& fn ) 1689 { 1690 _isClassifier = (int)fn["is_classifier"] != 0; 1691 /*int var_all = (int)fn["var_all"]; 1692 int var_count = (int)fn["var_count"]; 1693 int cat_var_count = (int)fn["cat_var_count"]; 1694 int ord_var_count = (int)fn["ord_var_count"];*/ 1695 1696 FileNode tparams_node = fn["training_params"]; 1697 1698 TreeParams params0 = TreeParams(); 1699 1700 if( !tparams_node.empty() ) // training parameters are not necessary 1701 { 1702 params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0; 1703 params0.setMaxCategories((int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"])); 1704 params0.setRegressionAccuracy((float)tparams_node["regression_accuracy"]); 1705 params0.setMaxDepth((int)tparams_node["max_depth"]); 1706 params0.setMinSampleCount((int)tparams_node["min_sample_count"]); 1707 params0.setCVFolds((int)tparams_node["cross_validation_folds"]); 1708 1709 if( params0.getCVFolds() > 1 ) 1710 { 1711 params.use1SERule = (int)tparams_node["use_1se_rule"] != 0; 1712 } 1713 1714 tparams_node["priors"] >> params0.priors; 1715 } 1716 1717 readVectorOrMat(fn["var_idx"], varIdx); 1718 fn["var_type"] >> varType; 1719 1720 int format = 0; 1721 fn["format"] >> format; 1722 bool isLegacy = format < 3; 1723 1724 int varAll = (int)fn["var_all"]; 1725 if (isLegacy && (int)varType.size() <= varAll) 1726 { 1727 std::vector<uchar> extendedTypes(varAll + 1, 0); 1728 1729 int i = 0, n; 1730 if (!varIdx.empty()) 1731 { 1732 n = (int)varIdx.size(); 1733 for (; i < n; ++i) 1734 { 1735 int var = varIdx[i]; 1736 extendedTypes[var] = varType[i]; 1737 } 1738 } 1739 else 1740 { 1741 n = (int)varType.size(); 1742 for (; i < n; ++i) 1743 { 1744 extendedTypes[i] = varType[i]; 1745 } 1746 } 1747 extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED); 1748 extendedTypes.swap(varType); 1749 } 1750 1751 readVectorOrMat(fn["cat_map"], catMap); 1752 1753 if (isLegacy) 1754 { 1755 // generating "catOfs" from "cat_count" 1756 catOfs.clear(); 1757 classLabels.clear(); 1758 std::vector<int> counts; 1759 readVectorOrMat(fn["cat_count"], counts); 1760 unsigned int i = 0, j = 0, curShift = 0, size = (int)varType.size() - 1; 1761 for (; i < size; ++i) 1762 { 1763 Vec2i newOffsets(0, 0); 1764 if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap 1765 { 1766 newOffsets[0] = curShift; 1767 curShift += counts[j]; 1768 newOffsets[1] = curShift; 1769 ++j; 1770 } 1771 catOfs.push_back(newOffsets); 1772 } 1773 // other elements in "catMap" are "classLabels" 1774 if (curShift < catMap.size()) 1775 { 1776 classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end()); 1777 catMap.erase(catMap.begin() + curShift, catMap.end()); 1778 } 1779 } 1780 else 1781 { 1782 fn["cat_ofs"] >> catOfs; 1783 fn["missing_subst"] >> missingSubst; 1784 fn["class_labels"] >> classLabels; 1785 } 1786 1787 // init var mapping for node reading (var indexes or varIdx indexes) 1788 bool globalVarIdx = false; 1789 fn["global_var_idx"] >> globalVarIdx; 1790 if (globalVarIdx || varIdx.empty()) 1791 setRangeVector(varMapping, (int)varType.size()); 1792 else 1793 varMapping = varIdx; 1794 1795 initCompVarIdx(); 1796 setDParams(params0); 1797 } 1798 1799 int DTreesImpl::readSplit( const FileNode& fn ) 1800 { 1801 Split split; 1802 1803 int vi = (int)fn["var"]; 1804 CV_Assert( 0 <= vi && vi <= (int)varType.size() ); 1805 vi = varMapping[vi]; // convert to varIdx if needed 1806 split.varIdx = vi; 1807 1808 if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var 1809 { 1810 int i, val, ssize = getSubsetSize(vi); 1811 split.subsetOfs = (int)subsets.size(); 1812 for( i = 0; i < ssize; i++ ) 1813 subsets.push_back(0); 1814 int* subset = &subsets[split.subsetOfs]; 1815 FileNode fns = fn["in"]; 1816 if( fns.empty() ) 1817 { 1818 fns = fn["not_in"]; 1819 split.inversed = true; 1820 } 1821 1822 if( fns.isInt() ) 1823 { 1824 val = (int)fns; 1825 subset[val >> 5] |= 1 << (val & 31); 1826 } 1827 else 1828 { 1829 FileNodeIterator it = fns.begin(); 1830 int n = (int)fns.size(); 1831 for( i = 0; i < n; i++, ++it ) 1832 { 1833 val = (int)*it; 1834 subset[val >> 5] |= 1 << (val & 31); 1835 } 1836 } 1837 1838 // for categorical splits we do not use inversed splits, 1839 // instead we inverse the variable set in the split 1840 if( split.inversed ) 1841 { 1842 for( i = 0; i < ssize; i++ ) 1843 subset[i] ^= -1; 1844 split.inversed = false; 1845 } 1846 } 1847 else 1848 { 1849 FileNode cmpNode = fn["le"]; 1850 if( cmpNode.empty() ) 1851 { 1852 cmpNode = fn["gt"]; 1853 split.inversed = true; 1854 } 1855 split.c = (float)cmpNode; 1856 } 1857 1858 split.quality = (float)fn["quality"]; 1859 splits.push_back(split); 1860 1861 return (int)(splits.size() - 1); 1862 } 1863 1864 int DTreesImpl::readNode( const FileNode& fn ) 1865 { 1866 Node node; 1867 node.value = (double)fn["value"]; 1868 1869 if( _isClassifier ) 1870 node.classIdx = (int)fn["norm_class_idx"]; 1871 1872 FileNode sfn = fn["splits"]; 1873 if( !sfn.empty() ) 1874 { 1875 int i, n = (int)sfn.size(), prevsplit = -1; 1876 FileNodeIterator it = sfn.begin(); 1877 1878 for( i = 0; i < n; i++, ++it ) 1879 { 1880 int splitidx = readSplit(*it); 1881 if( splitidx < 0 ) 1882 break; 1883 if( prevsplit < 0 ) 1884 node.split = splitidx; 1885 else 1886 splits[prevsplit].next = splitidx; 1887 prevsplit = splitidx; 1888 } 1889 } 1890 nodes.push_back(node); 1891 return (int)(nodes.size() - 1); 1892 } 1893 1894 int DTreesImpl::readTree( const FileNode& fn ) 1895 { 1896 int i, n = (int)fn.size(), root = -1, pidx = -1; 1897 FileNodeIterator it = fn.begin(); 1898 1899 for( i = 0; i < n; i++, ++it ) 1900 { 1901 int nidx = readNode(*it); 1902 if( nidx < 0 ) 1903 break; 1904 Node& node = nodes[nidx]; 1905 node.parent = pidx; 1906 if( pidx < 0 ) 1907 root = nidx; 1908 else 1909 { 1910 Node& parent = nodes[pidx]; 1911 if( parent.left < 0 ) 1912 parent.left = nidx; 1913 else 1914 parent.right = nidx; 1915 } 1916 if( node.split >= 0 ) 1917 pidx = nidx; 1918 else 1919 { 1920 while( pidx >= 0 && nodes[pidx].right >= 0 ) 1921 pidx = nodes[pidx].parent; 1922 } 1923 } 1924 roots.push_back(root); 1925 return root; 1926 } 1927 1928 void DTreesImpl::read( const FileNode& fn ) 1929 { 1930 clear(); 1931 readParams(fn); 1932 1933 FileNode fnodes = fn["nodes"]; 1934 CV_Assert( !fnodes.empty() ); 1935 readTree(fnodes); 1936 } 1937 1938 Ptr<DTrees> DTrees::create() 1939 { 1940 return makePtr<DTreesImpl>(); 1941 } 1942 1943 } 1944 } 1945 1946 /* End of file. */ 1947