Home | History | Annotate | Download | only in src
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                           License Agreement
     11 //                For Open Source Computer Vision Library
     12 //
     13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
     14 // Copyright (C) 2014, Itseez Inc, all rights reserved.
     15 // Third party copyrights are property of their respective owners.
     16 //
     17 // Redistribution and use in source and binary forms, with or without modification,
     18 // are permitted provided that the following conditions are met:
     19 //
     20 //   * Redistribution's of source code must retain the above copyright notice,
     21 //     this list of conditions and the following disclaimer.
     22 //
     23 //   * Redistribution's in binary form must reproduce the above copyright notice,
     24 //     this list of conditions and the following disclaimer in the documentation
     25 //     and/or other materials provided with the distribution.
     26 //
     27 //   * The name of the copyright holders may not be used to endorse or promote products
     28 //     derived from this software without specific prior written permission.
     29 //
     30 // This software is provided by the copyright holders and contributors "as is" and
     31 // any express or implied warranties, including, but not limited to, the implied
     32 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     33 // In no event shall the Intel Corporation or contributors be liable for any direct,
     34 // indirect, incidental, special, exemplary, or consequential damages
     35 // (including, but not limited to, procurement of substitute goods or services;
     36 // loss of use, data, or profits; or business interruption) however caused
     37 // and on any theory of liability, whether in contract, strict liability,
     38 // or tort (including negligence or otherwise) arising in any way out of
     39 // the use of this software, even if advised of the possibility of such damage.
     40 //
     41 //M*/
     42 
     43 #include "precomp.hpp"
     44 #include <ctype.h>
     45 
     46 namespace cv {
     47 namespace ml {
     48 
     49 using std::vector;
     50 
     51 TreeParams::TreeParams()
     52 {
     53     maxDepth = INT_MAX;
     54     minSampleCount = 10;
     55     regressionAccuracy = 0.01f;
     56     useSurrogates = false;
     57     maxCategories = 10;
     58     CVFolds = 10;
     59     use1SERule = true;
     60     truncatePrunedTree = true;
     61     priors = Mat();
     62 }
     63 
     64 TreeParams::TreeParams(int _maxDepth, int _minSampleCount,
     65                        double _regressionAccuracy, bool _useSurrogates,
     66                        int _maxCategories, int _CVFolds,
     67                        bool _use1SERule, bool _truncatePrunedTree,
     68                        const Mat& _priors)
     69 {
     70     maxDepth = _maxDepth;
     71     minSampleCount = _minSampleCount;
     72     regressionAccuracy = (float)_regressionAccuracy;
     73     useSurrogates = _useSurrogates;
     74     maxCategories = _maxCategories;
     75     CVFolds = _CVFolds;
     76     use1SERule = _use1SERule;
     77     truncatePrunedTree = _truncatePrunedTree;
     78     priors = _priors;
     79 }
     80 
     81 DTrees::Node::Node()
     82 {
     83     classIdx = 0;
     84     value = 0;
     85     parent = left = right = split = defaultDir = -1;
     86 }
     87 
     88 DTrees::Split::Split()
     89 {
     90     varIdx = 0;
     91     inversed = false;
     92     quality = 0.f;
     93     next = -1;
     94     c = 0.f;
     95     subsetOfs = 0;
     96 }
     97 
     98 
     99 DTreesImpl::WorkData::WorkData(const Ptr<TrainData>& _data)
    100 {
    101     data = _data;
    102     vector<int> subsampleIdx;
    103     Mat sidx0 = _data->getTrainSampleIdx();
    104     if( !sidx0.empty() )
    105     {
    106         sidx0.copyTo(sidx);
    107         std::sort(sidx.begin(), sidx.end());
    108     }
    109     else
    110     {
    111         int n = _data->getNSamples();
    112         setRangeVector(sidx, n);
    113     }
    114 
    115     maxSubsetSize = 0;
    116 }
    117 
// Trivial ctor/dtor: model state is built in startTraining() and torn down
// by clear()/endTraining(), so there is nothing to do here.
DTreesImpl::DTreesImpl() {}
DTreesImpl::~DTreesImpl() {}
    120 void DTreesImpl::clear()
    121 {
    122     varIdx.clear();
    123     compVarIdx.clear();
    124     varType.clear();
    125     catOfs.clear();
    126     catMap.clear();
    127     roots.clear();
    128     nodes.clear();
    129     splits.clear();
    130     subsets.clear();
    131     classLabels.clear();
    132 
    133     w.release();
    134     _isClassifier = false;
    135 }
    136 
// Prepares all per-training state: copies variable/category metadata out of
// the TrainData object, builds the active-variable mapping, and, in the
// classification case, folds the class priors into the per-sample weights.
void DTreesImpl::startTraining( const Ptr<TrainData>& data, int )
{
    clear();
    w = makePtr<WorkData>(data);

    Mat vtype = data->getVarType();
    vtype.copyTo(varType);

    data->getCatOfs().copyTo(catOfs);
    data->getCatMap().copyTo(catMap);
    data->getDefaultSubstValues().copyTo(missingSubst);

    int nallvars = data->getNAllVars();

    // Use the explicit active-variable list if one was provided,
    // otherwise activate all variables 0..nallvars-1.
    Mat vidx0 = data->getVarIdx();
    if( !vidx0.empty() )
        vidx0.copyTo(varIdx);
    else
        setRangeVector(varIdx, nallvars);

    initCompVarIdx();

    // maxSubsetSize: first the largest category count over active variables...
    w->maxSubsetSize = 0;

    int i, nvars = (int)varIdx.size();
    for( i = 0; i < nvars; i++ )
        w->maxSubsetSize = std::max(w->maxSubsetSize, getCatCount(varIdx[i]));

    // ...then converted to the number of 32-bit words needed to hold a
    // category bit mask (at least 1).
    w->maxSubsetSize = std::max((w->maxSubsetSize + 31)/32, 1);

    data->getSampleWeights().copyTo(w->sample_weights);

    _isClassifier = data->getResponseType() == VAR_CATEGORICAL;

    if( _isClassifier )
    {
        data->getNormCatResponses().copyTo(w->cat_responses);
        data->getClassLabels().copyTo(classLabels);
        int nclasses = (int)classLabels.size();

        Mat class_weights = params.priors;
        if( !class_weights.empty() )
        {
            // Priors must be a continuous CV_64F vector with one entry per
            // class; convert if the caller supplied a different type/layout.
            if( class_weights.type() != CV_64F || !class_weights.isContinuous() )
            {
                Mat temp;
                class_weights.convertTo(temp, CV_64F);
                class_weights = temp;
            }
            CV_Assert( class_weights.checkVector(1, CV_64F) == nclasses );

            int nsamples = (int)w->cat_responses.size();
            const double* cw = class_weights.ptr<double>();
            CV_Assert( (int)w->sample_weights.size() == nsamples );

            // Scale every sample's weight by the prior of its class.
            for( i = 0; i < nsamples; i++ )
            {
                int ci = w->cat_responses[i];
                CV_Assert( 0 <= ci && ci < nclasses );
                w->sample_weights[i] *= cw[ci];
            }
        }
    }
    else
        data->getResponses().copyTo(w->ord_responses);
}
    203 
    204 
    205 void DTreesImpl::initCompVarIdx()
    206 {
    207     int nallvars = (int)varType.size();
    208     compVarIdx.assign(nallvars, -1);
    209     int i, nvars = (int)varIdx.size(), prevIdx = -1;
    210     for( i = 0; i < nvars; i++ )
    211     {
    212         int vi = varIdx[i];
    213         CV_Assert( 0 <= vi && vi < nallvars && vi > prevIdx );
    214         prevIdx = vi;
    215         compVarIdx[vi] = i;
    216     }
    217 }
    218 
// Releases the temporary working data once tree construction is finished;
// the learned tree itself lives in nodes/splits/subsets and is unaffected.
void DTreesImpl::endTraining()
{
    w.release();
}
    223 
    224 bool DTreesImpl::train( const Ptr<TrainData>& trainData, int flags )
    225 {
    226     startTraining(trainData, flags);
    227     bool ok = addTree( w->sidx ) >= 0;
    228     w.release();
    229     endTraining();
    230     return ok;
    231 }
    232 
// Returns the list of active (used) variable indices, as set up by
// startTraining(); split search iterates over exactly these variables.
const vector<int>& DTreesImpl::getActiveVars()
{
    return varIdx;
}
    237 
    238 int DTreesImpl::addTree(const vector<int>& sidx )
    239 {
    240     size_t n = (params.getMaxDepth() > 0 ? (1 << params.getMaxDepth()) : 1024) + w->wnodes.size();
    241 
    242     w->wnodes.reserve(n);
    243     w->wsplits.reserve(n);
    244     w->wsubsets.reserve(n*w->maxSubsetSize);
    245     w->wnodes.clear();
    246     w->wsplits.clear();
    247     w->wsubsets.clear();
    248 
    249     int cv_n = params.getCVFolds();
    250 
    251     if( cv_n > 0 )
    252     {
    253         w->cv_Tn.resize(n*cv_n);
    254         w->cv_node_error.resize(n*cv_n);
    255         w->cv_node_risk.resize(n*cv_n);
    256     }
    257 
    258     // build the tree recursively
    259     int w_root = addNodeAndTrySplit(-1, sidx);
    260     int maxdepth = INT_MAX;//pruneCV(root);
    261 
    262     int w_nidx = w_root, pidx = -1, depth = 0;
    263     int root = (int)nodes.size();
    264 
    265     for(;;)
    266     {
    267         const WNode& wnode = w->wnodes[w_nidx];
    268         Node node;
    269         node.parent = pidx;
    270         node.classIdx = wnode.class_idx;
    271         node.value = wnode.value;
    272         node.defaultDir = wnode.defaultDir;
    273 
    274         int wsplit_idx = wnode.split;
    275         if( wsplit_idx >= 0 )
    276         {
    277             const WSplit& wsplit = w->wsplits[wsplit_idx];
    278             Split split;
    279             split.c = wsplit.c;
    280             split.quality = wsplit.quality;
    281             split.inversed = wsplit.inversed;
    282             split.varIdx = wsplit.varIdx;
    283             split.subsetOfs = -1;
    284             if( wsplit.subsetOfs >= 0 )
    285             {
    286                 int ssize = getSubsetSize(split.varIdx);
    287                 split.subsetOfs = (int)subsets.size();
    288                 subsets.resize(split.subsetOfs + ssize);
    289                 // This check verifies that subsets index is in the correct range
    290                 // as in case ssize == 0 no real resize performed.
    291                 // Thus memory kept safe.
    292                 // Also this skips useless memcpy call when size parameter is zero
    293                 if(ssize > 0)
    294                 {
    295                     memcpy(&subsets[split.subsetOfs], &w->wsubsets[wsplit.subsetOfs], ssize*sizeof(int));
    296                 }
    297             }
    298             node.split = (int)splits.size();
    299             splits.push_back(split);
    300         }
    301         int nidx = (int)nodes.size();
    302         nodes.push_back(node);
    303         if( pidx >= 0 )
    304         {
    305             int w_pidx = w->wnodes[w_nidx].parent;
    306             if( w->wnodes[w_pidx].left == w_nidx )
    307             {
    308                 nodes[pidx].left = nidx;
    309             }
    310             else
    311             {
    312                 CV_Assert(w->wnodes[w_pidx].right == w_nidx);
    313                 nodes[pidx].right = nidx;
    314             }
    315         }
    316 
    317         if( wnode.left >= 0 && depth+1 < maxdepth )
    318         {
    319             w_nidx = wnode.left;
    320             pidx = nidx;
    321             depth++;
    322         }
    323         else
    324         {
    325             int w_pidx = wnode.parent;
    326             while( w_pidx >= 0 && w->wnodes[w_pidx].right == w_nidx )
    327             {
    328                 w_nidx = w_pidx;
    329                 w_pidx = w->wnodes[w_pidx].parent;
    330                 nidx = pidx;
    331                 pidx = nodes[pidx].parent;
    332                 depth--;
    333             }
    334 
    335             if( w_pidx < 0 )
    336                 break;
    337 
    338             w_nidx = w->wnodes[w_pidx].right;
    339             CV_Assert( w_nidx >= 0 );
    340         }
    341     }
    342     roots.push_back(root);
    343     return root;
    344 }
    345 
// Stores a copy of the tree-construction parameters used by later training.
void DTreesImpl::setDParams(const TreeParams& _params)
{
    params = _params;
}
    350 
// Appends a new working node for the sample set 'sidx' (child of 'parent',
// or the root when parent < 0), computes its value/risk, and — unless a
// stopping criterion fires — finds the best split and recurses into both
// children. Returns the index of the new node in w->wnodes.
int DTreesImpl::addNodeAndTrySplit( int parent, const vector<int>& sidx )
{
    w->wnodes.push_back(WNode());
    int nidx = (int)(w->wnodes.size() - 1);
    WNode& node = w->wnodes.back();

    node.parent = parent;
    node.depth = parent >= 0 ? w->wnodes[parent].depth + 1 : 0;
    int nfolds = params.getCVFolds();

    // Keep the per-fold bookkeeping arrays in step with the node count.
    if( nfolds > 0 )
    {
        w->cv_Tn.resize((nidx+1)*nfolds);
        w->cv_node_error.resize((nidx+1)*nfolds);
        w->cv_node_risk.resize((nidx+1)*nfolds);
    }

    int i, n = node.sample_count = (int)sidx.size();
    bool can_split = true;
    vector<int> sleft, sright;

    calcValue( nidx, sidx );

    // Stopping criteria: too few samples or maximum depth reached.
    if( n <= params.getMinSampleCount() || node.depth >= params.getMaxDepth() )
        can_split = false;
    else if( _isClassifier )
    {
        // Classification: stop when the node is pure (all responses equal).
        const int* responses = &w->cat_responses[0];
        const int* s = &sidx[0];
        int first = responses[s[0]];
        for( i = 1; i < n; i++ )
            if( responses[s[i]] != first )
                break;
        if( i == n )
            can_split = false;
    }
    else
    {
        // Regression: stop when the residual spread is below the accuracy.
        if( sqrt(node.node_risk) < params.getRegressionAccuracy() )
            can_split = false;
    }

    if( can_split )
        node.split = findBestSplit( sidx );

    //printf("depth=%d, nidx=%d, parent=%d, n=%d, %s, value=%.1f, risk=%.1f\n", node.depth, nidx, node.parent, n, (node.split < 0 ? "leaf" : varType[w->wsplits[node.split].varIdx] == VAR_CATEGORICAL ? "cat" : "ord"), node.value, node.node_risk);

    if( node.split >= 0 )
    {
        node.defaultDir = calcDir( node.split, sidx, sleft, sright );
        if( params.useSurrogates )
            CV_Error( CV_StsNotImplemented, "surrogate splits are not implemented yet");

        // NOTE: the recursive calls push into w->wnodes and may reallocate
        // it, invalidating the 'node' reference above — hence the children
        // are stored through w->wnodes[nidx], never through 'node'.
        int left = addNodeAndTrySplit( nidx, sleft );
        int right = addNodeAndTrySplit( nidx, sright );
        w->wnodes[nidx].left = left;
        w->wnodes[nidx].right = right;
        CV_Assert( w->wnodes[nidx].left > 0 && w->wnodes[nidx].right > 0 );
    }

    return nidx;
}
    413 
// Searches every active variable for the best split of the sample set _sidx.
// Returns the index of the winning split in w->wsplits, or -1 when no split
// improves on quality 0.
int DTreesImpl::findBestSplit( const vector<int>& _sidx )
{
    const vector<int>& activeVars = getActiveVars();
    int splitidx = -1;
    int vi_, nv = (int)activeVars.size();
    // Two category-subset scratch buffers: candidates are written into
    // 'subset'; on improvement the pointers are swapped so 'best_subset'
    // always holds the best candidate seen so far.
    AutoBuffer<int> buf(w->maxSubsetSize*2);
    int *subset = buf, *best_subset = subset + w->maxSubsetSize;
    WSplit split, best_split;
    best_split.quality = 0.;

    for( vi_ = 0; vi_ < nv; vi_++ )
    {
        int vi = activeVars[vi_];
        // Dispatch on variable type (categorical vs ordered) and on problem
        // type (classification vs regression).
        if( varType[vi] == VAR_CATEGORICAL )
        {
            if( _isClassifier )
                split = findSplitCatClass(vi, _sidx, 0, subset);
            else
                split = findSplitCatReg(vi, _sidx, 0, subset);
        }
        else
        {
            if( _isClassifier )
                split = findSplitOrdClass(vi, _sidx, 0);
            else
                split = findSplitOrdReg(vi, _sidx, 0);
        }
        if( split.quality > best_split.quality )
        {
            best_split = split;
            std::swap(subset, best_subset);
        }
    }

    if( best_split.quality > 0 )
    {
        int best_vi = best_split.varIdx;
        CV_Assert( compVarIdx[best_split.varIdx] >= 0 && best_vi >= 0 );
        // Append the winning category subset (if any) to the shared pool and
        // record its offset inside the stored split.
        int i, prevsz = (int)w->wsubsets.size(), ssize = getSubsetSize(best_vi);
        w->wsubsets.resize(prevsz + ssize);
        for( i = 0; i < ssize; i++ )
            w->wsubsets[prevsz + i] = best_subset[i];
        best_split.subsetOfs = prevsz;
        w->wsplits.push_back(best_split);
        splitidx = (int)(w->wsplits.size()-1);
    }

    return splitidx;
}
    463 
// Computes the node's value and risk from the samples _sidx, plus — when
// cross-validation folds are enabled — the per-fold value/risk/error used
// later by pruning. Classification and regression are handled separately;
// the detailed definitions are in the bullet comments below.
void DTreesImpl::calcValue( int nidx, const vector<int>& _sidx )
{
    WNode* node = &w->wnodes[nidx];
    int i, j, k, n = (int)_sidx.size(), cv_n = params.getCVFolds();
    int m = (int)classLabels.size();

    // Scratch space sized for the larger of the two branches below.
    cv::AutoBuffer<double> buf(std::max(m, 3)*(cv_n+1));

    if( cv_n > 0 )
    {
        // Grow the per-fold arrays by one node's worth of entries.
        size_t sz = w->cv_Tn.size();
        w->cv_Tn.resize(sz + cv_n);
        w->cv_node_risk.resize(sz + cv_n);
        w->cv_node_error.resize(sz + cv_n);
    }

    if( _isClassifier )
    {
        // in case of classification tree:
        //  * node value is the label of the class that has the largest weight in the node.
        //  * node risk is the weighted number of misclassified samples,
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated as the weighted number of
        //    misclassified samples with cv_labels(*)==j.

        // compute the number of instances of each class
        double* cls_count = buf;
        double* cv_cls_count = cls_count + m;

        double max_val = -1, total_weight = 0;
        int max_k = -1;

        for( k = 0; k < m; k++ )
            cls_count[k] = 0;

        if( cv_n == 0 )
        {
            // No folds: accumulate class weights directly.
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                cls_count[w->cat_responses[si]] += w->sample_weights[si];
            }
        }
        else
        {
            // With folds: accumulate per-fold class weights first, then sum
            // the folds into the overall class counts.
            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cv_cls_count[j*m + k] = 0;

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si]; k = w->cat_responses[si];
                cv_cls_count[j*m + k] += w->sample_weights[si];
            }

            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cls_count[k] += cv_cls_count[j*m + k];
        }

        // Pick the heaviest class; risk is the total weight of the rest.
        for( k = 0; k < m; k++ )
        {
            double val = cls_count[k];
            total_weight += val;
            if( max_val < val )
            {
                max_val = val;
                max_k = k;
            }
        }

        node->class_idx = max_k;
        node->value = classLabels[max_k];
        node->node_risk = total_weight - max_val;

        // Per-fold statistics: each fold j is evaluated with its own samples
        // (cv_labels == j) held out of the majority vote.
        for( j = 0; j < cv_n; j++ )
        {
            double sum_k = 0, sum = 0, max_val_k = 0;
            max_val = -1; max_k = -1;

            for( k = 0; k < m; k++ )
            {
                double val_k = cv_cls_count[j*m + k];
                double val = cls_count[k] - val_k;
                sum_k += val_k;
                sum += val;
                if( max_val < val )
                {
                    max_val = val;
                    max_val_k = val_k;
                    max_k = k;
                }
            }

            w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            w->cv_node_risk[nidx*cv_n + j] = sum - max_val;
            w->cv_node_error[nidx*cv_n + j] = sum_k - max_val_k;
        }
    }
    else
    {
        // in case of regression tree:
        //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
        //    n is the number of samples in the node.
        //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated
        //    using samples with cv_labels(*)==j as the test subset:
        //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
        //    where node_value_j is the node value calculated
        //    as described in the previous bullet, and summation is done
        //    over the samples with cv_labels(*)==j.
        double sum = 0, sum2 = 0, sumw = 0;

        if( cv_n == 0 )
        {
            // No folds: accumulate weighted first/second moments directly.
            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                sum += t*wval;
                sum2 += t*t*wval;
                sumw += wval;
            }
        }
        else
        {
            // Per-fold moments; the totals are the sum over all folds.
            double *cv_sum = buf, *cv_sum2 = cv_sum + cv_n;
            double* cv_count = (double*)(cv_sum2 + cv_n);

            for( j = 0; j < cv_n; j++ )
            {
                cv_sum[j] = cv_sum2[j] = 0.;
                cv_count[j] = 0;
            }

            for( i = 0; i < n; i++ )
            {
                int si = _sidx[i];
                j = w->cv_labels[si];
                double wval = w->sample_weights[si];
                double t = w->ord_responses[si];
                cv_sum[j] += t*wval;
                cv_sum2[j] += t*t*wval;
                cv_count[j] += wval;
            }

            for( j = 0; j < cv_n; j++ )
            {
                sum += cv_sum[j];
                sum2 += cv_sum2[j];
                sumw += cv_count[j];
            }

            // For each fold j: quantities with suffix 'i' exclude fold j
            // (training part); the others are fold j itself (test part).
            for( j = 0; j < cv_n; j++ )
            {
                double s = sum - cv_sum[j], si = sum - s;
                double s2 = sum2 - cv_sum2[j], s2i = sum2 - s2;
                double c = cv_count[j], ci = sumw - c;
                double r = si/std::max(ci, DBL_EPSILON);
                w->cv_node_risk[nidx*cv_n + j] = s2i - r*r*ci;
                w->cv_node_error[nidx*cv_n + j] = s2 - 2*r*s + c*r*r;
                w->cv_Tn[nidx*cv_n + j] = INT_MAX;
            }
        }

        node->node_risk = sum2 - (sum/sumw)*sum;
        node->value = sum/sumw;
    }
}
    638 
