Home | History | Annotate | Download | only in src
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                        Intel License Agreement
     11 //
     12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
     13 // Third party copyrights are property of their respective owners.
     14 //
     15 // Redistribution and use in source and binary forms, with or without modification,
     16 // are permitted provided that the following conditions are met:
     17 //
     18 //   * Redistribution's of source code must retain the above copyright notice,
     19 //     this list of conditions and the following disclaimer.
     20 //
     21 //   * Redistribution's in binary form must reproduce the above copyright notice,
     22 //     this list of conditions and the following disclaimer in the documentation
     23 //     and/or other materials provided with the distribution.
     24 //
     25 //   * The name of Intel Corporation may not be used to endorse or promote products
     26 //     derived from this software without specific prior written permission.
     27 //
     28 // This software is provided by the copyright holders and contributors "as is" and
     29 // any express or implied warranties, including, but not limited to, the implied
     30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     31 // In no event shall the Intel Corporation or contributors be liable for any direct,
     32 // indirect, incidental, special, exemplary, or consequential damages
     33 // (including, but not limited to, procurement of substitute goods or services;
     34 // loss of use, data, or profits; or business interruption) however caused
     35 // and on any theory of liability, whether in contract, strict liability,
     36 // or tort (including negligence or otherwise) arising in any way out of
     37 // the use of this software, even if advised of the possibility of such damage.
     38 //
     39 //M*/
     40 
     41 #include "_ml.h"
     42 
     43 static inline double
     44 log_ratio( double val )
     45 {
     46     const double eps = 1e-5;
     47 
     48     val = MAX( val, eps );
     49     val = MIN( val, 1. - eps );
     50     return log( val/(1. - val) );
     51 }
     52 
     53 
     54 CvBoostParams::CvBoostParams()
     55 {
     56     boost_type = CvBoost::REAL;
     57     weak_count = 100;
     58     weight_trim_rate = 0.95;
     59     cv_folds = 0;
     60     max_depth = 1;
     61 }
     62 
     63 
     64 CvBoostParams::CvBoostParams( int _boost_type, int _weak_count,
     65                                         double _weight_trim_rate, int _max_depth,
     66                                         bool _use_surrogates, const float* _priors )
     67 {
     68     boost_type = _boost_type;
     69     weak_count = _weak_count;
     70     weight_trim_rate = _weight_trim_rate;
     71     split_criteria = CvBoost::DEFAULT;
     72     cv_folds = 0;
     73     max_depth = _max_depth;
     74     use_surrogates = _use_surrogates;
     75     priors = _priors;
     76 }
     77 
     78 
     79 
     80 ///////////////////////////////// CvBoostTree ///////////////////////////////////
     81 
     82 CvBoostTree::CvBoostTree()
     83 {
     84     ensemble = 0;
     85 }
     86 
     87 
     88 CvBoostTree::~CvBoostTree()
     89 {
     90     clear();
     91 }
     92 
     93 
     94 void
     95 CvBoostTree::clear()
     96 {
     97     CvDTree::clear();
     98     ensemble = 0;
     99 }
    100 
    101 
    102 bool
    103 CvBoostTree::train( CvDTreeTrainData* _train_data,
    104                     const CvMat* _subsample_idx, CvBoost* _ensemble )
    105 {
    106     clear();
    107     ensemble = _ensemble;
    108     data = _train_data;
    109     data->shared = true;
    110 
    111     return do_train( _subsample_idx );
    112 }
    113 
    114 
    115 bool
    116 CvBoostTree::train( const CvMat*, int, const CvMat*, const CvMat*,
    117                     const CvMat*, const CvMat*, const CvMat*, CvDTreeParams )
    118 {
    119     assert(0);
    120     return false;
    121 }
    122 
    123 
    124 bool
    125 CvBoostTree::train( CvDTreeTrainData*, const CvMat* )
    126 {
    127     assert(0);
    128     return false;
    129 }
    130 
    131 
    132 void
    133 CvBoostTree::scale( double scale )
    134 {
    135     CvDTreeNode* node = root;
    136 
    137     // traverse the tree and scale all the node values
    138     for(;;)
    139     {
    140         CvDTreeNode* parent;
    141         for(;;)
    142         {
    143             node->value *= scale;
    144             if( !node->left )
    145                 break;
    146             node = node->left;
    147         }
    148 
    149         for( parent = node->parent; parent && parent->right == node;
    150             node = parent, parent = parent->parent )
    151             ;
    152 
    153         if( !parent )
    154             break;
    155 
    156         node = parent->right;
    157     }
    158 }
    159 
    160 
    161 void
    162 CvBoostTree::try_split_node( CvDTreeNode* node )
    163 {
    164     CvDTree::try_split_node( node );
    165 
    166     if( !node->left )
    167     {
    168         // if the node has not been split,
    169         // store the responses for the corresponding training samples
    170         double* weak_eval = ensemble->get_weak_response()->data.db;
    171         int* labels = data->get_labels( node );
    172         int i, count = node->sample_count;
    173         double value = node->value;
    174 
    175         for( i = 0; i < count; i++ )
    176             weak_eval[labels[i]] = value;
    177     }
    178 }
    179 
    180 
    181 double
    182 CvBoostTree::calc_node_dir( CvDTreeNode* node )
    183 {
    184     char* dir = (char*)data->direction->data.ptr;
    185     const double* weights = ensemble->get_subtree_weights()->data.db;
    186     int i, n = node->sample_count, vi = node->split->var_idx;
    187     double L, R;
    188 
    189     assert( !node->split->inversed );
    190 
    191     if( data->get_var_type(vi) >= 0 ) // split on categorical var
    192     {
    193         const int* cat_labels = data->get_cat_var_data( node, vi );
    194         const int* subset = node->split->subset;
    195         double sum = 0, sum_abs = 0;
    196 
    197         for( i = 0; i < n; i++ )
    198         {
    199             int idx = cat_labels[i];
    200             double w = weights[i];
    201             int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
    202             sum += d*w; sum_abs += (d & 1)*w;
    203             dir[i] = (char)d;
    204         }
    205 
    206         R = (sum_abs + sum) * 0.5;
    207         L = (sum_abs - sum) * 0.5;
    208     }
    209     else // split on ordered var
    210     {
    211         const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
    212         int split_point = node->split->ord.split_point;
    213         int n1 = node->get_num_valid(vi);
    214 
    215         assert( 0 <= split_point && split_point < n1-1 );
    216         L = R = 0;
    217 
    218         for( i = 0; i <= split_point; i++ )
    219         {
    220             int idx = sorted[i].i;
    221             double w = weights[idx];
    222             dir[idx] = (char)-1;
    223             L += w;
    224         }
    225 
    226         for( ; i < n1; i++ )
    227         {
    228             int idx = sorted[i].i;
    229             double w = weights[idx];
    230             dir[idx] = (char)1;
    231             R += w;
    232         }
    233 
    234         for( ; i < n; i++ )
    235             dir[sorted[i].i] = (char)0;
    236     }
    237 
    238     node->maxlr = MAX( L, R );
    239     return node->split->quality/(L + R);
    240 }
    241 
    242 
    243 CvDTreeSplit*
    244 CvBoostTree::find_split_ord_class( CvDTreeNode* node, int vi )
    245 {
    246     const float epsilon = FLT_EPSILON*2;
    247     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
    248     const int* responses = data->get_class_labels(node);
    249     const double* weights = ensemble->get_subtree_weights()->data.db;
    250     int n = node->sample_count;
    251     int n1 = node->get_num_valid(vi);
    252     const double* rcw0 = weights + n;
    253     double lcw[2] = {0,0}, rcw[2];
    254     int i, best_i = -1;
    255     double best_val = 0;
    256     int boost_type = ensemble->get_params().boost_type;
    257     int split_criteria = ensemble->get_params().split_criteria;
    258 
    259     rcw[0] = rcw0[0]; rcw[1] = rcw0[1];
    260     for( i = n1; i < n; i++ )
    261     {
    262         int idx = sorted[i].i;
    263         double w = weights[idx];
    264         rcw[responses[idx]] -= w;
    265     }
    266 
    267     if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
    268         split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;
    269 
    270     if( split_criteria == CvBoost::GINI )
    271     {
    272         double L = 0, R = rcw[0] + rcw[1];
    273         double lsum2 = 0, rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];
    274 
    275         for( i = 0; i < n1 - 1; i++ )
    276         {
    277             int idx = sorted[i].i;
    278             double w = weights[idx], w2 = w*w;
    279             double lv, rv;
    280             idx = responses[idx];
    281             L += w; R -= w;
    282             lv = lcw[idx]; rv = rcw[idx];
    283             lsum2 += 2*lv*w + w2;
    284             rsum2 -= 2*rv*w - w2;
    285             lcw[idx] = lv + w; rcw[idx] = rv - w;
    286 
    287             if( sorted[i].val + epsilon < sorted[i+1].val )
    288             {
    289                 double val = (lsum2*R + rsum2*L)/(L*R);
    290                 if( best_val < val )
    291                 {
    292                     best_val = val;
    293                     best_i = i;
    294                 }
    295             }
    296         }
    297     }
    298     else
    299     {
    300         for( i = 0; i < n1 - 1; i++ )
    301         {
    302             int idx = sorted[i].i;
    303             double w = weights[idx];
    304             idx = responses[idx];
    305             lcw[idx] += w;
    306             rcw[idx] -= w;
    307 
    308             if( sorted[i].val + epsilon < sorted[i+1].val )
    309             {
    310                 double val = lcw[0] + rcw[1], val2 = lcw[1] + rcw[0];
    311                 val = MAX(val, val2);
    312                 if( best_val < val )
    313                 {
    314                     best_val = val;
    315                     best_i = i;
    316                 }
    317             }
    318         }
    319     }
    320 
    321     return best_i >= 0 ? data->new_split_ord( vi,
    322         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
    323         0, (float)best_val ) : 0;
    324 }
    325 
    326 
    327 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
    328 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
    329 
    330 CvDTreeSplit*
    331 CvBoostTree::find_split_cat_class( CvDTreeNode* node, int vi )
    332 {
    333     CvDTreeSplit* split;
    334     const int* cat_labels = data->get_cat_var_data(node, vi);
    335     const int* responses = data->get_class_labels(node);
    336     int ci = data->get_var_type(vi);
    337     int n = node->sample_count;
    338     int mi = data->cat_count->data.i[ci];
    339     double lcw[2]={0,0}, rcw[2]={0,0};
    340     double* cjk = (double*)cvStackAlloc(2*(mi+1)*sizeof(cjk[0]))+2;
    341     const double* weights = ensemble->get_subtree_weights()->data.db;
    342     double** dbl_ptr = (double**)cvStackAlloc( mi*sizeof(dbl_ptr[0]) );
    343     int i, j, k, idx;
    344     double L = 0, R;
    345     double best_val = 0;
    346     int best_subset = -1, subset_i;
    347     int boost_type = ensemble->get_params().boost_type;
    348     int split_criteria = ensemble->get_params().split_criteria;
    349 
    350     // init array of counters:
    351     // c_{jk} - number of samples that have vi-th input variable = j and response = k.
    352     for( j = -1; j < mi; j++ )
    353         cjk[j*2] = cjk[j*2+1] = 0;
    354 
    355     for( i = 0; i < n; i++ )
    356     {
    357         double w = weights[i];
    358         j = cat_labels[i];
    359         k = responses[i];
    360         cjk[j*2 + k] += w;
    361     }
    362 
    363     for( j = 0; j < mi; j++ )
    364     {
    365         rcw[0] += cjk[j*2];
    366         rcw[1] += cjk[j*2+1];
    367         dbl_ptr[j] = cjk + j*2 + 1;
    368     }
    369 
    370     R = rcw[0] + rcw[1];
    371 
    372     if( split_criteria != CvBoost::GINI && split_criteria != CvBoost::MISCLASS )
    373         split_criteria = boost_type == CvBoost::DISCRETE ? CvBoost::MISCLASS : CvBoost::GINI;
    374 
    375     // sort rows of c_jk by increasing c_j,1
    376     // (i.e. by the weight of samples in j-th category that belong to class 1)
    377     icvSortDblPtr( dbl_ptr, mi, 0 );
    378 
    379     for( subset_i = 0; subset_i < mi-1; subset_i++ )
    380     {
    381         idx = (int)(dbl_ptr[subset_i] - cjk)/2;
    382         const double* crow = cjk + idx*2;
    383         double w0 = crow[0], w1 = crow[1];
    384         double weight = w0 + w1;
    385 
    386         if( weight < FLT_EPSILON )
    387             continue;
    388 
    389         lcw[0] += w0; rcw[0] -= w0;
    390         lcw[1] += w1; rcw[1] -= w1;
    391 
    392         if( split_criteria == CvBoost::GINI )
    393         {
    394             double lsum2 = lcw[0]*lcw[0] + lcw[1]*lcw[1];
    395             double rsum2 = rcw[0]*rcw[0] + rcw[1]*rcw[1];
    396 
    397             L += weight;
    398             R -= weight;
    399 
    400             if( L > FLT_EPSILON && R > FLT_EPSILON )
    401             {
    402                 double val = (lsum2*R + rsum2*L)/(L*R);
    403                 if( best_val < val )
    404                 {
    405                     best_val = val;
    406                     best_subset = subset_i;
    407                 }
    408             }
    409         }
    410         else
    411         {
    412             double val = lcw[0] + rcw[1];
    413             double val2 = lcw[1] + rcw[0];
    414 
    415             val = MAX(val, val2);
    416             if( best_val < val )
    417             {
    418                 best_val = val;
    419                 best_subset = subset_i;
    420             }
    421         }
    422     }
    423 
    424     if( best_subset < 0 )
    425         return 0;
    426 
    427     split = data->new_split_cat( vi, (float)best_val );
    428 
    429     for( i = 0; i <= best_subset; i++ )
    430     {
    431         idx = (int)(dbl_ptr[i] - cjk) >> 1;
    432         split->subset[idx >> 5] |= 1 << (idx & 31);
    433     }
    434 
    435     return split;
    436 }
    437 
    438 
    439 CvDTreeSplit*
    440 CvBoostTree::find_split_ord_reg( CvDTreeNode* node, int vi )
    441 {
    442     const float epsilon = FLT_EPSILON*2;
    443     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
    444     const float* responses = data->get_ord_responses(node);
    445     const double* weights = ensemble->get_subtree_weights()->data.db;
    446     int n = node->sample_count;
    447     int n1 = node->get_num_valid(vi);
    448     int i, best_i = -1;
    449     double best_val = 0, lsum = 0, rsum = node->value*n;
    450     double L = 0, R = weights[n];
    451 
    452     // compensate for missing values
    453     for( i = n1; i < n; i++ )
    454     {
    455         int idx = sorted[i].i;
    456         double w = weights[idx];
    457         rsum -= responses[idx]*w;
    458         R -= w;
    459     }
    460 
    461     // find the optimal split
    462     for( i = 0; i < n1 - 1; i++ )
    463     {
    464         int idx = sorted[i].i;
    465         double w = weights[idx];
    466         double t = responses[idx]*w;
    467         L += w; R -= w;
    468         lsum += t; rsum -= t;
    469 
    470         if( sorted[i].val + epsilon < sorted[i+1].val )
    471         {
    472             double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
    473             if( best_val < val )
    474             {
    475                 best_val = val;
    476                 best_i = i;
    477             }
    478         }
    479     }
    480 
    481     return best_i >= 0 ? data->new_split_ord( vi,
    482         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
    483         0, (float)best_val ) : 0;
    484 }
    485 
    486 
    487 CvDTreeSplit*
    488 CvBoostTree::find_split_cat_reg( CvDTreeNode* node, int vi )
    489 {
    490     CvDTreeSplit* split;
    491     const int* cat_labels = data->get_cat_var_data(node, vi);
    492     const float* responses = data->get_ord_responses(node);
    493     const double* weights = ensemble->get_subtree_weights()->data.db;
    494     int ci = data->get_var_type(vi);
    495     int n = node->sample_count;
    496     int mi = data->cat_count->data.i[ci];
    497     double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
    498     double* counts = (double*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
    499     double** sum_ptr = (double**)cvStackAlloc( mi*sizeof(sum_ptr[0]) );
    500     double L = 0, R = 0, best_val = 0, lsum = 0, rsum = 0;
    501     int i, best_subset = -1, subset_i;
    502 
    503     for( i = -1; i < mi; i++ )
    504         sum[i] = counts[i] = 0;
    505 
    506     // calculate sum response and weight of each category of the input var
    507     for( i = 0; i < n; i++ )
    508     {
    509         int idx = cat_labels[i];
    510         double w = weights[i];
    511         double s = sum[idx] + responses[i]*w;
    512         double nc = counts[idx] + w;
    513         sum[idx] = s;
    514         counts[idx] = nc;
    515     }
    516 
    517     // calculate average response in each category
    518     for( i = 0; i < mi; i++ )
    519     {
    520         R += counts[i];
    521         rsum += sum[i];
    522         sum[i] /= counts[i];
    523         sum_ptr[i] = sum + i;
    524     }
    525 
    526     icvSortDblPtr( sum_ptr, mi, 0 );
    527 
    528     // revert back to unnormalized sums
    529     // (there should be a very little loss in accuracy)
    530     for( i = 0; i < mi; i++ )
    531         sum[i] *= counts[i];
    532 
    533     for( subset_i = 0; subset_i < mi-1; subset_i++ )
    534     {
    535         int idx = (int)(sum_ptr[subset_i] - sum);
    536         double ni = counts[idx];
    537 
    538         if( ni > FLT_EPSILON )
    539         {
    540             double s = sum[idx];
    541             lsum += s; L += ni;
    542             rsum -= s; R -= ni;
    543 
    544             if( L > FLT_EPSILON && R > FLT_EPSILON )
    545             {
    546                 double val = (lsum*lsum*R + rsum*rsum*L)/(L*R);
    547                 if( best_val < val )
    548                 {
    549                     best_val = val;
    550                     best_subset = subset_i;
    551                 }
    552             }
    553         }
    554     }
    555 
    556     if( best_subset < 0 )
    557         return 0;
    558 
    559     split = data->new_split_cat( vi, (float)best_val );
    560     for( i = 0; i <= best_subset; i++ )
    561     {
    562         int idx = (int)(sum_ptr[i] - sum);
    563         split->subset[idx >> 5] |= 1 << (idx & 31);
    564     }
    565 
    566     return split;
    567 }
    568 
    569 
    570 CvDTreeSplit*
    571 CvBoostTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
    572 {
    573     const float epsilon = FLT_EPSILON*2;
    574     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
    575     const double* weights = ensemble->get_subtree_weights()->data.db;
    576     const char* dir = (char*)data->direction->data.ptr;
    577     int n1 = node->get_num_valid(vi);
    578     // LL - number of samples that both the primary and the surrogate splits send to the left
    579     // LR - ... primary split sends to the left and the surrogate split sends to the right
    580     // RL - ... primary split sends to the right and the surrogate split sends to the left
    581     // RR - ... both send to the right
    582     int i, best_i = -1, best_inversed = 0;
    583     double best_val;
    584     double LL = 0, RL = 0, LR, RR;
    585     double worst_val = node->maxlr;
    586     double sum = 0, sum_abs = 0;
    587     best_val = worst_val;
    588 
    589     for( i = 0; i < n1; i++ )
    590     {
    591         int idx = sorted[i].i;
    592         double w = weights[idx];
    593         int d = dir[idx];
    594         sum += d*w; sum_abs += (d & 1)*w;
    595     }
    596 
    597     // sum_abs = R + L; sum = R - L
    598     RR = (sum_abs + sum)*0.5;
    599     LR = (sum_abs - sum)*0.5;
    600 
    601     // initially all the samples are sent to the right by the surrogate split,
    602     // LR of them are sent to the left by primary split, and RR - to the right.
    603     // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
    604     for( i = 0; i < n1 - 1; i++ )
    605     {
    606         int idx = sorted[i].i;
    607         double w = weights[idx];
    608         int d = dir[idx];
    609 
    610         if( d < 0 )
    611         {
    612             LL += w; LR -= w;
    613             if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
    614             {
    615                 best_val = LL + RR;
    616                 best_i = i; best_inversed = 0;
    617             }
    618         }
    619         else if( d > 0 )
    620         {
    621             RL += w; RR -= w;
    622             if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
    623             {
    624                 best_val = RL + LR;
    625                 best_i = i; best_inversed = 1;
    626             }
    627         }
    628     }
    629 
    630     return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
    631         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
    632         best_inversed, (float)best_val ) : 0;
    633 }
    634 
    635 
    636 CvDTreeSplit*
    637 CvBoostTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
    638 {
    639     const int* cat_labels = data->get_cat_var_data(node, vi);
    640     const char* dir = (char*)data->direction->data.ptr;
    641     const double* weights = ensemble->get_subtree_weights()->data.db;
    642     int n = node->sample_count;
    643     // LL - number of samples that both the primary and the surrogate splits send to the left
    644     // LR - ... primary split sends to the left and the surrogate split sends to the right
    645     // RL - ... primary split sends to the right and the surrogate split sends to the left
    646     // RR - ... both send to the right
    647     CvDTreeSplit* split = data->new_split_cat( vi, 0 );
    648     int i, mi = data->cat_count->data.i[data->get_var_type(vi)];
    649     double best_val = 0;
    650     double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
    651     double* rc = lc + mi + 1;
    652 
    653     for( i = -1; i < mi; i++ )
    654         lc[i] = rc[i] = 0;
    655 
    656     // 1. for each category calculate the weight of samples
    657     // sent to the left (lc) and to the right (rc) by the primary split
    658     for( i = 0; i < n; i++ )
    659     {
    660         int idx = cat_labels[i];
    661         double w = weights[i];
    662         int d = dir[i];
    663         double sum = lc[idx] + d*w;
    664         double sum_abs = rc[idx] + (d & 1)*w;
    665         lc[idx] = sum; rc[idx] = sum_abs;
    666     }
    667 
    668     for( i = 0; i < mi; i++ )
    669     {
    670         double sum = lc[i];
    671         double sum_abs = rc[i];
    672         lc[i] = (sum_abs - sum) * 0.5;
    673         rc[i] = (sum_abs + sum) * 0.5;
    674     }
    675 
    676     // 2. now form the split.
    677     // in each category send all the samples to the same direction as majority
    678     for( i = 0; i < mi; i++ )
    679     {
    680         double lval = lc[i], rval = rc[i];
    681         if( lval > rval )
    682         {
    683             split->subset[i >> 5] |= 1 << (i & 31);
    684             best_val += lval;
    685         }
    686         else
    687             best_val += rval;
    688     }
    689 
    690     split->quality = (float)best_val;
    691     if( split->quality <= node->maxlr )
    692         cvSetRemoveByPtr( data->split_heap, split ), split = 0;
    693 
    694     return split;
    695 }
    696 
    697 
    698 void
    699 CvBoostTree::calc_node_value( CvDTreeNode* node )
    700 {
    701     int i, count = node->sample_count;
    702     const double* weights = ensemble->get_weights()->data.db;
    703     const int* labels = data->get_labels(node);
    704     double* subtree_weights = ensemble->get_subtree_weights()->data.db;
    705     double rcw[2] = {0,0};
    706     int boost_type = ensemble->get_params().boost_type;
    707     //const double* priors = data->priors->data.db;
    708 
    709     if( data->is_classifier )
    710     {
    711         const int* responses = data->get_class_labels(node);
    712 
    713         for( i = 0; i < count; i++ )
    714         {
    715             int idx = labels[i];
    716             double w = weights[idx]/*priors[responses[i]]*/;
    717             rcw[responses[i]] += w;
    718             subtree_weights[i] = w;
    719         }
    720 
    721         node->class_idx = rcw[1] > rcw[0];
    722 
    723         if( boost_type == CvBoost::DISCRETE )
    724         {
    725             // ignore cat_map for responses, and use {-1,1},
    726             // as the whole ensemble response is computes as sign(sum_i(weak_response_i)
    727             node->value = node->class_idx*2 - 1;
    728         }
    729         else
    730         {
    731             double p = rcw[1]/(rcw[0] + rcw[1]);
    732             assert( boost_type == CvBoost::REAL );
    733 
    734             // store log-ratio of the probability
    735             node->value = 0.5*log_ratio(p);
    736         }
    737     }
    738     else
    739     {
    740         // in case of regression tree:
    741         //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
    742         //    n is the number of samples in the node.
    743         //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
    744         double sum = 0, sum2 = 0, iw;
    745         const float* values = data->get_ord_responses(node);
    746 
    747         for( i = 0; i < count; i++ )
    748         {
    749             int idx = labels[i];
    750             double w = weights[idx]/*priors[values[i] > 0]*/;
    751             double t = values[i];
    752             rcw[0] += w;
    753             subtree_weights[i] = w;
    754             sum += t*w;
    755             sum2 += t*t*w;
    756         }
    757 
    758         iw = 1./rcw[0];
    759         node->value = sum*iw;
    760         node->node_risk = sum2 - (sum*iw)*sum;
    761 
    762         // renormalize the risk, as in try_split_node the unweighted formula
    763         // sqrt(risk)/n is used, rather than sqrt(risk)/sum(weights_i)
    764         node->node_risk *= count*iw*count*iw;
    765     }
    766 
    767     // store summary weights
    768     subtree_weights[count] = rcw[0];
    769     subtree_weights[count+1] = rcw[1];
    770 }
    771 
    772 
    773 void CvBoostTree::read( CvFileStorage* fs, CvFileNode* fnode, CvBoost* _ensemble, CvDTreeTrainData* _data )
    774 {
    775     CvDTree::read( fs, fnode, _data );
    776     ensemble = _ensemble;
    777 }
    778 
    779 
    780 void CvBoostTree::read( CvFileStorage*, CvFileNode* )
    781 {
    782     assert(0);
    783 }
    784 
    785 void CvBoostTree::read( CvFileStorage* _fs, CvFileNode* _node,
    786                         CvDTreeTrainData* _data )
    787 {
    788     CvDTree::read( _fs, _node, _data );
    789 }
    790 
    791 
    792 /////////////////////////////////// CvBoost /////////////////////////////////////
    793 
    794 CvBoost::CvBoost()
    795 {
    796     data = 0;
    797     weak = 0;
    798     default_model_name = "my_boost_tree";
    799     orig_response = sum_response = weak_eval = subsample_mask =
    800         weights = subtree_weights = 0;
    801 
    802     clear();
    803 }
    804 
    805 
    806 void CvBoost::prune( CvSlice slice )
    807 {
    808     if( weak )
    809     {
    810         CvSeqReader reader;
    811         int i, count = cvSliceLength( slice, weak );
    812 
    813         cvStartReadSeq( weak, &reader );
    814         cvSetSeqReaderPos( &reader, slice.start_index );
    815 
    816         for( i = 0; i < count; i++ )
    817         {
    818             CvBoostTree* w;
    819             CV_READ_SEQ_ELEM( w, reader );
    820             delete w;
    821         }
    822 
    823         cvSeqRemoveSlice( weak, slice );
    824     }
    825 }
    826 
    827 
    828 void CvBoost::clear()
    829 {
    830     if( weak )
    831     {
    832         prune( CV_WHOLE_SEQ );
    833         cvReleaseMemStorage( &weak->storage );
    834     }
    835     if( data )
    836         delete data;
    837     weak = 0;
    838     data = 0;
    839     cvReleaseMat( &orig_response );
    840     cvReleaseMat( &sum_response );
    841     cvReleaseMat( &weak_eval );
    842     cvReleaseMat( &subsample_mask );
    843     cvReleaseMat( &weights );
    844     have_subsample = false;
    845 }
    846 
    847 
    848 CvBoost::~CvBoost()
    849 {
    850     clear();
    851 }
    852 
    853 
    854 CvBoost::CvBoost( const CvMat* _train_data, int _tflag,
    855                   const CvMat* _responses, const CvMat* _var_idx,
    856                   const CvMat* _sample_idx, const CvMat* _var_type,
    857                   const CvMat* _missing_mask, CvBoostParams _params )
    858 {
    859     weak = 0;
    860     data = 0;
    861     default_model_name = "my_boost_tree";
    862     orig_response = sum_response = weak_eval = subsample_mask = weights = 0;
    863 
    864     train( _train_data, _tflag, _responses, _var_idx, _sample_idx,
    865            _var_type, _missing_mask, _params );
    866 }
    867 
    868 
    869 bool
    870 CvBoost::set_params( const CvBoostParams& _params )
    871 {
    872     bool ok = false;
    873 
    874     CV_FUNCNAME( "CvBoost::set_params" );
    875 
    876     __BEGIN__;
    877 
    878     params = _params;
    879     if( params.boost_type != DISCRETE && params.boost_type != REAL &&
    880         params.boost_type != LOGIT && params.boost_type != GENTLE )
    881         CV_ERROR( CV_StsBadArg, "Unknown/unsupported boosting type" );
    882 
    883     params.weak_count = MAX( params.weak_count, 1 );
    884     params.weight_trim_rate = MAX( params.weight_trim_rate, 0. );
    885     params.weight_trim_rate = MIN( params.weight_trim_rate, 1. );
    886     if( params.weight_trim_rate < FLT_EPSILON )
    887         params.weight_trim_rate = 1.f;
    888 
    889     if( params.boost_type == DISCRETE &&
    890         params.split_criteria != GINI && params.split_criteria != MISCLASS )
    891         params.split_criteria = MISCLASS;
    892     if( params.boost_type == REAL &&
    893         params.split_criteria != GINI && params.split_criteria != MISCLASS )
    894         params.split_criteria = GINI;
    895     if( (params.boost_type == LOGIT || params.boost_type == GENTLE) &&
    896         params.split_criteria != SQERR )
    897         params.split_criteria = SQERR;
    898 
    899     ok = true;
    900 
    901     __END__;
    902 
    903     return ok;
    904 }
    905 
    906 
    907 bool
    908 CvBoost::train( const CvMat* _train_data, int _tflag,
    909               const CvMat* _responses, const CvMat* _var_idx,
    910               const CvMat* _sample_idx, const CvMat* _var_type,
    911               const CvMat* _missing_mask,
    912               CvBoostParams _params, bool _update )
    913 {
    914     bool ok = false;
    915     CvMemStorage* storage = 0;
    916 
    917     CV_FUNCNAME( "CvBoost::train" );
    918 
    919     __BEGIN__;
    920 
    921     int i;
    922 
    923     set_params( _params );
    924 
    925     if( !_update || !data )
    926     {
    927         clear();
    928         data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
    929             _sample_idx, _var_type, _missing_mask, _params, true, true );
    930 
    931         if( data->get_num_classes() != 2 )
    932             CV_ERROR( CV_StsNotImplemented,
    933             "Boosted trees can only be used for 2-class classification." );
    934         CV_CALL( storage = cvCreateMemStorage() );
    935         weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
    936         storage = 0;
    937     }
    938     else
    939     {
    940         data->set_data( _train_data, _tflag, _responses, _var_idx,
    941             _sample_idx, _var_type, _missing_mask, _params, true, true, true );
    942     }
    943 
    944     update_weights( 0 );
    945 
    946     for( i = 0; i < params.weak_count; i++ )
    947     {
    948         CvBoostTree* tree = new CvBoostTree;
    949         if( !tree->train( data, subsample_mask, this ) )
    950         {
    951             delete tree;
    952             continue;
    953         }
    954         //cvCheckArr( get_weak_response());
    955         cvSeqPush( weak, &tree );
    956         update_weights( tree );
    957         trim_weights();
    958     }
    959 
    960     data->is_classifier = true;
    961     ok = true;
    962 
    963     __END__;
    964 
    965     return ok;
    966 }
    967 
    968 
    969 void
    970 CvBoost::update_weights( CvBoostTree* tree )
    971 {
    972     CV_FUNCNAME( "CvBoost::update_weights" );
    973 
    974     __BEGIN__;
    975 
    976     int i, count = data->sample_count;
    977     double sumw = 0.;
    978 
    979     if( !tree ) // before training the first tree, initialize weights and other parameters
    980     {
    981         const int* class_labels = data->get_class_labels(data->data_root);
    982         // in case of logitboost and gentle adaboost each weak tree is a regression tree,
    983         // so we need to convert class labels to floating-point values
    984         float* responses = data->get_ord_responses(data->data_root);
    985         int* labels = data->get_labels(data->data_root);
    986         double w0 = 1./count;
    987         double p[2] = { 1, 1 };
    988 
    989         cvReleaseMat( &orig_response );
    990         cvReleaseMat( &sum_response );
    991         cvReleaseMat( &weak_eval );
    992         cvReleaseMat( &subsample_mask );
    993         cvReleaseMat( &weights );
    994 
    995         CV_CALL( orig_response = cvCreateMat( 1, count, CV_32S ));
    996         CV_CALL( weak_eval = cvCreateMat( 1, count, CV_64F ));
    997         CV_CALL( subsample_mask = cvCreateMat( 1, count, CV_8U ));
    998         CV_CALL( weights = cvCreateMat( 1, count, CV_64F ));
    999         CV_CALL( subtree_weights = cvCreateMat( 1, count + 2, CV_64F ));
   1000 
   1001         if( data->have_priors )
   1002         {
   1003             // compute weight scale for each class from their prior probabilities
   1004             int c1 = 0;
   1005             for( i = 0; i < count; i++ )
   1006                 c1 += class_labels[i];
   1007             p[0] = data->priors->data.db[0]*(c1 < count ? 1./(count - c1) : 0.);
   1008             p[1] = data->priors->data.db[1]*(c1 > 0 ? 1./c1 : 0.);
   1009             p[0] /= p[0] + p[1];
   1010             p[1] = 1. - p[0];
   1011         }
   1012 
   1013         for( i = 0; i < count; i++ )
   1014         {
   1015             // save original categorical responses {0,1}, convert them to {-1,1}
   1016             orig_response->data.i[i] = class_labels[i]*2 - 1;
   1017             // make all the samples active at start.
   1018             // later, in trim_weights() deactivate/reactive again some, if need
   1019             subsample_mask->data.ptr[i] = (uchar)1;
   1020             // make all the initial weights the same.
   1021             weights->data.db[i] = w0*p[class_labels[i]];
   1022             // set the labels to find (from within weak tree learning proc)
   1023             // the particular sample weight, and where to store the response.
   1024             labels[i] = i;
   1025         }
   1026 
   1027         if( params.boost_type == LOGIT )
   1028         {
   1029             CV_CALL( sum_response = cvCreateMat( 1, count, CV_64F ));
   1030 
   1031             for( i = 0; i < count; i++ )
   1032             {
   1033                 sum_response->data.db[i] = 0;
   1034                 responses[i] = orig_response->data.i[i] > 0 ? 2.f : -2.f;
   1035             }
   1036 
   1037             // in case of logitboost each weak tree is a regression tree.
   1038             // the target function values are recalculated for each of the trees
   1039             data->is_classifier = false;
   1040         }
   1041         else if( params.boost_type == GENTLE )
   1042         {
   1043             for( i = 0; i < count; i++ )
   1044                 responses[i] = (float)orig_response->data.i[i];
   1045 
   1046             data->is_classifier = false;
   1047         }
   1048     }
   1049     else
   1050     {
   1051         // at this moment, for all the samples that participated in the training of the most
   1052         // recent weak classifier we know the responses. For other samples we need to compute them
   1053         if( have_subsample )
   1054         {
   1055             float* values = (float*)(data->buf->data.ptr + data->buf->step);
   1056             uchar* missing = data->buf->data.ptr + data->buf->step*2;
   1057             CvMat _sample, _mask;
   1058 
   1059             // invert the subsample mask
   1060             cvXorS( subsample_mask, cvScalar(1.), subsample_mask );
   1061             data->get_vectors( subsample_mask, values, missing, 0 );
   1062             //data->get_vectors( 0, values, missing, 0 );
   1063 
   1064             _sample = cvMat( 1, data->var_count, CV_32F );
   1065             _mask = cvMat( 1, data->var_count, CV_8U );
   1066 
   1067             // run tree through all the non-processed samples
   1068             for( i = 0; i < count; i++ )
   1069                 if( subsample_mask->data.ptr[i] )
   1070                 {
   1071                     _sample.data.fl = values;
   1072                     _mask.data.ptr = missing;
   1073                     values += _sample.cols;
   1074                     missing += _mask.cols;
   1075                     weak_eval->data.db[i] = tree->predict( &_sample, &_mask, true )->value;
   1076                 }
   1077         }
   1078 
   1079         // now update weights and other parameters for each type of boosting
   1080         if( params.boost_type == DISCRETE )
   1081         {
   1082             // Discrete AdaBoost:
   1083             //   weak_eval[i] (=f(x_i)) is in {-1,1}
   1084             //   err = sum(w_i*(f(x_i) != y_i))/sum(w_i)
   1085             //   C = log((1-err)/err)
   1086             //   w_i *= exp(C*(f(x_i) != y_i))
   1087 
   1088             double C, err = 0.;
   1089             double scale[] = { 1., 0. };
   1090 
   1091             for( i = 0; i < count; i++ )
   1092             {
   1093                 double w = weights->data.db[i];
   1094                 sumw += w;
   1095                 err += w*(weak_eval->data.db[i] != orig_response->data.i[i]);
   1096             }
   1097 
   1098             if( sumw != 0 )
   1099                 err /= sumw;
   1100             C = err = -log_ratio( err );
   1101             scale[1] = exp(err);
   1102 
   1103             sumw = 0;
   1104             for( i = 0; i < count; i++ )
   1105             {
   1106                 double w = weights->data.db[i]*
   1107                     scale[weak_eval->data.db[i] != orig_response->data.i[i]];
   1108                 sumw += w;
   1109                 weights->data.db[i] = w;
   1110             }
   1111 
   1112             tree->scale( C );
   1113         }
   1114         else if( params.boost_type == REAL )
   1115         {
   1116             // Real AdaBoost:
   1117             //   weak_eval[i] = f(x_i) = 0.5*log(p(x_i)/(1-p(x_i))), p(x_i)=P(y=1|x_i)
   1118             //   w_i *= exp(-y_i*f(x_i))
   1119 
   1120             for( i = 0; i < count; i++ )
   1121                 weak_eval->data.db[i] *= -orig_response->data.i[i];
   1122 
   1123             cvExp( weak_eval, weak_eval );
   1124 
   1125             for( i = 0; i < count; i++ )
   1126             {
   1127                 double w = weights->data.db[i]*weak_eval->data.db[i];
   1128                 sumw += w;
   1129                 weights->data.db[i] = w;
   1130             }
   1131         }
   1132         else if( params.boost_type == LOGIT )
   1133         {
   1134             // LogitBoost:
   1135             //   weak_eval[i] = f(x_i) in [-z_max,z_max]
   1136             //   sum_response = F(x_i).
   1137             //   F(x_i) += 0.5*f(x_i)
   1138             //   p(x_i) = exp(F(x_i))/(exp(F(x_i)) + exp(-F(x_i))=1/(1+exp(-2*F(x_i)))
   1139             //   reuse weak_eval: weak_eval[i] <- p(x_i)
   1140             //   w_i = p(x_i)*1(1 - p(x_i))
   1141             //   z_i = ((y_i+1)/2 - p(x_i))/(p(x_i)*(1 - p(x_i)))
   1142             //   store z_i to the data->data_root as the new target responses
   1143 
   1144             const double lb_weight_thresh = FLT_EPSILON;
   1145             const double lb_z_max = 10.;
   1146             float* responses = data->get_ord_responses(data->data_root);
   1147 
   1148             /*if( weak->total == 7 )
   1149                 putchar('*');*/
   1150 
   1151             for( i = 0; i < count; i++ )
   1152             {
   1153                 double s = sum_response->data.db[i] + 0.5*weak_eval->data.db[i];
   1154                 sum_response->data.db[i] = s;
   1155                 weak_eval->data.db[i] = -2*s;
   1156             }
   1157 
   1158             cvExp( weak_eval, weak_eval );
   1159 
   1160             for( i = 0; i < count; i++ )
   1161             {
   1162                 double p = 1./(1. + weak_eval->data.db[i]);
   1163                 double w = p*(1 - p), z;
   1164                 w = MAX( w, lb_weight_thresh );
   1165                 weights->data.db[i] = w;
   1166                 sumw += w;
   1167                 if( orig_response->data.i[i] > 0 )
   1168                 {
   1169                     z = 1./p;
   1170                     responses[i] = (float)MIN(z, lb_z_max);
   1171                 }
   1172                 else
   1173                 {
   1174                     z = 1./(1-p);
   1175                     responses[i] = (float)-MIN(z, lb_z_max);
   1176                 }
   1177             }
   1178         }
   1179         else
   1180         {
   1181             // Gentle AdaBoost:
   1182             //   weak_eval[i] = f(x_i) in [-1,1]
   1183             //   w_i *= exp(-y_i*f(x_i))
   1184             assert( params.boost_type == GENTLE );
   1185 
   1186             for( i = 0; i < count; i++ )
   1187                 weak_eval->data.db[i] *= -orig_response->data.i[i];
   1188 
   1189             cvExp( weak_eval, weak_eval );
   1190 
   1191             for( i = 0; i < count; i++ )
   1192             {
   1193                 double w = weights->data.db[i] * weak_eval->data.db[i];
   1194                 weights->data.db[i] = w;
   1195                 sumw += w;
   1196             }
   1197         }
   1198     }
   1199 
   1200     // renormalize weights
   1201     if( sumw > FLT_EPSILON )
   1202     {
   1203         sumw = 1./sumw;
   1204         for( i = 0; i < count; ++i )
   1205             weights->data.db[i] *= sumw;
   1206     }
   1207 
   1208     __END__;
   1209 }
   1210 
   1211 
   1212 static CV_IMPLEMENT_QSORT_EX( icvSort_64f, double, CV_LT, int )
   1213 
   1214 
   1215 void
   1216 CvBoost::trim_weights()
   1217 {
   1218     CV_FUNCNAME( "CvBoost::trim_weights" );
   1219 
   1220     __BEGIN__;
   1221 
   1222     int i, count = data->sample_count, nz_count = 0;
   1223     double sum, threshold;
   1224 
   1225     if( params.weight_trim_rate <= 0. || params.weight_trim_rate >= 1. )
   1226         EXIT;
   1227 
   1228     // use weak_eval as temporary buffer for sorted weights
   1229     cvCopy( weights, weak_eval );
   1230 
   1231     icvSort_64f( weak_eval->data.db, count, 0 );
   1232 
   1233     // as weight trimming occurs immediately after updating the weights,
   1234     // where they are renormalized, we assume that the weight sum = 1.
   1235     sum = 1. - params.weight_trim_rate;
   1236 
   1237     for( i = 0; i < count; i++ )
   1238     {
   1239         double w = weak_eval->data.db[i];
   1240         if( sum > w )
   1241             break;
   1242         sum -= w;
   1243     }
   1244 
   1245     threshold = i < count ? weak_eval->data.db[i] : DBL_MAX;
   1246 
   1247     for( i = 0; i < count; i++ )
   1248     {
   1249         double w = weights->data.db[i];
   1250         int f = w > threshold;
   1251         subsample_mask->data.ptr[i] = (uchar)f;
   1252         nz_count += f;
   1253     }
   1254 
   1255     have_subsample = nz_count < count;
   1256 
   1257     __END__;
   1258 }
   1259 
   1260 
   1261 float
   1262 CvBoost::predict( const CvMat* _sample, const CvMat* _missing,
   1263                   CvMat* weak_responses, CvSlice slice,
   1264                   bool raw_mode ) const
   1265 {
   1266     float* buf = 0;
   1267     bool allocated = false;
   1268     float value = -FLT_MAX;
   1269 
   1270     CV_FUNCNAME( "CvBoost::predict" );
   1271 
   1272     __BEGIN__;
   1273 
   1274     int i, weak_count, var_count;
   1275     CvMat sample, missing;
   1276     CvSeqReader reader;
   1277     double sum = 0;
   1278     int cls_idx;
   1279     int wstep = 0;
   1280     const int* vtype;
   1281     const int* cmap;
   1282     const int* cofs;
   1283 
   1284     if( !weak )
   1285         CV_ERROR( CV_StsError, "The boosted tree ensemble has not been trained yet" );
   1286 
   1287     if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
   1288         _sample->cols != 1 && _sample->rows != 1 ||
   1289         _sample->cols + _sample->rows - 1 != data->var_all && !raw_mode ||
   1290         _sample->cols + _sample->rows - 1 != data->var_count && raw_mode )
   1291             CV_ERROR( CV_StsBadArg,
   1292         "the input sample must be 1d floating-point vector with the same "
   1293         "number of elements as the total number of variables used for training" );
   1294 
   1295     if( _missing )
   1296     {
   1297         if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
   1298             !CV_ARE_SIZES_EQ(_missing, _sample) )
   1299             CV_ERROR( CV_StsBadArg,
   1300             "the missing data mask must be 8-bit vector of the same size as input sample" );
   1301     }
   1302 
   1303     weak_count = cvSliceLength( slice, weak );
   1304     if( weak_count >= weak->total )
   1305     {
   1306         weak_count = weak->total;
   1307         slice.start_index = 0;
   1308     }
   1309 
   1310     if( weak_responses )
   1311     {
   1312         if( !CV_IS_MAT(weak_responses) ||
   1313             CV_MAT_TYPE(weak_responses->type) != CV_32FC1 ||
   1314             weak_responses->cols != 1 && weak_responses->rows != 1 ||
   1315             weak_responses->cols + weak_responses->rows - 1 != weak_count )
   1316             CV_ERROR( CV_StsBadArg,
   1317             "The output matrix of weak classifier responses must be valid "
   1318             "floating-point vector of the same number of components as the length of input slice" );
   1319         wstep = CV_IS_MAT_CONT(weak_responses->type) ? 1 : weak_responses->step/sizeof(float);
   1320     }
   1321 
   1322     var_count = data->var_count;
   1323     vtype = data->var_type->data.i;
   1324     cmap = data->cat_map->data.i;
   1325     cofs = data->cat_ofs->data.i;
   1326 
   1327     // if need, preprocess the input vector
   1328     if( !raw_mode && (data->cat_var_count > 0 || data->var_idx) )
   1329     {
   1330         int bufsize;
   1331         int step, mstep = 0;
   1332         const float* src_sample;
   1333         const uchar* src_mask = 0;
   1334         float* dst_sample;
   1335         uchar* dst_mask;
   1336         const int* vidx = data->var_idx && !raw_mode ? data->var_idx->data.i : 0;
   1337         bool have_mask = _missing != 0;
   1338 
   1339         bufsize = var_count*(sizeof(float) + sizeof(uchar));
   1340         if( bufsize <= CV_MAX_LOCAL_SIZE )
   1341             buf = (float*)cvStackAlloc( bufsize );
   1342         else
   1343         {
   1344             CV_CALL( buf = (float*)cvAlloc( bufsize ));
   1345             allocated = true;
   1346         }
   1347         dst_sample = buf;
   1348         dst_mask = (uchar*)(buf + var_count);
   1349 
   1350         src_sample = _sample->data.fl;
   1351         step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(src_sample[0]);
   1352 
   1353         if( _missing )
   1354         {
   1355             src_mask = _missing->data.ptr;
   1356             mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step;
   1357         }
   1358 
   1359         for( i = 0; i < var_count; i++ )
   1360         {
   1361             int idx = vidx ? vidx[i] : i;
   1362             float val = src_sample[idx*step];
   1363             int ci = vtype[i];
   1364             uchar m = src_mask ? src_mask[i] : (uchar)0;
   1365 
   1366             if( ci >= 0 )
   1367             {
   1368                 int a = cofs[ci], b = cofs[ci+1], c = a;
   1369                 int ival = cvRound(val);
   1370                 if( ival != val )
   1371                     CV_ERROR( CV_StsBadArg,
   1372                     "one of input categorical variable is not an integer" );
   1373 
   1374                 while( a < b )
   1375                 {
   1376                     c = (a + b) >> 1;
   1377                     if( ival < cmap[c] )
   1378                         b = c;
   1379                     else if( ival > cmap[c] )
   1380                         a = c+1;
   1381                     else
   1382                         break;
   1383                 }
   1384 
   1385                 if( c < 0 || ival != cmap[c] )
   1386                 {
   1387                     m = 1;
   1388                     have_mask = true;
   1389                 }
   1390                 else
   1391                 {
   1392                     val = (float)(c - cofs[ci]);
   1393                 }
   1394             }
   1395 
   1396             dst_sample[i] = val;
   1397             dst_mask[i] = m;
   1398         }
   1399 
   1400         sample = cvMat( 1, var_count, CV_32F, dst_sample );
   1401         _sample = &sample;
   1402 
   1403         if( have_mask )
   1404         {
   1405             missing = cvMat( 1, var_count, CV_8UC1, dst_mask );
   1406             _missing = &missing;
   1407         }
   1408     }
   1409 
   1410     cvStartReadSeq( weak, &reader );
   1411     cvSetSeqReaderPos( &reader, slice.start_index );
   1412 
   1413     for( i = 0; i < weak_count; i++ )
   1414     {
   1415         CvBoostTree* wtree;
   1416         double val;
   1417 
   1418         CV_READ_SEQ_ELEM( wtree, reader );
   1419 
   1420         val = wtree->predict( _sample, _missing, true )->value;
   1421         if( weak_responses )
   1422             weak_responses->data.fl[i*wstep] = (float)val;
   1423 
   1424         sum += val;
   1425     }
   1426 
   1427     cls_idx = sum >= 0;
   1428     if( raw_mode )
   1429         value = (float)cls_idx;
   1430     else
   1431         value = (float)cmap[cofs[vtype[var_count]] + cls_idx];
   1432 
   1433     __END__;
   1434 
   1435     if( allocated )
   1436         cvFree( &buf );
   1437 
   1438     return value;
   1439 }
   1440 
   1441 
   1442 
   1443 void CvBoost::write_params( CvFileStorage* fs )
   1444 {
   1445     CV_FUNCNAME( "CvBoost::write_params" );
   1446 
   1447     __BEGIN__;
   1448 
   1449     const char* boost_type_str =
   1450         params.boost_type == DISCRETE ? "DiscreteAdaboost" :
   1451         params.boost_type == REAL ? "RealAdaboost" :
   1452         params.boost_type == LOGIT ? "LogitBoost" :
   1453         params.boost_type == GENTLE ? "GentleAdaboost" : 0;
   1454 
   1455     const char* split_crit_str =
   1456         params.split_criteria == DEFAULT ? "Default" :
   1457         params.split_criteria == GINI ? "Gini" :
   1458         params.boost_type == MISCLASS ? "Misclassification" :
   1459         params.boost_type == SQERR ? "SquaredErr" : 0;
   1460 
   1461     if( boost_type_str )
   1462         cvWriteString( fs, "boosting_type", boost_type_str );
   1463     else
   1464         cvWriteInt( fs, "boosting_type", params.boost_type );
   1465 
   1466     if( split_crit_str )
   1467         cvWriteString( fs, "splitting_criteria", split_crit_str );
   1468     else
   1469         cvWriteInt( fs, "splitting_criteria", params.split_criteria );
   1470 
   1471     cvWriteInt( fs, "ntrees", params.weak_count );
   1472     cvWriteReal( fs, "weight_trimming_rate", params.weight_trim_rate );
   1473 
   1474     data->write_params( fs );
   1475 
   1476     __END__;
   1477 }
   1478 
   1479 
   1480 void CvBoost::read_params( CvFileStorage* fs, CvFileNode* fnode )
   1481 {
   1482     CV_FUNCNAME( "CvBoost::read_params" );
   1483 
   1484     __BEGIN__;
   1485 
   1486     CvFileNode* temp;
   1487 
   1488     if( !fnode || !CV_NODE_IS_MAP(fnode->tag) )
   1489         return;
   1490 
   1491     data = new CvDTreeTrainData();
   1492     CV_CALL( data->read_params(fs, fnode));
   1493     data->shared = true;
   1494 
   1495     params.max_depth = data->params.max_depth;
   1496     params.min_sample_count = data->params.min_sample_count;
   1497     params.max_categories = data->params.max_categories;
   1498     params.priors = data->params.priors;
   1499     params.regression_accuracy = data->params.regression_accuracy;
   1500     params.use_surrogates = data->params.use_surrogates;
   1501 
   1502     temp = cvGetFileNodeByName( fs, fnode, "boosting_type" );
   1503     if( !temp )
   1504         return;
   1505 
   1506     if( temp && CV_NODE_IS_STRING(temp->tag) )
   1507     {
   1508         const char* boost_type_str = cvReadString( temp, "" );
   1509         params.boost_type = strcmp( boost_type_str, "DiscreteAdaboost" ) == 0 ? DISCRETE :
   1510                             strcmp( boost_type_str, "RealAdaboost" ) == 0 ? REAL :
   1511                             strcmp( boost_type_str, "LogitBoost" ) == 0 ? LOGIT :
   1512                             strcmp( boost_type_str, "GentleAdaboost" ) == 0 ? GENTLE : -1;
   1513     }
   1514     else
   1515         params.boost_type = cvReadInt( temp, -1 );
   1516 
   1517     if( params.boost_type < DISCRETE || params.boost_type > GENTLE )
   1518         CV_ERROR( CV_StsBadArg, "Unknown boosting type" );
   1519 
   1520     temp = cvGetFileNodeByName( fs, fnode, "splitting_criteria" );
   1521     if( temp && CV_NODE_IS_STRING(temp->tag) )
   1522     {
   1523         const char* split_crit_str = cvReadString( temp, "" );
   1524         params.split_criteria = strcmp( split_crit_str, "Default" ) == 0 ? DEFAULT :
   1525                                 strcmp( split_crit_str, "Gini" ) == 0 ? GINI :
   1526                                 strcmp( split_crit_str, "Misclassification" ) == 0 ? MISCLASS :
   1527                                 strcmp( split_crit_str, "SquaredErr" ) == 0 ? SQERR : -1;
   1528     }
   1529     else
   1530         params.split_criteria = cvReadInt( temp, -1 );
   1531 
   1532     if( params.split_criteria < DEFAULT || params.boost_type > SQERR )
   1533         CV_ERROR( CV_StsBadArg, "Unknown boosting type" );
   1534 
   1535     params.weak_count = cvReadIntByName( fs, fnode, "ntrees" );
   1536     params.weight_trim_rate = cvReadRealByName( fs, fnode, "weight_trimming_rate", 0. );
   1537 
   1538     __END__;
   1539 }
   1540 
   1541 
   1542 
   1543 void
   1544 CvBoost::read( CvFileStorage* fs, CvFileNode* node )
   1545 {
   1546     CV_FUNCNAME( "CvRTrees::read" );
   1547 
   1548     __BEGIN__;
   1549 
   1550     CvSeqReader reader;
   1551     CvFileNode* trees_fnode;
   1552     CvMemStorage* storage;
   1553     int i, ntrees;
   1554 
   1555     clear();
   1556     read_params( fs, node );
   1557 
   1558     if( !data )
   1559         EXIT;
   1560 
   1561     trees_fnode = cvGetFileNodeByName( fs, node, "trees" );
   1562     if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
   1563         CV_ERROR( CV_StsParseError, "<trees> tag is missing" );
   1564 
   1565     cvStartReadSeq( trees_fnode->data.seq, &reader );
   1566     ntrees = trees_fnode->data.seq->total;
   1567 
   1568     if( ntrees != params.weak_count )
   1569         CV_ERROR( CV_StsUnmatchedSizes,
   1570         "The number of trees stored does not match <ntrees> tag value" );
   1571 
   1572     CV_CALL( storage = cvCreateMemStorage() );
   1573     weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
   1574 
   1575     for( i = 0; i < ntrees; i++ )
   1576     {
   1577         CvBoostTree* tree = new CvBoostTree();
   1578         CV_CALL(tree->read( fs, (CvFileNode*)reader.ptr, this, data ));
   1579         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
   1580         cvSeqPush( weak, &tree );
   1581     }
   1582 
   1583     __END__;
   1584 }
   1585 
   1586 
   1587 void
   1588 CvBoost::write( CvFileStorage* fs, const char* name )
   1589 {
   1590     CV_FUNCNAME( "CvBoost::write" );
   1591 
   1592     __BEGIN__;
   1593 
   1594     CvSeqReader reader;
   1595     int i;
   1596 
   1597     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_BOOSTING );
   1598 
   1599     if( !weak )
   1600         CV_ERROR( CV_StsBadArg, "The classifier has not been trained yet" );
   1601 
   1602     write_params( fs );
   1603     cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );
   1604 
   1605     cvStartReadSeq( weak, &reader );
   1606 
   1607     for( i = 0; i < weak->total; i++ )
   1608     {
   1609         CvBoostTree* tree;
   1610         CV_READ_SEQ_ELEM( tree, reader );
   1611         cvStartWriteStruct( fs, 0, CV_NODE_MAP );
   1612         tree->write( fs );
   1613         cvEndWriteStruct( fs );
   1614     }
   1615 
   1616     cvEndWriteStruct( fs );
   1617     cvEndWriteStruct( fs );
   1618 
   1619     __END__;
   1620 }
   1621 
   1622 
   1623 CvMat*
   1624 CvBoost::get_weights()
   1625 {
   1626     return weights;
   1627 }
   1628 
   1629 
   1630 CvMat*
   1631 CvBoost::get_subtree_weights()
   1632 {
   1633     return subtree_weights;
   1634 }
   1635 
   1636 
   1637 CvMat*
   1638 CvBoost::get_weak_response()
   1639 {
   1640     return weak_eval;
   1641 }
   1642 
   1643 
   1644 const CvBoostParams&
   1645 CvBoost::get_params() const
   1646 {
   1647     return params;
   1648 }
   1649 
   1650 CvSeq* CvBoost::get_weak_predictors()
   1651 {
   1652     return weak;
   1653 }
   1654 
   1655 /* End of file. */
   1656