Home | History | Annotate | Download | only in src
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                        Intel License Agreement
     11 //
     12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
     13 // Third party copyrights are property of their respective owners.
     14 //
     15 // Redistribution and use in source and binary forms, with or without modification,
     16 // are permitted provided that the following conditions are met:
     17 //
     18 //   * Redistribution's of source code must retain the above copyright notice,
     19 //     this list of conditions and the following disclaimer.
     20 //
     21 //   * Redistribution's in binary form must reproduce the above copyright notice,
     22 //     this list of conditions and the following disclaimer in the documentation
     23 //     and/or other materials provided with the distribution.
     24 //
     25 //   * The name of Intel Corporation may not be used to endorse or promote products
     26 //     derived from this software without specific prior written permission.
     27 //
     28 // This software is provided by the copyright holders and contributors "as is" and
     29 // any express or implied warranties, including, but not limited to, the implied
     30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     31 // In no event shall the Intel Corporation or contributors be liable for any direct,
     32 // indirect, incidental, special, exemplary, or consequential damages
     33 // (including, but not limited to, procurement of substitute goods or services;
     34 // loss of use, data, or profits; or business interruption) however caused
     35 // and on any theory of liability, whether in contract, strict liability,
     36 // or tort (including negligence or otherwise) arising in any way out of
     37 // the use of this software, even if advised of the possibility of such damage.
     38 //
     39 //M*/
     40 
     41 #include "_ml.h"
     42 
     43 static const float ord_nan = FLT_MAX*0.5f;
     44 static const int min_block_size = 1 << 16;
     45 static const int block_size_delta = 1 << 10;
     46 
     47 CvDTreeTrainData::CvDTreeTrainData()
     48 {
     49     var_idx = var_type = cat_count = cat_ofs = cat_map =
     50         priors = priors_mult = counts = buf = direction = split_buf = 0;
     51     tree_storage = temp_storage = 0;
     52 
     53     clear();
     54 }
     55 
     56 
     57 CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
     58                       const CvMat* _responses, const CvMat* _var_idx,
     59                       const CvMat* _sample_idx, const CvMat* _var_type,
     60                       const CvMat* _missing_mask, const CvDTreeParams& _params,
     61                       bool _shared, bool _add_labels )
     62 {
     63     var_idx = var_type = cat_count = cat_ofs = cat_map =
     64         priors = priors_mult = counts = buf = direction = split_buf = 0;
     65     tree_storage = temp_storage = 0;
     66 
     67     set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
     68               _var_type, _missing_mask, _params, _shared, _add_labels );
     69 }
     70 
     71 
     72 CvDTreeTrainData::~CvDTreeTrainData()
     73 {
     74     clear();
     75 }
     76 
     77 
     78 bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
     79 {
     80     bool ok = false;
     81 
     82     CV_FUNCNAME( "CvDTreeTrainData::set_params" );
     83 
     84     __BEGIN__;
     85 
     86     // set parameters
     87     params = _params;
     88 
     89     if( params.max_categories < 2 )
     90         CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
     91     params.max_categories = MIN( params.max_categories, 15 );
     92 
     93     if( params.max_depth < 0 )
     94         CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
     95     params.max_depth = MIN( params.max_depth, 25 );
     96 
     97     params.min_sample_count = MAX(params.min_sample_count,1);
     98 
     99     if( params.cv_folds < 0 )
    100         CV_ERROR( CV_StsOutOfRange,
    101         "params.cv_folds should be =0 (the tree is not pruned) "
    102         "or n>0 (tree is pruned using n-fold cross-validation)" );
    103 
    104     if( params.cv_folds == 1 )
    105         params.cv_folds = 0;
    106 
    107     if( params.regression_accuracy < 0 )
    108         CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
    109 
    110     ok = true;
    111 
    112     __END__;
    113 
    114     return ok;
    115 }
    116 
    117 
    118 #define CV_CMP_NUM_PTR(a,b) (*(a) < *(b))
    119 static CV_IMPLEMENT_QSORT_EX( icvSortIntPtr, int*, CV_CMP_NUM_PTR, int )
    120 static CV_IMPLEMENT_QSORT_EX( icvSortDblPtr, double*, CV_CMP_NUM_PTR, int )
    121 
    122 #define CV_CMP_PAIRS(a,b) ((a).val < (b).val)
    123 static CV_IMPLEMENT_QSORT_EX( icvSortPairs, CvPair32s32f, CV_CMP_PAIRS, int )
    124 
    125 void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
    126     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
    127     const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
    128     bool _shared, bool _add_labels, bool _update_data )
    129 {
    130     CvMat* sample_idx = 0;
    131     CvMat* var_type0 = 0;
    132     CvMat* tmp_map = 0;
    133     int** int_ptr = 0;
    134     CvDTreeTrainData* data = 0;
    135 
    136     CV_FUNCNAME( "CvDTreeTrainData::set_data" );
    137 
    138     __BEGIN__;
    139 
    140     int sample_all = 0, r_type = 0, cv_n;
    141     int total_c_count = 0;
    142     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
    143     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
    144     int vi, i;
    145     char err[100];
    146     const int *sidx = 0, *vidx = 0;
    147 
    148     if( _update_data && data_root )
    149     {
    150         data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
    151             _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );
    152 
    153         // compare new and old train data
    154         if( !(data->var_count == var_count &&
    155             cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
    156             cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
    157             cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
    158             CV_ERROR( CV_StsBadArg,
    159             "The new training data must have the same types and the input and output variables "
    160             "and the same categories for categorical variables" );
    161 
    162         cvReleaseMat( &priors );
    163         cvReleaseMat( &priors_mult );
    164         cvReleaseMat( &buf );
    165         cvReleaseMat( &direction );
    166         cvReleaseMat( &split_buf );
    167         cvReleaseMemStorage( &temp_storage );
    168 
    169         priors = data->priors; data->priors = 0;
    170         priors_mult = data->priors_mult; data->priors_mult = 0;
    171         buf = data->buf; data->buf = 0;
    172         buf_count = data->buf_count; buf_size = data->buf_size;
    173         sample_count = data->sample_count;
    174 
    175         direction = data->direction; data->direction = 0;
    176         split_buf = data->split_buf; data->split_buf = 0;
    177         temp_storage = data->temp_storage; data->temp_storage = 0;
    178         nv_heap = data->nv_heap; cv_heap = data->cv_heap;
    179 
    180         data_root = new_node( 0, sample_count, 0, 0 );
    181         EXIT;
    182     }
    183 
    184     clear();
    185 
    186     var_all = 0;
    187     rng = cvRNG(-1);
    188 
    189     CV_CALL( set_params( _params ));
    190 
    191     // check parameter types and sizes
    192     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
    193     if( _tflag == CV_ROW_SAMPLE )
    194     {
    195         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
    196         dv_step = 1;
    197         if( _missing_mask )
    198             ms_step = _missing_mask->step, mv_step = 1;
    199     }
    200     else
    201     {
    202         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
    203         ds_step = 1;
    204         if( _missing_mask )
    205             mv_step = _missing_mask->step, ms_step = 1;
    206     }
    207 
    208     sample_count = sample_all;
    209     var_count = var_all;
    210 
    211     if( _sample_idx )
    212     {
    213         CV_CALL( sample_idx = cvPreprocessIndexArray( _sample_idx, sample_all ));
    214         sidx = sample_idx->data.i;
    215         sample_count = sample_idx->rows + sample_idx->cols - 1;
    216     }
    217 
    218     if( _var_idx )
    219     {
    220         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
    221         vidx = var_idx->data.i;
    222         var_count = var_idx->rows + var_idx->cols - 1;
    223     }
    224 
    225     if( !CV_IS_MAT(_responses) ||
    226         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
    227          CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
    228         _responses->rows != 1 && _responses->cols != 1 ||
    229         _responses->rows + _responses->cols - 1 != sample_all )
    230         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
    231                   "floating-point vector containing as many elements as "
    232                   "the total number of samples in the training data matrix" );
    233 
    234     CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_all, &r_type ));
    235     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
    236 
    237     cat_var_count = 0;
    238     ord_var_count = -1;
    239 
    240     is_classifier = r_type == CV_VAR_CATEGORICAL;
    241 
    242     // step 0. calc the number of categorical vars
    243     for( vi = 0; vi < var_count; vi++ )
    244     {
    245         var_type->data.i[vi] = var_type0->data.ptr[vi] == CV_VAR_CATEGORICAL ?
    246             cat_var_count++ : ord_var_count--;
    247     }
    248 
    249     ord_var_count = ~ord_var_count;
    250     cv_n = params.cv_folds;
    251     // set the two last elements of var_type array to be able
    252     // to locate responses and cross-validation labels using
    253     // the corresponding get_* functions.
    254     var_type->data.i[var_count] = cat_var_count;
    255     var_type->data.i[var_count+1] = cat_var_count+1;
    256 
    257     // in case of single ordered predictor we need dummy cv_labels
    258     // for safe split_node_data() operation
    259     have_labels = cv_n > 0 || ord_var_count == 1 && cat_var_count == 0 || _add_labels;
    260 
    261     buf_size = (ord_var_count + get_work_var_count())*sample_count + 2;
    262     shared = _shared;
    263     buf_count = shared ? 3 : 2;
    264     CV_CALL( buf = cvCreateMat( buf_count, buf_size, CV_32SC1 ));
    265     CV_CALL( cat_count = cvCreateMat( 1, cat_var_count+1, CV_32SC1 ));
    266     CV_CALL( cat_ofs = cvCreateMat( 1, cat_count->cols+1, CV_32SC1 ));
    267     CV_CALL( cat_map = cvCreateMat( 1, cat_count->cols*10 + 128, CV_32SC1 ));
    268 
    269     // now calculate the maximum size of split,
    270     // create memory storage that will keep nodes and splits of the decision tree
    271     // allocate root node and the buffer for the whole training data
    272     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
    273         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
    274     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
    275     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
    276     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
    277     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
    278 
    279     nv_size = var_count*sizeof(int);
    280     nv_size = MAX( nv_size, (int)sizeof(CvSetElem) );
    281 
    282     temp_block_size = nv_size;
    283 
    284     if( cv_n )
    285     {
    286         if( sample_count < cv_n*MAX(params.min_sample_count,10) )
    287             CV_ERROR( CV_StsOutOfRange,
    288                 "The many folds in cross-validation for such a small dataset" );
    289 
    290         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
    291         temp_block_size = MAX(temp_block_size, cv_size);
    292     }
    293 
    294     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
    295     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
    296     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
    297     if( cv_size )
    298         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
    299 
    300     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
    301     CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
    302 
    303     max_c_count = 1;
    304 
    305     // transform the training data to convenient representation
    306     for( vi = 0; vi <= var_count; vi++ )
    307     {
    308         int ci;
    309         const uchar* mask = 0;
    310         int m_step = 0, step;
    311         const int* idata = 0;
    312         const float* fdata = 0;
    313         int num_valid = 0;
    314 
    315         if( vi < var_count ) // analyze i-th input variable
    316         {
    317             int vi0 = vidx ? vidx[vi] : vi;
    318             ci = get_var_type(vi);
    319             step = ds_step; m_step = ms_step;
    320             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
    321                 idata = _train_data->data.i + vi0*dv_step;
    322             else
    323                 fdata = _train_data->data.fl + vi0*dv_step;
    324             if( _missing_mask )
    325                 mask = _missing_mask->data.ptr + vi0*mv_step;
    326         }
    327         else // analyze _responses
    328         {
    329             ci = cat_var_count;
    330             step = CV_IS_MAT_CONT(_responses->type) ?
    331                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
    332             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
    333                 idata = _responses->data.i;
    334             else
    335                 fdata = _responses->data.fl;
    336         }
    337 
    338         if( vi < var_count && ci >= 0 ||
    339             vi == var_count && is_classifier ) // process categorical variable or response
    340         {
    341             int c_count, prev_label;
    342             int* c_map, *dst = get_cat_var_data( data_root, vi );
    343 
    344             // copy data
    345             for( i = 0; i < sample_count; i++ )
    346             {
    347                 int val = INT_MAX, si = sidx ? sidx[i] : i;
    348                 if( !mask || !mask[si*m_step] )
    349                 {
    350                     if( idata )
    351                         val = idata[si*step];
    352                     else
    353                     {
    354                         float t = fdata[si*step];
    355                         val = cvRound(t);
    356                         if( val != t )
    357                         {
    358                             sprintf( err, "%d-th value of %d-th (categorical) "
    359                                 "variable is not an integer", i, vi );
    360                             CV_ERROR( CV_StsBadArg, err );
    361                         }
    362                     }
    363 
    364                     if( val == INT_MAX )
    365                     {
    366                         sprintf( err, "%d-th value of %d-th (categorical) "
    367                             "variable is too large", i, vi );
    368                         CV_ERROR( CV_StsBadArg, err );
    369                     }
    370                     num_valid++;
    371                 }
    372                 dst[i] = val;
    373                 int_ptr[i] = dst + i;
    374             }
    375 
    376             // sort all the values, including the missing measurements
    377             // that should all move to the end
    378             icvSortIntPtr( int_ptr, sample_count, 0 );
    379             //qsort( int_ptr, sample_count, sizeof(int_ptr[0]), icvCmpIntPtr );
    380 
    381             c_count = num_valid > 0;
    382 
    383             // count the categories
    384             for( i = 1; i < num_valid; i++ )
    385                 c_count += *int_ptr[i] != *int_ptr[i-1];
    386 
    387             if( vi > 0 )
    388                 max_c_count = MAX( max_c_count, c_count );
    389             cat_count->data.i[ci] = c_count;
    390             cat_ofs->data.i[ci] = total_c_count;
    391 
    392             // resize cat_map, if need
    393             if( cat_map->cols < total_c_count + c_count )
    394             {
    395                 tmp_map = cat_map;
    396                 CV_CALL( cat_map = cvCreateMat( 1,
    397                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
    398                 for( i = 0; i < total_c_count; i++ )
    399                     cat_map->data.i[i] = tmp_map->data.i[i];
    400                 cvReleaseMat( &tmp_map );
    401             }
    402 
    403             c_map = cat_map->data.i + total_c_count;
    404             total_c_count += c_count;
    405 
    406             // compact the class indices and build the map
    407             prev_label = ~*int_ptr[0];
    408             c_count = -1;
    409 
    410             for( i = 0; i < num_valid; i++ )
    411             {
    412                 int cur_label = *int_ptr[i];
    413                 if( cur_label != prev_label )
    414                     c_map[++c_count] = prev_label = cur_label;
    415                 *int_ptr[i] = c_count;
    416             }
    417 
    418             // replace labels for missing values with -1
    419             for( ; i < sample_count; i++ )
    420                 *int_ptr[i] = -1;
    421         }
    422         else if( ci < 0 ) // process ordered variable
    423         {
    424             CvPair32s32f* dst = get_ord_var_data( data_root, vi );
    425 
    426             for( i = 0; i < sample_count; i++ )
    427             {
    428                 float val = ord_nan;
    429                 int si = sidx ? sidx[i] : i;
    430                 if( !mask || !mask[si*m_step] )
    431                 {
    432                     if( idata )
    433                         val = (float)idata[si*step];
    434                     else
    435                         val = fdata[si*step];
    436 
    437                     if( fabs(val) >= ord_nan )
    438                     {
    439                         sprintf( err, "%d-th value of %d-th (ordered) "
    440                             "variable (=%g) is too large", i, vi, val );
    441                         CV_ERROR( CV_StsBadArg, err );
    442                     }
    443                     num_valid++;
    444                 }
    445                 dst[i].i = i;
    446                 dst[i].val = val;
    447             }
    448 
    449             icvSortPairs( dst, sample_count, 0 );
    450         }
    451         else // special case: process ordered response,
    452              // it will be stored similarly to categorical vars (i.e. no pairs)
    453         {
    454             float* dst = get_ord_responses( data_root );
    455 
    456             for( i = 0; i < sample_count; i++ )
    457             {
    458                 float val = ord_nan;
    459                 int si = sidx ? sidx[i] : i;
    460                 if( idata )
    461                     val = (float)idata[si*step];
    462                 else
    463                     val = fdata[si*step];
    464 
    465                 if( fabs(val) >= ord_nan )
    466                 {
    467                     sprintf( err, "%d-th value of %d-th (ordered) "
    468                         "variable (=%g) is out of range", i, vi, val );
    469                     CV_ERROR( CV_StsBadArg, err );
    470                 }
    471                 dst[i] = val;
    472             }
    473 
    474             cat_count->data.i[cat_var_count] = 0;
    475             cat_ofs->data.i[cat_var_count] = total_c_count;
    476             num_valid = sample_count;
    477         }
    478 
    479         if( vi < var_count )
    480             data_root->set_num_valid(vi, num_valid);
    481     }
    482 
    483     if( cv_n )
    484     {
    485         int* dst = get_labels(data_root);
    486         CvRNG* r = &rng;
    487 
    488         for( i = vi = 0; i < sample_count; i++ )
    489         {
    490             dst[i] = vi++;
    491             vi &= vi < cv_n ? -1 : 0;
    492         }
    493 
    494         for( i = 0; i < sample_count; i++ )
    495         {
    496             int a = cvRandInt(r) % sample_count;
    497             int b = cvRandInt(r) % sample_count;
    498             CV_SWAP( dst[a], dst[b], vi );
    499         }
    500     }
    501 
    502     cat_map->cols = MAX( total_c_count, 1 );
    503 
    504     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
    505         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
    506     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
    507 
    508     have_priors = is_classifier && params.priors;
    509     if( is_classifier )
    510     {
    511         int m = get_num_classes();
    512         double sum = 0;
    513         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
    514         for( i = 0; i < m; i++ )
    515         {
    516             double val = have_priors ? params.priors[i] : 1.;
    517             if( val <= 0 )
    518                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
    519             priors->data.db[i] = val;
    520             sum += val;
    521         }
    522 
    523         // normalize weights
    524         if( have_priors )
    525             cvScale( priors, priors, 1./sum );
    526 
    527         CV_CALL( priors_mult = cvCloneMat( priors ));
    528         CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
    529     }
    530 
    531     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
    532     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
    533 
    534     __END__;
    535 
    536     if( data )
    537         delete data;
    538 
    539     cvFree( &int_ptr );
    540     cvReleaseMat( &sample_idx );
    541     cvReleaseMat( &var_type0 );
    542     cvReleaseMat( &tmp_map );
    543 }
    544 
    545 
    546 CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
    547 {
    548     CvDTreeNode* root = 0;
    549     CvMat* isubsample_idx = 0;
    550     CvMat* subsample_co = 0;
    551 
    552     CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );
    553 
    554     __BEGIN__;
    555 
    556     if( !data_root )
    557         CV_ERROR( CV_StsError, "No training data has been set" );
    558 
    559     if( _subsample_idx )
    560         CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
    561 
    562     if( !isubsample_idx )
    563     {
    564         // make a copy of the root node
    565         CvDTreeNode temp;
    566         int i;
    567         root = new_node( 0, 1, 0, 0 );
    568         temp = *root;
    569         *root = *data_root;
    570         root->num_valid = temp.num_valid;
    571         if( root->num_valid )
    572         {
    573             for( i = 0; i < var_count; i++ )
    574                 root->num_valid[i] = data_root->num_valid[i];
    575         }
    576         root->cv_Tn = temp.cv_Tn;
    577         root->cv_node_risk = temp.cv_node_risk;
    578         root->cv_node_error = temp.cv_node_error;
    579     }
    580     else
    581     {
    582         int* sidx = isubsample_idx->data.i;
    583         // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
    584         int* co, cur_ofs = 0;
    585         int vi, i, total = data_root->sample_count;
    586         int count = isubsample_idx->rows + isubsample_idx->cols - 1;
    587         int work_var_count = get_work_var_count();
    588         root = new_node( 0, count, 1, 0 );
    589 
    590         CV_CALL( subsample_co = cvCreateMat( 1, total*2, CV_32SC1 ));
    591         cvZero( subsample_co );
    592         co = subsample_co->data.i;
    593         for( i = 0; i < count; i++ )
    594             co[sidx[i]*2]++;
    595         for( i = 0; i < total; i++ )
    596         {
    597             if( co[i*2] )
    598             {
    599                 co[i*2+1] = cur_ofs;
    600                 cur_ofs += co[i*2];
    601             }
    602             else
    603                 co[i*2+1] = -1;
    604         }
    605 
    606         for( vi = 0; vi < work_var_count; vi++ )
    607         {
    608             int ci = get_var_type(vi);
    609 
    610             if( ci >= 0 || vi >= var_count )
    611             {
    612                 const int* src = get_cat_var_data( data_root, vi );
    613                 int* dst = get_cat_var_data( root, vi );
    614                 int num_valid = 0;
    615 
    616                 for( i = 0; i < count; i++ )
    617                 {
    618                     int val = src[sidx[i]];
    619                     dst[i] = val;
    620                     num_valid += val >= 0;
    621                 }
    622 
    623                 if( vi < var_count )
    624                     root->set_num_valid(vi, num_valid);
    625             }
    626             else
    627             {
    628                 const CvPair32s32f* src = get_ord_var_data( data_root, vi );
    629                 CvPair32s32f* dst = get_ord_var_data( root, vi );
    630                 int j = 0, idx, count_i;
    631                 int num_valid = data_root->get_num_valid(vi);
    632 
    633                 for( i = 0; i < num_valid; i++ )
    634                 {
    635                     idx = src[i].i;
    636                     count_i = co[idx*2];
    637                     if( count_i )
    638                     {
    639                         float val = src[i].val;
    640                         for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
    641                         {
    642                             dst[j].val = val;
    643                             dst[j].i = cur_ofs;
    644                         }
    645                     }
    646                 }
    647 
    648                 root->set_num_valid(vi, j);
    649 
    650                 for( ; i < total; i++ )
    651                 {
    652                     idx = src[i].i;
    653                     count_i = co[idx*2];
    654                     if( count_i )
    655                     {
    656                         float val = src[i].val;
    657                         for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
    658                         {
    659                             dst[j].val = val;
    660                             dst[j].i = cur_ofs;
    661                         }
    662                     }
    663                 }
    664             }
    665         }
    666     }
    667 
    668     __END__;
    669 
    670     cvReleaseMat( &isubsample_idx );
    671     cvReleaseMat( &subsample_co );
    672 
    673     return root;
    674 }
    675 
    676 
    677 void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
    678                                     float* values, uchar* missing,
    679                                     float* responses, bool get_class_idx )
    680 {
    681     CvMat* subsample_idx = 0;
    682     CvMat* subsample_co = 0;
    683 
    684     CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );
    685 
    686     __BEGIN__;
    687 
    688     int i, vi, total = sample_count, count = total, cur_ofs = 0;
    689     int* sidx = 0;
    690     int* co = 0;
    691 
    692     if( _subsample_idx )
    693     {
    694         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
    695         sidx = subsample_idx->data.i;
    696         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
    697         co = subsample_co->data.i;
    698         cvZero( subsample_co );
    699         count = subsample_idx->cols + subsample_idx->rows - 1;
    700         for( i = 0; i < count; i++ )
    701             co[sidx[i]*2]++;
    702         for( i = 0; i < total; i++ )
    703         {
    704             int count_i = co[i*2];
    705             if( count_i )
    706             {
    707                 co[i*2+1] = cur_ofs*var_count;
    708                 cur_ofs += count_i;
    709             }
    710         }
    711     }
    712 
    713     if( missing )
    714         memset( missing, 1, count*var_count );
    715 
    716     for( vi = 0; vi < var_count; vi++ )
    717     {
    718         int ci = get_var_type(vi);
    719         if( ci >= 0 ) // categorical
    720         {
    721             float* dst = values + vi;
    722             uchar* m = missing ? missing + vi : 0;
    723             const int* src = get_cat_var_data(data_root, vi);
    724 
    725             for( i = 0; i < count; i++, dst += var_count )
    726             {
    727                 int idx = sidx ? sidx[i] : i;
    728                 int val = src[idx];
    729                 *dst = (float)val;
    730                 if( m )
    731                 {
    732                     *m = val < 0;
    733                     m += var_count;
    734                 }
    735             }
    736         }
    737         else // ordered
    738         {
    739             float* dst = values + vi;
    740             uchar* m = missing ? missing + vi : 0;
    741             const CvPair32s32f* src = get_ord_var_data(data_root, vi);
    742             int count1 = data_root->get_num_valid(vi);
    743 
    744             for( i = 0; i < count1; i++ )
    745             {
    746                 int idx = src[i].i;
    747                 int count_i = 1;
    748                 if( co )
    749                 {
    750                     count_i = co[idx*2];
    751                     cur_ofs = co[idx*2+1];
    752                 }
    753                 else
    754                     cur_ofs = idx*var_count;
    755                 if( count_i )
    756                 {
    757                     float val = src[i].val;
    758                     for( ; count_i > 0; count_i--, cur_ofs += var_count )
    759                     {
    760                         dst[cur_ofs] = val;
    761                         if( m )
    762                             m[cur_ofs] = 0;
    763                     }
    764                 }
    765             }
    766         }
    767     }
    768 
    769     // copy responses
    770     if( responses )
    771     {
    772         if( is_classifier )
    773         {
    774             const int* src = get_class_labels(data_root);
    775             for( i = 0; i < count; i++ )
    776             {
    777                 int idx = sidx ? sidx[i] : i;
    778                 int val = get_class_idx ? src[idx] :
    779                     cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
    780                 responses[i] = (float)val;
    781             }
    782         }
    783         else
    784         {
    785             const float* src = get_ord_responses(data_root);
    786             for( i = 0; i < count; i++ )
    787             {
    788                 int idx = sidx ? sidx[i] : i;
    789                 responses[i] = src[idx];
    790             }
    791         }
    792     }
    793 
    794     __END__;
    795 
    796     cvReleaseMat( &subsample_idx );
    797     cvReleaseMat( &subsample_co );
    798 }
    799 
    800 
    801 CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
    802                                          int storage_idx, int offset )
    803 {
    804     CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );
    805 
    806     node->sample_count = count;
    807     node->depth = parent ? parent->depth + 1 : 0;
    808     node->parent = parent;
    809     node->left = node->right = 0;
    810     node->split = 0;
    811     node->value = 0;
    812     node->class_idx = 0;
    813     node->maxlr = 0.;
    814 
    815     node->buf_idx = storage_idx;
    816     node->offset = offset;
    817     if( nv_heap )
    818         node->num_valid = (int*)cvSetNew( nv_heap );
    819     else
    820         node->num_valid = 0;
    821     node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
    822     node->complexity = 0;
    823 
    824     if( params.cv_folds > 0 && cv_heap )
    825     {
    826         int cv_n = params.cv_folds;
    827         node->Tn = INT_MAX;
    828         node->cv_Tn = (int*)cvSetNew( cv_heap );
    829         node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
    830         node->cv_node_error = node->cv_node_risk + cv_n;
    831     }
    832     else
    833     {
    834         node->Tn = 0;
    835         node->cv_Tn = 0;
    836         node->cv_node_risk = 0;
    837         node->cv_node_error = 0;
    838     }
    839 
    840     return node;
    841 }
    842 
    843 
    844 CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
    845                 int split_point, int inversed, float quality )
    846 {
    847     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
    848     split->var_idx = vi;
    849     split->ord.c = cmp_val;
    850     split->ord.split_point = split_point;
    851     split->inversed = inversed;
    852     split->quality = quality;
    853     split->next = 0;
    854 
    855     return split;
    856 }
    857 
    858 
    859 CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
    860 {
    861     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
    862     int i, n = (max_c_count + 31)/32;
    863 
    864     split->var_idx = vi;
    865     split->inversed = 0;
    866     split->quality = quality;
    867     for( i = 0; i < n; i++ )
    868         split->subset[i] = 0;
    869     split->next = 0;
    870 
    871     return split;
    872 }
    873 
    874 
    875 void CvDTreeTrainData::free_node( CvDTreeNode* node )
    876 {
    877     CvDTreeSplit* split = node->split;
    878     free_node_data( node );
    879     while( split )
    880     {
    881         CvDTreeSplit* next = split->next;
    882         cvSetRemoveByPtr( split_heap, split );
    883         split = next;
    884     }
    885     node->split = 0;
    886     cvSetRemoveByPtr( node_heap, node );
    887 }
    888 
    889 
    890 void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
    891 {
    892     if( node->num_valid )
    893     {
    894         cvSetRemoveByPtr( nv_heap, node->num_valid );
    895         node->num_valid = 0;
    896     }
    897     // do not free cv_* fields, as all the cross-validation related data is released at once.
    898 }
    899 
    900 
    901 void CvDTreeTrainData::free_train_data()
    902 {
    903     cvReleaseMat( &counts );
    904     cvReleaseMat( &buf );
    905     cvReleaseMat( &direction );
    906     cvReleaseMat( &split_buf );
    907     cvReleaseMemStorage( &temp_storage );
    908     cv_heap = nv_heap = 0;
    909 }
    910 
    911 
    912 void CvDTreeTrainData::clear()
    913 {
    914     free_train_data();
    915 
    916     cvReleaseMemStorage( &tree_storage );
    917 
    918     cvReleaseMat( &var_idx );
    919     cvReleaseMat( &var_type );
    920     cvReleaseMat( &cat_count );
    921     cvReleaseMat( &cat_ofs );
    922     cvReleaseMat( &cat_map );
    923     cvReleaseMat( &priors );
    924     cvReleaseMat( &priors_mult );
    925 
    926     node_heap = split_heap = 0;
    927 
    928     sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
    929     have_labels = have_priors = is_classifier = false;
    930 
    931     buf_count = buf_size = 0;
    932     shared = false;
    933 
    934     data_root = 0;
    935 
    936     rng = cvRNG(-1);
    937 }
    938 
    939 
    940 int CvDTreeTrainData::get_num_classes() const
    941 {
    942     return is_classifier ? cat_count->data.i[cat_var_count] : 0;
    943 }
    944 
    945 
    946 int CvDTreeTrainData::get_var_type(int vi) const
    947 {
    948     return var_type->data.i[vi];
    949 }
    950 
    951 
    952 int CvDTreeTrainData::get_work_var_count() const
    953 {
    954     return var_count + 1 + (have_labels ? 1 : 0);
    955 }
    956 
    957 CvPair32s32f* CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi )
    958 {
    959     int oi = ~get_var_type(vi);
    960     assert( 0 <= oi && oi < ord_var_count );
    961     return (CvPair32s32f*)(buf->data.i + n->buf_idx*buf->cols +
    962                            n->offset + oi*n->sample_count*2);
    963 }
    964 
    965 
    966 int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n )
    967 {
    968     return get_cat_var_data( n, var_count );
    969 }
    970 
    971 
    972 float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n )
    973 {
    974     return (float*)get_cat_var_data( n, var_count );
    975 }
    976 
    977 
    978 int* CvDTreeTrainData::get_labels( CvDTreeNode* n )
    979 {
    980     return have_labels ? get_cat_var_data( n, var_count + 1 ) : 0;
    981 }
    982 
    983 
    984 int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi )
    985 {
    986     int ci = get_var_type(vi);
    987     assert( 0 <= ci && ci <= cat_var_count + 1 );
    988     return buf->data.i + n->buf_idx*buf->cols + n->offset +
    989            (ord_var_count*2 + ci)*n->sample_count;
    990 }
    991 
    992 
    993 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
    994 {
    995     int idx = n->buf_idx + 1;
    996     if( idx >= buf_count )
    997         idx = shared ? 1 : 0;
    998     return idx;
    999 }
   1000 
   1001 
   1002 void CvDTreeTrainData::write_params( CvFileStorage* fs )
   1003 {
   1004     CV_FUNCNAME( "CvDTreeTrainData::write_params" );
   1005 
   1006     __BEGIN__;
   1007 
   1008     int vi, vcount = var_count;
   1009 
   1010     cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );
   1011     cvWriteInt( fs, "var_all", var_all );
   1012     cvWriteInt( fs, "var_count", var_count );
   1013     cvWriteInt( fs, "ord_var_count", ord_var_count );
   1014     cvWriteInt( fs, "cat_var_count", cat_var_count );
   1015 
   1016     cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
   1017     cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );
   1018 
   1019     if( is_classifier )
   1020     {
   1021         cvWriteInt( fs, "max_categories", params.max_categories );
   1022     }
   1023     else
   1024     {
   1025         cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );
   1026     }
   1027 
   1028     cvWriteInt( fs, "max_depth", params.max_depth );
   1029     cvWriteInt( fs, "min_sample_count", params.min_sample_count );
   1030     cvWriteInt( fs, "cross_validation_folds", params.cv_folds );
   1031 
   1032     if( params.cv_folds > 1 )
   1033     {
   1034         cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
   1035         cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );
   1036     }
   1037 
   1038     if( priors )
   1039         cvWrite( fs, "priors", priors );
   1040 
   1041     cvEndWriteStruct( fs );
   1042 
   1043     if( var_idx )
   1044         cvWrite( fs, "var_idx", var_idx );
   1045 
   1046     cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
   1047 
   1048     for( vi = 0; vi < vcount; vi++ )
   1049         cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );
   1050 
   1051     cvEndWriteStruct( fs );
   1052 
   1053     if( cat_count && (cat_var_count > 0 || is_classifier) )
   1054     {
   1055         CV_ASSERT( cat_count != 0 );
   1056         cvWrite( fs, "cat_count", cat_count );
   1057         cvWrite( fs, "cat_map", cat_map );
   1058     }
   1059 
   1060     __END__;
   1061 }
   1062 
   1063 
   1064 void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
   1065 {
   1066     CV_FUNCNAME( "CvDTreeTrainData::read_params" );
   1067 
   1068     __BEGIN__;
   1069 
   1070     CvFileNode *tparams_node, *vartype_node;
   1071     CvSeqReader reader;
   1072     int vi, max_split_size, tree_block_size;
   1073 
   1074     is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
   1075     var_all = cvReadIntByName( fs, node, "var_all" );
   1076     var_count = cvReadIntByName( fs, node, "var_count", var_all );
   1077     cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
   1078     ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
   1079 
   1080     tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
   1081 
   1082     if( tparams_node ) // training parameters are not necessary
   1083     {
   1084         params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
   1085 
   1086         if( is_classifier )
   1087         {
   1088             params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
   1089         }
   1090         else
   1091         {
   1092             params.regression_accuracy =
   1093                 (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
   1094         }
   1095 
   1096         params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
   1097         params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
   1098         params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
   1099 
   1100         if( params.cv_folds > 1 )
   1101         {
   1102             params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
   1103             params.truncate_pruned_tree =
   1104                 cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
   1105         }
   1106 
   1107         priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
   1108         if( priors )
   1109         {
   1110             if( !CV_IS_MAT(priors) )
   1111                 CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
   1112             priors_mult = cvCloneMat( priors );
   1113         }
   1114     }
   1115 
   1116     CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
   1117     if( var_idx )
   1118     {
   1119         if( !CV_IS_MAT(var_idx) ||
   1120             var_idx->cols != 1 && var_idx->rows != 1 ||
   1121             var_idx->cols + var_idx->rows - 1 != var_count ||
   1122             CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
   1123             CV_ERROR( CV_StsParseError,
   1124                 "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
   1125 
   1126         for( vi = 0; vi < var_count; vi++ )
   1127             if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )
   1128                 CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
   1129     }
   1130 
   1131     ////// read var type
   1132     CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));
   1133 
   1134     cat_var_count = 0;
   1135     ord_var_count = -1;
   1136     vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
   1137 
   1138     if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
   1139         var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
   1140     else
   1141     {
   1142         if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
   1143             vartype_node->data.seq->total != var_count )
   1144             CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
   1145 
   1146         cvStartReadSeq( vartype_node->data.seq, &reader );
   1147 
   1148         for( vi = 0; vi < var_count; vi++ )
   1149         {
   1150             CvFileNode* n = (CvFileNode*)reader.ptr;
   1151             if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
   1152                 CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
   1153             var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
   1154             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
   1155         }
   1156     }
   1157     var_type->data.i[var_count] = cat_var_count;
   1158 
   1159     ord_var_count = ~ord_var_count;
   1160     if( cat_var_count != cat_var_count || ord_var_count != ord_var_count )
   1161         CV_ERROR( CV_StsParseError, "var_type is inconsistent with cat_var_count and ord_var_count" );
   1162     //////
   1163 
   1164     if( cat_var_count > 0 || is_classifier )
   1165     {
   1166         int ccount, total_c_count = 0;
   1167         CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
   1168         CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
   1169 
   1170         if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
   1171             cat_count->cols != 1 && cat_count->rows != 1 ||
   1172             CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
   1173             cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||
   1174             cat_map->cols != 1 && cat_map->rows != 1 ||
   1175             CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
   1176             CV_ERROR( CV_StsParseError,
   1177             "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
   1178 
   1179         ccount = cat_var_count + is_classifier;
   1180 
   1181         CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
   1182         cat_ofs->data.i[0] = 0;
   1183         max_c_count = 1;
   1184 
   1185         for( vi = 0; vi < ccount; vi++ )
   1186         {
   1187             int val = cat_count->data.i[vi];
   1188             if( val <= 0 )
   1189                 CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
   1190             max_c_count = MAX( max_c_count, val );
   1191             cat_ofs->data.i[vi+1] = total_c_count += val;
   1192         }
   1193 
   1194         if( cat_map->cols + cat_map->rows - 1 != total_c_count )
   1195             CV_ERROR( CV_StsBadSize,
   1196             "cat_map vector length is not equal to the total number of categories in all categorical vars" );
   1197     }
   1198 
   1199     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
   1200         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
   1201 
   1202     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
   1203     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
   1204     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
   1205     CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
   1206             sizeof(CvDTreeNode), tree_storage ));
   1207     CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
   1208             max_split_size, tree_storage ));
   1209 
   1210     __END__;
   1211 }
   1212 
   1213 
   1214 /////////////////////// Decision Tree /////////////////////////
   1215 
   1216 CvDTree::CvDTree()
   1217 {
   1218     data = 0;
   1219     var_importance = 0;
   1220     default_model_name = "my_tree";
   1221 
   1222     clear();
   1223 }
   1224 
   1225 
   1226 void CvDTree::clear()
   1227 {
   1228     cvReleaseMat( &var_importance );
   1229     if( data )
   1230     {
   1231         if( !data->shared )
   1232             delete data;
   1233         else
   1234             free_tree();
   1235         data = 0;
   1236     }
   1237     root = 0;
   1238     pruned_tree_idx = -1;
   1239 }
   1240 
   1241 
   1242 CvDTree::~CvDTree()
   1243 {
   1244     clear();
   1245 }
   1246 
   1247 
   1248 const CvDTreeNode* CvDTree::get_root() const
   1249 {
   1250     return root;
   1251 }
   1252 
   1253 
   1254 int CvDTree::get_pruned_tree_idx() const
   1255 {
   1256     return pruned_tree_idx;
   1257 }
   1258 
   1259 
   1260 CvDTreeTrainData* CvDTree::get_data()
   1261 {
   1262     return data;
   1263 }
   1264 
   1265 
   1266 bool CvDTree::train( const CvMat* _train_data, int _tflag,
   1267                      const CvMat* _responses, const CvMat* _var_idx,
   1268                      const CvMat* _sample_idx, const CvMat* _var_type,
   1269                      const CvMat* _missing_mask, CvDTreeParams _params )
   1270 {
   1271     bool result = false;
   1272 
   1273     CV_FUNCNAME( "CvDTree::train" );
   1274 
   1275     __BEGIN__;
   1276 
   1277     clear();
   1278     data = new CvDTreeTrainData( _train_data, _tflag, _responses,
   1279                                  _var_idx, _sample_idx, _var_type,
   1280                                  _missing_mask, _params, false );
   1281     CV_CALL( result = do_train(0));
   1282 
   1283     __END__;
   1284 
   1285     return result;
   1286 }
   1287 
   1288 
   1289 bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
   1290 {
   1291     bool result = false;
   1292 
   1293     CV_FUNCNAME( "CvDTree::train" );
   1294 
   1295     __BEGIN__;
   1296 
   1297     clear();
   1298     data = _data;
   1299     data->shared = true;
   1300     CV_CALL( result = do_train(_subsample_idx));
   1301 
   1302     __END__;
   1303 
   1304     return result;
   1305 }
   1306 
   1307 
   1308 bool CvDTree::do_train( const CvMat* _subsample_idx )
   1309 {
   1310     bool result = false;
   1311 
   1312     CV_FUNCNAME( "CvDTree::do_train" );
   1313 
   1314     __BEGIN__;
   1315 
   1316     root = data->subsample_data( _subsample_idx );
   1317 
   1318     CV_CALL( try_split_node(root));
   1319 
   1320     if( data->params.cv_folds > 0 )
   1321         CV_CALL( prune_cv());
   1322 
   1323     if( !data->shared )
   1324         data->free_train_data();
   1325 
   1326     result = true;
   1327 
   1328     __END__;
   1329 
   1330     return result;
   1331 }
   1332 
   1333 
   1334 void CvDTree::try_split_node( CvDTreeNode* node )
   1335 {
   1336     CvDTreeSplit* best_split = 0;
   1337     int i, n = node->sample_count, vi;
   1338     bool can_split = true;
   1339     double quality_scale;
   1340 
   1341     calc_node_value( node );
   1342 
   1343     if( node->sample_count <= data->params.min_sample_count ||
   1344         node->depth >= data->params.max_depth )
   1345         can_split = false;
   1346 
   1347     if( can_split && data->is_classifier )
   1348     {
   1349         // check if we have a "pure" node,
   1350         // we assume that cls_count is filled by calc_node_value()
   1351         int* cls_count = data->counts->data.i;
   1352         int nz = 0, m = data->get_num_classes();
   1353         for( i = 0; i < m; i++ )
   1354             nz += cls_count[i] != 0;
   1355         if( nz == 1 ) // there is only one class
   1356             can_split = false;
   1357     }
   1358     else if( can_split )
   1359     {
   1360         if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
   1361             can_split = false;
   1362     }
   1363 
   1364     if( can_split )
   1365     {
   1366         best_split = find_best_split(node);
   1367         // TODO: check the split quality ...
   1368         node->split = best_split;
   1369     }
   1370 
   1371     if( !can_split || !best_split )
   1372     {
   1373         data->free_node_data(node);
   1374         return;
   1375     }
   1376 
   1377     quality_scale = calc_node_dir( node );
   1378 
   1379     if( data->params.use_surrogates )
   1380     {
   1381         // find all the surrogate splits
   1382         // and sort them by their similarity to the primary one
   1383         for( vi = 0; vi < data->var_count; vi++ )
   1384         {
   1385             CvDTreeSplit* split;
   1386             int ci = data->get_var_type(vi);
   1387 
   1388             if( vi == best_split->var_idx )
   1389                 continue;
   1390 
   1391             if( ci >= 0 )
   1392                 split = find_surrogate_split_cat( node, vi );
   1393             else
   1394                 split = find_surrogate_split_ord( node, vi );
   1395 
   1396             if( split )
   1397             {
   1398                 // insert the split
   1399                 CvDTreeSplit* prev_split = node->split;
   1400                 split->quality = (float)(split->quality*quality_scale);
   1401 
   1402                 while( prev_split->next &&
   1403                        prev_split->next->quality > split->quality )
   1404                     prev_split = prev_split->next;
   1405                 split->next = prev_split->next;
   1406                 prev_split->next = split;
   1407             }
   1408         }
   1409     }
   1410 
   1411     split_node_data( node );
   1412     try_split_node( node->left );
   1413     try_split_node( node->right );
   1414 }
   1415 
   1416 
   1417 // calculate direction (left(-1),right(1),missing(0))
   1418 // for each sample using the best split
   1419 // the function returns scale coefficients for surrogate split quality factors.
   1420 // the scale is applied to normalize surrogate split quality relatively to the
   1421 // best (primary) split quality. That is, if a surrogate split is absolutely
   1422 // identical to the primary split, its quality will be set to the maximum value =
   1423 // quality of the primary split; otherwise, it will be lower.
   1424 // besides, the function compute node->maxlr,
   1425 // minimum possible quality (w/o considering the above mentioned scale)
   1426 // for a surrogate split. Surrogate splits with quality less than node->maxlr
   1427 // are not discarded.
   1428 double CvDTree::calc_node_dir( CvDTreeNode* node )
   1429 {
   1430     char* dir = (char*)data->direction->data.ptr;
   1431     int i, n = node->sample_count, vi = node->split->var_idx;
   1432     double L, R;
   1433 
   1434     assert( !node->split->inversed );
   1435 
   1436     if( data->get_var_type(vi) >= 0 ) // split on categorical var
   1437     {
   1438         const int* labels = data->get_cat_var_data(node,vi);
   1439         const int* subset = node->split->subset;
   1440 
   1441         if( !data->have_priors )
   1442         {
   1443             int sum = 0, sum_abs = 0;
   1444 
   1445             for( i = 0; i < n; i++ )
   1446             {
   1447                 int idx = labels[i];
   1448                 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
   1449                 sum += d; sum_abs += d & 1;
   1450                 dir[i] = (char)d;
   1451             }
   1452 
   1453             R = (sum_abs + sum) >> 1;
   1454             L = (sum_abs - sum) >> 1;
   1455         }
   1456         else
   1457         {
   1458             const int* responses = data->get_class_labels(node);
   1459             const double* priors = data->priors_mult->data.db;
   1460             double sum = 0, sum_abs = 0;
   1461 
   1462             for( i = 0; i < n; i++ )
   1463             {
   1464                 int idx = labels[i];
   1465                 double w = priors[responses[i]];
   1466                 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
   1467                 sum += d*w; sum_abs += (d & 1)*w;
   1468                 dir[i] = (char)d;
   1469             }
   1470 
   1471             R = (sum_abs + sum) * 0.5;
   1472             L = (sum_abs - sum) * 0.5;
   1473         }
   1474     }
   1475     else // split on ordered var
   1476     {
   1477         const CvPair32s32f* sorted = data->get_ord_var_data(node,vi);
   1478         int split_point = node->split->ord.split_point;
   1479         int n1 = node->get_num_valid(vi);
   1480 
   1481         assert( 0 <= split_point && split_point < n1-1 );
   1482 
   1483         if( !data->have_priors )
   1484         {
   1485             for( i = 0; i <= split_point; i++ )
   1486                 dir[sorted[i].i] = (char)-1;
   1487             for( ; i < n1; i++ )
   1488                 dir[sorted[i].i] = (char)1;
   1489             for( ; i < n; i++ )
   1490                 dir[sorted[i].i] = (char)0;
   1491 
   1492             L = split_point-1;
   1493             R = n1 - split_point + 1;
   1494         }
   1495         else
   1496         {
   1497             const int* responses = data->get_class_labels(node);
   1498             const double* priors = data->priors_mult->data.db;
   1499             L = R = 0;
   1500 
   1501             for( i = 0; i <= split_point; i++ )
   1502             {
   1503                 int idx = sorted[i].i;
   1504                 double w = priors[responses[idx]];
   1505                 dir[idx] = (char)-1;
   1506                 L += w;
   1507             }
   1508 
   1509             for( ; i < n1; i++ )
   1510             {
   1511                 int idx = sorted[i].i;
   1512                 double w = priors[responses[idx]];
   1513                 dir[idx] = (char)1;
   1514                 R += w;
   1515             }
   1516 
   1517             for( ; i < n; i++ )
   1518                 dir[sorted[i].i] = (char)0;
   1519         }
   1520     }
   1521 
   1522     node->maxlr = MAX( L, R );
   1523     return node->split->quality/(L + R);
   1524 }
   1525 
   1526 
   1527 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
   1528 {
   1529     int vi;
   1530     CvDTreeSplit *best_split = 0, *split = 0, *t;
   1531 
   1532     for( vi = 0; vi < data->var_count; vi++ )
   1533     {
   1534         int ci = data->get_var_type(vi);
   1535         if( node->get_num_valid(vi) <= 1 )
   1536             continue;
   1537 
   1538         if( data->is_classifier )
   1539         {
   1540             if( ci >= 0 )
   1541                 split = find_split_cat_class( node, vi );
   1542             else
   1543                 split = find_split_ord_class( node, vi );
   1544         }
   1545         else
   1546         {
   1547             if( ci >= 0 )
   1548                 split = find_split_cat_reg( node, vi );
   1549             else
   1550                 split = find_split_ord_reg( node, vi );
   1551         }
   1552 
   1553         if( split )
   1554         {
   1555             if( !best_split || best_split->quality < split->quality )
   1556                 CV_SWAP( best_split, split, t );
   1557             if( split )
   1558                 cvSetRemoveByPtr( data->split_heap, split );
   1559         }
   1560     }
   1561 
   1562     return best_split;
   1563 }
   1564 
   1565 
   1566 CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi )
   1567 {
   1568     const float epsilon = FLT_EPSILON*2;
   1569     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
   1570     const int* responses = data->get_class_labels(node);
   1571     int n = node->sample_count;
   1572     int n1 = node->get_num_valid(vi);
   1573     int m = data->get_num_classes();
   1574     const int* rc0 = data->counts->data.i;
   1575     int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
   1576     int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
   1577     int i, best_i = -1;
   1578     double lsum2 = 0, rsum2 = 0, best_val = 0;
   1579     const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
   1580 
   1581     // init arrays of class instance counters on both sides of the split
   1582     for( i = 0; i < m; i++ )
   1583     {
   1584         lc[i] = 0;
   1585         rc[i] = rc0[i];
   1586     }
   1587 
   1588     // compensate for missing values
   1589     for( i = n1; i < n; i++ )
   1590         rc[responses[sorted[i].i]]--;
   1591 
   1592     if( !priors )
   1593     {
   1594         int L = 0, R = n1;
   1595 
   1596         for( i = 0; i < m; i++ )
   1597             rsum2 += (double)rc[i]*rc[i];
   1598 
   1599         for( i = 0; i < n1 - 1; i++ )
   1600         {
   1601             int idx = responses[sorted[i].i];
   1602             int lv, rv;
   1603             L++; R--;
   1604             lv = lc[idx]; rv = rc[idx];
   1605             lsum2 += lv*2 + 1;
   1606             rsum2 -= rv*2 - 1;
   1607             lc[idx] = lv + 1; rc[idx] = rv - 1;
   1608 
   1609             if( sorted[i].val + epsilon < sorted[i+1].val )
   1610             {
   1611                 double val = (lsum2*R + rsum2*L)/((double)L*R);
   1612                 if( best_val < val )
   1613                 {
   1614                     best_val = val;
   1615                     best_i = i;
   1616                 }
   1617             }
   1618         }
   1619     }
   1620     else
   1621     {
   1622         double L = 0, R = 0;
   1623         for( i = 0; i < m; i++ )
   1624         {
   1625             double wv = rc[i]*priors[i];
   1626             R += wv;
   1627             rsum2 += wv*wv;
   1628         }
   1629 
   1630         for( i = 0; i < n1 - 1; i++ )
   1631         {
   1632             int idx = responses[sorted[i].i];
   1633             int lv, rv;
   1634             double p = priors[idx], p2 = p*p;
   1635             L += p; R -= p;
   1636             lv = lc[idx]; rv = rc[idx];
   1637             lsum2 += p2*(lv*2 + 1);
   1638             rsum2 -= p2*(rv*2 - 1);
   1639             lc[idx] = lv + 1; rc[idx] = rv - 1;
   1640 
   1641             if( sorted[i].val + epsilon < sorted[i+1].val )
   1642             {
   1643                 double val = (lsum2*R + rsum2*L)/((double)L*R);
   1644                 if( best_val < val )
   1645                 {
   1646                     best_val = val;
   1647                     best_i = i;
   1648                 }
   1649             }
   1650         }
   1651     }
   1652 
   1653     return best_i >= 0 ? data->new_split_ord( vi,
   1654         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
   1655         0, (float)best_val ) : 0;
   1656 }
   1657 
   1658 
   1659 void CvDTree::cluster_categories( const int* vectors, int n, int m,
   1660                                 int* csums, int k, int* labels )
   1661 {
   1662     // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
   1663     int iters = 0, max_iters = 100;
   1664     int i, j, idx;
   1665     double* buf = (double*)cvStackAlloc( (n + k)*sizeof(buf[0]) );
   1666     double *v_weights = buf, *c_weights = buf + k;
   1667     bool modified = true;
   1668     CvRNG* r = &data->rng;
   1669 
   1670     // assign labels randomly
   1671     for( i = idx = 0; i < n; i++ )
   1672     {
   1673         int sum = 0;
   1674         const int* v = vectors + i*m;
   1675         labels[i] = idx++;
   1676         idx &= idx < k ? -1 : 0;
   1677 
   1678         // compute weight of each vector
   1679         for( j = 0; j < m; j++ )
   1680             sum += v[j];
   1681         v_weights[i] = sum ? 1./sum : 0.;
   1682     }
   1683 
   1684     for( i = 0; i < n; i++ )
   1685     {
   1686         int i1 = cvRandInt(r) % n;
   1687         int i2 = cvRandInt(r) % n;
   1688         CV_SWAP( labels[i1], labels[i2], j );
   1689     }
   1690 
   1691     for( iters = 0; iters <= max_iters; iters++ )
   1692     {
   1693         // calculate csums
   1694         for( i = 0; i < k; i++ )
   1695         {
   1696             for( j = 0; j < m; j++ )
   1697                 csums[i*m + j] = 0;
   1698         }
   1699 
   1700         for( i = 0; i < n; i++ )
   1701         {
   1702             const int* v = vectors + i*m;
   1703             int* s = csums + labels[i]*m;
   1704             for( j = 0; j < m; j++ )
   1705                 s[j] += v[j];
   1706         }
   1707 
   1708         // exit the loop here, when we have up-to-date csums
   1709         if( iters == max_iters || !modified )
   1710             break;
   1711 
   1712         modified = false;
   1713 
   1714         // calculate weight of each cluster
   1715         for( i = 0; i < k; i++ )
   1716         {
   1717             const int* s = csums + i*m;
   1718             int sum = 0;
   1719             for( j = 0; j < m; j++ )
   1720                 sum += s[j];
   1721             c_weights[i] = sum ? 1./sum : 0;
   1722         }
   1723 
   1724         // now for each vector determine the closest cluster
   1725         for( i = 0; i < n; i++ )
   1726         {
   1727             const int* v = vectors + i*m;
   1728             double alpha = v_weights[i];
   1729             double min_dist2 = DBL_MAX;
   1730             int min_idx = -1;
   1731 
   1732             for( idx = 0; idx < k; idx++ )
   1733             {
   1734                 const int* s = csums + idx*m;
   1735                 double dist2 = 0., beta = c_weights[idx];
   1736                 for( j = 0; j < m; j++ )
   1737                 {
   1738                     double t = v[j]*alpha - s[j]*beta;
   1739                     dist2 += t*t;
   1740                 }
   1741                 if( min_dist2 > dist2 )
   1742                 {
   1743                     min_dist2 = dist2;
   1744                     min_idx = idx;
   1745                 }
   1746             }
   1747 
   1748             if( min_idx != labels[i] )
   1749                 modified = true;
   1750             labels[i] = min_idx;
   1751         }
   1752     }
   1753 }
   1754 
   1755 
   1756 CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi )
   1757 {
   1758     CvDTreeSplit* split;
   1759     const int* labels = data->get_cat_var_data(node, vi);
   1760     const int* responses = data->get_class_labels(node);
   1761     int ci = data->get_var_type(vi);
   1762     int n = node->sample_count;
   1763     int m = data->get_num_classes();
   1764     int _mi = data->cat_count->data.i[ci], mi = _mi;
   1765     int* lc = (int*)cvStackAlloc(m*sizeof(lc[0]));
   1766     int* rc = (int*)cvStackAlloc(m*sizeof(rc[0]));
   1767     int* _cjk = (int*)cvStackAlloc(m*(mi+1)*sizeof(_cjk[0]))+m, *cjk = _cjk;
   1768     double* c_weights = (double*)cvStackAlloc( mi*sizeof(c_weights[0]) );
   1769     int* cluster_labels = 0;
   1770     int** int_ptr = 0;
   1771     int i, j, k, idx;
   1772     double L = 0, R = 0;
   1773     double best_val = 0;
   1774     int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
   1775     const double* priors = data->priors_mult->data.db;
   1776 
   1777     // init array of counters:
   1778     // c_{jk} - number of samples that have vi-th input variable = j and response = k.
   1779     for( j = -1; j < mi; j++ )
   1780         for( k = 0; k < m; k++ )
   1781             cjk[j*m + k] = 0;
   1782 
   1783     for( i = 0; i < n; i++ )
   1784     {
   1785         j = labels[i];
   1786         k = responses[i];
   1787         cjk[j*m + k]++;
   1788     }
   1789 
   1790     if( m > 2 )
   1791     {
   1792         if( mi > data->params.max_categories )
   1793         {
   1794             mi = MIN(data->params.max_categories, n);
   1795             cjk += _mi*m;
   1796             cluster_labels = (int*)cvStackAlloc(mi*sizeof(cluster_labels[0]));
   1797             cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
   1798         }
   1799         subset_i = 1;
   1800         subset_n = 1 << mi;
   1801     }
   1802     else
   1803     {
   1804         assert( m == 2 );
   1805         int_ptr = (int**)cvStackAlloc( mi*sizeof(int_ptr[0]) );
   1806         for( j = 0; j < mi; j++ )
   1807             int_ptr[j] = cjk + j*2 + 1;
   1808         icvSortIntPtr( int_ptr, mi, 0 );
   1809         subset_i = 0;
   1810         subset_n = mi;
   1811     }
   1812 
   1813     for( k = 0; k < m; k++ )
   1814     {
   1815         int sum = 0;
   1816         for( j = 0; j < mi; j++ )
   1817             sum += cjk[j*m + k];
   1818         rc[k] = sum;
   1819         lc[k] = 0;
   1820     }
   1821 
   1822     for( j = 0; j < mi; j++ )
   1823     {
   1824         double sum = 0;
   1825         for( k = 0; k < m; k++ )
   1826             sum += cjk[j*m + k]*priors[k];
   1827         c_weights[j] = sum;
   1828         R += c_weights[j];
   1829     }
   1830 
   1831     for( ; subset_i < subset_n; subset_i++ )
   1832     {
   1833         double weight;
   1834         int* crow;
   1835         double lsum2 = 0, rsum2 = 0;
   1836 
   1837         if( m == 2 )
   1838             idx = (int)(int_ptr[subset_i] - cjk)/2;
   1839         else
   1840         {
   1841             int graycode = (subset_i>>1)^subset_i;
   1842             int diff = graycode ^ prevcode;
   1843 
   1844             // determine index of the changed bit.
   1845             Cv32suf u;
   1846             idx = diff >= (1 << 16) ? 16 : 0;
   1847             u.f = (float)(((diff >> 16) | diff) & 65535);
   1848             idx += (u.i >> 23) - 127;
   1849             subtract = graycode < prevcode;
   1850             prevcode = graycode;
   1851         }
   1852 
   1853         crow = cjk + idx*m;
   1854         weight = c_weights[idx];
   1855         if( weight < FLT_EPSILON )
   1856             continue;
   1857 
   1858         if( !subtract )
   1859         {
   1860             for( k = 0; k < m; k++ )
   1861             {
   1862                 int t = crow[k];
   1863                 int lval = lc[k] + t;
   1864                 int rval = rc[k] - t;
   1865                 double p = priors[k], p2 = p*p;
   1866                 lsum2 += p2*lval*lval;
   1867                 rsum2 += p2*rval*rval;
   1868                 lc[k] = lval; rc[k] = rval;
   1869             }
   1870             L += weight;
   1871             R -= weight;
   1872         }
   1873         else
   1874         {
   1875             for( k = 0; k < m; k++ )
   1876             {
   1877                 int t = crow[k];
   1878                 int lval = lc[k] - t;
   1879                 int rval = rc[k] + t;
   1880                 double p = priors[k], p2 = p*p;
   1881                 lsum2 += p2*lval*lval;
   1882                 rsum2 += p2*rval*rval;
   1883                 lc[k] = lval; rc[k] = rval;
   1884             }
   1885             L -= weight;
   1886             R += weight;
   1887         }
   1888 
   1889         if( L > FLT_EPSILON && R > FLT_EPSILON )
   1890         {
   1891             double val = (lsum2*R + rsum2*L)/((double)L*R);
   1892             if( best_val < val )
   1893             {
   1894                 best_val = val;
   1895                 best_subset = subset_i;
   1896             }
   1897         }
   1898     }
   1899 
   1900     if( best_subset < 0 )
   1901         return 0;
   1902 
   1903     split = data->new_split_cat( vi, (float)best_val );
   1904 
   1905     if( m == 2 )
   1906     {
   1907         for( i = 0; i <= best_subset; i++ )
   1908         {
   1909             idx = (int)(int_ptr[i] - cjk) >> 1;
   1910             split->subset[idx >> 5] |= 1 << (idx & 31);
   1911         }
   1912     }
   1913     else
   1914     {
   1915         for( i = 0; i < _mi; i++ )
   1916         {
   1917             idx = cluster_labels ? cluster_labels[i] : i;
   1918             if( best_subset & (1 << idx) )
   1919                 split->subset[i >> 5] |= 1 << (i & 31);
   1920         }
   1921     }
   1922 
   1923     return split;
   1924 }
   1925 
   1926 
   1927 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi )
   1928 {
   1929     const float epsilon = FLT_EPSILON*2;
   1930     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
   1931     const float* responses = data->get_ord_responses(node);
   1932     int n = node->sample_count;
   1933     int n1 = node->get_num_valid(vi);
   1934     int i, best_i = -1;
   1935     double best_val = 0, lsum = 0, rsum = node->value*n;
   1936     int L = 0, R = n1;
   1937 
   1938     // compensate for missing values
   1939     for( i = n1; i < n; i++ )
   1940         rsum -= responses[sorted[i].i];
   1941 
   1942     // find the optimal split
   1943     for( i = 0; i < n1 - 1; i++ )
   1944     {
   1945         float t = responses[sorted[i].i];
   1946         L++; R--;
   1947         lsum += t;
   1948         rsum -= t;
   1949 
   1950         if( sorted[i].val + epsilon < sorted[i+1].val )
   1951         {
   1952             double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
   1953             if( best_val < val )
   1954             {
   1955                 best_val = val;
   1956                 best_i = i;
   1957             }
   1958         }
   1959     }
   1960 
   1961     return best_i >= 0 ? data->new_split_ord( vi,
   1962         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
   1963         0, (float)best_val ) : 0;
   1964 }
   1965 
   1966 
   1967 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi )
   1968 {
   1969     CvDTreeSplit* split;
   1970     const int* labels = data->get_cat_var_data(node, vi);
   1971     const float* responses = data->get_ord_responses(node);
   1972     int ci = data->get_var_type(vi);
   1973     int n = node->sample_count;
   1974     int mi = data->cat_count->data.i[ci];
   1975     double* sum = (double*)cvStackAlloc( (mi+1)*sizeof(sum[0]) ) + 1;
   1976     int* counts = (int*)cvStackAlloc( (mi+1)*sizeof(counts[0]) ) + 1;
   1977     double** sum_ptr = 0;
   1978     int i, L = 0, R = 0;
   1979     double best_val = 0, lsum = 0, rsum = 0;
   1980     int best_subset = -1, subset_i;
   1981 
   1982     for( i = -1; i < mi; i++ )
   1983         sum[i] = counts[i] = 0;
   1984 
   1985     // calculate sum response and weight of each category of the input var
   1986     for( i = 0; i < n; i++ )
   1987     {
   1988         int idx = labels[i];
   1989         double s = sum[idx] + responses[i];
   1990         int nc = counts[idx] + 1;
   1991         sum[idx] = s;
   1992         counts[idx] = nc;
   1993     }
   1994 
   1995     // calculate average response in each category
   1996     for( i = 0; i < mi; i++ )
   1997     {
   1998         R += counts[i];
   1999         rsum += sum[i];
   2000         sum[i] /= MAX(counts[i],1);
   2001         sum_ptr[i] = sum + i;
   2002     }
   2003 
   2004     icvSortDblPtr( sum_ptr, mi, 0 );
   2005 
   2006     // revert back to unnormalized sums
   2007     // (there should be a very little loss of accuracy)
   2008     for( i = 0; i < mi; i++ )
   2009         sum[i] *= counts[i];
   2010 
   2011     for( subset_i = 0; subset_i < mi-1; subset_i++ )
   2012     {
   2013         int idx = (int)(sum_ptr[subset_i] - sum);
   2014         int ni = counts[idx];
   2015 
   2016         if( ni )
   2017         {
   2018             double s = sum[idx];
   2019             lsum += s; L += ni;
   2020             rsum -= s; R -= ni;
   2021 
   2022             if( L && R )
   2023             {
   2024                 double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
   2025                 if( best_val < val )
   2026                 {
   2027                     best_val = val;
   2028                     best_subset = subset_i;
   2029                 }
   2030             }
   2031         }
   2032     }
   2033 
   2034     if( best_subset < 0 )
   2035         return 0;
   2036 
   2037     split = data->new_split_cat( vi, (float)best_val );
   2038     for( i = 0; i <= best_subset; i++ )
   2039     {
   2040         int idx = (int)(sum_ptr[i] - sum);
   2041         split->subset[idx >> 5] |= 1 << (idx & 31);
   2042     }
   2043 
   2044     return split;
   2045 }
   2046 
   2047 
   2048 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi )
   2049 {
   2050     const float epsilon = FLT_EPSILON*2;
   2051     const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
   2052     const char* dir = (char*)data->direction->data.ptr;
   2053     int n1 = node->get_num_valid(vi);
   2054     // LL - number of samples that both the primary and the surrogate splits send to the left
   2055     // LR - ... primary split sends to the left and the surrogate split sends to the right
   2056     // RL - ... primary split sends to the right and the surrogate split sends to the left
   2057     // RR - ... both send to the right
   2058     int i, best_i = -1, best_inversed = 0;
   2059     double best_val;
   2060 
   2061     if( !data->have_priors )
   2062     {
   2063         int LL = 0, RL = 0, LR, RR;
   2064         int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
   2065         int sum = 0, sum_abs = 0;
   2066 
   2067         for( i = 0; i < n1; i++ )
   2068         {
   2069             int d = dir[sorted[i].i];
   2070             sum += d; sum_abs += d & 1;
   2071         }
   2072 
   2073         // sum_abs = R + L; sum = R - L
   2074         RR = (sum_abs + sum) >> 1;
   2075         LR = (sum_abs - sum) >> 1;
   2076 
   2077         // initially all the samples are sent to the right by the surrogate split,
   2078         // LR of them are sent to the left by primary split, and RR - to the right.
   2079         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
   2080         for( i = 0; i < n1 - 1; i++ )
   2081         {
   2082             int d = dir[sorted[i].i];
   2083 
   2084             if( d < 0 )
   2085             {
   2086                 LL++; LR--;
   2087                 if( LL + RR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
   2088                 {
   2089                     best_val = LL + RR;
   2090                     best_i = i; best_inversed = 0;
   2091                 }
   2092             }
   2093             else if( d > 0 )
   2094             {
   2095                 RL++; RR--;
   2096                 if( RL + LR > _best_val && sorted[i].val + epsilon < sorted[i+1].val )
   2097                 {
   2098                     best_val = RL + LR;
   2099                     best_i = i; best_inversed = 1;
   2100                 }
   2101             }
   2102         }
   2103         best_val = _best_val;
   2104     }
   2105     else
   2106     {
   2107         double LL = 0, RL = 0, LR, RR;
   2108         double worst_val = node->maxlr;
   2109         double sum = 0, sum_abs = 0;
   2110         const double* priors = data->priors_mult->data.db;
   2111         const int* responses = data->get_class_labels(node);
   2112         best_val = worst_val;
   2113 
   2114         for( i = 0; i < n1; i++ )
   2115         {
   2116             int idx = sorted[i].i;
   2117             double w = priors[responses[idx]];
   2118             int d = dir[idx];
   2119             sum += d*w; sum_abs += (d & 1)*w;
   2120         }
   2121 
   2122         // sum_abs = R + L; sum = R - L
   2123         RR = (sum_abs + sum)*0.5;
   2124         LR = (sum_abs - sum)*0.5;
   2125 
   2126         // initially all the samples are sent to the right by the surrogate split,
   2127         // LR of them are sent to the left by primary split, and RR - to the right.
   2128         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
   2129         for( i = 0; i < n1 - 1; i++ )
   2130         {
   2131             int idx = sorted[i].i;
   2132             double w = priors[responses[idx]];
   2133             int d = dir[idx];
   2134 
   2135             if( d < 0 )
   2136             {
   2137                 LL += w; LR -= w;
   2138                 if( LL + RR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
   2139                 {
   2140                     best_val = LL + RR;
   2141                     best_i = i; best_inversed = 0;
   2142                 }
   2143             }
   2144             else if( d > 0 )
   2145             {
   2146                 RL += w; RR -= w;
   2147                 if( RL + LR > best_val && sorted[i].val + epsilon < sorted[i+1].val )
   2148                 {
   2149                     best_val = RL + LR;
   2150                     best_i = i; best_inversed = 1;
   2151                 }
   2152             }
   2153         }
   2154     }
   2155 
   2156     return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
   2157         (sorted[best_i].val + sorted[best_i+1].val)*0.5f, best_i,
   2158         best_inversed, (float)best_val ) : 0;
   2159 }
   2160 
   2161 
   2162 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi )
   2163 {
   2164     const int* labels = data->get_cat_var_data(node, vi);
   2165     const char* dir = (char*)data->direction->data.ptr;
   2166     int n = node->sample_count;
   2167     // LL - number of samples that both the primary and the surrogate splits send to the left
   2168     // LR - ... primary split sends to the left and the surrogate split sends to the right
   2169     // RL - ... primary split sends to the right and the surrogate split sends to the left
   2170     // RR - ... both send to the right
   2171     CvDTreeSplit* split = data->new_split_cat( vi, 0 );
   2172     int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;
   2173     double best_val = 0;
   2174     double* lc = (double*)cvStackAlloc( (mi+1)*2*sizeof(lc[0]) ) + 1;
   2175     double* rc = lc + mi + 1;
   2176 
   2177     for( i = -1; i < mi; i++ )
   2178         lc[i] = rc[i] = 0;
   2179 
   2180     // for each category calculate the weight of samples
   2181     // sent to the left (lc) and to the right (rc) by the primary split
   2182     if( !data->have_priors )
   2183     {
   2184         int* _lc = (int*)cvStackAlloc((mi+2)*2*sizeof(_lc[0])) + 1;
   2185         int* _rc = _lc + mi + 1;
   2186 
   2187         for( i = -1; i < mi; i++ )
   2188             _lc[i] = _rc[i] = 0;
   2189 
   2190         for( i = 0; i < n; i++ )
   2191         {
   2192             int idx = labels[i];
   2193             int d = dir[i];
   2194             int sum = _lc[idx] + d;
   2195             int sum_abs = _rc[idx] + (d & 1);
   2196             _lc[idx] = sum; _rc[idx] = sum_abs;
   2197         }
   2198 
   2199         for( i = 0; i < mi; i++ )
   2200         {
   2201             int sum = _lc[i];
   2202             int sum_abs = _rc[i];
   2203             lc[i] = (sum_abs - sum) >> 1;
   2204             rc[i] = (sum_abs + sum) >> 1;
   2205         }
   2206     }
   2207     else
   2208     {
   2209         const double* priors = data->priors_mult->data.db;
   2210         const int* responses = data->get_class_labels(node);
   2211 
   2212         for( i = 0; i < n; i++ )
   2213         {
   2214             int idx = labels[i];
   2215             double w = priors[responses[i]];
   2216             int d = dir[i];
   2217             double sum = lc[idx] + d*w;
   2218             double sum_abs = rc[idx] + (d & 1)*w;
   2219             lc[idx] = sum; rc[idx] = sum_abs;
   2220         }
   2221 
   2222         for( i = 0; i < mi; i++ )
   2223         {
   2224             double sum = lc[i];
   2225             double sum_abs = rc[i];
   2226             lc[i] = (sum_abs - sum) * 0.5;
   2227             rc[i] = (sum_abs + sum) * 0.5;
   2228         }
   2229     }
   2230 
   2231     // 2. now form the split.
   2232     // in each category send all the samples to the same direction as majority
   2233     for( i = 0; i < mi; i++ )
   2234     {
   2235         double lval = lc[i], rval = rc[i];
   2236         if( lval > rval )
   2237         {
   2238             split->subset[i >> 5] |= 1 << (i & 31);
   2239             best_val += lval;
   2240             l_win++;
   2241         }
   2242         else
   2243             best_val += rval;
   2244     }
   2245 
   2246     split->quality = (float)best_val;
   2247     if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
   2248         cvSetRemoveByPtr( data->split_heap, split ), split = 0;
   2249 
   2250     return split;
   2251 }
   2252 
   2253 