// Finds the best threshold split on the ordered variable vi for a
// classification problem. Samples are sorted by the variable's value and the
// split criterion (lsum2/L + rsum2/R, i.e. the sum of squared class weights
// normalized by each side's total — maximizing it minimizes weighted Gini
// impurity) is updated incrementally as samples move left one at a time.
// Returns a WSplit with quality > initQuality, or a default one if no
// threshold beats initQuality.
DTreesImpl::WSplit DTreesImpl::findSplitOrdClass( int vi, const vector<int>& _sidx, double initQuality )
{
    const double epsilon = FLT_EPSILON*2;
    int n = (int)_sidx.size();
    int m = (int)classLabels.size();

    // One scratch buffer carved into: per-class left/right weights (lcw/rcw),
    // the variable's values, and the sort permutation.
    cv::AutoBuffer<uchar> buf(n*(sizeof(float) + sizeof(int)) + m*2*sizeof(double));
    const int* sidx = &_sidx[0];
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];
    double* lcw = (double*)(uchar*)buf;
    double* rcw = lcw + m;
    float* values = (float*)(rcw + m);
    int* sorted_idx = (int*)(values + n);
    int i, best_i = -1;
    double best_val = initQuality;

    for( i = 0; i < m; i++ )
        lcw[i] = rcw[i] = 0.;

    w->data->getValues( vi, _sidx, values );

    // Initially everything is on the right side.
    for( i = 0; i < n; i++ )
    {
        sorted_idx[i] = i;
        int si = sidx[i];
        rcw[responses[si]] += weights[si];
    }

    std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));

    double L = 0, R = 0, lsum2 = 0, rsum2 = 0;
    for( i = 0; i < m; i++ )
    {
        double wval = rcw[i];
        R += wval;
        rsum2 += wval*wval;
    }

    // Sweep the sorted samples; after moving sample i to the left side,
    // a threshold between values[curr] and values[next] is a candidate.
    for( i = 0; i < n - 1; i++ )
    {
        int curr = sorted_idx[i];
        int next = sorted_idx[i+1];
        int si = sidx[curr];
        double wval = weights[si], w2 = wval*wval;
        L += wval; R -= wval;
        int idx = responses[si];
        // Incremental update of the sums of squared class weights:
        // (x + w)^2 = x^2 + 2xw + w^2 and (x - w)^2 = x^2 - 2xw + w^2.
        double lv = lcw[idx], rv = rcw[idx];
        lsum2 += 2*lv*wval + w2;
        rsum2 -= 2*rv*wval - w2;
        lcw[idx] = lv + wval; rcw[idx] = rv - wval;

        // Only distinct consecutive values can carry a threshold.
        if( values[curr] + epsilon < values[next] )
        {
            double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_i = i;
            }
        }
    }

    WSplit split;
    if( best_i >= 0 )
    {
        split.varIdx = vi;
        // Threshold midway between the two adjacent distinct values.
        split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
        split.inversed = false;
        split.quality = (float)best_val;
    }
    return split;
}
    712 
// simple k-means, slightly modified to take into account the "weight" (L1-norm) of each vector.
//
// Clusters n m-dimensional rows of 'vectors' into k groups. On return,
// csums[k x m] holds the (unnormalized) per-cluster column sums and
// labels[n] the cluster assignment of each row. Distances are computed
// between L1-normalized rows and L1-normalized cluster sums, so only the
// "shape" of each vector matters, not its magnitude.
void DTreesImpl::clusterCategories( const double* vectors, int n, int m, double* csums, int k, int* labels )
{
    int iters = 0, max_iters = 100;
    int i, j, idx;
    cv::AutoBuffer<double> buf(n + k);
    double *v_weights = buf, *c_weights = buf + n;
    bool modified = true;
    RNG r((uint64)-1);

    // assign labels randomly
    for( i = 0; i < n; i++ )
    {
        double sum = 0;
        const double* v = vectors + i*m;
        // First k rows seed one cluster each so no cluster starts empty.
        labels[i] = i < k ? i : r.uniform(0, k);

        // compute weight of each vector
        for( j = 0; j < m; j++ )
            sum += v[j];
        v_weights[i] = sum ? 1./sum : 0.;
    }

    // Shuffle the initial assignment.
    for( i = 0; i < n; i++ )
    {
        int i1 = r.uniform(0, n);
        int i2 = r.uniform(0, n);
        std::swap( labels[i1], labels[i2] );
    }

    for( iters = 0; iters <= max_iters; iters++ )
    {
        // calculate csums
        for( i = 0; i < k; i++ )
        {
            for( j = 0; j < m; j++ )
                csums[i*m + j] = 0;
        }

        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double* s = csums + labels[i]*m;
            for( j = 0; j < m; j++ )
                s[j] += v[j];
        }

        // exit the loop here, when we have up-to-date csums
        if( iters == max_iters || !modified )
            break;

        modified = false;

        // calculate weight of each cluster
        for( i = 0; i < k; i++ )
        {
            const double* s = csums + i*m;
            double sum = 0;
            for( j = 0; j < m; j++ )
                sum += s[j];
            c_weights[i] = sum ? 1./sum : 0;
        }

        // now for each vector determine the closest cluster
        for( i = 0; i < n; i++ )
        {
            const double* v = vectors + i*m;
            double alpha = v_weights[i];
            double min_dist2 = DBL_MAX;
            int min_idx = -1;

            for( idx = 0; idx < k; idx++ )
            {
                const double* s = csums + idx*m;
                // Squared distance between the normalized vector and the
                // normalized cluster sum.
                double dist2 = 0., beta = c_weights[idx];
                for( j = 0; j < m; j++ )
                {
                    double t = v[j]*alpha - s[j]*beta;
                    dist2 += t*t;
                }
                if( min_dist2 > dist2 )
                {
                    min_dist2 = dist2;
                    min_idx = idx;
                }
            }

            // Keep iterating while at least one label changed.
            if( min_idx != labels[i] )
                modified = true;
            labels[i] = min_idx;
        }
    }
}
    806 