// Computes the prediction and the risk of a node from the training samples that
// reach it, together with the per-fold quantities used by cross-validation pruning
// (cv_n = params.cv_folds folds; cv_labels gives the fold of each sample).
//
// Classification: node->class_idx is the class with the largest prior-weighted
// sample count, node->value is the original response label of that class (looked up
// via data->cat_map), and node->node_risk is the prior-weighted count of the samples
// of all the other classes.
// Regression: node->value is the mean response and node->node_risk is the sum of
// squared deviations from it.
// For every fold j: cv_node_risk[j] is the risk computed without fold j's samples,
// cv_node_error[j] is the error of that fold-j-free prediction on fold j's samples,
// and cv_Tn[j] is reset to INT_MAX.
void CvDTree::calc_node_value( CvDTreeNode* node )
{
    int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
    const int* cv_labels = data->get_labels(node);

    if( data->is_classifier )
    {
        // in case of classification tree:
        //  * node value is the label of the class that has the largest weight in the node.
        //  * node risk is the weighted number of misclassified samples,
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated as the weighted number of
        //    misclassified samples with cv_labels(*)==j.

        // compute the number of instances of each class
        // NOTE: cls_count is the shared buffer data->counts; it is overwritten here.
        int* cls_count = data->counts->data.i;
        const int* responses = data->get_class_labels(node);
        int m = data->get_num_classes();
        // cv_cls_count[j*m + k] = number of class-k samples in fold j
        int* cv_cls_count = (int*)cvStackAlloc(m*cv_n*sizeof(cv_cls_count[0]));
        double max_val = -1, total_weight = 0;
        int max_k = -1;
        double* priors = data->priors_mult->data.db;

        for( k = 0; k < m; k++ )
            cls_count[k] = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
                cls_count[responses[i]]++;
        }
        else
        {
            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cv_cls_count[j*m + k] = 0;

            for( i = 0; i < n; i++ )
            {
                j = cv_labels[i]; k = responses[i];
                cv_cls_count[j*m + k]++;
            }

            // the total per-class counts are the sums over all the folds
            for( j = 0; j < cv_n; j++ )
                for( k = 0; k < m; k++ )
                    cls_count[k] += cv_cls_count[j*m + k];
        }

        if( data->have_priors && node->parent == 0 )
        {
            // compute priors_mult from priors, take the sample ratio into account.
            // priors_mult[k] = priors[k]/n_k, normalized to sum to 1, so that the
            // weighted class totals in the root become proportional to the priors.
            double sum = 0;
            for( k = 0; k < m; k++ )
            {
                int n_k = cls_count[k];
                priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);
                sum += priors[k];
            }
            sum = 1./sum;
            for( k = 0; k < m; k++ )
                priors[k] *= sum;
        }

        // pick the class with the largest weighted count
        for( k = 0; k < m; k++ )
        {
            double val = cls_count[k]*priors[k];
            total_weight += val;
            if( max_val < val )
            {
                max_val = val;
                max_k = k;
            }
        }

        node->class_idx = max_k;
        // map the internal class index back to the original response value
        node->value = data->cat_map->data.i[
            data->cat_ofs->data.i[data->cat_var_count] + max_k];
        node->node_risk = total_weight - max_val;

        for( j = 0; j < cv_n; j++ )
        {
            // val: weight of class k without fold j; val_k: weight of class k in fold j.
            // The fold-j-free winner max_k is used to count the errors on fold j.
            double sum_k = 0, sum = 0, max_val_k = 0;
            max_val = -1; max_k = -1;

            for( k = 0; k < m; k++ )
            {
                double w = priors[k];
                double val_k = cv_cls_count[j*m + k]*w;
                double val = cls_count[k]*w - val_k;
                sum_k += val_k;
                sum += val;
                if( max_val < val )
                {
                    max_val = val;
                    max_val_k = val_k;
                    max_k = k;
                }
            }

            node->cv_Tn[j] = INT_MAX;
            node->cv_node_risk[j] = sum - max_val;
            node->cv_node_error[j] = sum_k - max_val_k;
        }
    }
    else
    {
        // in case of regression tree:
        //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
        //    n is the number of samples in the node.
        //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
        //  * j-th cross-validation fold value and risk are calculated as above,
        //    but using the samples with cv_labels(*)!=j.
        //  * j-th cross-validation fold error is calculated
        //    using samples with cv_labels(*)==j as the test subset:
        //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
        //    where node_value_j is the node value calculated
        //    as described in the previous bullet, and summation is done
        //    over the samples with cv_labels(*)==j.

        double sum = 0, sum2 = 0;
        const float* values = data->get_ord_responses(node);
        double *cv_sum = 0, *cv_sum2 = 0;
        int* cv_count = 0;

        if( cv_n == 0 )
        {
            for( i = 0; i < n; i++ )
            {
                double t = values[i];
                sum += t;
                sum2 += t*t;
            }
        }
        else
        {
            // per-fold sum, sum of squares and count of the responses
            cv_sum = (double*)cvStackAlloc( cv_n*sizeof(cv_sum[0]) );
            cv_sum2 = (double*)cvStackAlloc( cv_n*sizeof(cv_sum2[0]) );
            cv_count = (int*)cvStackAlloc( cv_n*sizeof(cv_count[0]) );

            for( j = 0; j < cv_n; j++ )
            {
                cv_sum[j] = cv_sum2[j] = 0.;
                cv_count[j] = 0;
            }

            for( i = 0; i < n; i++ )
            {
                j = cv_labels[i];
                double t = values[i];
                double s = cv_sum[j] + t;
                double s2 = cv_sum2[j] + t*t;
                int nc = cv_count[j] + 1;
                cv_sum[j] = s;
                cv_sum2[j] = s2;
                cv_count[j] = nc;
            }

            for( j = 0; j < cv_n; j++ )
            {
                sum += cv_sum[j];
                sum2 += cv_sum2[j];
            }
        }

        // sum((Y_i - mean)^2) = sum(Y_i^2) - n*mean^2
        node->node_risk = sum2 - (sum/n)*sum;
        node->value = sum/n;

        for( j = 0; j < cv_n; j++ )
        {
            // s, s2, c: sum, sum of squares and count of fold j;
            // si, s2i, ci: the same quantities over all the other folds
            double s = cv_sum[j], si = sum - s;
            double s2 = cv_sum2[j], s2i = sum2 - s2;
            int c = cv_count[j], ci = n - c;
            double r = si/MAX(ci,1); // mean response without fold j
            node->cv_node_risk[j] = s2i - r*r*ci;
            // expanded form of sum over fold j of (Y_i - r)^2
            node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
            node->cv_Tn[j] = INT_MAX;
        }
    }
}
   2434 
   2435 
   2436 void CvDTree::complete_node_dir( CvDTreeNode* node )
   2437 {
   2438     int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
   2439     int nz = n - node->get_num_valid(node->split->var_idx);
   2440     char* dir = (char*)data->direction->data.ptr;
   2441 
   2442     // try to complete direction using surrogate splits
   2443     if( nz && data->params.use_surrogates )
   2444     {
   2445         CvDTreeSplit* split = node->split->next;
   2446         for( ; split != 0 && nz; split = split->next )
   2447         {
   2448             int inversed_mask = split->inversed ? -1 : 0;
   2449             vi = split->var_idx;
   2450 
   2451             if( data->get_var_type(vi) >= 0 ) // split on categorical var
   2452             {
   2453                 const int* labels = data->get_cat_var_data(node, vi);
   2454                 const int* subset = split->subset;
   2455 
   2456                 for( i = 0; i < n; i++ )
   2457                 {
   2458                     int idx;
   2459                     if( !dir[i] && (idx = labels[i]) >= 0 )
   2460                     {
   2461                         int d = CV_DTREE_CAT_DIR(idx,subset);
   2462                         dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
   2463                         if( --nz )
   2464                             break;
   2465                     }
   2466                 }
   2467             }
   2468             else // split on ordered var
   2469             {
   2470                 const CvPair32s32f* sorted = data->get_ord_var_data(node, vi);
   2471                 int split_point = split->ord.split_point;
   2472                 int n1 = node->get_num_valid(vi);
   2473 
   2474                 assert( 0 <= split_point && split_point < n-1 );
   2475 
   2476                 for( i = 0; i < n1; i++ )
   2477                 {
   2478                     int idx = sorted[i].i;
   2479                     if( !dir[idx] )
   2480                     {
   2481                         int d = i <= split_point ? -1 : 1;
   2482                         dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
   2483                         if( --nz )
   2484                             break;
   2485                     }
   2486                 }
   2487             }
   2488         }
   2489     }
   2490 
   2491     // find the default direction for the rest
   2492     if( nz )
   2493     {
   2494         for( i = nr = 0; i < n; i++ )
   2495             nr += dir[i] > 0;
   2496         nl = n - nr - nz;
   2497         d0 = nl > nr ? -1 : nr > nl;
   2498     }
   2499 
   2500     // make sure that every sample is directed either to the left or to the right
   2501     for( i = 0; i < n; i++ )
   2502     {
   2503         int d = dir[i];
   2504         if( !d )
   2505         {
   2506             d = d0;
   2507             if( !d )
   2508                 d = d1, d1 = -d1;
   2509         }
   2510         d = d > 0;
   2511         dir[i] = (char)d; // remap (-1,1) to (0,1)
   2512     }
   2513 }
   2514 
   2515 