// Finds the best "subset of categories goes left" split on categorical input
// variable vi for a classification tree.
//  _sidx       - indices of the samples in the current node
//  initQuality - quality value to beat; only strictly better splits are kept
//  subset      - output bit array (one bit per category) of the chosen subset
// Returns a WSplit whose quality/varIdx are filled in only when a split
// better than initQuality was found.
DTreesImpl::WSplit DTreesImpl::findSplitCatClass( int vi, const vector<int>& _sidx,
                                                  double initQuality, int* subset )
{
    int _mi = getCatCount(vi), mi = _mi;  // number of categories of vi
    int n = (int)_sidx.size();            // number of samples in the node
    int m = (int)classLabels.size();      // number of classes

    // one AutoBuffer carved manually into several arrays (layout below)
    int base_size = m*(3 + mi) + mi + 1;
    if( m > 2 && mi > params.getMaxCategories() )
        base_size += m*std::min(params.getMaxCategories(), n) + mi;
    else
        base_size += mi;
    AutoBuffer<double> buf(base_size + n);

    double* lc = (double*)buf;            // per-class weight in the left branch
    double* rc = lc + m;                  // per-class weight in the right branch (m*2 slots; the extra m hold the j==-1 row of cjk, see below)
    double* _cjk = rc + m*2, *cjk = _cjk; // c_{jk} matrix (mi rows x m cols)
    double* c_weights = cjk + m*mi;       // total sample weight of each category

    int* labels = (int*)(buf + base_size);
    w->data->getNormCatValues(vi, _sidx, labels);
    const int* responses = &w->cat_responses[0];
    const double* weights = &w->sample_weights[0];

    int* cluster_labels = 0;
    double** dbl_ptr = 0;
    int i, j, k, si, idx;
    double L = 0, R = 0;                  // total left/right branch weights
    double best_val = initQuality;
    int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;

    // init array of counters:
    // c_{jk} - number of samples that have vi-th input variable = j and response = k.
    // j starts at -1: the row just before cjk absorbs samples with a
    // negative (missing/unseen) category label.
    for( j = -1; j < mi; j++ )
        for( k = 0; k < m; k++ )
            cjk[j*m + k] = 0;

    // accumulate the weighted per-category class histograms
    for( i = 0; i < n; i++ )
    {
        si = _sidx[i];
        j = labels[i];
        k = responses[si];
        cjk[j*m + k] += weights[si];
    }

    if( m > 2 )
    {
        // Multi-class: subsets are enumerated exhaustively (2^mi of them).
        // If there are too many categories, cluster them first so that at
        // most getMaxCategories() pseudo-categories remain.
        if( mi > params.getMaxCategories() )
        {
            mi = std::min(params.getMaxCategories(), n);
            cjk = c_weights + _mi;               // clustered c_{jk} lives in the extra buffer space
            cluster_labels = (int*)(cjk + m*mi); // original category -> cluster index
            clusterCategories( _cjk, _mi, m, cjk, mi, cluster_labels );
        }
        subset_i = 1;
        subset_n = 1 << mi;
    }
    else
    {
        // Two-class case: the optimal subset is always a prefix of the
        // categories sorted by their class-1 weight, so only mi candidate
        // prefixes need to be scanned.
        assert( m == 2 );
        dbl_ptr = (double**)(c_weights + _mi);
        for( j = 0; j < mi; j++ )
            dbl_ptr[j] = cjk + j*2 + 1;          // -> class-1 weight of category j
        std::sort(dbl_ptr, dbl_ptr + mi, cmp_lt_ptr<double>());
        subset_i = 0;
        subset_n = mi;
    }

    // initially every category is in the right branch
    for( k = 0; k < m; k++ )
    {
        double sum = 0;
        for( j = 0; j < mi; j++ )
            sum += cjk[j*m + k];
        CV_Assert(sum > 0);
        rc[k] = sum;
        lc[k] = 0;
    }

    // per-category total weights and the overall node weight R
    for( j = 0; j < mi; j++ )
    {
        double sum = 0;
        for( k = 0; k < m; k++ )
            sum += cjk[j*m + k];
        c_weights[j] = sum;
        R += c_weights[j];
    }

    // Scan the candidate subsets. On every step exactly one category
    // switches sides, so the class histograms are updated incrementally.
    for( ; subset_i < subset_n; subset_i++ )
    {
        double lsum2 = 0, rsum2 = 0;  // sum of squared class weights per side

        if( m == 2 )
            idx = (int)(dbl_ptr[subset_i] - cjk)/2;  // next category in sorted order
        else
        {
            // Gray-code enumeration: consecutive codes differ in exactly one
            // bit, i.e. exactly one category moves per iteration.
            int graycode = (subset_i>>1)^subset_i;
            int diff = graycode ^ prevcode;

            // determine index of the changed bit.
            // (log2 of the single set bit, extracted from a float's exponent)
            Cv32suf u;
            idx = diff >= (1 << 16) ? 16 : 0;
            u.f = (float)(((diff >> 16) | diff) & 65535);
            idx += (u.i >> 23) - 127;
            subtract = graycode < prevcode;  // bit cleared -> move category back right
            prevcode = graycode;
        }

        double* crow = cjk + idx*m;
        double weight = c_weights[idx];
        if( weight < FLT_EPSILON )  // empty category: moving it changes nothing
            continue;

        if( !subtract )
        {
            // move category idx from the right branch to the left
            for( k = 0; k < m; k++ )
            {
                double t = crow[k];
                double lval = lc[k] + t;
                double rval = rc[k] - t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L += weight;
            R -= weight;
        }
        else
        {
            // move category idx from the left branch back to the right
            for( k = 0; k < m; k++ )
            {
                double t = crow[k];
                double lval = lc[k] - t;
                double rval = rc[k] + t;
                lsum2 += lval*lval;
                rsum2 += rval*rval;
                lc[k] = lval; rc[k] = rval;
            }
            L -= weight;
            R += weight;
        }

        // Gini-style criterion sum_k l_k^2/L + sum_k r_k^2/R written with a
        // common denominator
        if( L > FLT_EPSILON && R > FLT_EPSILON )
        {
            double val = (lsum2*R + rsum2*L)/(L*R);
            if( best_val < val )
            {
                best_val = val;
                best_subset = subset_i;
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int) );
        if( m == 2 )
        {
            // best subset = first best_subset+1 categories in sorted order
            for( i = 0; i <= best_subset; i++ )
            {
                idx = (int)(dbl_ptr[i] - cjk) >> 1;
                subset[idx >> 5] |= 1 << (idx & 31);
            }
        }
        else
        {
            // Map the (possibly clustered) subset bits back to original
            // categories.
            // NOTE(review): the subset actually evaluated at iteration
            // subset_i is its Gray code ((subset_i>>1)^subset_i), but the
            // stored best_subset is decoded here as plain binary - looks
            // inconsistent; confirm against upstream history before changing.
            for( i = 0; i < _mi; i++ )
            {
                idx = cluster_labels ? cluster_labels[i] : i;
                if( best_subset & (1 << idx) )
                    subset[i >> 5] |= 1 << (i & 31);
            }
        }
    }
    return split;
}
    985 
    986 DTreesImpl::WSplit DTreesImpl::findSplitOrdReg( int vi, const vector<int>& _sidx, double initQuality )
    987 {
    988     const float epsilon = FLT_EPSILON*2;
    989     const double* weights = &w->sample_weights[0];
    990     int n = (int)_sidx.size();
    991 
    992     AutoBuffer<uchar> buf(n*(sizeof(int) + sizeof(float)));
    993 
    994     float* values = (float*)(uchar*)buf;
    995     int* sorted_idx = (int*)(values + n);
    996     w->data->getValues(vi, _sidx, values);
    997     const double* responses = &w->ord_responses[0];
    998 
    999     int i, si, best_i = -1;
   1000     double L = 0, R = 0;
   1001     double best_val = initQuality, lsum = 0, rsum = 0;
   1002 
   1003     for( i = 0; i < n; i++ )
   1004     {
   1005         sorted_idx[i] = i;
   1006         si = _sidx[i];
   1007         R += weights[si];
   1008         rsum += weights[si]*responses[si];
   1009     }
   1010 
   1011     std::sort(sorted_idx, sorted_idx + n, cmp_lt_idx<float>(values));
   1012 
   1013     // find the optimal split
   1014     for( i = 0; i < n - 1; i++ )
   1015     {
   1016         int curr = sorted_idx[i];
   1017         int next = sorted_idx[i+1];
   1018         si = _sidx[curr];
   1019         double wval = weights[si];
   1020         double t = responses[si]*wval;
   1021         L += wval; R -= wval;
   1022         lsum += t; rsum -= t;
   1023 
   1024         if( values[curr] + epsilon < values[next] )
   1025         {
   1026             double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
   1027             if( best_val < val )
   1028             {
   1029                 best_val = val;
   1030                 best_i = i;
   1031             }
   1032         }
   1033     }
   1034 
   1035     WSplit split;
   1036     if( best_i >= 0 )
   1037     {
   1038         split.varIdx = vi;
   1039         split.c = (values[sorted_idx[best_i]] + values[sorted_idx[best_i+1]])*0.5f;
   1040         split.inversed = false;
   1041         split.quality = (float)best_val;
   1042     }
   1043     return split;
   1044 }
   1045 
// Finds the best category-subset split on categorical variable vi for a
// regression tree. The categories are sorted by their mean response; the
// optimal subset is then always a prefix of that order (classic CART
// result), so only mi-1 candidates are scanned. The chosen subset is
// written into the 'subset' bit array.
DTreesImpl::WSplit DTreesImpl::findSplitCatReg( int vi, const vector<int>& _sidx,
                                                double initQuality, int* subset )
{
    const double* weights = &w->sample_weights[0];
    const double* responses = &w->ord_responses[0];
    int n = (int)_sidx.size();
    int mi = getCatCount(vi);

    // one buffer carved into: sum (with a [-1] slot), counts, sum_ptr, cat_labels
    AutoBuffer<double> buf(3*mi + 3 + n);
    double* sum = (double*)buf + 1;   // +1 so that sum[-1] is a valid slot
    double* counts = sum + mi + 1;    // +1: counts[-1] aliases sum[mi]
    double** sum_ptr = (double**)(counts + mi);
    int* cat_labels = (int*)(sum_ptr + mi);

    w->data->getNormCatValues(vi, _sidx, cat_labels);

    double L = 0, R = 0, best_val = initQuality, lsum = 0, rsum = 0;
    int i, si, best_subset = -1, subset_i;

    // i starts at -1 to also clear the slots used by samples whose
    // category label is -1 (missing/unseen value)
    for( i = -1; i < mi; i++ )
        sum[i] = counts[i] = 0;

    // calculate sum response and weight of each category of the input var
    for( i = 0; i < n; i++ )
    {
        int idx = cat_labels[i];
        si = _sidx[i];
        double wval = weights[si];
        sum[idx] += responses[si]*wval;
        counts[idx] += wval;
    }

    // calculate average response in each category
    for( i = 0; i < mi; i++ )
    {
        R += counts[i];
        rsum += sum[i];
        sum[i] = fabs(counts[i]) > DBL_EPSILON ? sum[i]/counts[i] : 0;
        sum_ptr[i] = sum + i;         // indirect pointers used only for sorting
    }

    // sort the categories by their mean response
    std::sort(sum_ptr, sum_ptr + mi, cmp_lt_ptr<double>());

    // revert back to unnormalized sums
    // (there should be a very little loss in accuracy)
    for( i = 0; i < mi; i++ )
        sum[i] *= counts[i];

    // sweep prefixes of the sorted category order, moving one category at a
    // time from the right branch to the left and updating sums incrementally
    for( subset_i = 0; subset_i < mi-1; subset_i++ )
    {
        int idx = (int)(sum_ptr[subset_i] - sum);  // original category index
        double ni = counts[idx];

        if( ni > FLT_EPSILON )  // skip empty categories
        {
            double s = sum[idx];
            lsum += s; L += ni;
            rsum -= s; R -= ni;

            // maximize lsum^2/L + rsum^2/R (common-denominator form)
            if( L > FLT_EPSILON && R > FLT_EPSILON )
            {
                double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
                if( best_val < val )
                {
                    best_val = val;
                    best_subset = subset_i;
                }
            }
        }
    }

    WSplit split;
    if( best_subset >= 0 )
    {
        split.varIdx = vi;
        split.quality = (float)best_val;
        memset( subset, 0, getSubsetSize(vi) * sizeof(int));
        // mark the first best_subset+1 categories (sorted order) as "go left"
        for( i = 0; i <= best_subset; i++ )
        {
            int idx = (int)(sum_ptr[i] - sum);
            subset[idx >> 5] |= 1 << (idx & 31);
        }
    }
    return split;
}
   1131 
   1132 int DTreesImpl::calcDir( int splitidx, const vector<int>& _sidx,
   1133                          vector<int>& _sleft, vector<int>& _sright )
   1134 {
   1135     WSplit split = w->wsplits[splitidx];
   1136     int i, si, n = (int)_sidx.size(), vi = split.varIdx;
   1137     _sleft.reserve(n);
   1138     _sright.reserve(n);
   1139     _sleft.clear();
   1140     _sright.clear();
   1141 
   1142     AutoBuffer<float> buf(n);
   1143     int mi = getCatCount(vi);
   1144     double wleft = 0, wright = 0;
   1145     const double* weights = &w->sample_weights[0];
   1146 
   1147     if( mi <= 0 ) // split on an ordered variable
   1148     {
   1149         float c = split.c;
   1150         float* values = buf;
   1151         w->data->getValues(vi, _sidx, values);
   1152 
   1153         for( i = 0; i < n; i++ )
   1154         {
   1155             si = _sidx[i];
   1156             if( values[i] <= c )
   1157             {
   1158                 _sleft.push_back(si);
   1159                 wleft += weights[si];
   1160             }
   1161             else
   1162             {
   1163                 _sright.push_back(si);
   1164                 wright += weights[si];
   1165             }
   1166         }
   1167     }
   1168     else
   1169     {
   1170         const int* subset = &w->wsubsets[split.subsetOfs];
   1171         int* cat_labels = (int*)(float*)buf;
   1172         w->data->getNormCatValues(vi, _sidx, cat_labels);
   1173 
   1174         for( i = 0; i < n; i++ )
   1175         {
   1176             si = _sidx[i];
   1177             unsigned u = cat_labels[i];
   1178             if( CV_DTREE_CAT_DIR(u, subset) < 0 )
   1179             {
   1180                 _sleft.push_back(si);
   1181                 wleft += weights[si];
   1182             }
   1183             else
   1184             {
   1185                 _sright.push_back(si);
   1186                 wright += weights[si];
   1187             }
   1188         }
   1189     }
   1190     CV_Assert( (int)_sleft.size() < n && (int)_sright.size() < n );
   1191     return wleft > wright ? -1 : 1;
   1192 }
   1193 
// Cost-complexity pruning with cross-validation (CART-style).
// Builds the nested sequence of pruned subtrees of 'root', estimates each
// subtree's error on the params.getCVFolds() CV folds and returns the index
// of the best subtree in the sequence (or -1 if the sequence is empty).
int DTreesImpl::pruneCV( int root )
{
    vector<double> ab;

    // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
    // 2. choose the best tree index (if need, apply 1SE rule).
    // 3. store the best index and cut the branches.

    int ti, tree_count = 0, j, cv_n = params.getCVFolds(), n = w->wnodes[root].sample_count;
    // currently, 1SE for regression is not implemented
    bool use_1se = params.use1SERule != 0 && _isClassifier;
    double min_err = 0, min_err_se = 0;
    int min_idx = -1;

    // build the main tree sequence, calculate alpha's
    for(;;tree_count++)
    {
        double min_alpha = updateTreeRNC(root, tree_count, -1);
        if( cutTree(root, tree_count, -1, min_alpha) )
            break;

        ab.push_back(min_alpha);
    }

    if( tree_count > 0 )
    {
        ab[0] = 0.;

        // use geometric means of adjacent alphas as the representative
        // complexity parameters (beta_k) for each subtree in the sequence
        for( ti = 1; ti < tree_count-1; ti++ )
            ab[ti] = std::sqrt(ab[ti]*ab[ti+1]);
        ab[tree_count-1] = DBL_MAX*0.5;  // last beta: effectively +infinity

        // err_jk(j, tk): error of the tk-th subtree on fold j
        Mat err_jk(cv_n, tree_count, CV_64F);

        for( j = 0; j < cv_n; j++ )
        {
            int tj = 0, tk = 0;
            for( ; tj < tree_count; tj++ )
            {
                double min_alpha = updateTreeRNC(root, tj, j);
                if( cutTree(root, tj, j, min_alpha) )
                    min_alpha = DBL_MAX;

                // the tj-th fold subtree covers all beta_k up to min_alpha
                for( ; tk < tree_count; tk++ )
                {
                    if( ab[tk] > min_alpha )
                        break;
                    err_jk.at<double>(j, tk) = w->wnodes[root].tree_error;
                }
            }
        }

        // pick the subtree with the smallest total CV error, optionally
        // relaxed by the 1SE rule (prefer a simpler tree within one SE)
        for( ti = 0; ti < tree_count; ti++ )
        {
            double sum_err = 0;
            for( j = 0; j < cv_n; j++ )
                sum_err += err_jk.at<double>(j, ti);
            if( ti == 0 || sum_err < min_err )
            {
                min_err = sum_err;
                min_idx = ti;
                if( use_1se )
                    min_err_se = sqrt( sum_err*(n - sum_err) );
            }
            else if( sum_err < min_err + min_err_se )
                min_idx = ti;
        }
    }

    return min_idx;
}
   1265 
// One bottom-up pass of cost-complexity bookkeeping (Risk aNd Complexity)
// over the subtree rooted at 'root'. For every internal node of the current
// pruned tree (Tn > T) it recomputes:
//   complexity - number of leaves in the node's pruned subtree,
//   tree_risk  - total risk of those leaves,
//   tree_error - total error of those leaves,
//   alpha      - (node risk - subtree risk)/(complexity - 1), i.e. the
//                cost-complexity of collapsing the node into a leaf.
// fold >= 0 selects the per-fold risk/error tables; fold < 0 the main tree.
// Returns the minimum alpha over the tree - the next pruning threshold.
double DTreesImpl::updateTreeRNC( int root, double T, int fold )
{
    int nidx = root, pidx = -1, cv_n = params.getCVFolds();
    double min_alpha = DBL_MAX;

    // iterative depth-first traversal using parent links (no recursion)
    for(;;)
    {
        WNode *node = 0, *parent = 0;

        // descend along left children down to a leaf of the pruned tree
        for(;;)
        {
            node = &w->wnodes[nidx];
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
            if( t <= T || node->left < 0 )
            {
                // this node acts as a leaf in the current pruned tree
                node->complexity = 1;
                node->tree_risk = node->node_risk;
                node->tree_error = 0.;
                if( fold >= 0 )
                {
                    node->tree_risk = w->cv_node_risk[nidx*cv_n + fold];
                    node->tree_error = w->cv_node_error[nidx*cv_n + fold];
                }
                break;
            }
            nidx = node->left;
        }

        // climb up while we are the right child, folding our totals into
        // each parent and updating its alpha
        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
             nidx = pidx, pidx = w->wnodes[pidx].parent )
        {
            node = &w->wnodes[nidx];
            parent = &w->wnodes[pidx];
            parent->complexity += node->complexity;
            parent->tree_risk += node->tree_risk;
            parent->tree_error += node->tree_error;

            parent->alpha = ((fold >= 0 ? w->cv_node_risk[pidx*cv_n + fold] : parent->node_risk)
                             - parent->tree_risk)/(parent->complexity - 1);
            min_alpha = std::min( min_alpha, parent->alpha );
        }

        if( pidx < 0 )  // climbed past the root - traversal finished
            break;

        // we are a left child: seed the parent's accumulators with our
        // totals, then descend into the right sibling
        node = &w->wnodes[nidx];
        parent = &w->wnodes[pidx];
        parent->complexity = node->complexity;
        parent->tree_risk = node->tree_risk;
        parent->tree_error = node->tree_error;
        nidx = parent->right;
    }

    return min_alpha;
}
   1321 
// Collapses every node whose alpha is within FLT_EPSILON of min_alpha by
// stamping it with the pruning-sequence index T: node->Tn for the main tree
// (fold < 0), or the cv_Tn table entry for a CV fold. A node with Tn <= T
// is treated as a leaf by the pruning routines.
// Returns true when the whole tree is cut down to the root (pruning done).
bool DTreesImpl::cutTree( int root, double T, int fold, double min_alpha )
{
    int cv_n = params.getCVFolds(), nidx = root, pidx = -1;
    WNode* node = &w->wnodes[root];
    if( node->left < 0 )  // the root is already a leaf - nothing left to cut
        return true;

    // iterative depth-first traversal over the current pruned tree
    for(;;)
    {
        for(;;)
        {
            node = &w->wnodes[nidx];
            double t = fold >= 0 ? w->cv_Tn[nidx*cv_n + fold] : node->Tn;
            if( t <= T || node->left < 0 )  // already a leaf here
                break;
            if( node->alpha <= min_alpha + FLT_EPSILON )
            {
                // collapse this node by marking it with the current index T
                if( fold >= 0 )
                    w->cv_Tn[nidx*cv_n + fold] = T;
                else
                    node->Tn = T;
                if( nidx == root )  // the root itself was collapsed
                    return true;
                break;
            }
            nidx = node->left;
        }

        // climb up while we are the right child
        for( pidx = node->parent; pidx >= 0 && w->wnodes[pidx].right == nidx;
             nidx = pidx, pidx = w->wnodes[pidx].parent )
            ;

        if( pidx < 0 )
            break;

        nidx = w->wnodes[pidx].right;  // continue with the right sibling
    }

    return false;
}
   1362 