// Partitions the per-sample training data of `node` between two newly created
// children (node->left / node->right) according to the direction of every sample.
//
// complete_node_dir() first fills data->direction so that each of the node's
// n samples has direction 0 (left) or 1 (right). Ordered variables are copied
// only when split_input_data is true; categorical variables are copied under the
// same condition, while the extra work variables (vi >= data->var_count) are
// always copied. Finally the parent's own data is released with
// data->free_node_data(node).
void CvDTree::split_node_data( CvDTreeNode* node )
{
    int vi, i, n = node->sample_count, nl, nr;
    char* dir = (char*)data->direction->data.ptr;
    CvDTreeNode *left = 0, *right = 0;
    int* new_idx = data->split_buf->data.i;
    int new_buf_idx = data->get_child_buf_idx( node );
    int work_var_count = data->get_work_var_count();

    // speed things up a little, especially for tree ensembles with lots of small trees:
    //   do not physically split the input data between the left and right child nodes
    //   when we are not going to split them further,
    //   as calc_node_value() does not require input features anyway.
    bool split_input_data;

    complete_node_dir(node);

    // Count the samples going each way and build the relocation table:
    // new_idx[i] is the position of sample i inside the child it goes to.
    for( i = nl = nr = 0; i < n; i++ )
    {
        int d = dir[i];
        // initialize new indices for splitting ordered variables.
        // Branchless select: for d == 0, (d-1) == -1 keeps nl and -d == 0 drops nr;
        // for d == 1 the roles are swapped.
        new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
        nr += d;
        nl += d^1;
    }

    // The right child's data starts after the left child's nl samples.
    // NOTE(review): (ord_var_count + work_var_count) looks like the per-sample
    // stride of the child buffer — confirm against new_node()/get_*_var_data().
    node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
    node->right = right = data->new_node( node, nr, new_buf_idx, node->offset +
        (data->ord_var_count + work_var_count)*nl );

    // Input features are copied only if the children are still above max_depth
    // and at least one of them has more than min_sample_count samples.
    split_input_data = node->depth + 1 < data->params.max_depth &&
        (node->left->sample_count > data->params.min_sample_count ||
        node->right->sample_count > data->params.min_sample_count);

    // split ordered variables, keep both halves sorted.
    for( vi = 0; vi < data->var_count; vi++ )
    {
        int ci = data->get_var_type(vi);
        int n1 = node->get_num_valid(vi);
        CvPair32s32f *src, *ldst0, *rdst0, *ldst, *rdst;
        CvPair32s32f tl, tr;

        if( ci >= 0 || !split_input_data )
            continue;

        src = data->get_ord_var_data(node, vi);
        ldst0 = ldst = data->get_ord_var_data(left, vi);
        rdst0 = rdst = data->get_ord_var_data(right, vi);
        // The branchless copies below write each element to BOTH destinations and
        // advance only one of them, so they can write one element past the end of
        // each child's range; save those two elements and restore them at the end.
        tl = ldst0[nl]; tr = rdst0[nr];

        // split sorted: the first n1 entries are the valid values in sorted order;
        // their relative order is kept in each child and sample indices are
        // remapped through new_idx
        for( i = 0; i < n1; i++ )
        {
            int idx = src[i].i;
            float val = src[i].val;
            int d = dir[idx];
            idx = new_idx[idx];
            ldst->i = rdst->i = idx;
            ldst->val = rdst->val = val;
            ldst += d^1;
            rdst += d;
        }

        left->set_num_valid(vi, (int)(ldst - ldst0));
        right->set_num_valid(vi, (int)(rdst - rdst0));

        // split missing: the remaining n - n1 entries are samples without a value
        // for this variable; they are stored with val = ord_nan
        for( ; i < n; i++ )
        {
            int idx = src[i].i;
            int d = dir[idx];
            idx = new_idx[idx];
            ldst->i = rdst->i = idx;
            ldst->val = rdst->val = ord_nan;
            ldst += d^1;
            rdst += d;
        }

        ldst0[nl] = tl; rdst0[nr] = tr;
    }

    // split categorical vars, responses and cv_labels.
    // These arrays are indexed by sample, so copying them in sample order puts
    // sample i at position new_idx[i] of its child without an explicit lookup.
    // Work variables with vi >= data->var_count are split even when
    // split_input_data is false.
    for( vi = 0; vi < work_var_count; vi++ )
    {
        int ci = data->get_var_type(vi);
        int n1 = node->get_num_valid(vi), nr1 = 0;
        int *src, *ldst0, *rdst0, *ldst, *rdst;
        int tl, tr;

        if( ci < 0 || (vi < data->var_count && !split_input_data) )
            continue;

        src = data->get_cat_var_data(node, vi);
        ldst0 = ldst = data->get_cat_var_data(left, vi);
        rdst0 = rdst = data->get_cat_var_data(right, vi);
        // same one-past-the-end protection as for the ordered variables
        tl = ldst0[nl]; tr = rdst0[nr];

        for( i = 0; i < n; i++ )
        {
            int d = dir[i];
            int val = src[i];
            *ldst = *rdst = val;
            ldst += d^1;
            rdst += d;
            // count the non-missing (val >= 0) values that go to the right child
            nr1 += (val >= 0)&d;
        }

        if( vi < data->var_count )
        {
            left->set_num_valid(vi, n1 - nr1);
            right->set_num_valid(vi, nr1);
        }

        ldst0[nl] = tl; rdst0[nr] = tr;
    }

    // deallocate the parent node data that is not needed anymore
    data->free_node_data(node);
}
   2635 
   2636 