// Runs one sample through the trees with indices [range.start, range.end)
// and combines their outputs: either the sum of leaf values (PREDICT_SUM)
// or the class label / class index winning the vote (PREDICT_MAX_VOTE).
float DTreesImpl::predictTrees( const Range& range, const Mat& sample, int flags ) const
{
    CV_Assert( sample.type() == CV_32F );

    int predictType = flags & PREDICT_MASK;
    int nvars = (int)varIdx.size();
    if( nvars == 0 )
        nvars = (int)varType.size();
    int i, ncats = (int)catOfs.size(), nclasses = (int)classLabels.size();
    int catbufsize = ncats > 0 ? nvars : 0;
    // buf = votes[nclasses] + catbuf[catbufsize] (+1 spare slot)
    AutoBuffer<int> buf(nclasses + catbufsize + 1);
    int* votes = buf;
    int* catbuf = votes + nclasses;   // per-variable cache of decoded categories
    const int* cvidx = (flags & (COMPRESSED_INPUT|PREPROCESSED_INPUT)) == 0 && !varIdx.empty() ? &compVarIdx[0] : 0;
    const uchar* vtype = &varType[0];
    const Vec2i* cofs = !catOfs.empty() ? &catOfs[0] : 0;
    const int* cmap = !catMap.empty() ? &catMap[0] : 0;
    const float* psample = sample.ptr<float>();
    const float* missingSubstPtr = !missingSubst.empty() ? &missingSubst[0] : 0;
    size_t sstep = sample.isContinuous() ? 1 : sample.step/sizeof(float);
    double sum = 0.;
    int lastClassIdx = -1;
    const float MISSED_VAL = TrainData::missingValue();

    // -1 means "category of this variable has not been decoded yet"
    for( i = 0; i < catbufsize; i++ )
        catbuf[i] = -1;

    if( predictType == PREDICT_AUTO )
    {
        predictType = !_isClassifier || (classLabels.size() == 2 && (flags & RAW_OUTPUT) != 0) ?
            PREDICT_SUM : PREDICT_MAX_VOTE;
    }

    if( predictType == PREDICT_MAX_VOTE )
    {
        for( i = 0; i < nclasses; i++ )
            votes[i] = 0;
    }

    for( int ridx = range.start; ridx < range.end; ridx++ )
    {
        int nidx = roots[ridx], prev = nidx, c = 0;

        // walk from the root down to a leaf (a node without a split)
        for(;;)
        {
            prev = nidx;
            const Node& node = nodes[nidx];
            if( node.split < 0 )
                break;
            const Split& split = splits[node.split];
            int vi = split.varIdx;
            int ci = cvidx ? cvidx[vi] : vi;  // position of vi in the (possibly compressed) sample
            float val = psample[ci*sstep];
            if( val == MISSED_VAL )
            {
                if( !missingSubstPtr )
                {
                    // no substitution values: follow the default direction
                    nidx = node.defaultDir < 0 ? node.left : node.right;
                    continue;
                }
                val = missingSubstPtr[vi];
            }

            if( vtype[vi] == VAR_ORDERED )
                nidx = val <= split.c ? node.left : node.right;
            else
            {
                if( flags & PREPROCESSED_INPUT )
                    c = cvRound(val);  // already a normalized category index
                else
                {
                    c = catbuf[ci];
                    if( c < 0 )
                    {
                        // decode the raw category value into a normalized
                        // index by binary search in catMap, then cache it
                        int a = c = cofs[vi][0];
                        int b = cofs[vi][1];

                        int ival = cvRound(val);
                        if( ival != val )
                            CV_Error( CV_StsBadArg,
                                     "one of input categorical variable is not an integer" );

                        while( a < b )
                        {
                            c = (a + b) >> 1;
                            if( ival < cmap[c] )
                                b = c;
                            else if( ival > cmap[c] )
                                a = c+1;
                            else
                                break;
                        }

                        // the value must be one of the training categories
                        CV_Assert( c >= 0 && ival == cmap[c] );

                        c -= cofs[vi][0];
                        catbuf[ci] = c;
                    }
                    const int* subset = &subsets[split.subsetOfs];
                    unsigned u = c;
                    nidx = CV_DTREE_CAT_DIR(u, subset) < 0 ? node.left : node.right;
                }
            }
        }

        // 'prev' now holds the leaf this tree ended in
        if( predictType == PREDICT_SUM )
            sum += nodes[prev].value;
        else
        {
            lastClassIdx = nodes[prev].classIdx;
            votes[lastClassIdx]++;
        }
    }

    if( predictType == PREDICT_MAX_VOTE )
    {
        // single tree: its leaf class wins outright; otherwise count votes
        int best_idx = lastClassIdx;
        if( range.end - range.start > 1 )
        {
            best_idx = 0;
            for( i = 1; i < nclasses; i++ )
                if( votes[best_idx] < votes[i] )
                    best_idx = i;
        }
        // RAW_OUTPUT: return the class index; otherwise the original label
        sum = (flags & RAW_OUTPUT) ? (float)best_idx : classLabels[best_idx];
    }

    return (float)sum;
}
   1492 
   1493 
   1494 float DTreesImpl::predict( InputArray _samples, OutputArray _results, int flags ) const
   1495 {
   1496     CV_Assert( !roots.empty() );
   1497     Mat samples = _samples.getMat(), results;
   1498     int i, nsamples = samples.rows;
   1499     int rtype = CV_32F;
   1500     bool needresults = _results.needed();
   1501     float retval = 0.f;
   1502     bool iscls = isClassifier();
   1503     float scale = !iscls ? 1.f/(int)roots.size() : 1.f;
   1504 
   1505     if( iscls && (flags & PREDICT_MASK) == PREDICT_MAX_VOTE )
   1506         rtype = CV_32S;
   1507 
   1508     if( needresults )
   1509     {
   1510         _results.create(nsamples, 1, rtype);
   1511         results = _results.getMat();
   1512     }
   1513     else
   1514         nsamples = std::min(nsamples, 1);
   1515 
   1516     for( i = 0; i < nsamples; i++ )
   1517     {
   1518         float val = predictTrees( Range(0, (int)roots.size()), samples.row(i), flags )*scale;
   1519         if( needresults )
   1520         {
   1521             if( rtype == CV_32F )
   1522                 results.at<float>(i) = val;
   1523             else
   1524                 results.at<int>(i) = cvRound(val);
   1525         }
   1526         if( i == 0 )
   1527             retval = val;
   1528     }
   1529     return retval;
   1530 }
   1531 
   1532 void DTreesImpl::writeTrainingParams(FileStorage& fs) const
   1533 {
   1534     fs << "use_surrogates" << (params.useSurrogates ? 1 : 0);
   1535     fs << "max_categories" << params.getMaxCategories();
   1536     fs << "regression_accuracy" << params.getRegressionAccuracy();
   1537 
   1538     fs << "max_depth" << params.getMaxDepth();
   1539     fs << "min_sample_count" << params.getMinSampleCount();
   1540     fs << "cross_validation_folds" << params.getCVFolds();
   1541 
   1542     if( params.getCVFolds() > 1 )
   1543         fs << "use_1se_rule" << (params.use1SERule ? 1 : 0);
   1544 
   1545     if( !params.priors.empty() )
   1546         fs << "priors" << params.priors;
   1547 }
   1548 
   1549 void DTreesImpl::writeParams(FileStorage& fs) const
   1550 {
   1551     fs << "is_classifier" << isClassifier();
   1552     fs << "var_all" << (int)varType.size();
   1553     fs << "var_count" << getVarCount();
   1554 
   1555     int ord_var_count = 0, cat_var_count = 0;
   1556     int i, n = (int)varType.size();
   1557     for( i = 0; i < n; i++ )
   1558         if( varType[i] == VAR_ORDERED )
   1559             ord_var_count++;
   1560         else
   1561             cat_var_count++;
   1562     fs << "ord_var_count" << ord_var_count;
   1563     fs << "cat_var_count" << cat_var_count;
   1564 
   1565     fs << "training_params" << "{";
   1566     writeTrainingParams(fs);
   1567 
   1568     fs << "}";
   1569 
   1570     if( !varIdx.empty() )
   1571     {
   1572         fs << "global_var_idx" << 1;
   1573         fs << "var_idx" << varIdx;
   1574     }
   1575 
   1576     fs << "var_type" << varType;
   1577 
   1578     if( !catOfs.empty() )
   1579         fs << "cat_ofs" << catOfs;
   1580     if( !catMap.empty() )
   1581         fs << "cat_map" << catMap;
   1582     if( !classLabels.empty() )
   1583         fs << "class_labels" << classLabels;
   1584     if( !missingSubst.empty() )
   1585         fs << "missing_subst" << missingSubst;
   1586 }
   1587 