// Cost-complexity pruning with cross-validation.
//
// The main tree sequence is built with update_tree_rnc()/cut_tree(), and the
// alpha at which every step is cut is collected in `ab`. These alphas are then
// turned into representative betas. For each cv fold the fold's own sequence is
// built, and the root error of the fold tree is recorded for every beta in err_jk.
// The index with the smallest summed cv error (optionally relaxed by the 1SE
// rule) is stored in pruned_tree_idx. free_prune_data() then drops the
// cross-validation data; with params.truncate_pruned_tree it also frees the
// pruned branches.
void CvDTree::prune_cv()
{
    CvMat* ab = 0;
    CvMat* temp = 0;
    CvMat* err_jk = 0;

    // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
    // 2. choose the best tree index (if need, apply 1SE rule).
    // 3. store the best index and cut the branches.

    CV_FUNCNAME( "CvDTree::prune_cv" );

    __BEGIN__;

    int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
    // currently, 1SE for regression is not implemented
    bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
    double* err;
    double min_err = 0, min_err_se = 0;
    int min_idx = -1;

    // ab: alpha of every step of the main sequence (grown by 1.5x when full)
    CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));

    // build the main tree sequence, calculate alpha's
    for(;;tree_count++)
    {
        double min_alpha = update_tree_rnc(tree_count, -1);
        // nonzero once the root itself is cut (or the tree is a single leaf)
        if( cut_tree(tree_count, -1, min_alpha) )
            break;

        if( ab->cols <= tree_count )
        {
            CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
            for( ti = 0; ti < ab->cols; ti++ )
                temp->data.db[ti] = ab->data.db[ti];
            cvReleaseMat( &ab );
            ab = temp;
            temp = 0;
        }

        ab->data.db[tree_count] = min_alpha;
    }

    // Convert alphas to betas: beta[0] = 0, beta[ti] = sqrt(alpha[ti]*alpha[ti+1])
    // for the inner steps (ab[ti+1] is still the original alpha when read), and the
    // last beta is effectively infinite.
    ab->data.db[0] = 0.;

    if( tree_count > 0 )
    {
        for( ti = 1; ti < tree_count-1; ti++ )
            ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
        ab->data.db[tree_count-1] = DBL_MAX*0.5;

        // err_jk(j, tk): error of fold j's tree that corresponds to beta[tk]
        CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
        err = err_jk->data.db;

        for( j = 0; j < cv_n; j++ )
        {
            // tj walks the fold's own pruning sequence, tk walks the main-sequence
            // betas; every beta not larger than the fold's current cut alpha gets
            // the current fold tree's error. After the fold's root is cut,
            // min_alpha = DBL_MAX assigns all remaining betas.
            int tj = 0, tk = 0;
            for( ; tk < tree_count; tj++ )
            {
                double min_alpha = update_tree_rnc(tj, j);
                if( cut_tree(tj, j, min_alpha) )
                    min_alpha = DBL_MAX;

                for( ; tk < tree_count; tk++ )
                {
                    if( ab->data.db[tk] > min_alpha )
                        break;
                    err[j*tree_count + tk] = root->tree_error;
                }
            }
        }

        for( ti = 0; ti < tree_count; ti++ )
        {
            double sum_err = 0;
            for( j = 0; j < cv_n; j++ )
                sum_err += err[j*tree_count + ti];
            if( ti == 0 || sum_err < min_err )
            {
                min_err = sum_err;
                min_idx = ti;
                // NOTE(review): for a binomial error count the usual standard
                // error is sqrt(sum_err*(n - sum_err)/n); confirm whether the
                // missing division by n is intended.
                if( use_1se )
                    min_err_se = sqrt( sum_err*(n - sum_err) );
            }
            // 1SE rule: prefer a later (more pruned) tree whose error is within
            // min_err_se of the minimum. min_err_se stays 0 without the rule,
            // so this branch is then never taken.
            else if( sum_err < min_err + min_err_se )
                min_idx = ti;
        }
    }

    pruned_tree_idx = min_idx;
    free_prune_data(data->params.truncate_pruned_tree != 0);

    __END__;

    // reached both on success and after an error inside __BEGIN__/__END__
    cvReleaseMat( &err_jk );
    cvReleaseMat( &ab );
    cvReleaseMat( &temp );
}
   2735 
   2736 
// Recomputes the cost-complexity statistics of the tree of pruning step T and
// returns the smallest alpha over its internal nodes (DBL_MAX for a single leaf).
//   T    - pruning step; a node whose Tn (cv_Tn[fold] for a fold) is <= T, or that
//          has no children, is treated as a leaf
//   fold - cross-validation fold index, or -1 for the main tree; for a fold the
//          per-fold cv_node_risk / cv_node_error values are used
// For every node of that tree it sets:
//   complexity - number of leaves of its subtree,
//   tree_risk  - sum of the leaf risks,
//   tree_error - sum of the leaf errors (always 0 for the main tree);
// and for internal nodes alpha = (node risk - tree_risk)/(complexity - 1).
// The walk is iterative: it descends along left children and climbs back
// through parent links.
double CvDTree::update_tree_rnc( int T, int fold )
{
    CvDTreeNode* node = root;
    double min_alpha = DBL_MAX;

    for(;;)
    {
        CvDTreeNode* parent;
        // descend to the leftmost leaf of the current subtree and initialize its
        // leaf statistics
        for(;;)
        {
            int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
            if( t <= T || !node->left )
            {
                node->complexity = 1;
                node->tree_risk = node->node_risk;
                node->tree_error = 0.;
                if( fold >= 0 )
                {
                    node->tree_risk = node->cv_node_risk[fold];
                    node->tree_error = node->cv_node_error[fold];
                }
                break;
            }
            node = node->left;
        }

        // Climb while coming back from a right child. Each such parent already
        // holds its left subtree's sums, so adding the right child's sums
        // completes it and its alpha can be computed.
        for( parent = node->parent; parent && parent->right == node;
            node = parent, parent = parent->parent )
        {
            parent->complexity += node->complexity;
            parent->tree_risk += node->tree_risk;
            parent->tree_error += node->tree_error;

            parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
                - parent->tree_risk)/(parent->complexity - 1);
            min_alpha = MIN( min_alpha, parent->alpha );
        }

        if( !parent )
            break;

        // coming back from a left child: seed the parent's sums with the finished
        // left subtree, then process the right subtree
        parent->complexity = node->complexity;
        parent->tree_risk = node->tree_risk;
        parent->tree_error = node->tree_error;
        node = parent->right;
    }

    return min_alpha;
}
   2786 
   2787 
   2788 int CvDTree::cut_tree( int T, int fold, double min_alpha )
   2789 {
   2790     CvDTreeNode* node = root;
   2791     if( !node->left )
   2792         return 1;
   2793 
   2794     for(;;)
   2795     {
   2796         CvDTreeNode* parent;
   2797         for(;;)
   2798         {
   2799             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
   2800             if( t <= T || !node->left )
   2801                 break;
   2802             if( node->alpha <= min_alpha + FLT_EPSILON )
   2803             {
   2804                 if( fold >= 0 )
   2805                     node->cv_Tn[fold] = T;
   2806                 else
   2807                     node->Tn = T;
   2808                 if( node == root )
   2809                     return 1;
   2810                 break;
   2811             }
   2812             node = node->left;
   2813         }
   2814 
   2815         for( parent = node->parent; parent && parent->right == node;
   2816             node = parent, parent = parent->parent )
   2817             ;
   2818 
   2819         if( !parent )
   2820             break;
   2821 
   2822         node = parent->right;
   2823     }
   2824 
   2825     return 0;
   2826 }
   2827 
   2828 
   2829 void CvDTree::free_prune_data(bool cut_tree)
   2830 {
   2831     CvDTreeNode* node = root;
   2832 
   2833     for(;;)
   2834     {
   2835         CvDTreeNode* parent;
   2836         for(;;)
   2837         {
   2838             // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
   2839             // as we will clear the whole cross-validation heap at the end
   2840             node->cv_Tn = 0;
   2841             node->cv_node_error = node->cv_node_risk = 0;
   2842             if( !node->left )
   2843                 break;
   2844             node = node->left;
   2845         }
   2846 
   2847         for( parent = node->parent; parent && parent->right == node;
   2848             node = parent, parent = parent->parent )
   2849         {
   2850             if( cut_tree && parent->Tn <= pruned_tree_idx )
   2851             {
   2852                 data->free_node( parent->left );
   2853                 data->free_node( parent->right );
   2854                 parent->left = parent->right = 0;
   2855             }
   2856         }
   2857 
   2858         if( !parent )
   2859             break;
   2860 
   2861         node = parent->right;
   2862     }
   2863 
   2864     if( data->cv_heap )
   2865         cvClearSet( data->cv_heap );
   2866 }
   2867 
   2868 
   2869 void CvDTree::free_tree()
   2870 {
   2871     if( root && data && data->shared )
   2872     {
   2873         pruned_tree_idx = INT_MIN;
   2874         free_prune_data(true);
   2875         data->free_node(root);
   2876         root = 0;
   2877     }
   2878 }
   2879 
   2880 