// Serializes a single split as a compact map: variable index, quality, and
// either a category list (categorical var) or a threshold (ordered var).
void DTreesImpl::writeSplit( FileStorage& fs, int splitidx ) const
{
    const Split& split = splits[splitidx];

    fs << "{:";

    int vi = split.varIdx;
    fs << "var" << vi;
    fs << "quality" << split.quality;

    if( varType[vi] == VAR_CATEGORICAL ) // split on a categorical var
    {
        // count how many categories are sent to the right branch
        int i, n = getCatCount(vi), to_right = 0;
        const int* subset = &subsets[split.subsetOfs];
        for( i = 0; i < n; i++ )
            to_right += CV_DTREE_CAT_DIR(i, subset) > 0;

        // ad-hoc rule when to use inverse categorical split notation
        // to achieve more compact and clear representation
        int default_dir = to_right <= 1 || to_right <= std::min(3, n/2) || to_right <= n/3 ? -1 : 1;

        // write "in" or "not_in" followed by the list of categories that go
        // against the default direction (the shorter list by construction)
        fs << (default_dir*(split.inversed ? -1 : 1) > 0 ? "in" : "not_in") << "[:";

        for( i = 0; i < n; i++ )
        {
            int dir = CV_DTREE_CAT_DIR(i, subset);
            if( dir*default_dir < 0 )
                fs << i;
        }

        fs << "]";
    }
    else
        // ordered var: condition is value <= c ("le") or value > c ("gt")
        fs << (!split.inversed ? "le" : "gt") << split.c;

    fs << "}";
}
   1625 
   1626 void DTreesImpl::writeNode( FileStorage& fs, int nidx, int depth ) const
   1627 {
   1628     const Node& node = nodes[nidx];
   1629     fs << "{";
   1630     fs << "depth" << depth;
   1631     fs << "value" << node.value;
   1632 
   1633     if( _isClassifier )
   1634         fs << "norm_class_idx" << node.classIdx;
   1635 
   1636     if( node.split >= 0 )
   1637     {
   1638         fs << "splits" << "[";
   1639 
   1640         for( int splitidx = node.split; splitidx >= 0; splitidx = splits[splitidx].next )
   1641             writeSplit( fs, splitidx );
   1642 
   1643         fs << "]";
   1644     }
   1645 
   1646     fs << "}";
   1647 }
   1648 
// Serializes all nodes of the tree rooted at 'root' into a "nodes" sequence
// in depth-first (pre-order) order; each node records its depth so the
// reader can rebuild the tree structure.
void DTreesImpl::writeTree( FileStorage& fs, int root ) const
{
    fs << "nodes" << "[";

    int nidx = root, pidx = 0, depth = 0;
    const Node *node = 0;

    // traverse the tree and save all the nodes in depth-first order
    for(;;)
    {
        // descend along left children, writing each node on the way down
        for(;;)
        {
            writeNode( fs, nidx, depth );
            node = &nodes[nidx];
            if( node->left < 0 )
                break;
            nidx = node->left;
            depth++;
        }

        // climb up while we are the right child, undoing the depth increments
        for( pidx = node->parent; pidx >= 0 && nodes[pidx].right == nidx;
             nidx = pidx, pidx = nodes[pidx].parent )
            depth--;

        if( pidx < 0 )  // climbed past the root - all nodes written
            break;

        nidx = nodes[pidx].right;  // continue with the right sibling
    }

    fs << "]";
}
   1681 