// Returns the leaf that the sample falls into. The descent stops at nodes whose
// Tn <= pruned_tree_idx, so the pruned tree is used.
//   _sample            - 1-row or 1-column CV_32FC1 matrix with data->var_all
//                        elements (raw input) or data->var_count elements when
//                        preprocessed_input is true
//   _missing           - optional 8-bit mask of the same size as _sample; a nonzero
//                        element marks the corresponding value as missing
//   preprocessed_input - when true, variables are addressed directly by var index
//                        and a categorical value is rounded and used as the
//                        category index without lookup
// At every node the primary split is tried first, then the following (surrogate)
// splits, until one yields a direction. If none does, the sample goes to the
// child with more training samples (right on ties).
// Errors: CV_StsError if the tree is not trained; CV_StsBadArg for a malformed
// sample/mask or a non-integer categorical value. After an error, 0 is returned.
CvDTreeNode* CvDTree::predict( const CvMat* _sample,
    const CvMat* _missing, bool preprocessed_input ) const
{
    CvDTreeNode* result = 0;
    int* catbuf = 0;

    CV_FUNCNAME( "CvDTree::predict" );

    __BEGIN__;

    int i, step, mstep = 0;
    const float* sample;
    const uchar* m = 0;
    CvDTreeNode* node = root;
    const int* vtype;
    const int* vidx;
    const int* cmap;
    const int* cofs;

    if( !node )
        CV_ERROR( CV_StsError, "The tree has not been trained yet" );

    if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
        _sample->cols != 1 && _sample->rows != 1 ||
        _sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input ||
        _sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input )
            CV_ERROR( CV_StsBadArg,
        "the input sample must be 1d floating-point vector with the same "
        "number of elements as the total number of variables used for training" );

    sample = _sample->data.fl;
    // element stride; 1 for a continuous matrix, the row step for a column vector
    step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);

    // Per-call cache of the category index found for every categorical variable;
    // -1 means "not looked up yet".
    if( data->cat_count && !preprocessed_input ) // cache for categorical variables
    {
        int n = data->cat_count->cols;
        catbuf = (int*)cvStackAlloc(n*sizeof(catbuf[0]));
        for( i = 0; i < n; i++ )
            catbuf[i] = -1;
    }

    if( _missing )
    {
        if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
        !CV_ARE_SIZES_EQ(_missing, _sample) )
            CV_ERROR( CV_StsBadArg,
        "the missing data mask must be 8-bit vector of the same size as input sample" );
        m = _missing->data.ptr;
        mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
    }

    vtype = data->var_type->data.i;
    // vidx maps a var index to its position in a raw sample (0 = identity)
    vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
    cmap = data->cat_map ? data->cat_map->data.i : 0;
    cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;

    while( node->Tn > pruned_tree_idx && node->left )
    {
        CvDTreeSplit* split = node->split;
        int dir = 0;
        for( ; !dir && split != 0; split = split->next )
        {
            int vi = split->var_idx;
            int ci = vtype[vi];
            i = vidx ? vidx[vi] : vi;
            float val = sample[i*step];
            // missing value: try the next (surrogate) split
            if( m && m[i*mstep] )
                continue;
            if( ci < 0 ) // ordered
                dir = val <= split->ord.c ? -1 : 1;
            else // categorical
            {
                int c;
                if( preprocessed_input )
                    c = cvRound(val);
                else
                {
                    c = catbuf[ci];
                    if( c < 0 )
                    {
                        // Binary search of the raw value among this variable's
                        // category values cmap[cofs[ci] .. cofs[ci+1]); assumes
                        // that range is sorted ascending.
                        int a = c = cofs[ci];
                        int b = cofs[ci+1];
                        int ival = cvRound(val);
                        if( ival != val )
                            CV_ERROR( CV_StsBadArg,
                            "one of input categorical variable is not an integer" );

                        while( a < b )
                        {
                            c = (a + b) >> 1;
                            if( ival < cmap[c] )
                                b = c;
                            else if( ival > cmap[c] )
                                a = c+1;
                            else
                                break;
                        }

                        // category unseen during training: try the next split
                        if( c < 0 || ival != cmap[c] )
                            continue;

                        // store the index relative to this variable's range
                        catbuf[ci] = c -= cofs[ci];
                    }
                }
                dir = CV_DTREE_CAT_DIR(c, split->subset);
            }

            if( split->inversed )
                dir = -dir;
        }

        // no split could route the sample: follow the larger child
        if( !dir )
        {
            double diff = node->right->sample_count - node->left->sample_count;
            dir = diff < 0 ? -1 : 1;
        }
        node = dir < 0 ? node->left : node->right;
    }

    result = node;

    __END__;

    return result;
}
   3006 
   3007 
   3008 const CvMat* CvDTree::get_var_importance()
   3009 {
   3010     if( !var_importance )
   3011     {
   3012         CvDTreeNode* node = root;
   3013         double* importance;
   3014         if( !node )
   3015             return 0;
   3016         var_importance = cvCreateMat( 1, data->var_count, CV_64F );
   3017         cvZero( var_importance );
   3018         importance = var_importance->data.db;
   3019 
   3020         for(;;)
   3021         {
   3022             CvDTreeNode* parent;
   3023             for( ;; node = node->left )
   3024             {
   3025                 CvDTreeSplit* split = node->split;
   3026 
   3027                 if( !node->left || node->Tn <= pruned_tree_idx )
   3028                     break;
   3029 
   3030                 for( ; split != 0; split = split->next )
   3031                     importance[split->var_idx] += split->quality;
   3032             }
   3033 
   3034             for( parent = node->parent; parent && parent->right == node;
   3035                 node = parent, parent = parent->parent )
   3036                 ;
   3037 
   3038             if( !parent )
   3039                 break;
   3040 
   3041             node = parent->right;
   3042         }
   3043 
   3044         cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
   3045     }
   3046 
   3047     return var_importance;
   3048 }
   3049 
   3050 
   3051 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split )
   3052 {
   3053     int ci;
   3054 
   3055     cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
   3056     cvWriteInt( fs, "var", split->var_idx );
   3057     cvWriteReal( fs, "quality", split->quality );
   3058 
   3059     ci = data->get_var_type(split->var_idx);
   3060     if( ci >= 0 ) // split on a categorical var
   3061     {
   3062         int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
   3063         for( i = 0; i < n; i++ )
   3064             to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;
   3065 
   3066         // ad-hoc rule when to use inverse categorical split notation
   3067         // to achieve more compact and clear representation
   3068         default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
   3069 
   3070         cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
   3071                             "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
   3072 
   3073         for( i = 0; i < n; i++ )
   3074         {
   3075             int dir = CV_DTREE_CAT_DIR(i,split->subset);
   3076             if( dir*default_dir < 0 )
   3077                 cvWriteInt( fs, 0, i );
   3078         }
   3079         cvEndWriteStruct( fs );
   3080     }
   3081     else
   3082         cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );
   3083 
   3084     cvEndWriteStruct( fs );
   3085 }
   3086 
   3087 
   3088 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node )
   3089 {
   3090     CvDTreeSplit* split;
   3091 
   3092     cvStartWriteStruct( fs, 0, CV_NODE_MAP );
   3093 
   3094     cvWriteInt( fs, "depth", node->depth );
   3095     cvWriteInt( fs, "sample_count", node->sample_count );
   3096     cvWriteReal( fs, "value", node->value );
   3097 
   3098     if( data->is_classifier )
   3099         cvWriteInt( fs, "norm_class_idx", node->class_idx );
   3100 
   3101     cvWriteInt( fs, "Tn", node->Tn );
   3102     cvWriteInt( fs, "complexity", node->complexity );
   3103     cvWriteReal( fs, "alpha", node->alpha );
   3104     cvWriteReal( fs, "node_risk", node->node_risk );
   3105     cvWriteReal( fs, "tree_risk", node->tree_risk );
   3106     cvWriteReal( fs, "tree_error", node->tree_error );
   3107 
   3108     if( node->left )
   3109     {
   3110         cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );
   3111 
   3112         for( split = node->split; split != 0; split = split->next )
   3113             write_split( fs, split );
   3114 
   3115         cvEndWriteStruct( fs );
   3116     }
   3117 
   3118     cvEndWriteStruct( fs );
   3119 }
   3120 
   3121 
   3122 void CvDTree::write_tree_nodes( CvFileStorage* fs )
   3123 {
   3124     //CV_FUNCNAME( "CvDTree::write_tree_nodes" );
   3125 
   3126     __BEGIN__;
   3127 
   3128     CvDTreeNode* node = root;
   3129 
   3130     // traverse the tree and save all the nodes in depth-first order
   3131     for(;;)
   3132     {
   3133         CvDTreeNode* parent;
   3134         for(;;)
   3135         {
   3136             write_node( fs, node );
   3137             if( !node->left )
   3138                 break;
   3139             node = node->left;
   3140         }
   3141 
   3142         for( parent = node->parent; parent && parent->right == node;
   3143             node = parent, parent = parent->parent )
   3144             ;
   3145 
   3146         if( !parent )
   3147             break;
   3148 
   3149         node = parent->right;
   3150     }
   3151 
   3152     __END__;
   3153 }
   3154 
   3155 
   3156 void CvDTree::write( CvFileStorage* fs, const char* name )
   3157 {
   3158     //CV_FUNCNAME( "CvDTree::write" );
   3159 
   3160     __BEGIN__;
   3161 
   3162     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
   3163 
   3164     get_var_importance();
   3165     data->write_params( fs );
   3166     if( var_importance )
   3167         cvWrite( fs, "var_importance", var_importance );
   3168     write( fs );
   3169 
   3170     cvEndWriteStruct( fs );
   3171 
   3172     __END__;
   3173 }
   3174 
   3175 
   3176 void CvDTree::write( CvFileStorage* fs )
   3177 {
   3178     //CV_FUNCNAME( "CvDTree::write" );
   3179 
   3180     __BEGIN__;
   3181 
   3182     cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
   3183 
   3184     cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
   3185     write_tree_nodes( fs );
   3186     cvEndWriteStruct( fs );
   3187 
   3188     __END__;
   3189 }
   3190 
   3191 
   3192 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
   3193 {
   3194     CvDTreeSplit* split = 0;
   3195 
   3196     CV_FUNCNAME( "CvDTree::read_split" );
   3197 
   3198     __BEGIN__;
   3199 
   3200     int vi, ci;
   3201 
   3202     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
   3203         CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );
   3204 
   3205     vi = cvReadIntByName( fs, fnode, "var", -1 );
   3206     if( (unsigned)vi >= (unsigned)data->var_count )
   3207         CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );
   3208 
   3209     ci = data->get_var_type(vi);
   3210     if( ci >= 0 ) // split on categorical var
   3211     {
   3212         int i, n = data->cat_count->data.i[ci], inversed = 0, val;
   3213         CvSeqReader reader;
   3214         CvFileNode* inseq;
   3215         split = data->new_split_cat( vi, 0 );
   3216         inseq = cvGetFileNodeByName( fs, fnode, "in" );
   3217         if( !inseq )
   3218         {
   3219             inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
   3220             inversed = 1;
   3221         }
   3222         if( !inseq ||
   3223             (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))
   3224             CV_ERROR( CV_StsParseError,
   3225             "Either 'in' or 'not_in' tags should be inside a categorical split data" );
   3226 
   3227         if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )
   3228         {
   3229             val = inseq->data.i;
   3230             if( (unsigned)val >= (unsigned)n )
   3231                 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
   3232 
   3233             split->subset[val >> 5] |= 1 << (val & 31);
   3234         }
   3235         else
   3236         {
   3237             cvStartReadSeq( inseq->data.seq, &reader );
   3238 
   3239             for( i = 0; i < reader.seq->total; i++ )
   3240             {
   3241                 CvFileNode* inode = (CvFileNode*)reader.ptr;
   3242                 val = inode->data.i;
   3243                 if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
   3244                     CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
   3245 
   3246                 split->subset[val >> 5] |= 1 << (val & 31);
   3247                 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
   3248             }
   3249         }
   3250 
   3251         // for categorical splits we do not use inversed splits,
   3252         // instead we inverse the variable set in the split
   3253         if( inversed )
   3254             for( i = 0; i < (n + 31) >> 5; i++ )
   3255                 split->subset[i] ^= -1;
   3256     }
   3257     else
   3258     {
   3259         CvFileNode* cmp_node;
   3260         split = data->new_split_ord( vi, 0, 0, 0, 0 );
   3261 
   3262         cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
   3263         if( !cmp_node )
   3264         {
   3265             cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
   3266             split->inversed = 1;
   3267         }
   3268 
   3269         split->ord.c = (float)cvReadReal( cmp_node );
   3270     }
   3271 
   3272     split->quality = (float)cvReadRealByName( fs, fnode, "quality" );
   3273 
   3274     __END__;
   3275 
   3276     return split;
   3277 }
   3278 
   3279 
   3280 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
   3281 {
   3282     CvDTreeNode* node = 0;
   3283 
   3284     CV_FUNCNAME( "CvDTree::read_node" );
   3285 
   3286     __BEGIN__;
   3287 
   3288     CvFileNode* splits;
   3289     int i, depth;
   3290 
   3291     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
   3292         CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );
   3293 
   3294     CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
   3295     depth = cvReadIntByName( fs, fnode, "depth", -1 );
   3296     if( depth != node->depth )
   3297         CV_ERROR( CV_StsParseError, "incorrect node depth" );
   3298 
   3299     node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
   3300     node->value = cvReadRealByName( fs, fnode, "value" );
   3301     if( data->is_classifier )
   3302         node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );
   3303 
   3304     node->Tn = cvReadIntByName( fs, fnode, "Tn" );
   3305     node->complexity = cvReadIntByName( fs, fnode, "complexity" );
   3306     node->alpha = cvReadRealByName( fs, fnode, "alpha" );
   3307     node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
   3308     node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
   3309     node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );
   3310 
   3311     splits = cvGetFileNodeByName( fs, fnode, "splits" );
   3312     if( splits )
   3313     {
   3314         CvSeqReader reader;
   3315         CvDTreeSplit* last_split = 0;
   3316 
   3317         if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
   3318             CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );
   3319 
   3320         cvStartReadSeq( splits->data.seq, &reader );
   3321         for( i = 0; i < reader.seq->total; i++ )
   3322         {
   3323             CvDTreeSplit* split;
   3324             CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
   3325             if( !last_split )
   3326                 node->split = last_split = split;
   3327             else
   3328                 last_split = last_split->next = split;
   3329 
   3330             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
   3331         }
   3332     }
   3333 
   3334     __END__;
   3335 
   3336     return node;
   3337 }
   3338 
   3339 
   3340 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
   3341 {
   3342     CV_FUNCNAME( "CvDTree::read_tree_nodes" );
   3343 
   3344     __BEGIN__;
   3345 
   3346     CvSeqReader reader;
   3347     CvDTreeNode _root;
   3348     CvDTreeNode* parent = &_root;
   3349     int i;
   3350     parent->left = parent->right = parent->parent = 0;
   3351 
   3352     cvStartReadSeq( fnode->data.seq, &reader );
   3353 
   3354     for( i = 0; i < reader.seq->total; i++ )
   3355     {
   3356         CvDTreeNode* node;
   3357 
   3358         CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
   3359         if( !parent->left )
   3360             parent->left = node;
   3361         else
   3362             parent->right = node;
   3363         if( node->split )
   3364             parent = node;
   3365         else
   3366         {
   3367             while( parent && parent->right )
   3368                 parent = parent->parent;
   3369         }
   3370 
   3371         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
   3372     }
   3373 
   3374     root = _root.left;
   3375 
   3376     __END__;
   3377 }
   3378 
   3379 
   3380 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
   3381 {
   3382     CvDTreeTrainData* _data = new CvDTreeTrainData();
   3383     _data->read_params( fs, fnode );
   3384 
   3385     read( fs, fnode, _data );
   3386     get_var_importance();
   3387 }
   3388 
   3389 
   3390 // a special entry point for reading weak decision trees from the tree ensembles
   3391 void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
   3392 {
   3393     CV_FUNCNAME( "CvDTree::read" );
   3394 
   3395     __BEGIN__;
   3396 
   3397     CvFileNode* tree_nodes;
   3398 
   3399     clear();
   3400     data = _data;
   3401 
   3402     tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
   3403     if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
   3404         CV_ERROR( CV_StsParseError, "nodes tag is missing" );
   3405 
   3406     pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );
   3407     read_tree_nodes( fs, tree_nodes );
   3408 
   3409     __END__;
   3410 }
   3411 
   3412 /* End of file. */
   3413