// Serializes the whole model: global parameters first, then the (single)
// tree stored at roots[0].
void DTreesImpl::write( FileStorage& fs ) const
{
    writeParams(fs);
    writeTree(fs, roots[0]);
}
   1687 
   1688 void DTreesImpl::readParams( const FileNode& fn )
   1689 {
   1690     _isClassifier = (int)fn["is_classifier"] != 0;
   1691     /*int var_all = (int)fn["var_all"];
   1692     int var_count = (int)fn["var_count"];
   1693     int cat_var_count = (int)fn["cat_var_count"];
   1694     int ord_var_count = (int)fn["ord_var_count"];*/
   1695 
   1696     FileNode tparams_node = fn["training_params"];
   1697 
   1698     TreeParams params0 = TreeParams();
   1699 
   1700     if( !tparams_node.empty() ) // training parameters are not necessary
   1701     {
   1702         params0.useSurrogates = (int)tparams_node["use_surrogates"] != 0;
   1703         params0.setMaxCategories((int)(tparams_node["max_categories"].empty() ? 16 : tparams_node["max_categories"]));
   1704         params0.setRegressionAccuracy((float)tparams_node["regression_accuracy"]);
   1705         params0.setMaxDepth((int)tparams_node["max_depth"]);
   1706         params0.setMinSampleCount((int)tparams_node["min_sample_count"]);
   1707         params0.setCVFolds((int)tparams_node["cross_validation_folds"]);
   1708 
   1709         if( params0.getCVFolds() > 1 )
   1710         {
   1711             params.use1SERule = (int)tparams_node["use_1se_rule"] != 0;
   1712         }
   1713 
   1714         tparams_node["priors"] >> params0.priors;
   1715     }
   1716 
   1717     readVectorOrMat(fn["var_idx"], varIdx);
   1718     fn["var_type"] >> varType;
   1719 
   1720     int format = 0;
   1721     fn["format"] >> format;
   1722     bool isLegacy = format < 3;
   1723 
   1724     int varAll = (int)fn["var_all"];
   1725     if (isLegacy && (int)varType.size() <= varAll)
   1726     {
   1727         std::vector<uchar> extendedTypes(varAll + 1, 0);
   1728 
   1729         int i = 0, n;
   1730         if (!varIdx.empty())
   1731         {
   1732             n = (int)varIdx.size();
   1733             for (; i < n; ++i)
   1734             {
   1735                 int var = varIdx[i];
   1736                 extendedTypes[var] = varType[i];
   1737             }
   1738         }
   1739         else
   1740         {
   1741             n = (int)varType.size();
   1742             for (; i < n; ++i)
   1743             {
   1744                 extendedTypes[i] = varType[i];
   1745             }
   1746         }
   1747         extendedTypes[varAll] = (uchar)(_isClassifier ? VAR_CATEGORICAL : VAR_ORDERED);
   1748         extendedTypes.swap(varType);
   1749     }
   1750 
   1751     readVectorOrMat(fn["cat_map"], catMap);
   1752 
   1753     if (isLegacy)
   1754     {
   1755         // generating "catOfs" from "cat_count"
   1756         catOfs.clear();
   1757         classLabels.clear();
   1758         std::vector<int> counts;
   1759         readVectorOrMat(fn["cat_count"], counts);
   1760         unsigned int i = 0, j = 0, curShift = 0, size = (int)varType.size() - 1;
   1761         for (; i < size; ++i)
   1762         {
   1763             Vec2i newOffsets(0, 0);
   1764             if (varType[i] == VAR_CATEGORICAL) // only categorical vars are represented in catMap
   1765             {
   1766                 newOffsets[0] = curShift;
   1767                 curShift += counts[j];
   1768                 newOffsets[1] = curShift;
   1769                 ++j;
   1770             }
   1771             catOfs.push_back(newOffsets);
   1772         }
   1773         // other elements in "catMap" are "classLabels"
   1774         if (curShift < catMap.size())
   1775         {
   1776             classLabels.insert(classLabels.end(), catMap.begin() + curShift, catMap.end());
   1777             catMap.erase(catMap.begin() + curShift, catMap.end());
   1778         }
   1779     }
   1780     else
   1781     {
   1782         fn["cat_ofs"] >> catOfs;
   1783         fn["missing_subst"] >> missingSubst;
   1784         fn["class_labels"] >> classLabels;
   1785     }
   1786 
   1787     // init var mapping for node reading (var indexes or varIdx indexes)
   1788     bool globalVarIdx = false;
   1789     fn["global_var_idx"] >> globalVarIdx;
   1790     if (globalVarIdx || varIdx.empty())
   1791         setRangeVector(varMapping, (int)varType.size());
   1792     else
   1793         varMapping = varIdx;
   1794 
   1795     initCompVarIdx();
   1796     setDParams(params0);
   1797 }
   1798 
   1799 int DTreesImpl::readSplit( const FileNode& fn )
   1800 {
   1801     Split split;
   1802 
   1803     int vi = (int)fn["var"];
   1804     CV_Assert( 0 <= vi && vi <= (int)varType.size() );
   1805     vi = varMapping[vi]; // convert to varIdx if needed
   1806     split.varIdx = vi;
   1807 
   1808     if( varType[vi] == VAR_CATEGORICAL ) // split on categorical var
   1809     {
   1810         int i, val, ssize = getSubsetSize(vi);
   1811         split.subsetOfs = (int)subsets.size();
   1812         for( i = 0; i < ssize; i++ )
   1813             subsets.push_back(0);
   1814         int* subset = &subsets[split.subsetOfs];
   1815         FileNode fns = fn["in"];
   1816         if( fns.empty() )
   1817         {
   1818             fns = fn["not_in"];
   1819             split.inversed = true;
   1820         }
   1821 
   1822         if( fns.isInt() )
   1823         {
   1824             val = (int)fns;
   1825             subset[val >> 5] |= 1 << (val & 31);
   1826         }
   1827         else
   1828         {
   1829             FileNodeIterator it = fns.begin();
   1830             int n = (int)fns.size();
   1831             for( i = 0; i < n; i++, ++it )
   1832             {
   1833                 val = (int)*it;
   1834                 subset[val >> 5] |= 1 << (val & 31);
   1835             }
   1836         }
   1837 
   1838         // for categorical splits we do not use inversed splits,
   1839         // instead we inverse the variable set in the split
   1840         if( split.inversed )
   1841         {
   1842             for( i = 0; i < ssize; i++ )
   1843                 subset[i] ^= -1;
   1844             split.inversed = false;
   1845         }
   1846     }
   1847     else
   1848     {
   1849         FileNode cmpNode = fn["le"];
   1850         if( cmpNode.empty() )
   1851         {
   1852             cmpNode = fn["gt"];
   1853             split.inversed = true;
   1854         }
   1855         split.c = (float)cmpNode;
   1856     }
   1857 
   1858     split.quality = (float)fn["quality"];
   1859     splits.push_back(split);
   1860 
   1861     return (int)(splits.size() - 1);
   1862 }
   1863 
   1864 int DTreesImpl::readNode( const FileNode& fn )
   1865 {
   1866     Node node;
   1867     node.value = (double)fn["value"];
   1868 
   1869     if( _isClassifier )
   1870         node.classIdx = (int)fn["norm_class_idx"];
   1871 
   1872     FileNode sfn = fn["splits"];
   1873     if( !sfn.empty() )
   1874     {
   1875         int i, n = (int)sfn.size(), prevsplit = -1;
   1876         FileNodeIterator it = sfn.begin();
   1877 
   1878         for( i = 0; i < n; i++, ++it )
   1879         {
   1880             int splitidx = readSplit(*it);
   1881             if( splitidx < 0 )
   1882                 break;
   1883             if( prevsplit < 0 )
   1884                 node.split = splitidx;
   1885             else
   1886                 splits[prevsplit].next = splitidx;
   1887             prevsplit = splitidx;
   1888         }
   1889     }
   1890     nodes.push_back(node);
   1891     return (int)(nodes.size() - 1);
   1892 }
   1893 
   1894 int DTreesImpl::readTree( const FileNode& fn )
   1895 {
   1896     int i, n = (int)fn.size(), root = -1, pidx = -1;
   1897     FileNodeIterator it = fn.begin();
   1898 
   1899     for( i = 0; i < n; i++, ++it )
   1900     {
   1901         int nidx = readNode(*it);
   1902         if( nidx < 0 )
   1903             break;
   1904         Node& node = nodes[nidx];
   1905         node.parent = pidx;
   1906         if( pidx < 0 )
   1907             root = nidx;
   1908         else
   1909         {
   1910             Node& parent = nodes[pidx];
   1911             if( parent.left < 0 )
   1912                 parent.left = nidx;
   1913             else
   1914                 parent.right = nidx;
   1915         }
   1916         if( node.split >= 0 )
   1917             pidx = nidx;
   1918         else
   1919         {
   1920             while( pidx >= 0 && nodes[pidx].right >= 0 )
   1921                 pidx = nodes[pidx].parent;
   1922         }
   1923     }
   1924     roots.push_back(root);
   1925     return root;
   1926 }
   1927 
   1928 void DTreesImpl::read( const FileNode& fn )
   1929 {
   1930     clear();
   1931     readParams(fn);
   1932 
   1933     FileNode fnodes = fn["nodes"];
   1934     CV_Assert( !fnodes.empty() );
   1935     readTree(fnodes);
   1936 }
   1937 
// Factory method: returns a fresh, untrained decision-tree implementation.
Ptr<DTrees> DTrees::create()
{
    return makePtr<DTreesImpl>();
}
   1942 
   1943 }
   1944 }
   1945 
   1946 /* End of file. */
   1947