Home | History | Annotate | Download | only in traincascade
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                        Intel License Agreement
     11 //
     12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
     13 // Third party copyrights are property of their respective owners.
     14 //
     15 // Redistribution and use in source and binary forms, with or without modification,
     16 // are permitted provided that the following conditions are met:
     17 //
     18 //   * Redistribution's of source code must retain the above copyright notice,
     19 //     this list of conditions and the following disclaimer.
     20 //
     21 //   * Redistribution's in binary form must reproduce the above copyright notice,
     22 //     this list of conditions and the following disclaimer in the documentation
     23 //     and/or other materials provided with the distribution.
     24 //
     25 //   * The name of Intel Corporation may not be used to endorse or promote products
     26 //     derived from this software without specific prior written permission.
     27 //
     28 // This software is provided by the copyright holders and contributors "as is" and
     29 // any express or implied warranties, including, but not limited to, the implied
     30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     31 // In no event shall the Intel Corporation or contributors be liable for any direct,
     32 // indirect, incidental, special, exemplary, or consequential damages
     33 // (including, but not limited to, procurement of substitute goods or services;
     34 // loss of use, data, or profits; or business interruption) however caused
     35 // and on any theory of liability, whether in contract, strict liability,
     36 // or tort (including negligence or otherwise) arising in any way out of
     37 // the use of this software, even if advised of the possibility of such damage.
     38 //
     39 //M*/
     40 
     41 #include "old_ml_precomp.hpp"
     42 #include <ctype.h>
     43 
     44 using namespace cv;
     45 
     // Sentinel stored for missing ordered values; set_data() rejects any real
     // input whose magnitude is >= this value.
     static const float ord_nan = FLT_MAX*0.5f;
     // Lower bound (in bytes) for the block size of the memory storages created in set_data().
     static const int min_block_size = 1 << 16;
     // Slack (in bytes) added to the computed block size before the lower bound is applied.
     static const int block_size_delta = 1 << 10;
     49 
     50 CvDTreeTrainData::CvDTreeTrainData()
     51 {
     52     var_idx = var_type = cat_count = cat_ofs = cat_map =
     53         priors = priors_mult = counts = direction = split_buf = responses_copy = 0;
     54     buf = 0;
     55     tree_storage = temp_storage = 0;
     56 
     57     clear();
     58 }
     59 
     60 
     61 CvDTreeTrainData::CvDTreeTrainData( const CvMat* _train_data, int _tflag,
     62                       const CvMat* _responses, const CvMat* _var_idx,
     63                       const CvMat* _sample_idx, const CvMat* _var_type,
     64                       const CvMat* _missing_mask, const CvDTreeParams& _params,
     65                       bool _shared, bool _add_labels )
     66 {
     67     var_idx = var_type = cat_count = cat_ofs = cat_map =
     68         priors = priors_mult = counts = direction = split_buf = responses_copy = 0;
     69     buf = 0;
     70 
     71     tree_storage = temp_storage = 0;
     72 
     73     set_data( _train_data, _tflag, _responses, _var_idx, _sample_idx,
     74               _var_type, _missing_mask, _params, _shared, _add_labels );
     75 }
     76 
     77 
     78 CvDTreeTrainData::~CvDTreeTrainData()
     79 {
     80     clear();
     81 }
     82 
     83 
     84 bool CvDTreeTrainData::set_params( const CvDTreeParams& _params )
     85 {
     86     bool ok = false;
     87 
     88     CV_FUNCNAME( "CvDTreeTrainData::set_params" );
     89 
     90     __BEGIN__;
     91 
     92     // set parameters
     93     params = _params;
     94 
     95     if( params.max_categories < 2 )
     96         CV_ERROR( CV_StsOutOfRange, "params.max_categories should be >= 2" );
     97     params.max_categories = MIN( params.max_categories, 15 );
     98 
     99     if( params.max_depth < 0 )
    100         CV_ERROR( CV_StsOutOfRange, "params.max_depth should be >= 0" );
    101     params.max_depth = MIN( params.max_depth, 25 );
    102 
    103     params.min_sample_count = MAX(params.min_sample_count,1);
    104 
    105     if( params.cv_folds < 0 )
    106         CV_ERROR( CV_StsOutOfRange,
    107         "params.cv_folds should be =0 (the tree is not pruned) "
    108         "or n>0 (tree is pruned using n-fold cross-validation)" );
    109 
    110     if( params.cv_folds == 1 )
    111         params.cv_folds = 0;
    112 
    113     if( params.regression_accuracy < 0 )
    114         CV_ERROR( CV_StsOutOfRange, "params.regression_accuracy should be >= 0" );
    115 
    116     ok = true;
    117 
    118     __END__;
    119 
    120     return ok;
    121 }
    122 
    // Comparator that orders pointers by the values they point to.
    // Used with std::sort to sort an array of pointers without moving the pointees.
    // The arguments are pointers to const, so both T* and const T* are accepted.
    template<typename T>
    class LessThanPtr
    {
    public:
        bool operator()(const T* a, const T* b) const { return *a < *b; }
    };
    129 
    // Comparator that orders indices (of type Idx) by the values they select in an
    // external array: a precedes b when arr[a] < arr[b].
    // The array is not owned; it must stay valid while the comparator is in use.
    template<typename T, typename Idx>
    class LessThanIdx
    {
    public:
        LessThanIdx( const T* _arr ) : arr(_arr) {}
        bool operator()(Idx a, Idx b) const { return arr[a] < arr[b]; }
        const T* arr; // keys looked up by index
    };
    138 
    139 class LessThanPairs
    140 {
    141 public:
    142     bool operator()(const CvPair16u32s& a, const CvPair16u32s& b) const { return *a.i < *b.i; }
    143 };
    144 
    145 void CvDTreeTrainData::set_data( const CvMat* _train_data, int _tflag,
    146     const CvMat* _responses, const CvMat* _var_idx, const CvMat* _sample_idx,
    147     const CvMat* _var_type, const CvMat* _missing_mask, const CvDTreeParams& _params,
    148     bool _shared, bool _add_labels, bool _update_data )
    149 {
    150     CvMat* sample_indices = 0;
    151     CvMat* var_type0 = 0;
    152     CvMat* tmp_map = 0;
    153     int** int_ptr = 0;
    154     CvPair16u32s* pair16u32s_ptr = 0;
    155     CvDTreeTrainData* data = 0;
    156     float *_fdst = 0;
    157     int *_idst = 0;
    158     unsigned short* udst = 0;
    159     int* idst = 0;
    160 
    161     CV_FUNCNAME( "CvDTreeTrainData::set_data" );
    162 
    163     __BEGIN__;
    164 
    165     int sample_all = 0, r_type, cv_n;
    166     int total_c_count = 0;
    167     int tree_block_size, temp_block_size, max_split_size, nv_size, cv_size = 0;
    168     int ds_step, dv_step, ms_step = 0, mv_step = 0; // {data|mask}{sample|var}_step
    169     int vi, i, size;
    170     char err[100];
    171     const int *sidx = 0, *vidx = 0;
    172 
    173     uint64 effective_buf_size = 0;
    174     int effective_buf_height = 0, effective_buf_width = 0;
    175 
    176     if( _update_data && data_root )
    177     {
    178         data = new CvDTreeTrainData( _train_data, _tflag, _responses, _var_idx,
    179             _sample_idx, _var_type, _missing_mask, _params, _shared, _add_labels );
    180 
    181         // compare new and old train data
    182         if( !(data->var_count == var_count &&
    183             cvNorm( data->var_type, var_type, CV_C ) < FLT_EPSILON &&
    184             cvNorm( data->cat_count, cat_count, CV_C ) < FLT_EPSILON &&
    185             cvNorm( data->cat_map, cat_map, CV_C ) < FLT_EPSILON) )
    186             CV_ERROR( CV_StsBadArg,
    187             "The new training data must have the same types and the input and output variables "
    188             "and the same categories for categorical variables" );
    189 
    190         cvReleaseMat( &priors );
    191         cvReleaseMat( &priors_mult );
    192         cvReleaseMat( &buf );
    193         cvReleaseMat( &direction );
    194         cvReleaseMat( &split_buf );
    195         cvReleaseMemStorage( &temp_storage );
    196 
    197         priors = data->priors; data->priors = 0;
    198         priors_mult = data->priors_mult; data->priors_mult = 0;
    199         buf = data->buf; data->buf = 0;
    200         buf_count = data->buf_count; buf_size = data->buf_size;
    201         sample_count = data->sample_count;
    202 
    203         direction = data->direction; data->direction = 0;
    204         split_buf = data->split_buf; data->split_buf = 0;
    205         temp_storage = data->temp_storage; data->temp_storage = 0;
    206         nv_heap = data->nv_heap; cv_heap = data->cv_heap;
    207 
    208         data_root = new_node( 0, sample_count, 0, 0 );
    209         EXIT;
    210     }
    211 
    212     clear();
    213 
    214     var_all = 0;
    215     rng = &cv::theRNG();
    216 
    217     CV_CALL( set_params( _params ));
    218 
    219     // check parameter types and sizes
    220     CV_CALL( cvCheckTrainData( _train_data, _tflag, _missing_mask, &var_all, &sample_all ));
    221 
    222     train_data = _train_data;
    223     responses = _responses;
    224 
    225     if( _tflag == CV_ROW_SAMPLE )
    226     {
    227         ds_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
    228         dv_step = 1;
    229         if( _missing_mask )
    230             ms_step = _missing_mask->step, mv_step = 1;
    231     }
    232     else
    233     {
    234         dv_step = _train_data->step/CV_ELEM_SIZE(_train_data->type);
    235         ds_step = 1;
    236         if( _missing_mask )
    237             mv_step = _missing_mask->step, ms_step = 1;
    238     }
    239     tflag = _tflag;
    240 
    241     sample_count = sample_all;
    242     var_count = var_all;
    243 
    244     if( _sample_idx )
    245     {
    246         CV_CALL( sample_indices = cvPreprocessIndexArray( _sample_idx, sample_all ));
    247         sidx = sample_indices->data.i;
    248         sample_count = sample_indices->rows + sample_indices->cols - 1;
    249     }
    250 
    251     if( _var_idx )
    252     {
    253         CV_CALL( var_idx = cvPreprocessIndexArray( _var_idx, var_all ));
    254         vidx = var_idx->data.i;
    255         var_count = var_idx->rows + var_idx->cols - 1;
    256     }
    257 
    258     is_buf_16u = false;
    259     if ( sample_count < 65536 )
    260         is_buf_16u = true;
    261 
    262     if( !CV_IS_MAT(_responses) ||
    263         (CV_MAT_TYPE(_responses->type) != CV_32SC1 &&
    264          CV_MAT_TYPE(_responses->type) != CV_32FC1) ||
    265         (_responses->rows != 1 && _responses->cols != 1) ||
    266         _responses->rows + _responses->cols - 1 != sample_all )
    267         CV_ERROR( CV_StsBadArg, "The array of _responses must be an integer or "
    268                   "floating-point vector containing as many elements as "
    269                   "the total number of samples in the training data matrix" );
    270 
    271     r_type = CV_VAR_CATEGORICAL;
    272     if( _var_type )
    273         CV_CALL( var_type0 = cvPreprocessVarType( _var_type, var_idx, var_count, &r_type ));
    274 
    275     CV_CALL( var_type = cvCreateMat( 1, var_count+2, CV_32SC1 ));
    276 
    277     cat_var_count = 0;
    278     ord_var_count = -1;
    279 
    280     is_classifier = r_type == CV_VAR_CATEGORICAL;
    281 
    282     // step 0. calc the number of categorical vars
    283     for( vi = 0; vi < var_count; vi++ )
    284     {
    285         char vt = var_type0 ? var_type0->data.ptr[vi] : CV_VAR_ORDERED;
    286         var_type->data.i[vi] = vt == CV_VAR_CATEGORICAL ? cat_var_count++ : ord_var_count--;
    287     }
    288 
    289     ord_var_count = ~ord_var_count;
    290     cv_n = params.cv_folds;
    291     // set the two last elements of var_type array to be able
    292     // to locate responses and cross-validation labels using
    293     // the corresponding get_* functions.
    294     var_type->data.i[var_count] = cat_var_count;
    295     var_type->data.i[var_count+1] = cat_var_count+1;
    296 
    297     // in case of single ordered predictor we need dummy cv_labels
    298     // for safe split_node_data() operation
    299     have_labels = cv_n > 0 || (ord_var_count == 1 && cat_var_count == 0) || _add_labels;
    300 
    301     work_var_count = var_count + (is_classifier ? 1 : 0) // for responses class_labels
    302                                + (have_labels ? 1 : 0); // for cv_labels
    303 
    304     shared = _shared;
    305     buf_count = shared ? 2 : 1;
    306 
    307     buf_size = -1; // the member buf_size is obsolete
    308 
    309     effective_buf_size = (uint64)(work_var_count + 1)*(uint64)sample_count * buf_count; // this is the total size of "CvMat buf" to be allocated
    310     effective_buf_width = sample_count;
    311     effective_buf_height = work_var_count+1;
    312 
    313     if (effective_buf_width >= effective_buf_height)
    314         effective_buf_height *= buf_count;
    315     else
    316         effective_buf_width *= buf_count;
    317 
    318     if ((uint64)effective_buf_width * (uint64)effective_buf_height != effective_buf_size)
    319     {
    320         CV_Error(CV_StsBadArg, "The memory buffer cannot be allocated since its size exceeds integer fields limit");
    321     }
    322 
    323 
    324 
    325     if ( is_buf_16u )
    326     {
    327         CV_CALL( buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_16UC1 ));
    328         CV_CALL( pair16u32s_ptr = (CvPair16u32s*)cvAlloc( sample_count*sizeof(pair16u32s_ptr[0]) ));
    329     }
    330     else
    331     {
    332         CV_CALL( buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_32SC1 ));
    333         CV_CALL( int_ptr = (int**)cvAlloc( sample_count*sizeof(int_ptr[0]) ));
    334     }
    335 
    336     size = is_classifier ? (cat_var_count+1) : cat_var_count;
    337     size = !size ? 1 : size;
    338     CV_CALL( cat_count = cvCreateMat( 1, size, CV_32SC1 ));
    339     CV_CALL( cat_ofs = cvCreateMat( 1, size, CV_32SC1 ));
    340 
    341     size = is_classifier ? (cat_var_count + 1)*params.max_categories : cat_var_count*params.max_categories;
    342     size = !size ? 1 : size;
    343     CV_CALL( cat_map = cvCreateMat( 1, size, CV_32SC1 ));
    344 
    345     // now calculate the maximum size of split,
    346     // create memory storage that will keep nodes and splits of the decision tree
    347     // allocate root node and the buffer for the whole training data
    348     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
    349         (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
    350     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
    351     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
    352     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
    353     CV_CALL( node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage ));
    354 
    355     nv_size = var_count*sizeof(int);
    356     nv_size = cvAlign(MAX( nv_size, (int)sizeof(CvSetElem) ), sizeof(void*));
    357 
    358     temp_block_size = nv_size;
    359 
    360     if( cv_n )
    361     {
    362         if( sample_count < cv_n*MAX(params.min_sample_count,10) )
    363             CV_ERROR( CV_StsOutOfRange,
    364                 "The many folds in cross-validation for such a small dataset" );
    365 
    366         cv_size = cvAlign( cv_n*(sizeof(int) + sizeof(double)*2), sizeof(double) );
    367         temp_block_size = MAX(temp_block_size, cv_size);
    368     }
    369 
    370     temp_block_size = MAX( temp_block_size + block_size_delta, min_block_size );
    371     CV_CALL( temp_storage = cvCreateMemStorage( temp_block_size ));
    372     CV_CALL( nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nv_size, temp_storage ));
    373     if( cv_size )
    374         CV_CALL( cv_heap = cvCreateSet( 0, sizeof(*cv_heap), cv_size, temp_storage ));
    375 
    376     CV_CALL( data_root = new_node( 0, sample_count, 0, 0 ));
    377 
    378     max_c_count = 1;
    379 
    380     _fdst = 0;
    381     _idst = 0;
    382     if (ord_var_count)
    383         _fdst = (float*)cvAlloc(sample_count*sizeof(_fdst[0]));
    384     if (is_buf_16u && (cat_var_count || is_classifier))
    385         _idst = (int*)cvAlloc(sample_count*sizeof(_idst[0]));
    386 
    387     // transform the training data to convenient representation
    388     for( vi = 0; vi <= var_count; vi++ )
    389     {
    390         int ci;
    391         const uchar* mask = 0;
    392         int64 m_step = 0, step;
    393         const int* idata = 0;
    394         const float* fdata = 0;
    395         int num_valid = 0;
    396 
    397         if( vi < var_count ) // analyze i-th input variable
    398         {
    399             int vi0 = vidx ? vidx[vi] : vi;
    400             ci = get_var_type(vi);
    401             step = ds_step; m_step = ms_step;
    402             if( CV_MAT_TYPE(_train_data->type) == CV_32SC1 )
    403                 idata = _train_data->data.i + vi0*dv_step;
    404             else
    405                 fdata = _train_data->data.fl + vi0*dv_step;
    406             if( _missing_mask )
    407                 mask = _missing_mask->data.ptr + vi0*mv_step;
    408         }
    409         else // analyze _responses
    410         {
    411             ci = cat_var_count;
    412             step = CV_IS_MAT_CONT(_responses->type) ?
    413                 1 : _responses->step / CV_ELEM_SIZE(_responses->type);
    414             if( CV_MAT_TYPE(_responses->type) == CV_32SC1 )
    415                 idata = _responses->data.i;
    416             else
    417                 fdata = _responses->data.fl;
    418         }
    419 
    420         if( (vi < var_count && ci>=0) ||
    421             (vi == var_count && is_classifier) ) // process categorical variable or response
    422         {
    423             int c_count, prev_label;
    424             int* c_map;
    425 
    426             if (is_buf_16u)
    427                 udst = (unsigned short*)(buf->data.s + (size_t)vi*sample_count);
    428             else
    429                 idst = buf->data.i + (size_t)vi*sample_count;
    430 
    431             // copy data
    432             for( i = 0; i < sample_count; i++ )
    433             {
    434                 int val = INT_MAX, si = sidx ? sidx[i] : i;
    435                 if( !mask || !mask[(size_t)si*m_step] )
    436                 {
    437                     if( idata )
    438                         val = idata[(size_t)si*step];
    439                     else
    440                     {
    441                         float t = fdata[(size_t)si*step];
    442                         val = cvRound(t);
    443                         if( fabs(t - val) > FLT_EPSILON )
    444                         {
    445                             sprintf( err, "%d-th value of %d-th (categorical) "
    446                                 "variable is not an integer", i, vi );
    447                             CV_ERROR( CV_StsBadArg, err );
    448                         }
    449                     }
    450 
    451                     if( val == INT_MAX )
    452                     {
    453                         sprintf( err, "%d-th value of %d-th (categorical) "
    454                             "variable is too large", i, vi );
    455                         CV_ERROR( CV_StsBadArg, err );
    456                     }
    457                     num_valid++;
    458                 }
    459                 if (is_buf_16u)
    460                 {
    461                     _idst[i] = val;
    462                     pair16u32s_ptr[i].u = udst + i;
    463                     pair16u32s_ptr[i].i = _idst + i;
    464                 }
    465                 else
    466                 {
    467                     idst[i] = val;
    468                     int_ptr[i] = idst + i;
    469                 }
    470             }
    471 
    472             c_count = num_valid > 0;
    473             if (is_buf_16u)
    474             {
    475                 std::sort(pair16u32s_ptr, pair16u32s_ptr + sample_count, LessThanPairs());
    476                 // count the categories
    477                 for( i = 1; i < num_valid; i++ )
    478                     if (*pair16u32s_ptr[i].i != *pair16u32s_ptr[i-1].i)
    479                         c_count ++ ;
    480             }
    481             else
    482             {
    483                 std::sort(int_ptr, int_ptr + sample_count, LessThanPtr<int>());
    484                 // count the categories
    485                 for( i = 1; i < num_valid; i++ )
    486                     c_count += *int_ptr[i] != *int_ptr[i-1];
    487             }
    488 
    489             if( vi > 0 )
    490                 max_c_count = MAX( max_c_count, c_count );
    491             cat_count->data.i[ci] = c_count;
    492             cat_ofs->data.i[ci] = total_c_count;
    493 
    494             // resize cat_map, if need
    495             if( cat_map->cols < total_c_count + c_count )
    496             {
    497                 tmp_map = cat_map;
    498                 CV_CALL( cat_map = cvCreateMat( 1,
    499                     MAX(cat_map->cols*3/2,total_c_count+c_count), CV_32SC1 ));
    500                 for( i = 0; i < total_c_count; i++ )
    501                     cat_map->data.i[i] = tmp_map->data.i[i];
    502                 cvReleaseMat( &tmp_map );
    503             }
    504 
    505             c_map = cat_map->data.i + total_c_count;
    506             total_c_count += c_count;
    507 
    508             c_count = -1;
    509             if (is_buf_16u)
    510             {
    511                 // compact the class indices and build the map
    512                 prev_label = ~*pair16u32s_ptr[0].i;
    513                 for( i = 0; i < num_valid; i++ )
    514                 {
    515                     int cur_label = *pair16u32s_ptr[i].i;
    516                     if( cur_label != prev_label )
    517                         c_map[++c_count] = prev_label = cur_label;
    518                     *pair16u32s_ptr[i].u = (unsigned short)c_count;
    519                 }
    520                 // replace labels for missing values with -1
    521                 for( ; i < sample_count; i++ )
    522                     *pair16u32s_ptr[i].u = 65535;
    523             }
    524             else
    525             {
    526                 // compact the class indices and build the map
    527                 prev_label = ~*int_ptr[0];
    528                 for( i = 0; i < num_valid; i++ )
    529                 {
    530                     int cur_label = *int_ptr[i];
    531                     if( cur_label != prev_label )
    532                         c_map[++c_count] = prev_label = cur_label;
    533                     *int_ptr[i] = c_count;
    534                 }
    535                 // replace labels for missing values with -1
    536                 for( ; i < sample_count; i++ )
    537                     *int_ptr[i] = -1;
    538             }
    539         }
    540         else if( ci < 0 ) // process ordered variable
    541         {
    542             if (is_buf_16u)
    543                 udst = (unsigned short*)(buf->data.s + (size_t)vi*sample_count);
    544             else
    545                 idst = buf->data.i + (size_t)vi*sample_count;
    546 
    547             for( i = 0; i < sample_count; i++ )
    548             {
    549                 float val = ord_nan;
    550                 int si = sidx ? sidx[i] : i;
    551                 if( !mask || !mask[(size_t)si*m_step] )
    552                 {
    553                     if( idata )
    554                         val = (float)idata[(size_t)si*step];
    555                     else
    556                         val = fdata[(size_t)si*step];
    557 
    558                     if( fabs(val) >= ord_nan )
    559                     {
    560                         sprintf( err, "%d-th value of %d-th (ordered) "
    561                             "variable (=%g) is too large", i, vi, val );
    562                         CV_ERROR( CV_StsBadArg, err );
    563                     }
    564                     num_valid++;
    565                 }
    566 
    567                 if (is_buf_16u)
    568                     udst[i] = (unsigned short)i; // TODO: memory corruption may be here
    569                 else
    570                     idst[i] = i;
    571                 _fdst[i] = val;
    572 
    573             }
    574             if (is_buf_16u)
    575                 std::sort(udst, udst + sample_count, LessThanIdx<float, unsigned short>(_fdst));
    576             else
    577                 std::sort(idst, idst + sample_count, LessThanIdx<float, int>(_fdst));
    578         }
    579 
    580         if( vi < var_count )
    581             data_root->set_num_valid(vi, num_valid);
    582     }
    583 
    584     // set sample labels
    585     if (is_buf_16u)
    586         udst = (unsigned short*)(buf->data.s + (size_t)work_var_count*sample_count);
    587     else
    588         idst = buf->data.i + (size_t)work_var_count*sample_count;
    589 
    590     for (i = 0; i < sample_count; i++)
    591     {
    592         if (udst)
    593             udst[i] = sidx ? (unsigned short)sidx[i] : (unsigned short)i;
    594         else
    595             idst[i] = sidx ? sidx[i] : i;
    596     }
    597 
    598     if( cv_n )
    599     {
    600         unsigned short* usdst = 0;
    601         int* idst2 = 0;
    602 
    603         if (is_buf_16u)
    604         {
    605             usdst = (unsigned short*)(buf->data.s + (size_t)(get_work_var_count()-1)*sample_count);
    606             for( i = vi = 0; i < sample_count; i++ )
    607             {
    608                 usdst[i] = (unsigned short)vi++;
    609                 vi &= vi < cv_n ? -1 : 0;
    610             }
    611 
    612             for( i = 0; i < sample_count; i++ )
    613             {
    614                 int a = (*rng)(sample_count);
    615                 int b = (*rng)(sample_count);
    616                 unsigned short unsh = (unsigned short)vi;
    617                 CV_SWAP( usdst[a], usdst[b], unsh );
    618             }
    619         }
    620         else
    621         {
    622             idst2 = buf->data.i + (size_t)(get_work_var_count()-1)*sample_count;
    623             for( i = vi = 0; i < sample_count; i++ )
    624             {
    625                 idst2[i] = vi++;
    626                 vi &= vi < cv_n ? -1 : 0;
    627             }
    628 
    629             for( i = 0; i < sample_count; i++ )
    630             {
    631                 int a = (*rng)(sample_count);
    632                 int b = (*rng)(sample_count);
    633                 CV_SWAP( idst2[a], idst2[b], vi );
    634             }
    635         }
    636     }
    637 
    638     if ( cat_map )
    639         cat_map->cols = MAX( total_c_count, 1 );
    640 
    641     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
    642         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
    643     CV_CALL( split_heap = cvCreateSet( 0, sizeof(*split_heap), max_split_size, tree_storage ));
    644 
    645     have_priors = is_classifier && params.priors;
    646     if( is_classifier )
    647     {
    648         int m = get_num_classes();
    649         double sum = 0;
    650         CV_CALL( priors = cvCreateMat( 1, m, CV_64F ));
    651         for( i = 0; i < m; i++ )
    652         {
    653             double val = have_priors ? params.priors[i] : 1.;
    654             if( val <= 0 )
    655                 CV_ERROR( CV_StsOutOfRange, "Every class weight should be positive" );
    656             priors->data.db[i] = val;
    657             sum += val;
    658         }
    659 
    660         // normalize weights
    661         if( have_priors )
    662             cvScale( priors, priors, 1./sum );
    663 
    664         CV_CALL( priors_mult = cvCloneMat( priors ));
    665         CV_CALL( counts = cvCreateMat( 1, m, CV_32SC1 ));
    666     }
    667 
    668 
    669     CV_CALL( direction = cvCreateMat( 1, sample_count, CV_8UC1 ));
    670     CV_CALL( split_buf = cvCreateMat( 1, sample_count, CV_32SC1 ));
    671 
    672     __END__;
    673 
    674     if( data )
    675         delete data;
    676 
    677     if (_fdst)
    678         cvFree( &_fdst );
    679     if (_idst)
    680         cvFree( &_idst );
    681     cvFree( &int_ptr );
    682     cvFree( &pair16u32s_ptr);
    683     cvReleaseMat( &var_type0 );
    684     cvReleaseMat( &sample_indices );
    685     cvReleaseMat( &tmp_map );
    686 }
    687 
    688 void CvDTreeTrainData::do_responses_copy()
    689 {
    690     responses_copy = cvCreateMat( responses->rows, responses->cols, responses->type );
    691     cvCopy( responses, responses_copy);
    692     responses = responses_copy;
    693 }
    694 
    695 CvDTreeNode* CvDTreeTrainData::subsample_data( const CvMat* _subsample_idx )
    696 {
    697     CvDTreeNode* root = 0;
    698     CvMat* isubsample_idx = 0;
    699     CvMat* subsample_co = 0;
    700 
    701     bool isMakeRootCopy = true;
    702 
    703     CV_FUNCNAME( "CvDTreeTrainData::subsample_data" );
    704 
    705     __BEGIN__;
    706 
    707     if( !data_root )
    708         CV_ERROR( CV_StsError, "No training data has been set" );
    709 
    710     if( _subsample_idx )
    711     {
    712         CV_CALL( isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
    713 
    714         if( isubsample_idx->cols + isubsample_idx->rows - 1 == sample_count )
    715         {
    716             const int* sidx = isubsample_idx->data.i;
    717             for( int i = 0; i < sample_count; i++ )
    718             {
    719                 if( sidx[i] != i )
    720                 {
    721                     isMakeRootCopy = false;
    722                     break;
    723                 }
    724             }
    725         }
    726         else
    727             isMakeRootCopy = false;
    728     }
    729 
    730     if( isMakeRootCopy )
    731     {
    732         // make a copy of the root node
    733         CvDTreeNode temp;
    734         int i;
    735         root = new_node( 0, 1, 0, 0 );
    736         temp = *root;
    737         *root = *data_root;
    738         root->num_valid = temp.num_valid;
    739         if( root->num_valid )
    740         {
    741             for( i = 0; i < var_count; i++ )
    742                 root->num_valid[i] = data_root->num_valid[i];
    743         }
    744         root->cv_Tn = temp.cv_Tn;
    745         root->cv_node_risk = temp.cv_node_risk;
    746         root->cv_node_error = temp.cv_node_error;
    747     }
    748     else
    749     {
    750         int* sidx = isubsample_idx->data.i;
    751         // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
    752         int* co, cur_ofs = 0;
    753         int vi, i;
    754         int workVarCount = get_work_var_count();
    755         int count = isubsample_idx->rows + isubsample_idx->cols - 1;
    756 
    757         root = new_node( 0, count, 1, 0 );
    758 
    759         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
    760         cvZero( subsample_co );
    761         co = subsample_co->data.i;
    762         for( i = 0; i < count; i++ )
    763             co[sidx[i]*2]++;
    764         for( i = 0; i < sample_count; i++ )
    765         {
    766             if( co[i*2] )
    767             {
    768                 co[i*2+1] = cur_ofs;
    769                 cur_ofs += co[i*2];
    770             }
    771             else
    772                 co[i*2+1] = -1;
    773         }
    774 
    775         cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
    776         for( vi = 0; vi < workVarCount; vi++ )
    777         {
    778             int ci = get_var_type(vi);
    779 
    780             if( ci >= 0 || vi >= var_count )
    781             {
    782                 int num_valid = 0;
    783                 const int* src = CvDTreeTrainData::get_cat_var_data( data_root, vi, (int*)(uchar*)inn_buf );
    784 
    785                 if (is_buf_16u)
    786                 {
    787                     unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
    788                         (size_t)vi*sample_count + root->offset);
    789                     for( i = 0; i < count; i++ )
    790                     {
    791                         int val = src[sidx[i]];
    792                         udst[i] = (unsigned short)val;
    793                         num_valid += val >= 0;
    794                     }
    795                 }
    796                 else
    797                 {
    798                     int* idst = buf->data.i + root->buf_idx*get_length_subbuf() +
    799                         (size_t)vi*sample_count + root->offset;
    800                     for( i = 0; i < count; i++ )
    801                     {
    802                         int val = src[sidx[i]];
    803                         idst[i] = val;
    804                         num_valid += val >= 0;
    805                     }
    806                 }
    807 
    808                 if( vi < var_count )
    809                     root->set_num_valid(vi, num_valid);
    810             }
    811             else
    812             {
    813                 int *src_idx_buf = (int*)(uchar*)inn_buf;
    814                 float *src_val_buf = (float*)(src_idx_buf + sample_count);
    815                 int* sample_indices_buf = (int*)(src_val_buf + sample_count);
    816                 const int* src_idx = 0;
    817                 const float* src_val = 0;
    818                 get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf );
    819                 int j = 0, idx, count_i;
    820                 int num_valid = data_root->get_num_valid(vi);
    821 
    822                 if (is_buf_16u)
    823                 {
    824                     unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
    825                         (size_t)vi*sample_count + data_root->offset);
    826                     for( i = 0; i < num_valid; i++ )
    827                     {
    828                         idx = src_idx[i];
    829                         count_i = co[idx*2];
    830                         if( count_i )
    831                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
    832                                 udst_idx[j] = (unsigned short)cur_ofs;
    833                     }
    834 
    835                     root->set_num_valid(vi, j);
    836 
    837                     for( ; i < sample_count; i++ )
    838                     {
    839                         idx = src_idx[i];
    840                         count_i = co[idx*2];
    841                         if( count_i )
    842                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
    843                                 udst_idx[j] = (unsigned short)cur_ofs;
    844                     }
    845                 }
    846                 else
    847                 {
    848                     int* idst_idx = buf->data.i + root->buf_idx*get_length_subbuf() +
    849                         (size_t)vi*sample_count + root->offset;
    850                     for( i = 0; i < num_valid; i++ )
    851                     {
    852                         idx = src_idx[i];
    853                         count_i = co[idx*2];
    854                         if( count_i )
    855                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
    856                                 idst_idx[j] = cur_ofs;
    857                     }
    858 
    859                     root->set_num_valid(vi, j);
    860 
    861                     for( ; i < sample_count; i++ )
    862                     {
    863                         idx = src_idx[i];
    864                         count_i = co[idx*2];
    865                         if( count_i )
    866                             for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
    867                                 idst_idx[j] = cur_ofs;
    868                     }
    869                 }
    870             }
    871         }
    872         // sample indices subsampling
    873         const int* sample_idx_src = get_sample_indices(data_root, (int*)(uchar*)inn_buf);
    874         if (is_buf_16u)
    875         {
    876             unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
    877                 (size_t)workVarCount*sample_count + root->offset);
    878             for (i = 0; i < count; i++)
    879                 sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];
    880         }
    881         else
    882         {
    883             int* sample_idx_dst = buf->data.i + root->buf_idx*get_length_subbuf() +
    884                 (size_t)workVarCount*sample_count + root->offset;
    885             for (i = 0; i < count; i++)
    886                 sample_idx_dst[i] = sample_idx_src[sidx[i]];
    887         }
    888     }
    889 
    890     __END__;
    891 
    892     cvReleaseMat( &isubsample_idx );
    893     cvReleaseMat( &subsample_co );
    894 
    895     return root;
    896 }
    897 
    898 
    899 void CvDTreeTrainData::get_vectors( const CvMat* _subsample_idx,
    900                                     float* values, uchar* missing,
    901                                     float* _responses, bool get_class_idx )
    902 {
    903     CvMat* subsample_idx = 0;
    904     CvMat* subsample_co = 0;
    905 
    906     CV_FUNCNAME( "CvDTreeTrainData::get_vectors" );
    907 
    908     __BEGIN__;
    909 
    910     int i, vi, total = sample_count, count = total, cur_ofs = 0;
    911     int* sidx = 0;
    912     int* co = 0;
    913 
    914     cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
    915     if( _subsample_idx )
    916     {
    917         CV_CALL( subsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count ));
    918         sidx = subsample_idx->data.i;
    919         CV_CALL( subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 ));
    920         co = subsample_co->data.i;
    921         cvZero( subsample_co );
    922         count = subsample_idx->cols + subsample_idx->rows - 1;
    923         for( i = 0; i < count; i++ )
    924             co[sidx[i]*2]++;
    925         for( i = 0; i < total; i++ )
    926         {
    927             int count_i = co[i*2];
    928             if( count_i )
    929             {
    930                 co[i*2+1] = cur_ofs*var_count;
    931                 cur_ofs += count_i;
    932             }
    933         }
    934     }
    935 
    936     if( missing )
    937         memset( missing, 1, count*var_count );
    938 
    939     for( vi = 0; vi < var_count; vi++ )
    940     {
    941         int ci = get_var_type(vi);
    942         if( ci >= 0 ) // categorical
    943         {
    944             float* dst = values + vi;
    945             uchar* m = missing ? missing + vi : 0;
    946             const int* src = get_cat_var_data(data_root, vi, (int*)(uchar*)inn_buf);
    947 
    948             for( i = 0; i < count; i++, dst += var_count )
    949             {
    950                 int idx = sidx ? sidx[i] : i;
    951                 int val = src[idx];
    952                 *dst = (float)val;
    953                 if( m )
    954                 {
    955                     *m = (!is_buf_16u && val < 0) || (is_buf_16u && (val == 65535));
    956                     m += var_count;
    957                 }
    958             }
    959         }
    960         else // ordered
    961         {
    962             float* dst = values + vi;
    963             uchar* m = missing ? missing + vi : 0;
    964             int count1 = data_root->get_num_valid(vi);
    965             float *src_val_buf = (float*)(uchar*)inn_buf;
    966             int* src_idx_buf = (int*)(src_val_buf + sample_count);
    967             int* sample_indices_buf = src_idx_buf + sample_count;
    968             const float *src_val = 0;
    969             const int* src_idx = 0;
    970             get_ord_var_data(data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf);
    971 
    972             for( i = 0; i < count1; i++ )
    973             {
    974                 int idx = src_idx[i];
    975                 int count_i = 1;
    976                 if( co )
    977                 {
    978                     count_i = co[idx*2];
    979                     cur_ofs = co[idx*2+1];
    980                 }
    981                 else
    982                     cur_ofs = idx*var_count;
    983                 if( count_i )
    984                 {
    985                     float val = src_val[i];
    986                     for( ; count_i > 0; count_i--, cur_ofs += var_count )
    987                     {
    988                         dst[cur_ofs] = val;
    989                         if( m )
    990                             m[cur_ofs] = 0;
    991                     }
    992                 }
    993             }
    994         }
    995     }
    996 
    997     // copy responses
    998     if( _responses )
    999     {
   1000         if( is_classifier )
   1001         {
   1002             const int* src = get_class_labels(data_root, (int*)(uchar*)inn_buf);
   1003             for( i = 0; i < count; i++ )
   1004             {
   1005                 int idx = sidx ? sidx[i] : i;
   1006                 int val = get_class_idx ? src[idx] :
   1007                     cat_map->data.i[cat_ofs->data.i[cat_var_count]+src[idx]];
   1008                 _responses[i] = (float)val;
   1009             }
   1010         }
   1011         else
   1012         {
   1013             float* val_buf = (float*)(uchar*)inn_buf;
   1014             int* sample_idx_buf = (int*)(val_buf + sample_count);
   1015             const float* _values = get_ord_responses(data_root, val_buf, sample_idx_buf);
   1016             for( i = 0; i < count; i++ )
   1017             {
   1018                 int idx = sidx ? sidx[i] : i;
   1019                 _responses[i] = _values[idx];
   1020             }
   1021         }
   1022     }
   1023 
   1024     __END__;
   1025 
   1026     cvReleaseMat( &subsample_idx );
   1027     cvReleaseMat( &subsample_co );
   1028 }
   1029 
   1030 
   1031 CvDTreeNode* CvDTreeTrainData::new_node( CvDTreeNode* parent, int count,
   1032                                          int storage_idx, int offset )
   1033 {
   1034     CvDTreeNode* node = (CvDTreeNode*)cvSetNew( node_heap );
   1035 
   1036     node->sample_count = count;
   1037     node->depth = parent ? parent->depth + 1 : 0;
   1038     node->parent = parent;
   1039     node->left = node->right = 0;
   1040     node->split = 0;
   1041     node->value = 0;
   1042     node->class_idx = 0;
   1043     node->maxlr = 0.;
   1044 
   1045     node->buf_idx = storage_idx;
   1046     node->offset = offset;
   1047     if( nv_heap )
   1048         node->num_valid = (int*)cvSetNew( nv_heap );
   1049     else
   1050         node->num_valid = 0;
   1051     node->alpha = node->node_risk = node->tree_risk = node->tree_error = 0.;
   1052     node->complexity = 0;
   1053 
   1054     if( params.cv_folds > 0 && cv_heap )
   1055     {
   1056         int cv_n = params.cv_folds;
   1057         node->Tn = INT_MAX;
   1058         node->cv_Tn = (int*)cvSetNew( cv_heap );
   1059         node->cv_node_risk = (double*)cvAlignPtr(node->cv_Tn + cv_n, sizeof(double));
   1060         node->cv_node_error = node->cv_node_risk + cv_n;
   1061     }
   1062     else
   1063     {
   1064         node->Tn = 0;
   1065         node->cv_Tn = 0;
   1066         node->cv_node_risk = 0;
   1067         node->cv_node_error = 0;
   1068     }
   1069 
   1070     return node;
   1071 }
   1072 
   1073 
   1074 CvDTreeSplit* CvDTreeTrainData::new_split_ord( int vi, float cmp_val,
   1075                 int split_point, int inversed, float quality )
   1076 {
   1077     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
   1078     split->var_idx = vi;
   1079     split->condensed_idx = INT_MIN;
   1080     split->ord.c = cmp_val;
   1081     split->ord.split_point = split_point;
   1082     split->inversed = inversed;
   1083     split->quality = quality;
   1084     split->next = 0;
   1085 
   1086     return split;
   1087 }
   1088 
   1089 
   1090 CvDTreeSplit* CvDTreeTrainData::new_split_cat( int vi, float quality )
   1091 {
   1092     CvDTreeSplit* split = (CvDTreeSplit*)cvSetNew( split_heap );
   1093     int i, n = (max_c_count + 31)/32;
   1094 
   1095     split->var_idx = vi;
   1096     split->condensed_idx = INT_MIN;
   1097     split->inversed = 0;
   1098     split->quality = quality;
   1099     for( i = 0; i < n; i++ )
   1100         split->subset[i] = 0;
   1101     split->next = 0;
   1102 
   1103     return split;
   1104 }
   1105 
   1106 
   1107 void CvDTreeTrainData::free_node( CvDTreeNode* node )
   1108 {
   1109     CvDTreeSplit* split = node->split;
   1110     free_node_data( node );
   1111     while( split )
   1112     {
   1113         CvDTreeSplit* next = split->next;
   1114         cvSetRemoveByPtr( split_heap, split );
   1115         split = next;
   1116     }
   1117     node->split = 0;
   1118     cvSetRemoveByPtr( node_heap, node );
   1119 }
   1120 
   1121 
   1122 void CvDTreeTrainData::free_node_data( CvDTreeNode* node )
   1123 {
   1124     if( node->num_valid )
   1125     {
   1126         cvSetRemoveByPtr( nv_heap, node->num_valid );
   1127         node->num_valid = 0;
   1128     }
   1129     // do not free cv_* fields, as all the cross-validation related data is released at once.
   1130 }
   1131 
   1132 
   1133 void CvDTreeTrainData::free_train_data()
   1134 {
   1135     cvReleaseMat( &counts );
   1136     cvReleaseMat( &buf );
   1137     cvReleaseMat( &direction );
   1138     cvReleaseMat( &split_buf );
   1139     cvReleaseMemStorage( &temp_storage );
   1140     cvReleaseMat( &responses_copy );
   1141     cv_heap = nv_heap = 0;
   1142 }
   1143 
   1144 
   1145 void CvDTreeTrainData::clear()
   1146 {
   1147     free_train_data();
   1148 
   1149     cvReleaseMemStorage( &tree_storage );
   1150 
   1151     cvReleaseMat( &var_idx );
   1152     cvReleaseMat( &var_type );
   1153     cvReleaseMat( &cat_count );
   1154     cvReleaseMat( &cat_ofs );
   1155     cvReleaseMat( &cat_map );
   1156     cvReleaseMat( &priors );
   1157     cvReleaseMat( &priors_mult );
   1158 
   1159     node_heap = split_heap = 0;
   1160 
   1161     sample_count = var_all = var_count = max_c_count = ord_var_count = cat_var_count = 0;
   1162     have_labels = have_priors = is_classifier = false;
   1163 
   1164     buf_count = buf_size = 0;
   1165     shared = false;
   1166 
   1167     data_root = 0;
   1168 
   1169     rng = &cv::theRNG();
   1170 }
   1171 
   1172 
   1173 int CvDTreeTrainData::get_num_classes() const
   1174 {
   1175     return is_classifier ? cat_count->data.i[cat_var_count] : 0;
   1176 }
   1177 
   1178 
   1179 int CvDTreeTrainData::get_var_type(int vi) const
   1180 {
   1181     return var_type->data.i[vi];
   1182 }
   1183 
   1184 void CvDTreeTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
   1185                                          const float** ord_values, const int** sorted_indices, int* sample_indices_buf )
   1186 {
   1187     int vidx = var_idx ? var_idx->data.i[vi] : vi;
   1188     int node_sample_count = n->sample_count;
   1189     int td_step = train_data->step/CV_ELEM_SIZE(train_data->type);
   1190 
   1191     const int* sample_indices = get_sample_indices(n, sample_indices_buf);
   1192 
   1193     if( !is_buf_16u )
   1194         *sorted_indices = buf->data.i + n->buf_idx*get_length_subbuf() +
   1195         (size_t)vi*sample_count + n->offset;
   1196     else {
   1197         const unsigned short* short_indices = (const unsigned short*)(buf->data.s + n->buf_idx*get_length_subbuf() +
   1198             (size_t)vi*sample_count + n->offset );
   1199         for( int i = 0; i < node_sample_count; i++ )
   1200             sorted_indices_buf[i] = short_indices[i];
   1201         *sorted_indices = sorted_indices_buf;
   1202     }
   1203 
   1204     if( tflag == CV_ROW_SAMPLE )
   1205     {
   1206         for( int i = 0; i < node_sample_count &&
   1207             ((((*sorted_indices)[i] >= 0) && !is_buf_16u) || (((*sorted_indices)[i] != 65535) && is_buf_16u)); i++ )
   1208         {
   1209             int idx = (*sorted_indices)[i];
   1210             idx = sample_indices[idx];
   1211             ord_values_buf[i] = *(train_data->data.fl + idx * td_step + vidx);
   1212         }
   1213     }
   1214     else
   1215         for( int i = 0; i < node_sample_count &&
   1216             ((((*sorted_indices)[i] >= 0) && !is_buf_16u) || (((*sorted_indices)[i] != 65535) && is_buf_16u)); i++ )
   1217         {
   1218             int idx = (*sorted_indices)[i];
   1219             idx = sample_indices[idx];
   1220             ord_values_buf[i] = *(train_data->data.fl + vidx* td_step + idx);
   1221         }
   1222 
   1223     *ord_values = ord_values_buf;
   1224 }
   1225 
   1226 
   1227 const int* CvDTreeTrainData::get_class_labels( CvDTreeNode* n, int* labels_buf )
   1228 {
   1229     if (is_classifier)
   1230         return get_cat_var_data( n, var_count, labels_buf);
   1231     return 0;
   1232 }
   1233 
   1234 const int* CvDTreeTrainData::get_sample_indices( CvDTreeNode* n, int* indices_buf )
   1235 {
   1236     return get_cat_var_data( n, get_work_var_count(), indices_buf );
   1237 }
   1238 
   1239 const float* CvDTreeTrainData::get_ord_responses( CvDTreeNode* n, float* values_buf, int*sample_indices_buf )
   1240 {
   1241     int _sample_count = n->sample_count;
   1242     int r_step = CV_IS_MAT_CONT(responses->type) ? 1 : responses->step/CV_ELEM_SIZE(responses->type);
   1243     const int* indices = get_sample_indices(n, sample_indices_buf);
   1244 
   1245     for( int i = 0; i < _sample_count &&
   1246         (((indices[i] >= 0) && !is_buf_16u) || ((indices[i] != 65535) && is_buf_16u)); i++ )
   1247     {
   1248         int idx = indices[i];
   1249         values_buf[i] = *(responses->data.fl + idx * r_step);
   1250     }
   1251 
   1252     return values_buf;
   1253 }
   1254 
   1255 
   1256 const int* CvDTreeTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf )
   1257 {
   1258     if (have_labels)
   1259         return get_cat_var_data( n, get_work_var_count()- 1, labels_buf);
   1260     return 0;
   1261 }
   1262 
   1263 
   1264 const int* CvDTreeTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf)
   1265 {
   1266     const int* cat_values = 0;
   1267     if( !is_buf_16u )
   1268         cat_values = buf->data.i + n->buf_idx*get_length_subbuf() +
   1269             (size_t)vi*sample_count + n->offset;
   1270     else {
   1271         const unsigned short* short_values = (const unsigned short*)(buf->data.s + n->buf_idx*get_length_subbuf() +
   1272             (size_t)vi*sample_count + n->offset);
   1273         for( int i = 0; i < n->sample_count; i++ )
   1274             cat_values_buf[i] = short_values[i];
   1275         cat_values = cat_values_buf;
   1276     }
   1277     return cat_values;
   1278 }
   1279 
   1280 
   1281 int CvDTreeTrainData::get_child_buf_idx( CvDTreeNode* n )
   1282 {
   1283     int idx = n->buf_idx + 1;
   1284     if( idx >= buf_count )
   1285         idx = shared ? 1 : 0;
   1286     return idx;
   1287 }
   1288 
   1289 
   1290 void CvDTreeTrainData::write_params( CvFileStorage* fs ) const
   1291 {
   1292     CV_FUNCNAME( "CvDTreeTrainData::write_params" );
   1293 
   1294     __BEGIN__;
   1295 
   1296     int vi, vcount = var_count;
   1297 
   1298     cvWriteInt( fs, "is_classifier", is_classifier ? 1 : 0 );
   1299     cvWriteInt( fs, "var_all", var_all );
   1300     cvWriteInt( fs, "var_count", var_count );
   1301     cvWriteInt( fs, "ord_var_count", ord_var_count );
   1302     cvWriteInt( fs, "cat_var_count", cat_var_count );
   1303 
   1304     cvStartWriteStruct( fs, "training_params", CV_NODE_MAP );
   1305     cvWriteInt( fs, "use_surrogates", params.use_surrogates ? 1 : 0 );
   1306 
   1307     if( is_classifier )
   1308     {
   1309         cvWriteInt( fs, "max_categories", params.max_categories );
   1310     }
   1311     else
   1312     {
   1313         cvWriteReal( fs, "regression_accuracy", params.regression_accuracy );
   1314     }
   1315 
   1316     cvWriteInt( fs, "max_depth", params.max_depth );
   1317     cvWriteInt( fs, "min_sample_count", params.min_sample_count );
   1318     cvWriteInt( fs, "cross_validation_folds", params.cv_folds );
   1319 
   1320     if( params.cv_folds > 1 )
   1321     {
   1322         cvWriteInt( fs, "use_1se_rule", params.use_1se_rule ? 1 : 0 );
   1323         cvWriteInt( fs, "truncate_pruned_tree", params.truncate_pruned_tree ? 1 : 0 );
   1324     }
   1325 
   1326     if( priors )
   1327         cvWrite( fs, "priors", priors );
   1328 
   1329     cvEndWriteStruct( fs );
   1330 
   1331     if( var_idx )
   1332         cvWrite( fs, "var_idx", var_idx );
   1333 
   1334     cvStartWriteStruct( fs, "var_type", CV_NODE_SEQ+CV_NODE_FLOW );
   1335 
   1336     for( vi = 0; vi < vcount; vi++ )
   1337         cvWriteInt( fs, 0, var_type->data.i[vi] >= 0 );
   1338 
   1339     cvEndWriteStruct( fs );
   1340 
   1341     if( cat_count && (cat_var_count > 0 || is_classifier) )
   1342     {
   1343         CV_ASSERT( cat_count != 0 );
   1344         cvWrite( fs, "cat_count", cat_count );
   1345         cvWrite( fs, "cat_map", cat_map );
   1346     }
   1347 
   1348     __END__;
   1349 }
   1350 
   1351 
   1352 void CvDTreeTrainData::read_params( CvFileStorage* fs, CvFileNode* node )
   1353 {
   1354     CV_FUNCNAME( "CvDTreeTrainData::read_params" );
   1355 
   1356     __BEGIN__;
   1357 
   1358     CvFileNode *tparams_node, *vartype_node;
   1359     CvSeqReader reader;
   1360     int vi, max_split_size, tree_block_size;
   1361 
   1362     is_classifier = (cvReadIntByName( fs, node, "is_classifier" ) != 0);
   1363     var_all = cvReadIntByName( fs, node, "var_all" );
   1364     var_count = cvReadIntByName( fs, node, "var_count", var_all );
   1365     cat_var_count = cvReadIntByName( fs, node, "cat_var_count" );
   1366     ord_var_count = cvReadIntByName( fs, node, "ord_var_count" );
   1367 
   1368     tparams_node = cvGetFileNodeByName( fs, node, "training_params" );
   1369 
   1370     if( tparams_node ) // training parameters are not necessary
   1371     {
   1372         params.use_surrogates = cvReadIntByName( fs, tparams_node, "use_surrogates", 1 ) != 0;
   1373 
   1374         if( is_classifier )
   1375         {
   1376             params.max_categories = cvReadIntByName( fs, tparams_node, "max_categories" );
   1377         }
   1378         else
   1379         {
   1380             params.regression_accuracy =
   1381                 (float)cvReadRealByName( fs, tparams_node, "regression_accuracy" );
   1382         }
   1383 
   1384         params.max_depth = cvReadIntByName( fs, tparams_node, "max_depth" );
   1385         params.min_sample_count = cvReadIntByName( fs, tparams_node, "min_sample_count" );
   1386         params.cv_folds = cvReadIntByName( fs, tparams_node, "cross_validation_folds" );
   1387 
   1388         if( params.cv_folds > 1 )
   1389         {
   1390             params.use_1se_rule = cvReadIntByName( fs, tparams_node, "use_1se_rule" ) != 0;
   1391             params.truncate_pruned_tree =
   1392                 cvReadIntByName( fs, tparams_node, "truncate_pruned_tree" ) != 0;
   1393         }
   1394 
   1395         priors = (CvMat*)cvReadByName( fs, tparams_node, "priors" );
   1396         if( priors )
   1397         {
   1398             if( !CV_IS_MAT(priors) )
   1399                 CV_ERROR( CV_StsParseError, "priors must stored as a matrix" );
   1400             priors_mult = cvCloneMat( priors );
   1401         }
   1402     }
   1403 
   1404     CV_CALL( var_idx = (CvMat*)cvReadByName( fs, node, "var_idx" ));
   1405     if( var_idx )
   1406     {
   1407         if( !CV_IS_MAT(var_idx) ||
   1408             (var_idx->cols != 1 && var_idx->rows != 1) ||
   1409             var_idx->cols + var_idx->rows - 1 != var_count ||
   1410             CV_MAT_TYPE(var_idx->type) != CV_32SC1 )
   1411             CV_ERROR( CV_StsParseError,
   1412                 "var_idx (if exist) must be valid 1d integer vector containing <var_count> elements" );
   1413 
   1414         for( vi = 0; vi < var_count; vi++ )
   1415             if( (unsigned)var_idx->data.i[vi] >= (unsigned)var_all )
   1416                 CV_ERROR( CV_StsOutOfRange, "some of var_idx elements are out of range" );
   1417     }
   1418 
   1419     ////// read var type
   1420     CV_CALL( var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 ));
   1421 
   1422     cat_var_count = 0;
   1423     ord_var_count = -1;
   1424     vartype_node = cvGetFileNodeByName( fs, node, "var_type" );
   1425 
   1426     if( vartype_node && CV_NODE_TYPE(vartype_node->tag) == CV_NODE_INT && var_count == 1 )
   1427         var_type->data.i[0] = vartype_node->data.i ? cat_var_count++ : ord_var_count--;
   1428     else
   1429     {
   1430         if( !vartype_node || CV_NODE_TYPE(vartype_node->tag) != CV_NODE_SEQ ||
   1431             vartype_node->data.seq->total != var_count )
   1432             CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
   1433 
   1434         cvStartReadSeq( vartype_node->data.seq, &reader );
   1435 
   1436         for( vi = 0; vi < var_count; vi++ )
   1437         {
   1438             CvFileNode* n = (CvFileNode*)reader.ptr;
   1439             if( CV_NODE_TYPE(n->tag) != CV_NODE_INT || (n->data.i & ~1) )
   1440                 CV_ERROR( CV_StsParseError, "var_type must exist and be a sequence of 0's and 1's" );
   1441             var_type->data.i[vi] = n->data.i ? cat_var_count++ : ord_var_count--;
   1442             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
   1443         }
   1444     }
   1445     var_type->data.i[var_count] = cat_var_count;
   1446 
   1447     ord_var_count = ~ord_var_count;
   1448     //////
   1449 
   1450     if( cat_var_count > 0 || is_classifier )
   1451     {
   1452         int ccount, total_c_count = 0;
   1453         CV_CALL( cat_count = (CvMat*)cvReadByName( fs, node, "cat_count" ));
   1454         CV_CALL( cat_map = (CvMat*)cvReadByName( fs, node, "cat_map" ));
   1455 
   1456         if( !CV_IS_MAT(cat_count) || !CV_IS_MAT(cat_map) ||
   1457             (cat_count->cols != 1 && cat_count->rows != 1) ||
   1458             CV_MAT_TYPE(cat_count->type) != CV_32SC1 ||
   1459             cat_count->cols + cat_count->rows - 1 != cat_var_count + is_classifier ||
   1460             (cat_map->cols != 1 && cat_map->rows != 1) ||
   1461             CV_MAT_TYPE(cat_map->type) != CV_32SC1 )
   1462             CV_ERROR( CV_StsParseError,
   1463             "Both cat_count and cat_map must exist and be valid 1d integer vectors of an appropriate size" );
   1464 
   1465         ccount = cat_var_count + is_classifier;
   1466 
   1467         CV_CALL( cat_ofs = cvCreateMat( 1, ccount + 1, CV_32SC1 ));
   1468         cat_ofs->data.i[0] = 0;
   1469         max_c_count = 1;
   1470 
   1471         for( vi = 0; vi < ccount; vi++ )
   1472         {
   1473             int val = cat_count->data.i[vi];
   1474             if( val <= 0 )
   1475                 CV_ERROR( CV_StsOutOfRange, "some of cat_count elements are out of range" );
   1476             max_c_count = MAX( max_c_count, val );
   1477             cat_ofs->data.i[vi+1] = total_c_count += val;
   1478         }
   1479 
   1480         if( cat_map->cols + cat_map->rows - 1 != total_c_count )
   1481             CV_ERROR( CV_StsBadSize,
   1482             "cat_map vector length is not equal to the total number of categories in all categorical vars" );
   1483     }
   1484 
   1485     max_split_size = cvAlign(sizeof(CvDTreeSplit) +
   1486         (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
   1487 
   1488     tree_block_size = MAX((int)sizeof(CvDTreeNode)*8, max_split_size);
   1489     tree_block_size = MAX(tree_block_size + block_size_delta, min_block_size);
   1490     CV_CALL( tree_storage = cvCreateMemStorage( tree_block_size ));
   1491     CV_CALL( node_heap = cvCreateSet( 0, sizeof(node_heap[0]),
   1492             sizeof(CvDTreeNode), tree_storage ));
   1493     CV_CALL( split_heap = cvCreateSet( 0, sizeof(split_heap[0]),
   1494             max_split_size, tree_storage ));
   1495 
   1496     __END__;
   1497 }
   1498 
   1499 /////////////////////// Decision Tree /////////////////////////
   1500 CvDTreeParams::CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
   1501     cv_folds(10), use_surrogates(true), use_1se_rule(true),
   1502     truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
   1503 {}
   1504 
   1505 CvDTreeParams::CvDTreeParams( int _max_depth, int _min_sample_count,
   1506                               float _regression_accuracy, bool _use_surrogates,
   1507                               int _max_categories, int _cv_folds,
   1508                               bool _use_1se_rule, bool _truncate_pruned_tree,
   1509                               const float* _priors ) :
   1510     max_categories(_max_categories), max_depth(_max_depth),
   1511     min_sample_count(_min_sample_count), cv_folds (_cv_folds),
   1512     use_surrogates(_use_surrogates), use_1se_rule(_use_1se_rule),
   1513     truncate_pruned_tree(_truncate_pruned_tree),
   1514     regression_accuracy(_regression_accuracy),
   1515     priors(_priors)
   1516 {}
   1517 
   1518 CvDTree::CvDTree()
   1519 {
   1520     data = 0;
   1521     var_importance = 0;
   1522     default_model_name = "my_tree";
   1523 
   1524     clear();
   1525 }
   1526 
   1527 
   1528 void CvDTree::clear()
   1529 {
   1530     cvReleaseMat( &var_importance );
   1531     if( data )
   1532     {
   1533         if( !data->shared )
   1534             delete data;
   1535         else
   1536             free_tree();
   1537         data = 0;
   1538     }
   1539     root = 0;
   1540     pruned_tree_idx = -1;
   1541 }
   1542 
   1543 
// Destructor: releases everything the tree holds by delegating to clear().
CvDTree::~CvDTree()
{
    clear();
}
   1548 
   1549 
// Returns the root node of the tree; the pointer is 0 after clear().
const CvDTreeNode* CvDTree::get_root() const
{
    return root;
}
   1554 
   1555 
// Returns the stored pruned-tree index; clear() resets it to -1.
int CvDTree::get_pruned_tree_idx() const
{
    return pruned_tree_idx;
}
   1560 
   1561 
// Returns the training-data object currently attached to the tree
// (0 when no data is attached, e.g. right after clear()).
CvDTreeTrainData* CvDTree::get_data()
{
    return data;
}
   1566 
   1567 
   1568 bool CvDTree::train( const CvMat* _train_data, int _tflag,
   1569                      const CvMat* _responses, const CvMat* _var_idx,
   1570                      const CvMat* _sample_idx, const CvMat* _var_type,
   1571                      const CvMat* _missing_mask, CvDTreeParams _params )
   1572 {
   1573     bool result = false;
   1574 
   1575     CV_FUNCNAME( "CvDTree::train" );
   1576 
   1577     __BEGIN__;
   1578 
   1579     clear();
   1580     data = new CvDTreeTrainData( _train_data, _tflag, _responses,
   1581                                  _var_idx, _sample_idx, _var_type,
   1582                                  _missing_mask, _params, false );
   1583     CV_CALL( result = do_train(0) );
   1584 
   1585     __END__;
   1586 
   1587     return result;
   1588 }
   1589 
   1590 bool CvDTree::train( const Mat& _train_data, int _tflag,
   1591                     const Mat& _responses, const Mat& _var_idx,
   1592                     const Mat& _sample_idx, const Mat& _var_type,
   1593                     const Mat& _missing_mask, CvDTreeParams _params )
   1594 {
   1595     train_data_hdr = _train_data;
   1596     train_data_mat = _train_data;
   1597     responses_hdr = _responses;
   1598     responses_mat = _responses;
   1599 
   1600     CvMat vidx=_var_idx, sidx=_sample_idx, vtype=_var_type, mmask=_missing_mask;
   1601 
   1602     return train(&train_data_hdr, _tflag, &responses_hdr, vidx.data.ptr ? &vidx : 0, sidx.data.ptr ? &sidx : 0,
   1603                  vtype.data.ptr ? &vtype : 0, mmask.data.ptr ? &mmask : 0, _params);
   1604 }
   1605 
   1606 
   1607 bool CvDTree::train( CvMLData* _data, CvDTreeParams _params )
   1608 {
   1609    bool result = false;
   1610 
   1611     CV_FUNCNAME( "CvDTree::train" );
   1612 
   1613     __BEGIN__;
   1614 
   1615     const CvMat* values = _data->get_values();
   1616     const CvMat* response = _data->get_responses();
   1617     const CvMat* missing = _data->get_missing();
   1618     const CvMat* var_types = _data->get_var_types();
   1619     const CvMat* train_sidx = _data->get_train_sample_idx();
   1620     const CvMat* var_idx = _data->get_var_idx();
   1621 
   1622     CV_CALL( result = train( values, CV_ROW_SAMPLE, response, var_idx,
   1623         train_sidx, var_types, missing, _params ) );
   1624 
   1625     __END__;
   1626 
   1627     return result;
   1628 }
   1629 
   1630 bool CvDTree::train( CvDTreeTrainData* _data, const CvMat* _subsample_idx )
   1631 {
   1632     bool result = false;
   1633 
   1634     CV_FUNCNAME( "CvDTree::train" );
   1635 
   1636     __BEGIN__;
   1637 
   1638     clear();
   1639     data = _data;
   1640     data->shared = true;
   1641     CV_CALL( result = do_train(_subsample_idx));
   1642 
   1643     __END__;
   1644 
   1645     return result;
   1646 }
   1647 
   1648 
   1649 bool CvDTree::do_train( const CvMat* _subsample_idx )
   1650 {
   1651     bool result = false;
   1652 
   1653     CV_FUNCNAME( "CvDTree::do_train" );
   1654 
   1655     __BEGIN__;
   1656 
   1657     root = data->subsample_data( _subsample_idx );
   1658 
   1659     CV_CALL( try_split_node(root));
   1660 
   1661     if( root->split )
   1662     {
   1663         CV_Assert( root->left );
   1664         CV_Assert( root->right );
   1665 
   1666         if( data->params.cv_folds > 0 )
   1667             CV_CALL( prune_cv() );
   1668 
   1669         if( !data->shared )
   1670             data->free_train_data();
   1671 
   1672         result = true;
   1673     }
   1674 
   1675     __END__;
   1676 
   1677     return result;
   1678 }
   1679 
   1680 
   1681 void CvDTree::try_split_node( CvDTreeNode* node )
   1682 {
   1683     CvDTreeSplit* best_split = 0;
   1684     int i, n = node->sample_count, vi;
   1685     bool can_split = true;
   1686     double quality_scale;
   1687 
   1688     calc_node_value( node );
   1689 
   1690     if( node->sample_count <= data->params.min_sample_count ||
   1691         node->depth >= data->params.max_depth )
   1692         can_split = false;
   1693 
   1694     if( can_split && data->is_classifier )
   1695     {
   1696         // check if we have a "pure" node,
   1697         // we assume that cls_count is filled by calc_node_value()
   1698         int* cls_count = data->counts->data.i;
   1699         int nz = 0, m = data->get_num_classes();
   1700         for( i = 0; i < m; i++ )
   1701             nz += cls_count[i] != 0;
   1702         if( nz == 1 ) // there is only one class
   1703             can_split = false;
   1704     }
   1705     else if( can_split )
   1706     {
   1707         if( sqrt(node->node_risk)/n < data->params.regression_accuracy )
   1708             can_split = false;
   1709     }
   1710 
   1711     if( can_split )
   1712     {
   1713         best_split = find_best_split(node);
   1714         // TODO: check the split quality ...
   1715         node->split = best_split;
   1716     }
   1717     if( !can_split || !best_split )
   1718     {
   1719         data->free_node_data(node);
   1720         return;
   1721     }
   1722 
   1723     quality_scale = calc_node_dir( node );
   1724     if( data->params.use_surrogates )
   1725     {
   1726         // find all the surrogate splits
   1727         // and sort them by their similarity to the primary one
   1728         for( vi = 0; vi < data->var_count; vi++ )
   1729         {
   1730             CvDTreeSplit* split;
   1731             int ci = data->get_var_type(vi);
   1732 
   1733             if( vi == best_split->var_idx )
   1734                 continue;
   1735 
   1736             if( ci >= 0 )
   1737                 split = find_surrogate_split_cat( node, vi );
   1738             else
   1739                 split = find_surrogate_split_ord( node, vi );
   1740 
   1741             if( split )
   1742             {
   1743                 // insert the split
   1744                 CvDTreeSplit* prev_split = node->split;
   1745                 split->quality = (float)(split->quality*quality_scale);
   1746 
   1747                 while( prev_split->next &&
   1748                        prev_split->next->quality > split->quality )
   1749                     prev_split = prev_split->next;
   1750                 split->next = prev_split->next;
   1751                 prev_split->next = split;
   1752             }
   1753         }
   1754     }
   1755     split_node_data( node );
   1756     try_split_node( node->left );
   1757     try_split_node( node->right );
   1758 }
   1759 
   1760 
   1761 // calculate direction (left(-1),right(1),missing(0))
   1762 // for each sample using the best split
   1763 // the function returns scale coefficients for surrogate split quality factors.
   1764 // the scale is applied to normalize surrogate split quality relatively to the
   1765 // best (primary) split quality. That is, if a surrogate split is absolutely
   1766 // identical to the primary split, its quality will be set to the maximum value =
   1767 // quality of the primary split; otherwise, it will be lower.
   1768 // besides, the function compute node->maxlr,
   1769 // minimum possible quality (w/o considering the above mentioned scale)
   1770 // for a surrogate split. Surrogate splits with quality less than node->maxlr
   1771 // are not discarded.
   1772 double CvDTree::calc_node_dir( CvDTreeNode* node )
   1773 {
   1774     char* dir = (char*)data->direction->data.ptr;
   1775     int i, n = node->sample_count, vi = node->split->var_idx;
   1776     double L, R;
   1777 
   1778     assert( !node->split->inversed );
   1779 
   1780     if( data->get_var_type(vi) >= 0 ) // split on categorical var
   1781     {
   1782         cv::AutoBuffer<int> inn_buf(n*(!data->have_priors ? 1 : 2));
   1783         int* labels_buf = (int*)inn_buf;
   1784         const int* labels = data->get_cat_var_data( node, vi, labels_buf );
   1785         const int* subset = node->split->subset;
   1786         if( !data->have_priors )
   1787         {
   1788             int sum = 0, sum_abs = 0;
   1789 
   1790             for( i = 0; i < n; i++ )
   1791             {
   1792                 int idx = labels[i];
   1793                 int d = ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ) ?
   1794                     CV_DTREE_CAT_DIR(idx,subset) : 0;
   1795                 sum += d; sum_abs += d & 1;
   1796                 dir[i] = (char)d;
   1797             }
   1798 
   1799             R = (sum_abs + sum) >> 1;
   1800             L = (sum_abs - sum) >> 1;
   1801         }
   1802         else
   1803         {
   1804             const double* priors = data->priors_mult->data.db;
   1805             double sum = 0, sum_abs = 0;
   1806             int* responses_buf = labels_buf + n;
   1807             const int* responses = data->get_class_labels(node, responses_buf);
   1808 
   1809             for( i = 0; i < n; i++ )
   1810             {
   1811                 int idx = labels[i];
   1812                 double w = priors[responses[i]];
   1813                 int d = idx >= 0 ? CV_DTREE_CAT_DIR(idx,subset) : 0;
   1814                 sum += d*w; sum_abs += (d & 1)*w;
   1815                 dir[i] = (char)d;
   1816             }
   1817 
   1818             R = (sum_abs + sum) * 0.5;
   1819             L = (sum_abs - sum) * 0.5;
   1820         }
   1821     }
   1822     else // split on ordered var
   1823     {
   1824         int split_point = node->split->ord.split_point;
   1825         int n1 = node->get_num_valid(vi);
   1826         cv::AutoBuffer<uchar> inn_buf(n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)));
   1827         float* val_buf = (float*)(uchar*)inn_buf;
   1828         int* sorted_buf = (int*)(val_buf + n);
   1829         int* sample_idx_buf = sorted_buf + n;
   1830         const float* val = 0;
   1831         const int* sorted = 0;
   1832         data->get_ord_var_data( node, vi, val_buf, sorted_buf, &val, &sorted, sample_idx_buf);
   1833 
   1834         assert( 0 <= split_point && split_point < n1-1 );
   1835 
   1836         if( !data->have_priors )
   1837         {
   1838             for( i = 0; i <= split_point; i++ )
   1839                 dir[sorted[i]] = (char)-1;
   1840             for( ; i < n1; i++ )
   1841                 dir[sorted[i]] = (char)1;
   1842             for( ; i < n; i++ )
   1843                 dir[sorted[i]] = (char)0;
   1844 
   1845             L = split_point-1;
   1846             R = n1 - split_point + 1;
   1847         }
   1848         else
   1849         {
   1850             const double* priors = data->priors_mult->data.db;
   1851             int* responses_buf = sample_idx_buf + n;
   1852             const int* responses = data->get_class_labels(node, responses_buf);
   1853             L = R = 0;
   1854 
   1855             for( i = 0; i <= split_point; i++ )
   1856             {
   1857                 int idx = sorted[i];
   1858                 double w = priors[responses[idx]];
   1859                 dir[idx] = (char)-1;
   1860                 L += w;
   1861             }
   1862 
   1863             for( ; i < n1; i++ )
   1864             {
   1865                 int idx = sorted[i];
   1866                 double w = priors[responses[idx]];
   1867                 dir[idx] = (char)1;
   1868                 R += w;
   1869             }
   1870 
   1871             for( ; i < n; i++ )
   1872                 dir[sorted[i]] = (char)0;
   1873         }
   1874     }
   1875     node->maxlr = MAX( L, R );
   1876     return node->split->quality/(L + R);
   1877 }
   1878 
   1879 
   1880 namespace cv
   1881 {
   1882 
   1883 template<> CV_EXPORTS void DefaultDeleter<CvDTreeSplit>::operator ()(CvDTreeSplit* obj) const
   1884 {
   1885     fastFree(obj);
   1886 }
   1887 
   1888 DTreeBestSplitFinder::DTreeBestSplitFinder( CvDTree* _tree, CvDTreeNode* _node)
   1889 {
   1890     tree = _tree;
   1891     node = _node;
   1892     splitSize = tree->get_data()->split_heap->elem_size;
   1893 
   1894     bestSplit.reset((CvDTreeSplit*)fastMalloc(splitSize));
   1895     memset(bestSplit.get(), 0, splitSize);
   1896     bestSplit->quality = -1;
   1897     bestSplit->condensed_idx = INT_MIN;
   1898     split.reset((CvDTreeSplit*)fastMalloc(splitSize));
   1899     memset(split.get(), 0, splitSize);
   1900     //haveSplit = false;
   1901 }
   1902 
   1903 DTreeBestSplitFinder::DTreeBestSplitFinder( const DTreeBestSplitFinder& finder, Split )
   1904 {
   1905     tree = finder.tree;
   1906     node = finder.node;
   1907     splitSize = tree->get_data()->split_heap->elem_size;
   1908 
   1909     bestSplit.reset((CvDTreeSplit*)fastMalloc(splitSize));
   1910     memcpy(bestSplit.get(), finder.bestSplit.get(), splitSize);
   1911     split.reset((CvDTreeSplit*)fastMalloc(splitSize));
   1912     memset(split.get(), 0, splitSize);
   1913 }
   1914 
   1915 void DTreeBestSplitFinder::operator()(const BlockedRange& range)
   1916 {
   1917     int vi, vi1 = range.begin(), vi2 = range.end();
   1918     int n = node->sample_count;
   1919     CvDTreeTrainData* data = tree->get_data();
   1920     AutoBuffer<uchar> inn_buf(2*n*(sizeof(int) + sizeof(float)));
   1921 
   1922     for( vi = vi1; vi < vi2; vi++ )
   1923     {
   1924         CvDTreeSplit *res;
   1925         int ci = data->get_var_type(vi);
   1926         if( node->get_num_valid(vi) <= 1 )
   1927             continue;
   1928 
   1929         if( data->is_classifier )
   1930         {
   1931             if( ci >= 0 )
   1932                 res = tree->find_split_cat_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
   1933             else
   1934                 res = tree->find_split_ord_class( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
   1935         }
   1936         else
   1937         {
   1938             if( ci >= 0 )
   1939                 res = tree->find_split_cat_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
   1940             else
   1941                 res = tree->find_split_ord_reg( node, vi, bestSplit->quality, split, (uchar*)inn_buf );
   1942         }
   1943 
   1944         if( res && bestSplit->quality < split->quality )
   1945                 memcpy( bestSplit.get(), split.get(), splitSize );
   1946     }
   1947 }
   1948 
   1949 void DTreeBestSplitFinder::join( DTreeBestSplitFinder& rhs )
   1950 {
   1951     if( bestSplit->quality < rhs.bestSplit->quality )
   1952         memcpy( bestSplit.get(), rhs.bestSplit.get(), splitSize );
   1953 }
   1954 }
   1955 
   1956 
   1957 CvDTreeSplit* CvDTree::find_best_split( CvDTreeNode* node )
   1958 {
   1959     DTreeBestSplitFinder finder( this, node );
   1960 
   1961     cv::parallel_reduce(cv::BlockedRange(0, data->var_count), finder);
   1962 
   1963     CvDTreeSplit *bestSplit = 0;
   1964     if( finder.bestSplit->quality > 0 )
   1965     {
   1966         bestSplit = data->new_split_cat( 0, -1.0f );
   1967         memcpy( bestSplit, finder.bestSplit, finder.splitSize );
   1968     }
   1969 
   1970     return bestSplit;
   1971 }
   1972 
   1973 CvDTreeSplit* CvDTree::find_split_ord_class( CvDTreeNode* node, int vi,
   1974                                              float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
   1975 {
   1976     const float epsilon = FLT_EPSILON*2;
   1977     int n = node->sample_count;
   1978     int n1 = node->get_num_valid(vi);
   1979     int m = data->get_num_classes();
   1980 
   1981     int base_size = 2*m*sizeof(int);
   1982     cv::AutoBuffer<uchar> inn_buf(base_size);
   1983     if( !_ext_buf )
   1984       inn_buf.allocate(base_size + n*(3*sizeof(int)+sizeof(float)));
   1985     uchar* base_buf = (uchar*)inn_buf;
   1986     uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
   1987     float* values_buf = (float*)ext_buf;
   1988     int* sorted_indices_buf = (int*)(values_buf + n);
   1989     int* sample_indices_buf = sorted_indices_buf + n;
   1990     const float* values = 0;
   1991     const int* sorted_indices = 0;
   1992     data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values,
   1993                             &sorted_indices, sample_indices_buf );
   1994     int* responses_buf =  sample_indices_buf + n;
   1995     const int* responses = data->get_class_labels( node, responses_buf );
   1996 
   1997     const int* rc0 = data->counts->data.i;
   1998     int* lc = (int*)base_buf;
   1999     int* rc = lc + m;
   2000     int i, best_i = -1;
   2001     double lsum2 = 0, rsum2 = 0, best_val = init_quality;
   2002     const double* priors = data->have_priors ? data->priors_mult->data.db : 0;
   2003 
   2004     // init arrays of class instance counters on both sides of the split
   2005     for( i = 0; i < m; i++ )
   2006     {
   2007         lc[i] = 0;
   2008         rc[i] = rc0[i];
   2009     }
   2010 
   2011     // compensate for missing values
   2012     for( i = n1; i < n; i++ )
   2013     {
   2014         rc[responses[sorted_indices[i]]]--;
   2015     }
   2016 
   2017     if( !priors )
   2018     {
   2019         int L = 0, R = n1;
   2020 
   2021         for( i = 0; i < m; i++ )
   2022             rsum2 += (double)rc[i]*rc[i];
   2023 
   2024         for( i = 0; i < n1 - 1; i++ )
   2025         {
   2026             int idx = responses[sorted_indices[i]];
   2027             int lv, rv;
   2028             L++; R--;
   2029             lv = lc[idx]; rv = rc[idx];
   2030             lsum2 += lv*2 + 1;
   2031             rsum2 -= rv*2 - 1;
   2032             lc[idx] = lv + 1; rc[idx] = rv - 1;
   2033 
   2034             if( values[i] + epsilon < values[i+1] )
   2035             {
   2036                 double val = (lsum2*R + rsum2*L)/((double)L*R);
   2037                 if( best_val < val )
   2038                 {
   2039                     best_val = val;
   2040                     best_i = i;
   2041                 }
   2042             }
   2043         }
   2044     }
   2045     else
   2046     {
   2047         double L = 0, R = 0;
   2048         for( i = 0; i < m; i++ )
   2049         {
   2050             double wv = rc[i]*priors[i];
   2051             R += wv;
   2052             rsum2 += wv*wv;
   2053         }
   2054 
   2055         for( i = 0; i < n1 - 1; i++ )
   2056         {
   2057             int idx = responses[sorted_indices[i]];
   2058             int lv, rv;
   2059             double p = priors[idx], p2 = p*p;
   2060             L += p; R -= p;
   2061             lv = lc[idx]; rv = rc[idx];
   2062             lsum2 += p2*(lv*2 + 1);
   2063             rsum2 -= p2*(rv*2 - 1);
   2064             lc[idx] = lv + 1; rc[idx] = rv - 1;
   2065 
   2066             if( values[i] + epsilon < values[i+1] )
   2067             {
   2068                 double val = (lsum2*R + rsum2*L)/((double)L*R);
   2069                 if( best_val < val )
   2070                 {
   2071                     best_val = val;
   2072                     best_i = i;
   2073                 }
   2074             }
   2075         }
   2076     }
   2077 
   2078     CvDTreeSplit* split = 0;
   2079     if( best_i >= 0 )
   2080     {
   2081         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
   2082         split->var_idx = vi;
   2083         split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
   2084         split->ord.split_point = best_i;
   2085         split->inversed = 0;
   2086         split->quality = (float)best_val;
   2087     }
   2088     return split;
   2089 }
   2090 
   2091 
   2092 void CvDTree::cluster_categories( const int* vectors, int n, int m,
   2093                                 int* csums, int k, int* labels )
   2094 {
   2095     // TODO: consider adding priors (class weights) and sample weights to the clustering algorithm
   2096     int iters = 0, max_iters = 100;
   2097     int i, j, idx;
   2098     cv::AutoBuffer<double> buf(n + k);
   2099     double *v_weights = buf, *c_weights = buf + n;
   2100     bool modified = true;
   2101     RNG* r = data->rng;
   2102 
   2103     // assign labels randomly
   2104     for( i = 0; i < n; i++ )
   2105     {
   2106         int sum = 0;
   2107         const int* v = vectors + i*m;
   2108         labels[i] = i < k ? i : r->uniform(0, k);
   2109 
   2110         // compute weight of each vector
   2111         for( j = 0; j < m; j++ )
   2112             sum += v[j];
   2113         v_weights[i] = sum ? 1./sum : 0.;
   2114     }
   2115 
   2116     for( i = 0; i < n; i++ )
   2117     {
   2118         int i1 = (*r)(n);
   2119         int i2 = (*r)(n);
   2120         CV_SWAP( labels[i1], labels[i2], j );
   2121     }
   2122 
   2123     for( iters = 0; iters <= max_iters; iters++ )
   2124     {
   2125         // calculate csums
   2126         for( i = 0; i < k; i++ )
   2127         {
   2128             for( j = 0; j < m; j++ )
   2129                 csums[i*m + j] = 0;
   2130         }
   2131 
   2132         for( i = 0; i < n; i++ )
   2133         {
   2134             const int* v = vectors + i*m;
   2135             int* s = csums + labels[i]*m;
   2136             for( j = 0; j < m; j++ )
   2137                 s[j] += v[j];
   2138         }
   2139 
   2140         // exit the loop here, when we have up-to-date csums
   2141         if( iters == max_iters || !modified )
   2142             break;
   2143 
   2144         modified = false;
   2145 
   2146         // calculate weight of each cluster
   2147         for( i = 0; i < k; i++ )
   2148         {
   2149             const int* s = csums + i*m;
   2150             int sum = 0;
   2151             for( j = 0; j < m; j++ )
   2152                 sum += s[j];
   2153             c_weights[i] = sum ? 1./sum : 0;
   2154         }
   2155 
   2156         // now for each vector determine the closest cluster
   2157         for( i = 0; i < n; i++ )
   2158         {
   2159             const int* v = vectors + i*m;
   2160             double alpha = v_weights[i];
   2161             double min_dist2 = DBL_MAX;
   2162             int min_idx = -1;
   2163 
   2164             for( idx = 0; idx < k; idx++ )
   2165             {
   2166                 const int* s = csums + idx*m;
   2167                 double dist2 = 0., beta = c_weights[idx];
   2168                 for( j = 0; j < m; j++ )
   2169                 {
   2170                     double t = v[j]*alpha - s[j]*beta;
   2171                     dist2 += t*t;
   2172                 }
   2173                 if( min_dist2 > dist2 )
   2174                 {
   2175                     min_dist2 = dist2;
   2176                     min_idx = idx;
   2177                 }
   2178             }
   2179 
   2180             if( min_idx != labels[i] )
   2181                 modified = true;
   2182             labels[i] = min_idx;
   2183         }
   2184     }
   2185 }
   2186 
   2187 
   2188 CvDTreeSplit* CvDTree::find_split_cat_class( CvDTreeNode* node, int vi, float init_quality,
   2189                                              CvDTreeSplit* _split, uchar* _ext_buf )
   2190 {
   2191     int ci = data->get_var_type(vi);
   2192     int n = node->sample_count;
   2193     int m = data->get_num_classes();
   2194     int _mi = data->cat_count->data.i[ci], mi = _mi;
   2195 
   2196     int base_size = m*(3 + mi)*sizeof(int) + (mi+1)*sizeof(double);
   2197     if( m > 2 && mi > data->params.max_categories )
   2198         base_size += (m*std::min(data->params.max_categories, n) + mi)*sizeof(int);
   2199     else
   2200         base_size += mi*sizeof(int*);
   2201     cv::AutoBuffer<uchar> inn_buf(base_size);
   2202     if( !_ext_buf )
   2203         inn_buf.allocate(base_size + 2*n*sizeof(int));
   2204     uchar* base_buf = (uchar*)inn_buf;
   2205     uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
   2206 
   2207     int* lc = (int*)base_buf;
   2208     int* rc = lc + m;
   2209     int* _cjk = rc + m*2, *cjk = _cjk;
   2210     double* c_weights = (double*)alignPtr(cjk + m*mi, sizeof(double));
   2211 
   2212     int* labels_buf = (int*)ext_buf;
   2213     const int* labels = data->get_cat_var_data(node, vi, labels_buf);
   2214     int* responses_buf = labels_buf + n;
   2215     const int* responses = data->get_class_labels(node, responses_buf);
   2216 
   2217     int* cluster_labels = 0;
   2218     int** int_ptr = 0;
   2219     int i, j, k, idx;
   2220     double L = 0, R = 0;
   2221     double best_val = init_quality;
   2222     int prevcode = 0, best_subset = -1, subset_i, subset_n, subtract = 0;
   2223     const double* priors = data->priors_mult->data.db;
   2224 
   2225     // init array of counters:
   2226     // c_{jk} - number of samples that have vi-th input variable = j and response = k.
   2227     for( j = -1; j < mi; j++ )
   2228         for( k = 0; k < m; k++ )
   2229             cjk[j*m + k] = 0;
   2230 
   2231     for( i = 0; i < n; i++ )
   2232     {
   2233        j = ( labels[i] == 65535 && data->is_buf_16u) ? -1 : labels[i];
   2234        k = responses[i];
   2235        cjk[j*m + k]++;
   2236     }
   2237 
   2238     if( m > 2 )
   2239     {
   2240         if( mi > data->params.max_categories )
   2241         {
   2242             mi = MIN(data->params.max_categories, n);
   2243             cjk = (int*)(c_weights + _mi);
   2244             cluster_labels = cjk + m*mi;
   2245             cluster_categories( _cjk, _mi, m, cjk, mi, cluster_labels );
   2246         }
   2247         subset_i = 1;
   2248         subset_n = 1 << mi;
   2249     }
   2250     else
   2251     {
   2252         assert( m == 2 );
   2253         int_ptr = (int**)(c_weights + _mi);
   2254         for( j = 0; j < mi; j++ )
   2255             int_ptr[j] = cjk + j*2 + 1;
   2256         std::sort(int_ptr, int_ptr + mi, LessThanPtr<int>());
   2257         subset_i = 0;
   2258         subset_n = mi;
   2259     }
   2260 
   2261     for( k = 0; k < m; k++ )
   2262     {
   2263         int sum = 0;
   2264         for( j = 0; j < mi; j++ )
   2265             sum += cjk[j*m + k];
   2266         rc[k] = sum;
   2267         lc[k] = 0;
   2268     }
   2269 
   2270     for( j = 0; j < mi; j++ )
   2271     {
   2272         double sum = 0;
   2273         for( k = 0; k < m; k++ )
   2274             sum += cjk[j*m + k]*priors[k];
   2275         c_weights[j] = sum;
   2276         R += c_weights[j];
   2277     }
   2278 
   2279     for( ; subset_i < subset_n; subset_i++ )
   2280     {
   2281         double weight;
   2282         int* crow;
   2283         double lsum2 = 0, rsum2 = 0;
   2284 
   2285         if( m == 2 )
   2286             idx = (int)(int_ptr[subset_i] - cjk)/2;
   2287         else
   2288         {
   2289             int graycode = (subset_i>>1)^subset_i;
   2290             int diff = graycode ^ prevcode;
   2291 
   2292             // determine index of the changed bit.
   2293             Cv32suf u;
   2294             idx = diff >= (1 << 16) ? 16 : 0;
   2295             u.f = (float)(((diff >> 16) | diff) & 65535);
   2296             idx += (u.i >> 23) - 127;
   2297             subtract = graycode < prevcode;
   2298             prevcode = graycode;
   2299         }
   2300 
   2301         crow = cjk + idx*m;
   2302         weight = c_weights[idx];
   2303         if( weight < FLT_EPSILON )
   2304             continue;
   2305 
   2306         if( !subtract )
   2307         {
   2308             for( k = 0; k < m; k++ )
   2309             {
   2310                 int t = crow[k];
   2311                 int lval = lc[k] + t;
   2312                 int rval = rc[k] - t;
   2313                 double p = priors[k], p2 = p*p;
   2314                 lsum2 += p2*lval*lval;
   2315                 rsum2 += p2*rval*rval;
   2316                 lc[k] = lval; rc[k] = rval;
   2317             }
   2318             L += weight;
   2319             R -= weight;
   2320         }
   2321         else
   2322         {
   2323             for( k = 0; k < m; k++ )
   2324             {
   2325                 int t = crow[k];
   2326                 int lval = lc[k] - t;
   2327                 int rval = rc[k] + t;
   2328                 double p = priors[k], p2 = p*p;
   2329                 lsum2 += p2*lval*lval;
   2330                 rsum2 += p2*rval*rval;
   2331                 lc[k] = lval; rc[k] = rval;
   2332             }
   2333             L -= weight;
   2334             R += weight;
   2335         }
   2336 
   2337         if( L > FLT_EPSILON && R > FLT_EPSILON )
   2338         {
   2339             double val = (lsum2*R + rsum2*L)/((double)L*R);
   2340             if( best_val < val )
   2341             {
   2342                 best_val = val;
   2343                 best_subset = subset_i;
   2344             }
   2345         }
   2346     }
   2347 
   2348     CvDTreeSplit* split = 0;
   2349     if( best_subset >= 0 )
   2350     {
   2351         split = _split ? _split : data->new_split_cat( 0, -1.0f );
   2352         split->var_idx = vi;
   2353         split->quality = (float)best_val;
   2354         memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
   2355         if( m == 2 )
   2356         {
   2357             for( i = 0; i <= best_subset; i++ )
   2358             {
   2359                 idx = (int)(int_ptr[i] - cjk) >> 1;
   2360                 split->subset[idx >> 5] |= 1 << (idx & 31);
   2361             }
   2362         }
   2363         else
   2364         {
   2365             for( i = 0; i < _mi; i++ )
   2366             {
   2367                 idx = cluster_labels ? cluster_labels[i] : i;
   2368                 if( best_subset & (1 << idx) )
   2369                     split->subset[i >> 5] |= 1 << (i & 31);
   2370             }
   2371         }
   2372     }
   2373     return split;
   2374 }
   2375 
   2376 
   2377 CvDTreeSplit* CvDTree::find_split_ord_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
   2378 {
   2379     const float epsilon = FLT_EPSILON*2;
   2380     int n = node->sample_count;
   2381     int n1 = node->get_num_valid(vi);
   2382 
   2383     cv::AutoBuffer<uchar> inn_buf;
   2384     if( !_ext_buf )
   2385         inn_buf.allocate(2*n*(sizeof(int) + sizeof(float)));
   2386     uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
   2387     float* values_buf = (float*)ext_buf;
   2388     int* sorted_indices_buf = (int*)(values_buf + n);
   2389     int* sample_indices_buf = sorted_indices_buf + n;
   2390     const float* values = 0;
   2391     const int* sorted_indices = 0;
   2392     data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
   2393     float* responses_buf =  (float*)(sample_indices_buf + n);
   2394     const float* responses = data->get_ord_responses( node, responses_buf, sample_indices_buf );
   2395 
   2396     int i, best_i = -1;
   2397     double best_val = init_quality, lsum = 0, rsum = node->value*n;
   2398     int L = 0, R = n1;
   2399 
   2400     // compensate for missing values
   2401     for( i = n1; i < n; i++ )
   2402         rsum -= responses[sorted_indices[i]];
   2403 
   2404     // find the optimal split
   2405     for( i = 0; i < n1 - 1; i++ )
   2406     {
   2407         float t = responses[sorted_indices[i]];
   2408         L++; R--;
   2409         lsum += t;
   2410         rsum -= t;
   2411 
   2412         if( values[i] + epsilon < values[i+1] )
   2413         {
   2414             double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
   2415             if( best_val < val )
   2416             {
   2417                 best_val = val;
   2418                 best_i = i;
   2419             }
   2420         }
   2421     }
   2422 
   2423     CvDTreeSplit* split = 0;
   2424     if( best_i >= 0 )
   2425     {
   2426         split = _split ? _split : data->new_split_ord( 0, 0.0f, 0, 0, 0.0f );
   2427         split->var_idx = vi;
   2428         split->ord.c = (values[best_i] + values[best_i+1])*0.5f;
   2429         split->ord.split_point = best_i;
   2430         split->inversed = 0;
   2431         split->quality = (float)best_val;
   2432     }
   2433     return split;
   2434 }
   2435 
   2436 CvDTreeSplit* CvDTree::find_split_cat_reg( CvDTreeNode* node, int vi, float init_quality, CvDTreeSplit* _split, uchar* _ext_buf )
   2437 {
   2438     int ci = data->get_var_type(vi);
   2439     int n = node->sample_count;
   2440     int mi = data->cat_count->data.i[ci];
   2441 
   2442     int base_size = (mi+2)*sizeof(double) + (mi+1)*(sizeof(int) + sizeof(double*));
   2443     cv::AutoBuffer<uchar> inn_buf(base_size);
   2444     if( !_ext_buf )
   2445         inn_buf.allocate(base_size + n*(2*sizeof(int) + sizeof(float)));
   2446     uchar* base_buf = (uchar*)inn_buf;
   2447     uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
   2448     int* labels_buf = (int*)ext_buf;
   2449     const int* labels = data->get_cat_var_data(node, vi, labels_buf);
   2450     float* responses_buf = (float*)(labels_buf + n);
   2451     int* sample_indices_buf = (int*)(responses_buf + n);
   2452     const float* responses = data->get_ord_responses(node, responses_buf, sample_indices_buf);
   2453 
   2454     double* sum = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
   2455     int* counts = (int*)(sum + mi) + 1;
   2456     double** sum_ptr = (double**)(counts + mi);
   2457     int i, L = 0, R = 0;
   2458     double best_val = init_quality, lsum = 0, rsum = 0;
   2459     int best_subset = -1, subset_i;
   2460 
   2461     for( i = -1; i < mi; i++ )
   2462         sum[i] = counts[i] = 0;
   2463 
   2464     // calculate sum response and weight of each category of the input var
   2465     for( i = 0; i < n; i++ )
   2466     {
   2467         int idx = ( (labels[i] == 65535) && data->is_buf_16u ) ? -1 : labels[i];
   2468         double s = sum[idx] + responses[i];
   2469         int nc = counts[idx] + 1;
   2470         sum[idx] = s;
   2471         counts[idx] = nc;
   2472     }
   2473 
   2474     // calculate average response in each category
   2475     for( i = 0; i < mi; i++ )
   2476     {
   2477         R += counts[i];
   2478         rsum += sum[i];
   2479         sum[i] /= MAX(counts[i],1);
   2480         sum_ptr[i] = sum + i;
   2481     }
   2482 
   2483     std::sort(sum_ptr, sum_ptr + mi, LessThanPtr<double>());
   2484 
   2485     // revert back to unnormalized sums
   2486     // (there should be a very little loss of accuracy)
   2487     for( i = 0; i < mi; i++ )
   2488         sum[i] *= counts[i];
   2489 
   2490     for( subset_i = 0; subset_i < mi-1; subset_i++ )
   2491     {
   2492         int idx = (int)(sum_ptr[subset_i] - sum);
   2493         int ni = counts[idx];
   2494 
   2495         if( ni )
   2496         {
   2497             double s = sum[idx];
   2498             lsum += s; L += ni;
   2499             rsum -= s; R -= ni;
   2500 
   2501             if( L && R )
   2502             {
   2503                 double val = (lsum*lsum*R + rsum*rsum*L)/((double)L*R);
   2504                 if( best_val < val )
   2505                 {
   2506                     best_val = val;
   2507                     best_subset = subset_i;
   2508                 }
   2509             }
   2510         }
   2511     }
   2512 
   2513     CvDTreeSplit* split = 0;
   2514     if( best_subset >= 0 )
   2515     {
   2516         split = _split ? _split : data->new_split_cat( 0, -1.0f);
   2517         split->var_idx = vi;
   2518         split->quality = (float)best_val;
   2519         memset( split->subset, 0, (data->max_c_count + 31)/32 * sizeof(int));
   2520         for( i = 0; i <= best_subset; i++ )
   2521         {
   2522             int idx = (int)(sum_ptr[i] - sum);
   2523             split->subset[idx >> 5] |= 1 << (idx & 31);
   2524         }
   2525     }
   2526     return split;
   2527 }
   2528 
   2529 CvDTreeSplit* CvDTree::find_surrogate_split_ord( CvDTreeNode* node, int vi, uchar* _ext_buf )
   2530 {
   2531     const float epsilon = FLT_EPSILON*2;
   2532     const char* dir = (char*)data->direction->data.ptr;
   2533     int n = node->sample_count, n1 = node->get_num_valid(vi);
   2534     cv::AutoBuffer<uchar> inn_buf;
   2535     if( !_ext_buf )
   2536         inn_buf.allocate( n*(sizeof(int)*(data->have_priors ? 3 : 2) + sizeof(float)) );
   2537     uchar* ext_buf = _ext_buf ? _ext_buf : (uchar*)inn_buf;
   2538     float* values_buf = (float*)ext_buf;
   2539     int* sorted_indices_buf = (int*)(values_buf + n);
   2540     int* sample_indices_buf = sorted_indices_buf + n;
   2541     const float* values = 0;
   2542     const int* sorted_indices = 0;
   2543     data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
   2544     // LL - number of samples that both the primary and the surrogate splits send to the left
   2545     // LR - ... primary split sends to the left and the surrogate split sends to the right
   2546     // RL - ... primary split sends to the right and the surrogate split sends to the left
   2547     // RR - ... both send to the right
   2548     int i, best_i = -1, best_inversed = 0;
   2549     double best_val;
   2550 
   2551     if( !data->have_priors )
   2552     {
   2553         int LL = 0, RL = 0, LR, RR;
   2554         int worst_val = cvFloor(node->maxlr), _best_val = worst_val;
   2555         int sum = 0, sum_abs = 0;
   2556 
   2557         for( i = 0; i < n1; i++ )
   2558         {
   2559             int d = dir[sorted_indices[i]];
   2560             sum += d; sum_abs += d & 1;
   2561         }
   2562 
   2563         // sum_abs = R + L; sum = R - L
   2564         RR = (sum_abs + sum) >> 1;
   2565         LR = (sum_abs - sum) >> 1;
   2566 
   2567         // initially all the samples are sent to the right by the surrogate split,
   2568         // LR of them are sent to the left by primary split, and RR - to the right.
   2569         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
   2570         for( i = 0; i < n1 - 1; i++ )
   2571         {
   2572             int d = dir[sorted_indices[i]];
   2573 
   2574             if( d < 0 )
   2575             {
   2576                 LL++; LR--;
   2577                 if( LL + RR > _best_val && values[i] + epsilon < values[i+1] )
   2578                 {
   2579                     best_val = LL + RR;
   2580                     best_i = i; best_inversed = 0;
   2581                 }
   2582             }
   2583             else if( d > 0 )
   2584             {
   2585                 RL++; RR--;
   2586                 if( RL + LR > _best_val && values[i] + epsilon < values[i+1] )
   2587                 {
   2588                     best_val = RL + LR;
   2589                     best_i = i; best_inversed = 1;
   2590                 }
   2591             }
   2592         }
   2593         best_val = _best_val;
   2594     }
   2595     else
   2596     {
   2597         double LL = 0, RL = 0, LR, RR;
   2598         double worst_val = node->maxlr;
   2599         double sum = 0, sum_abs = 0;
   2600         const double* priors = data->priors_mult->data.db;
   2601         int* responses_buf = sample_indices_buf + n;
   2602         const int* responses = data->get_class_labels(node, responses_buf);
   2603         best_val = worst_val;
   2604 
   2605         for( i = 0; i < n1; i++ )
   2606         {
   2607             int idx = sorted_indices[i];
   2608             double w = priors[responses[idx]];
   2609             int d = dir[idx];
   2610             sum += d*w; sum_abs += (d & 1)*w;
   2611         }
   2612 
   2613         // sum_abs = R + L; sum = R - L
   2614         RR = (sum_abs + sum)*0.5;
   2615         LR = (sum_abs - sum)*0.5;
   2616 
   2617         // initially all the samples are sent to the right by the surrogate split,
   2618         // LR of them are sent to the left by primary split, and RR - to the right.
   2619         // now iteratively compute LL, LR, RL and RR for every possible surrogate split value.
   2620         for( i = 0; i < n1 - 1; i++ )
   2621         {
   2622             int idx = sorted_indices[i];
   2623             double w = priors[responses[idx]];
   2624             int d = dir[idx];
   2625 
   2626             if( d < 0 )
   2627             {
   2628                 LL += w; LR -= w;
   2629                 if( LL + RR > best_val && values[i] + epsilon < values[i+1] )
   2630                 {
   2631                     best_val = LL + RR;
   2632                     best_i = i; best_inversed = 0;
   2633                 }
   2634             }
   2635             else if( d > 0 )
   2636             {
   2637                 RL += w; RR -= w;
   2638                 if( RL + LR > best_val && values[i] + epsilon < values[i+1] )
   2639                 {
   2640                     best_val = RL + LR;
   2641                     best_i = i; best_inversed = 1;
   2642                 }
   2643             }
   2644         }
   2645     }
   2646     return best_i >= 0 && best_val > node->maxlr ? data->new_split_ord( vi,
   2647         (values[best_i] + values[best_i+1])*0.5f, best_i, best_inversed, (float)best_val ) : 0;
   2648 }
   2649 
   2650 
   2651 CvDTreeSplit* CvDTree::find_surrogate_split_cat( CvDTreeNode* node, int vi, uchar* _ext_buf )
   2652 {
   2653     const char* dir = (char*)data->direction->data.ptr;
   2654     int n = node->sample_count;
   2655     int i, mi = data->cat_count->data.i[data->get_var_type(vi)], l_win = 0;
   2656 
   2657     int base_size = (2*(mi+1)+1)*sizeof(double) + (!data->have_priors ? 2*(mi+1)*sizeof(int) : 0);
   2658     cv::AutoBuffer<uchar> inn_buf(base_size);
   2659     if( !_ext_buf )
   2660         inn_buf.allocate(base_size + n*(sizeof(int) + (data->have_priors ? sizeof(int) : 0)));
   2661     uchar* base_buf = (uchar*)inn_buf;
   2662     uchar* ext_buf = _ext_buf ? _ext_buf : base_buf + base_size;
   2663 
   2664     int* labels_buf = (int*)ext_buf;
   2665     const int* labels = data->get_cat_var_data(node, vi, labels_buf);
   2666     // LL - number of samples that both the primary and the surrogate splits send to the left
   2667     // LR - ... primary split sends to the left and the surrogate split sends to the right
   2668     // RL - ... primary split sends to the right and the surrogate split sends to the left
   2669     // RR - ... both send to the right
   2670     CvDTreeSplit* split = data->new_split_cat( vi, 0 );
   2671     double best_val = 0;
   2672     double* lc = (double*)cv::alignPtr(base_buf,sizeof(double)) + 1;
   2673     double* rc = lc + mi + 1;
   2674 
   2675     for( i = -1; i < mi; i++ )
   2676         lc[i] = rc[i] = 0;
   2677 
   2678     // for each category calculate the weight of samples
   2679     // sent to the left (lc) and to the right (rc) by the primary split
   2680     if( !data->have_priors )
   2681     {
   2682         int* _lc = (int*)rc + 1;
   2683         int* _rc = _lc + mi + 1;
   2684 
   2685         for( i = -1; i < mi; i++ )
   2686             _lc[i] = _rc[i] = 0;
   2687 
   2688         for( i = 0; i < n; i++ )
   2689         {
   2690             int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
   2691             int d = dir[i];
   2692             int sum = _lc[idx] + d;
   2693             int sum_abs = _rc[idx] + (d & 1);
   2694             _lc[idx] = sum; _rc[idx] = sum_abs;
   2695         }
   2696 
   2697         for( i = 0; i < mi; i++ )
   2698         {
   2699             int sum = _lc[i];
   2700             int sum_abs = _rc[i];
   2701             lc[i] = (sum_abs - sum) >> 1;
   2702             rc[i] = (sum_abs + sum) >> 1;
   2703         }
   2704     }
   2705     else
   2706     {
   2707         const double* priors = data->priors_mult->data.db;
   2708         int* responses_buf = labels_buf + n;
   2709         const int* responses = data->get_class_labels(node, responses_buf);
   2710 
   2711         for( i = 0; i < n; i++ )
   2712         {
   2713             int idx = ( (labels[i] == 65535) && (data->is_buf_16u) ) ? -1 : labels[i];
   2714             double w = priors[responses[i]];
   2715             int d = dir[i];
   2716             double sum = lc[idx] + d*w;
   2717             double sum_abs = rc[idx] + (d & 1)*w;
   2718             lc[idx] = sum; rc[idx] = sum_abs;
   2719         }
   2720 
   2721         for( i = 0; i < mi; i++ )
   2722         {
   2723             double sum = lc[i];
   2724             double sum_abs = rc[i];
   2725             lc[i] = (sum_abs - sum) * 0.5;
   2726             rc[i] = (sum_abs + sum) * 0.5;
   2727         }
   2728     }
   2729 
   2730     // 2. now form the split.
   2731     // in each category send all the samples to the same direction as majority
   2732     for( i = 0; i < mi; i++ )
   2733     {
   2734         double lval = lc[i], rval = rc[i];
   2735         if( lval > rval )
   2736         {
   2737             split->subset[i >> 5] |= 1 << (i & 31);
   2738             best_val += lval;
   2739             l_win++;
   2740         }
   2741         else
   2742             best_val += rval;
   2743     }
   2744 
   2745     split->quality = (float)best_val;
   2746     if( split->quality <= node->maxlr || l_win == 0 || l_win == mi )
   2747         cvSetRemoveByPtr( data->split_heap, split ), split = 0;
   2748 
   2749     return split;
   2750 }
   2751 
   2752 
   2753 void CvDTree::calc_node_value( CvDTreeNode* node )
   2754 {
   2755     int i, j, k, n = node->sample_count, cv_n = data->params.cv_folds;
   2756     int m = data->get_num_classes();
   2757 
   2758     int base_size = data->is_classifier ? m*cv_n*sizeof(int) : 2*cv_n*sizeof(double)+cv_n*sizeof(int);
   2759     int ext_size = n*(sizeof(int) + (data->is_classifier ? sizeof(int) : sizeof(int)+sizeof(float)));
   2760     cv::AutoBuffer<uchar> inn_buf(base_size + ext_size);
   2761     uchar* base_buf = (uchar*)inn_buf;
   2762     uchar* ext_buf = base_buf + base_size;
   2763 
   2764     int* cv_labels_buf = (int*)ext_buf;
   2765     const int* cv_labels = data->get_cv_labels(node, cv_labels_buf);
   2766 
   2767     if( data->is_classifier )
   2768     {
   2769         // in case of classification tree:
   2770         //  * node value is the label of the class that has the largest weight in the node.
   2771         //  * node risk is the weighted number of misclassified samples,
   2772         //  * j-th cross-validation fold value and risk are calculated as above,
   2773         //    but using the samples with cv_labels(*)!=j.
   2774         //  * j-th cross-validation fold error is calculated as the weighted number of
   2775         //    misclassified samples with cv_labels(*)==j.
   2776 
   2777         // compute the number of instances of each class
   2778         int* cls_count = data->counts->data.i;
   2779         int* responses_buf = cv_labels_buf + n;
   2780         const int* responses = data->get_class_labels(node, responses_buf);
   2781         int* cv_cls_count = (int*)base_buf;
   2782         double max_val = -1, total_weight = 0;
   2783         int max_k = -1;
   2784         double* priors = data->priors_mult->data.db;
   2785 
   2786         for( k = 0; k < m; k++ )
   2787             cls_count[k] = 0;
   2788 
   2789         if( cv_n == 0 )
   2790         {
   2791             for( i = 0; i < n; i++ )
   2792                 cls_count[responses[i]]++;
   2793         }
   2794         else
   2795         {
   2796             for( j = 0; j < cv_n; j++ )
   2797                 for( k = 0; k < m; k++ )
   2798                     cv_cls_count[j*m + k] = 0;
   2799 
   2800             for( i = 0; i < n; i++ )
   2801             {
   2802                 j = cv_labels[i]; k = responses[i];
   2803                 cv_cls_count[j*m + k]++;
   2804             }
   2805 
   2806             for( j = 0; j < cv_n; j++ )
   2807                 for( k = 0; k < m; k++ )
   2808                     cls_count[k] += cv_cls_count[j*m + k];
   2809         }
   2810 
   2811         if( data->have_priors && node->parent == 0 )
   2812         {
   2813             // compute priors_mult from priors, take the sample ratio into account.
   2814             double sum = 0;
   2815             for( k = 0; k < m; k++ )
   2816             {
   2817                 int n_k = cls_count[k];
   2818                 priors[k] = data->priors->data.db[k]*(n_k ? 1./n_k : 0.);
   2819                 sum += priors[k];
   2820             }
   2821             sum = 1./sum;
   2822             for( k = 0; k < m; k++ )
   2823                 priors[k] *= sum;
   2824         }
   2825 
   2826         for( k = 0; k < m; k++ )
   2827         {
   2828             double val = cls_count[k]*priors[k];
   2829             total_weight += val;
   2830             if( max_val < val )
   2831             {
   2832                 max_val = val;
   2833                 max_k = k;
   2834             }
   2835         }
   2836 
   2837         node->class_idx = max_k;
   2838         node->value = data->cat_map->data.i[
   2839             data->cat_ofs->data.i[data->cat_var_count] + max_k];
   2840         node->node_risk = total_weight - max_val;
   2841 
   2842         for( j = 0; j < cv_n; j++ )
   2843         {
   2844             double sum_k = 0, sum = 0, max_val_k = 0;
   2845             max_val = -1; max_k = -1;
   2846 
   2847             for( k = 0; k < m; k++ )
   2848             {
   2849                 double w = priors[k];
   2850                 double val_k = cv_cls_count[j*m + k]*w;
   2851                 double val = cls_count[k]*w - val_k;
   2852                 sum_k += val_k;
   2853                 sum += val;
   2854                 if( max_val < val )
   2855                 {
   2856                     max_val = val;
   2857                     max_val_k = val_k;
   2858                     max_k = k;
   2859                 }
   2860             }
   2861 
   2862             node->cv_Tn[j] = INT_MAX;
   2863             node->cv_node_risk[j] = sum - max_val;
   2864             node->cv_node_error[j] = sum_k - max_val_k;
   2865         }
   2866     }
   2867     else
   2868     {
   2869         // in case of regression tree:
   2870         //  * node value is 1/n*sum_i(Y_i), where Y_i is i-th response,
   2871         //    n is the number of samples in the node.
   2872         //  * node risk is the sum of squared errors: sum_i((Y_i - <node_value>)^2)
   2873         //  * j-th cross-validation fold value and risk are calculated as above,
   2874         //    but using the samples with cv_labels(*)!=j.
   2875         //  * j-th cross-validation fold error is calculated
   2876         //    using samples with cv_labels(*)==j as the test subset:
   2877         //    error_j = sum_(i,cv_labels(i)==j)((Y_i - <node_value_j>)^2),
   2878         //    where node_value_j is the node value calculated
   2879         //    as described in the previous bullet, and summation is done
   2880         //    over the samples with cv_labels(*)==j.
   2881 
   2882         double sum = 0, sum2 = 0;
   2883         float* values_buf = (float*)(cv_labels_buf + n);
   2884         int* sample_indices_buf = (int*)(values_buf + n);
   2885         const float* values = data->get_ord_responses(node, values_buf, sample_indices_buf);
   2886         double *cv_sum = 0, *cv_sum2 = 0;
   2887         int* cv_count = 0;
   2888 
   2889         if( cv_n == 0 )
   2890         {
   2891             for( i = 0; i < n; i++ )
   2892             {
   2893                 double t = values[i];
   2894                 sum += t;
   2895                 sum2 += t*t;
   2896             }
   2897         }
   2898         else
   2899         {
   2900             cv_sum = (double*)base_buf;
   2901             cv_sum2 = cv_sum + cv_n;
   2902             cv_count = (int*)(cv_sum2 + cv_n);
   2903 
   2904             for( j = 0; j < cv_n; j++ )
   2905             {
   2906                 cv_sum[j] = cv_sum2[j] = 0.;
   2907                 cv_count[j] = 0;
   2908             }
   2909 
   2910             for( i = 0; i < n; i++ )
   2911             {
   2912                 j = cv_labels[i];
   2913                 double t = values[i];
   2914                 double s = cv_sum[j] + t;
   2915                 double s2 = cv_sum2[j] + t*t;
   2916                 int nc = cv_count[j] + 1;
   2917                 cv_sum[j] = s;
   2918                 cv_sum2[j] = s2;
   2919                 cv_count[j] = nc;
   2920             }
   2921 
   2922             for( j = 0; j < cv_n; j++ )
   2923             {
   2924                 sum += cv_sum[j];
   2925                 sum2 += cv_sum2[j];
   2926             }
   2927         }
   2928 
   2929         node->node_risk = sum2 - (sum/n)*sum;
   2930         node->value = sum/n;
   2931 
   2932         for( j = 0; j < cv_n; j++ )
   2933         {
   2934             double s = cv_sum[j], si = sum - s;
   2935             double s2 = cv_sum2[j], s2i = sum2 - s2;
   2936             int c = cv_count[j], ci = n - c;
   2937             double r = si/MAX(ci,1);
   2938             node->cv_node_risk[j] = s2i - r*r*ci;
   2939             node->cv_node_error[j] = s2 - 2*r*s + c*r*r;
   2940             node->cv_Tn[j] = INT_MAX;
   2941         }
   2942     }
   2943 }
   2944 
   2945 
   2946 void CvDTree::complete_node_dir( CvDTreeNode* node )
   2947 {
   2948     int vi, i, n = node->sample_count, nl, nr, d0 = 0, d1 = -1;
   2949     int nz = n - node->get_num_valid(node->split->var_idx);
   2950     char* dir = (char*)data->direction->data.ptr;
   2951 
   2952     // try to complete direction using surrogate splits
   2953     if( nz && data->params.use_surrogates )
   2954     {
   2955         cv::AutoBuffer<uchar> inn_buf(n*(2*sizeof(int)+sizeof(float)));
   2956         CvDTreeSplit* split = node->split->next;
   2957         for( ; split != 0 && nz; split = split->next )
   2958         {
   2959             int inversed_mask = split->inversed ? -1 : 0;
   2960             vi = split->var_idx;
   2961 
   2962             if( data->get_var_type(vi) >= 0 ) // split on categorical var
   2963             {
   2964                 int* labels_buf = (int*)(uchar*)inn_buf;
   2965                 const int* labels = data->get_cat_var_data(node, vi, labels_buf);
   2966                 const int* subset = split->subset;
   2967 
   2968                 for( i = 0; i < n; i++ )
   2969                 {
   2970                     int idx = labels[i];
   2971                     if( !dir[i] && ( ((idx >= 0)&&(!data->is_buf_16u)) || ((idx != 65535)&&(data->is_buf_16u)) ))
   2972 
   2973                     {
   2974                         int d = CV_DTREE_CAT_DIR(idx,subset);
   2975                         dir[i] = (char)((d ^ inversed_mask) - inversed_mask);
   2976                         if( --nz )
   2977                             break;
   2978                     }
   2979                 }
   2980             }
   2981             else // split on ordered var
   2982             {
   2983                 float* values_buf = (float*)(uchar*)inn_buf;
   2984                 int* sorted_indices_buf = (int*)(values_buf + n);
   2985                 int* sample_indices_buf = sorted_indices_buf + n;
   2986                 const float* values = 0;
   2987                 const int* sorted_indices = 0;
   2988                 data->get_ord_var_data( node, vi, values_buf, sorted_indices_buf, &values, &sorted_indices, sample_indices_buf );
   2989                 int split_point = split->ord.split_point;
   2990                 int n1 = node->get_num_valid(vi);
   2991 
   2992                 assert( 0 <= split_point && split_point < n-1 );
   2993 
   2994                 for( i = 0; i < n1; i++ )
   2995                 {
   2996                     int idx = sorted_indices[i];
   2997                     if( !dir[idx] )
   2998                     {
   2999                         int d = i <= split_point ? -1 : 1;
   3000                         dir[idx] = (char)((d ^ inversed_mask) - inversed_mask);
   3001                         if( --nz )
   3002                             break;
   3003                     }
   3004                 }
   3005             }
   3006         }
   3007     }
   3008 
   3009     // find the default direction for the rest
   3010     if( nz )
   3011     {
   3012         for( i = nr = 0; i < n; i++ )
   3013             nr += dir[i] > 0;
   3014         nl = n - nr - nz;
   3015         d0 = nl > nr ? -1 : nr > nl;
   3016     }
   3017 
   3018     // make sure that every sample is directed either to the left or to the right
   3019     for( i = 0; i < n; i++ )
   3020     {
   3021         int d = dir[i];
   3022         if( !d )
   3023         {
   3024             d = d0;
   3025             if( !d )
   3026                 d = d1, d1 = -d1;
   3027         }
   3028         d = d > 0;
   3029         dir[i] = (char)d; // remap (-1,1) to (0,1)
   3030     }
   3031 }
   3032 
   3033 
   3034 void CvDTree::split_node_data( CvDTreeNode* node )
   3035 {
   3036     int vi, i, n = node->sample_count, nl, nr, scount = data->sample_count;
   3037     char* dir = (char*)data->direction->data.ptr;
   3038     CvDTreeNode *left = 0, *right = 0;
   3039     int* new_idx = data->split_buf->data.i;
   3040     int new_buf_idx = data->get_child_buf_idx( node );
   3041     int work_var_count = data->get_work_var_count();
   3042     CvMat* buf = data->buf;
   3043     size_t length_buf_row = data->get_length_subbuf();
   3044     cv::AutoBuffer<uchar> inn_buf(n*(3*sizeof(int) + sizeof(float)));
   3045     int* temp_buf = (int*)(uchar*)inn_buf;
   3046 
   3047     complete_node_dir(node);
   3048 
   3049     for( i = nl = nr = 0; i < n; i++ )
   3050     {
   3051         int d = dir[i];
   3052         // initialize new indices for splitting ordered variables
   3053         new_idx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
   3054         nr += d;
   3055         nl += d^1;
   3056     }
   3057 
   3058     bool split_input_data;
   3059     node->left = left = data->new_node( node, nl, new_buf_idx, node->offset );
   3060     node->right = right = data->new_node( node, nr, new_buf_idx, node->offset + nl );
   3061 
   3062     split_input_data = node->depth + 1 < data->params.max_depth &&
   3063         (node->left->sample_count > data->params.min_sample_count ||
   3064         node->right->sample_count > data->params.min_sample_count);
   3065 
   3066     // split ordered variables, keep both halves sorted.
   3067     for( vi = 0; vi < data->var_count; vi++ )
   3068     {
   3069         int ci = data->get_var_type(vi);
   3070 
   3071         if( ci >= 0 || !split_input_data )
   3072             continue;
   3073 
   3074         int n1 = node->get_num_valid(vi);
   3075         float* src_val_buf = (float*)(uchar*)(temp_buf + n);
   3076         int* src_sorted_idx_buf = (int*)(src_val_buf + n);
   3077         int* src_sample_idx_buf = src_sorted_idx_buf + n;
   3078         const float* src_val = 0;
   3079         const int* src_sorted_idx = 0;
   3080         data->get_ord_var_data(node, vi, src_val_buf, src_sorted_idx_buf, &src_val, &src_sorted_idx, src_sample_idx_buf);
   3081 
   3082         for(i = 0; i < n; i++)
   3083             temp_buf[i] = src_sorted_idx[i];
   3084 
   3085         if (data->is_buf_16u)
   3086         {
   3087             unsigned short *ldst, *rdst, *ldst0, *rdst0;
   3088             //unsigned short tl, tr;
   3089             ldst0 = ldst = (unsigned short*)(buf->data.s + left->buf_idx*length_buf_row +
   3090                 vi*scount + left->offset);
   3091             rdst0 = rdst = (unsigned short*)(ldst + nl);
   3092 
   3093             // split sorted
   3094             for( i = 0; i < n1; i++ )
   3095             {
   3096                 int idx = temp_buf[i];
   3097                 int d = dir[idx];
   3098                 idx = new_idx[idx];
   3099                 if (d)
   3100                 {
   3101                     *rdst = (unsigned short)idx;
   3102                     rdst++;
   3103                 }
   3104                 else
   3105                 {
   3106                     *ldst = (unsigned short)idx;
   3107                     ldst++;
   3108                 }
   3109             }
   3110 
   3111             left->set_num_valid(vi, (int)(ldst - ldst0));
   3112             right->set_num_valid(vi, (int)(rdst - rdst0));
   3113 
   3114             // split missing
   3115             for( ; i < n; i++ )
   3116             {
   3117                 int idx = temp_buf[i];
   3118                 int d = dir[idx];
   3119                 idx = new_idx[idx];
   3120                 if (d)
   3121                 {
   3122                     *rdst = (unsigned short)idx;
   3123                     rdst++;
   3124                 }
   3125                 else
   3126                 {
   3127                     *ldst = (unsigned short)idx;
   3128                     ldst++;
   3129                 }
   3130             }
   3131         }
   3132         else
   3133         {
   3134             int *ldst0, *ldst, *rdst0, *rdst;
   3135             ldst0 = ldst = buf->data.i + left->buf_idx*length_buf_row +
   3136                 vi*scount + left->offset;
   3137             rdst0 = rdst = buf->data.i + right->buf_idx*length_buf_row +
   3138                 vi*scount + right->offset;
   3139 
   3140             // split sorted
   3141             for( i = 0; i < n1; i++ )
   3142             {
   3143                 int idx = temp_buf[i];
   3144                 int d = dir[idx];
   3145                 idx = new_idx[idx];
   3146                 if (d)
   3147                 {
   3148                     *rdst = idx;
   3149                     rdst++;
   3150                 }
   3151                 else
   3152                 {
   3153                     *ldst = idx;
   3154                     ldst++;
   3155                 }
   3156             }
   3157 
   3158             left->set_num_valid(vi, (int)(ldst - ldst0));
   3159             right->set_num_valid(vi, (int)(rdst - rdst0));
   3160 
   3161             // split missing
   3162             for( ; i < n; i++ )
   3163             {
   3164                 int idx = temp_buf[i];
   3165                 int d = dir[idx];
   3166                 idx = new_idx[idx];
   3167                 if (d)
   3168                 {
   3169                     *rdst = idx;
   3170                     rdst++;
   3171                 }
   3172                 else
   3173                 {
   3174                     *ldst = idx;
   3175                     ldst++;
   3176                 }
   3177             }
   3178         }
   3179     }
   3180 
   3181     // split categorical vars, responses and cv_labels using new_idx relocation table
   3182     for( vi = 0; vi < work_var_count; vi++ )
   3183     {
   3184         int ci = data->get_var_type(vi);
   3185         int n1 = node->get_num_valid(vi), nr1 = 0;
   3186 
   3187         if( ci < 0 || (vi < data->var_count && !split_input_data) )
   3188             continue;
   3189 
   3190         int *src_lbls_buf = temp_buf + n;
   3191         const int* src_lbls = data->get_cat_var_data(node, vi, src_lbls_buf);
   3192 
   3193         for(i = 0; i < n; i++)
   3194             temp_buf[i] = src_lbls[i];
   3195 
   3196         if (data->is_buf_16u)
   3197         {
   3198             unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*length_buf_row +
   3199                 vi*scount + left->offset);
   3200             unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*length_buf_row +
   3201                 vi*scount + right->offset);
   3202 
   3203             for( i = 0; i < n; i++ )
   3204             {
   3205                 int d = dir[i];
   3206                 int idx = temp_buf[i];
   3207                 if (d)
   3208                 {
   3209                     *rdst = (unsigned short)idx;
   3210                     rdst++;
   3211                     nr1 += (idx != 65535 )&d;
   3212                 }
   3213                 else
   3214                 {
   3215                     *ldst = (unsigned short)idx;
   3216                     ldst++;
   3217                 }
   3218             }
   3219 
   3220             if( vi < data->var_count )
   3221             {
   3222                 left->set_num_valid(vi, n1 - nr1);
   3223                 right->set_num_valid(vi, nr1);
   3224             }
   3225         }
   3226         else
   3227         {
   3228             int *ldst = buf->data.i + left->buf_idx*length_buf_row +
   3229                 vi*scount + left->offset;
   3230             int *rdst = buf->data.i + right->buf_idx*length_buf_row +
   3231                 vi*scount + right->offset;
   3232 
   3233             for( i = 0; i < n; i++ )
   3234             {
   3235                 int d = dir[i];
   3236                 int idx = temp_buf[i];
   3237                 if (d)
   3238                 {
   3239                     *rdst = idx;
   3240                     rdst++;
   3241                     nr1 += (idx >= 0)&d;
   3242                 }
   3243                 else
   3244                 {
   3245                     *ldst = idx;
   3246                     ldst++;
   3247                 }
   3248 
   3249             }
   3250 
   3251             if( vi < data->var_count )
   3252             {
   3253                 left->set_num_valid(vi, n1 - nr1);
   3254                 right->set_num_valid(vi, nr1);
   3255             }
   3256         }
   3257     }
   3258 
   3259 
   3260     // split sample indices
   3261     int *sample_idx_src_buf = temp_buf + n;
   3262     const int* sample_idx_src = data->get_sample_indices(node, sample_idx_src_buf);
   3263 
   3264     for(i = 0; i < n; i++)
   3265         temp_buf[i] = sample_idx_src[i];
   3266 
   3267     int pos = data->get_work_var_count();
   3268     if (data->is_buf_16u)
   3269     {
   3270         unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*length_buf_row +
   3271             pos*scount + left->offset);
   3272         unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*length_buf_row +
   3273             pos*scount + right->offset);
   3274         for (i = 0; i < n; i++)
   3275         {
   3276             int d = dir[i];
   3277             unsigned short idx = (unsigned short)temp_buf[i];
   3278             if (d)
   3279             {
   3280                 *rdst = idx;
   3281                 rdst++;
   3282             }
   3283             else
   3284             {
   3285                 *ldst = idx;
   3286                 ldst++;
   3287             }
   3288         }
   3289     }
   3290     else
   3291     {
   3292         int* ldst = buf->data.i + left->buf_idx*length_buf_row +
   3293             pos*scount + left->offset;
   3294         int* rdst = buf->data.i + right->buf_idx*length_buf_row +
   3295             pos*scount + right->offset;
   3296         for (i = 0; i < n; i++)
   3297         {
   3298             int d = dir[i];
   3299             int idx = temp_buf[i];
   3300             if (d)
   3301             {
   3302                 *rdst = idx;
   3303                 rdst++;
   3304             }
   3305             else
   3306             {
   3307                 *ldst = idx;
   3308                 ldst++;
   3309             }
   3310         }
   3311     }
   3312 
   3313     // deallocate the parent node data that is not needed anymore
   3314     data->free_node_data(node);
   3315 }
   3316 
   3317 float CvDTree::calc_error( CvMLData* _data, int type, std::vector<float> *resp )
   3318 {
   3319     float err = 0;
   3320     const CvMat* values = _data->get_values();
   3321     const CvMat* response = _data->get_responses();
   3322     const CvMat* missing = _data->get_missing();
   3323     const CvMat* sample_idx = (type == CV_TEST_ERROR) ? _data->get_test_sample_idx() : _data->get_train_sample_idx();
   3324     const CvMat* var_types = _data->get_var_types();
   3325     int* sidx = sample_idx ? sample_idx->data.i : 0;
   3326     int r_step = CV_IS_MAT_CONT(response->type) ?
   3327                 1 : response->step / CV_ELEM_SIZE(response->type);
   3328     bool is_classifier = var_types->data.ptr[var_types->cols-1] == CV_VAR_CATEGORICAL;
   3329     int sample_count = sample_idx ? sample_idx->cols : 0;
   3330     sample_count = (type == CV_TRAIN_ERROR && sample_count == 0) ? values->rows : sample_count;
   3331     float* pred_resp = 0;
   3332     if( resp && (sample_count > 0) )
   3333     {
   3334         resp->resize( sample_count );
   3335         pred_resp = &((*resp)[0]);
   3336     }
   3337 
   3338     if ( is_classifier )
   3339     {
   3340         for( int i = 0; i < sample_count; i++ )
   3341         {
   3342             CvMat sample, miss;
   3343             int si = sidx ? sidx[i] : i;
   3344             cvGetRow( values, &sample, si );
   3345             if( missing )
   3346                 cvGetRow( missing, &miss, si );
   3347             float r = (float)predict( &sample, missing ? &miss : 0 )->value;
   3348             if( pred_resp )
   3349                 pred_resp[i] = r;
   3350             int d = fabs((double)r - response->data.fl[(size_t)si*r_step]) <= FLT_EPSILON ? 0 : 1;
   3351             err += d;
   3352         }
   3353         err = sample_count ? err / (float)sample_count * 100 : -FLT_MAX;
   3354     }
   3355     else
   3356     {
   3357         for( int i = 0; i < sample_count; i++ )
   3358         {
   3359             CvMat sample, miss;
   3360             int si = sidx ? sidx[i] : i;
   3361             cvGetRow( values, &sample, si );
   3362             if( missing )
   3363                 cvGetRow( missing, &miss, si );
   3364             float r = (float)predict( &sample, missing ? &miss : 0 )->value;
   3365             if( pred_resp )
   3366                 pred_resp[i] = r;
   3367             float d = r - response->data.fl[(size_t)si*r_step];
   3368             err += d*d;
   3369         }
   3370         err = sample_count ? err / (float)sample_count : -FLT_MAX;
   3371     }
   3372     return err;
   3373 }
   3374 
   3375 void CvDTree::prune_cv()
   3376 {
   3377     CvMat* ab = 0;
   3378     CvMat* temp = 0;
   3379     CvMat* err_jk = 0;
   3380 
   3381     // 1. build tree sequence for each cv fold, calculate error_{Tj,beta_k}.
   3382     // 2. choose the best tree index (if need, apply 1SE rule).
   3383     // 3. store the best index and cut the branches.
   3384 
   3385     CV_FUNCNAME( "CvDTree::prune_cv" );
   3386 
   3387     __BEGIN__;
   3388 
   3389     int ti, j, tree_count = 0, cv_n = data->params.cv_folds, n = root->sample_count;
   3390     // currently, 1SE for regression is not implemented
   3391     bool use_1se = data->params.use_1se_rule != 0 && data->is_classifier;
   3392     double* err;
   3393     double min_err = 0, min_err_se = 0;
   3394     int min_idx = -1;
   3395 
   3396     CV_CALL( ab = cvCreateMat( 1, 256, CV_64F ));
   3397 
   3398     // build the main tree sequence, calculate alpha's
   3399     for(;;tree_count++)
   3400     {
   3401         double min_alpha = update_tree_rnc(tree_count, -1);
   3402         if( cut_tree(tree_count, -1, min_alpha) )
   3403             break;
   3404 
   3405         if( ab->cols <= tree_count )
   3406         {
   3407             CV_CALL( temp = cvCreateMat( 1, ab->cols*3/2, CV_64F ));
   3408             for( ti = 0; ti < ab->cols; ti++ )
   3409                 temp->data.db[ti] = ab->data.db[ti];
   3410             cvReleaseMat( &ab );
   3411             ab = temp;
   3412             temp = 0;
   3413         }
   3414 
   3415         ab->data.db[tree_count] = min_alpha;
   3416     }
   3417 
   3418     ab->data.db[0] = 0.;
   3419 
   3420     if( tree_count > 0 )
   3421     {
   3422         for( ti = 1; ti < tree_count-1; ti++ )
   3423             ab->data.db[ti] = sqrt(ab->data.db[ti]*ab->data.db[ti+1]);
   3424         ab->data.db[tree_count-1] = DBL_MAX*0.5;
   3425 
   3426         CV_CALL( err_jk = cvCreateMat( cv_n, tree_count, CV_64F ));
   3427         err = err_jk->data.db;
   3428 
   3429         for( j = 0; j < cv_n; j++ )
   3430         {
   3431             int tj = 0, tk = 0;
   3432             for( ; tk < tree_count; tj++ )
   3433             {
   3434                 double min_alpha = update_tree_rnc(tj, j);
   3435                 if( cut_tree(tj, j, min_alpha) )
   3436                     min_alpha = DBL_MAX;
   3437 
   3438                 for( ; tk < tree_count; tk++ )
   3439                 {
   3440                     if( ab->data.db[tk] > min_alpha )
   3441                         break;
   3442                     err[j*tree_count + tk] = root->tree_error;
   3443                 }
   3444             }
   3445         }
   3446 
   3447         for( ti = 0; ti < tree_count; ti++ )
   3448         {
   3449             double sum_err = 0;
   3450             for( j = 0; j < cv_n; j++ )
   3451                 sum_err += err[j*tree_count + ti];
   3452             if( ti == 0 || sum_err < min_err )
   3453             {
   3454                 min_err = sum_err;
   3455                 min_idx = ti;
   3456                 if( use_1se )
   3457                     min_err_se = sqrt( sum_err*(n - sum_err) );
   3458             }
   3459             else if( sum_err < min_err + min_err_se )
   3460                 min_idx = ti;
   3461         }
   3462     }
   3463 
   3464     pruned_tree_idx = min_idx;
   3465     free_prune_data(data->params.truncate_pruned_tree != 0);
   3466 
   3467     __END__;
   3468 
   3469     cvReleaseMat( &err_jk );
   3470     cvReleaseMat( &ab );
   3471     cvReleaseMat( &temp );
   3472 }
   3473 
   3474 
   3475 double CvDTree::update_tree_rnc( int T, int fold )
   3476 {
   3477     CvDTreeNode* node = root;
   3478     double min_alpha = DBL_MAX;
   3479 
   3480     for(;;)
   3481     {
   3482         CvDTreeNode* parent;
   3483         for(;;)
   3484         {
   3485             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
   3486             if( t <= T || !node->left )
   3487             {
   3488                 node->complexity = 1;
   3489                 node->tree_risk = node->node_risk;
   3490                 node->tree_error = 0.;
   3491                 if( fold >= 0 )
   3492                 {
   3493                     node->tree_risk = node->cv_node_risk[fold];
   3494                     node->tree_error = node->cv_node_error[fold];
   3495                 }
   3496                 break;
   3497             }
   3498             node = node->left;
   3499         }
   3500 
   3501         for( parent = node->parent; parent && parent->right == node;
   3502             node = parent, parent = parent->parent )
   3503         {
   3504             parent->complexity += node->complexity;
   3505             parent->tree_risk += node->tree_risk;
   3506             parent->tree_error += node->tree_error;
   3507 
   3508             parent->alpha = ((fold >= 0 ? parent->cv_node_risk[fold] : parent->node_risk)
   3509                 - parent->tree_risk)/(parent->complexity - 1);
   3510             min_alpha = MIN( min_alpha, parent->alpha );
   3511         }
   3512 
   3513         if( !parent )
   3514             break;
   3515 
   3516         parent->complexity = node->complexity;
   3517         parent->tree_risk = node->tree_risk;
   3518         parent->tree_error = node->tree_error;
   3519         node = parent->right;
   3520     }
   3521 
   3522     return min_alpha;
   3523 }
   3524 
   3525 
   3526 int CvDTree::cut_tree( int T, int fold, double min_alpha )
   3527 {
   3528     CvDTreeNode* node = root;
   3529     if( !node->left )
   3530         return 1;
   3531 
   3532     for(;;)
   3533     {
   3534         CvDTreeNode* parent;
   3535         for(;;)
   3536         {
   3537             int t = fold >= 0 ? node->cv_Tn[fold] : node->Tn;
   3538             if( t <= T || !node->left )
   3539                 break;
   3540             if( node->alpha <= min_alpha + FLT_EPSILON )
   3541             {
   3542                 if( fold >= 0 )
   3543                     node->cv_Tn[fold] = T;
   3544                 else
   3545                     node->Tn = T;
   3546                 if( node == root )
   3547                     return 1;
   3548                 break;
   3549             }
   3550             node = node->left;
   3551         }
   3552 
   3553         for( parent = node->parent; parent && parent->right == node;
   3554             node = parent, parent = parent->parent )
   3555             ;
   3556 
   3557         if( !parent )
   3558             break;
   3559 
   3560         node = parent->right;
   3561     }
   3562 
   3563     return 0;
   3564 }
   3565 
   3566 
   3567 void CvDTree::free_prune_data(bool _cut_tree)
   3568 {
   3569     CvDTreeNode* node = root;
   3570 
   3571     for(;;)
   3572     {
   3573         CvDTreeNode* parent;
   3574         for(;;)
   3575         {
   3576             // do not call cvSetRemoveByPtr( cv_heap, node->cv_Tn )
   3577             // as we will clear the whole cross-validation heap at the end
   3578             node->cv_Tn = 0;
   3579             node->cv_node_error = node->cv_node_risk = 0;
   3580             if( !node->left )
   3581                 break;
   3582             node = node->left;
   3583         }
   3584 
   3585         for( parent = node->parent; parent && parent->right == node;
   3586             node = parent, parent = parent->parent )
   3587         {
   3588             if( _cut_tree && parent->Tn <= pruned_tree_idx )
   3589             {
   3590                 data->free_node( parent->left );
   3591                 data->free_node( parent->right );
   3592                 parent->left = parent->right = 0;
   3593             }
   3594         }
   3595 
   3596         if( !parent )
   3597             break;
   3598 
   3599         node = parent->right;
   3600     }
   3601 
   3602     if( data->cv_heap )
   3603         cvClearSet( data->cv_heap );
   3604 }
   3605 
   3606 
   3607 void CvDTree::free_tree()
   3608 {
   3609     if( root && data && data->shared )
   3610     {
   3611         pruned_tree_idx = INT_MIN;
   3612         free_prune_data(true);
   3613         data->free_node(root);
   3614         root = 0;
   3615     }
   3616 }
   3617 
   3618 CvDTreeNode* CvDTree::predict( const CvMat* _sample,
   3619     const CvMat* _missing, bool preprocessed_input ) const
   3620 {
   3621     cv::AutoBuffer<int> catbuf;
   3622 
   3623     int i, mstep = 0;
   3624     const uchar* m = 0;
   3625     CvDTreeNode* node = root;
   3626 
   3627     if( !node )
   3628         CV_Error( CV_StsError, "The tree has not been trained yet" );
   3629 
   3630     if( !CV_IS_MAT(_sample) || CV_MAT_TYPE(_sample->type) != CV_32FC1 ||
   3631         (_sample->cols != 1 && _sample->rows != 1) ||
   3632         (_sample->cols + _sample->rows - 1 != data->var_all && !preprocessed_input) ||
   3633         (_sample->cols + _sample->rows - 1 != data->var_count && preprocessed_input) )
   3634             CV_Error( CV_StsBadArg,
   3635         "the input sample must be 1d floating-point vector with the same "
   3636         "number of elements as the total number of variables used for training" );
   3637 
   3638     const float* sample = _sample->data.fl;
   3639     int step = CV_IS_MAT_CONT(_sample->type) ? 1 : _sample->step/sizeof(sample[0]);
   3640 
   3641     if( data->cat_count && !preprocessed_input ) // cache for categorical variables
   3642     {
   3643         int n = data->cat_count->cols;
   3644         catbuf.allocate(n);
   3645         for( i = 0; i < n; i++ )
   3646             catbuf[i] = -1;
   3647     }
   3648 
   3649     if( _missing )
   3650     {
   3651         if( !CV_IS_MAT(_missing) || !CV_IS_MASK_ARR(_missing) ||
   3652             !CV_ARE_SIZES_EQ(_missing, _sample) )
   3653             CV_Error( CV_StsBadArg,
   3654         "the missing data mask must be 8-bit vector of the same size as input sample" );
   3655         m = _missing->data.ptr;
   3656         mstep = CV_IS_MAT_CONT(_missing->type) ? 1 : _missing->step/sizeof(m[0]);
   3657     }
   3658 
   3659     const int* vtype = data->var_type->data.i;
   3660     const int* vidx = data->var_idx && !preprocessed_input ? data->var_idx->data.i : 0;
   3661     const int* cmap = data->cat_map ? data->cat_map->data.i : 0;
   3662     const int* cofs = data->cat_ofs ? data->cat_ofs->data.i : 0;
   3663 
   3664     while( node->Tn > pruned_tree_idx && node->left )
   3665     {
   3666         CvDTreeSplit* split = node->split;
   3667         int dir = 0;
   3668         for( ; !dir && split != 0; split = split->next )
   3669         {
   3670             int vi = split->var_idx;
   3671             int ci = vtype[vi];
   3672             i = vidx ? vidx[vi] : vi;
   3673             float val = sample[(size_t)i*step];
   3674             if( m && m[(size_t)i*mstep] )
   3675                 continue;
   3676             if( ci < 0 ) // ordered
   3677                 dir = val <= split->ord.c ? -1 : 1;
   3678             else // categorical
   3679             {
   3680                 int c;
   3681                 if( preprocessed_input )
   3682                     c = cvRound(val);
   3683                 else
   3684                 {
   3685                     c = catbuf[ci];
   3686                     if( c < 0 )
   3687                     {
   3688                         int a = c = cofs[ci];
   3689                         int b = (ci+1 >= data->cat_ofs->cols) ? data->cat_map->cols : cofs[ci+1];
   3690 
   3691                         int ival = cvRound(val);
   3692                         if( ival != val )
   3693                             CV_Error( CV_StsBadArg,
   3694                             "one of input categorical variable is not an integer" );
   3695 
   3696                         int sh = 0;
   3697                         while( a < b )
   3698                         {
   3699                             sh++;
   3700                             c = (a + b) >> 1;
   3701                             if( ival < cmap[c] )
   3702                                 b = c;
   3703                             else if( ival > cmap[c] )
   3704                                 a = c+1;
   3705                             else
   3706                                 break;
   3707                         }
   3708 
   3709                         if( c < 0 || ival != cmap[c] )
   3710                             continue;
   3711 
   3712                         catbuf[ci] = c -= cofs[ci];
   3713                     }
   3714                 }
   3715                 c = ( (c == 65535) && data->is_buf_16u ) ? -1 : c;
   3716                 dir = CV_DTREE_CAT_DIR(c, split->subset);
   3717             }
   3718 
   3719             if( split->inversed )
   3720                 dir = -dir;
   3721         }
   3722 
   3723         if( !dir )
   3724         {
   3725             double diff = node->right->sample_count - node->left->sample_count;
   3726             dir = diff < 0 ? -1 : 1;
   3727         }
   3728         node = dir < 0 ? node->left : node->right;
   3729     }
   3730 
   3731     return node;
   3732 }
   3733 
   3734 
   3735 CvDTreeNode* CvDTree::predict( const Mat& _sample, const Mat& _missing, bool preprocessed_input ) const
   3736 {
   3737     CvMat sample = _sample, mmask = _missing;
   3738     return predict(&sample, mmask.data.ptr ? &mmask : 0, preprocessed_input);
   3739 }
   3740 
   3741 
   3742 const CvMat* CvDTree::get_var_importance()
   3743 {
   3744     if( !var_importance )
   3745     {
   3746         CvDTreeNode* node = root;
   3747         double* importance;
   3748         if( !node )
   3749             return 0;
   3750         var_importance = cvCreateMat( 1, data->var_count, CV_64F );
   3751         cvZero( var_importance );
   3752         importance = var_importance->data.db;
   3753 
   3754         for(;;)
   3755         {
   3756             CvDTreeNode* parent;
   3757             for( ;; node = node->left )
   3758             {
   3759                 CvDTreeSplit* split = node->split;
   3760 
   3761                 if( !node->left || node->Tn <= pruned_tree_idx )
   3762                     break;
   3763 
   3764                 for( ; split != 0; split = split->next )
   3765                     importance[split->var_idx] += split->quality;
   3766             }
   3767 
   3768             for( parent = node->parent; parent && parent->right == node;
   3769                 node = parent, parent = parent->parent )
   3770                 ;
   3771 
   3772             if( !parent )
   3773                 break;
   3774 
   3775             node = parent->right;
   3776         }
   3777 
   3778         cvNormalize( var_importance, var_importance, 1., 0, CV_L1 );
   3779     }
   3780 
   3781     return var_importance;
   3782 }
   3783 
   3784 
   3785 void CvDTree::write_split( CvFileStorage* fs, CvDTreeSplit* split ) const
   3786 {
   3787     int ci;
   3788 
   3789     cvStartWriteStruct( fs, 0, CV_NODE_MAP + CV_NODE_FLOW );
   3790     cvWriteInt( fs, "var", split->var_idx );
   3791     cvWriteReal( fs, "quality", split->quality );
   3792 
   3793     ci = data->get_var_type(split->var_idx);
   3794     if( ci >= 0 ) // split on a categorical var
   3795     {
   3796         int i, n = data->cat_count->data.i[ci], to_right = 0, default_dir;
   3797         for( i = 0; i < n; i++ )
   3798             to_right += CV_DTREE_CAT_DIR(i,split->subset) > 0;
   3799 
   3800         // ad-hoc rule when to use inverse categorical split notation
   3801         // to achieve more compact and clear representation
   3802         default_dir = to_right <= 1 || to_right <= MIN(3, n/2) || to_right <= n/3 ? -1 : 1;
   3803 
   3804         cvStartWriteStruct( fs, default_dir*(split->inversed ? -1 : 1) > 0 ?
   3805                             "in" : "not_in", CV_NODE_SEQ+CV_NODE_FLOW );
   3806 
   3807         for( i = 0; i < n; i++ )
   3808         {
   3809             int dir = CV_DTREE_CAT_DIR(i,split->subset);
   3810             if( dir*default_dir < 0 )
   3811                 cvWriteInt( fs, 0, i );
   3812         }
   3813         cvEndWriteStruct( fs );
   3814     }
   3815     else
   3816         cvWriteReal( fs, !split->inversed ? "le" : "gt", split->ord.c );
   3817 
   3818     cvEndWriteStruct( fs );
   3819 }
   3820 
   3821 
   3822 void CvDTree::write_node( CvFileStorage* fs, CvDTreeNode* node ) const
   3823 {
   3824     CvDTreeSplit* split;
   3825 
   3826     cvStartWriteStruct( fs, 0, CV_NODE_MAP );
   3827 
   3828     cvWriteInt( fs, "depth", node->depth );
   3829     cvWriteInt( fs, "sample_count", node->sample_count );
   3830     cvWriteReal( fs, "value", node->value );
   3831 
   3832     if( data->is_classifier )
   3833         cvWriteInt( fs, "norm_class_idx", node->class_idx );
   3834 
   3835     cvWriteInt( fs, "Tn", node->Tn );
   3836     cvWriteInt( fs, "complexity", node->complexity );
   3837     cvWriteReal( fs, "alpha", node->alpha );
   3838     cvWriteReal( fs, "node_risk", node->node_risk );
   3839     cvWriteReal( fs, "tree_risk", node->tree_risk );
   3840     cvWriteReal( fs, "tree_error", node->tree_error );
   3841 
   3842     if( node->left )
   3843     {
   3844         cvStartWriteStruct( fs, "splits", CV_NODE_SEQ );
   3845 
   3846         for( split = node->split; split != 0; split = split->next )
   3847             write_split( fs, split );
   3848 
   3849         cvEndWriteStruct( fs );
   3850     }
   3851 
   3852     cvEndWriteStruct( fs );
   3853 }
   3854 
   3855 
   3856 void CvDTree::write_tree_nodes( CvFileStorage* fs ) const
   3857 {
   3858     //CV_FUNCNAME( "CvDTree::write_tree_nodes" );
   3859 
   3860     __BEGIN__;
   3861 
   3862     CvDTreeNode* node = root;
   3863 
   3864     // traverse the tree and save all the nodes in depth-first order
   3865     for(;;)
   3866     {
   3867         CvDTreeNode* parent;
   3868         for(;;)
   3869         {
   3870             write_node( fs, node );
   3871             if( !node->left )
   3872                 break;
   3873             node = node->left;
   3874         }
   3875 
   3876         for( parent = node->parent; parent && parent->right == node;
   3877             node = parent, parent = parent->parent )
   3878             ;
   3879 
   3880         if( !parent )
   3881             break;
   3882 
   3883         node = parent->right;
   3884     }
   3885 
   3886     __END__;
   3887 }
   3888 
   3889 
   3890 void CvDTree::write( CvFileStorage* fs, const char* name ) const
   3891 {
   3892     //CV_FUNCNAME( "CvDTree::write" );
   3893 
   3894     __BEGIN__;
   3895 
   3896     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_TREE );
   3897 
   3898     //get_var_importance();
   3899     data->write_params( fs );
   3900     //if( var_importance )
   3901     //cvWrite( fs, "var_importance", var_importance );
   3902     write( fs );
   3903 
   3904     cvEndWriteStruct( fs );
   3905 
   3906     __END__;
   3907 }
   3908 
   3909 
   3910 void CvDTree::write( CvFileStorage* fs ) const
   3911 {
   3912     //CV_FUNCNAME( "CvDTree::write" );
   3913 
   3914     __BEGIN__;
   3915 
   3916     cvWriteInt( fs, "best_tree_idx", pruned_tree_idx );
   3917 
   3918     cvStartWriteStruct( fs, "nodes", CV_NODE_SEQ );
   3919     write_tree_nodes( fs );
   3920     cvEndWriteStruct( fs );
   3921 
   3922     __END__;
   3923 }
   3924 
   3925 
   3926 CvDTreeSplit* CvDTree::read_split( CvFileStorage* fs, CvFileNode* fnode )
   3927 {
   3928     CvDTreeSplit* split = 0;
   3929 
   3930     CV_FUNCNAME( "CvDTree::read_split" );
   3931 
   3932     __BEGIN__;
   3933 
   3934     int vi, ci;
   3935 
   3936     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
   3937         CV_ERROR( CV_StsParseError, "some of the splits are not stored properly" );
   3938 
   3939     vi = cvReadIntByName( fs, fnode, "var", -1 );
   3940     if( (unsigned)vi >= (unsigned)data->var_count )
   3941         CV_ERROR( CV_StsOutOfRange, "Split variable index is out of range" );
   3942 
   3943     ci = data->get_var_type(vi);
   3944     if( ci >= 0 ) // split on categorical var
   3945     {
   3946         int i, n = data->cat_count->data.i[ci], inversed = 0, val;
   3947         CvSeqReader reader;
   3948         CvFileNode* inseq;
   3949         split = data->new_split_cat( vi, 0 );
   3950         inseq = cvGetFileNodeByName( fs, fnode, "in" );
   3951         if( !inseq )
   3952         {
   3953             inseq = cvGetFileNodeByName( fs, fnode, "not_in" );
   3954             inversed = 1;
   3955         }
   3956         if( !inseq ||
   3957             (CV_NODE_TYPE(inseq->tag) != CV_NODE_SEQ && CV_NODE_TYPE(inseq->tag) != CV_NODE_INT))
   3958             CV_ERROR( CV_StsParseError,
   3959             "Either 'in' or 'not_in' tags should be inside a categorical split data" );
   3960 
   3961         if( CV_NODE_TYPE(inseq->tag) == CV_NODE_INT )
   3962         {
   3963             val = inseq->data.i;
   3964             if( (unsigned)val >= (unsigned)n )
   3965                 CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
   3966 
   3967             split->subset[val >> 5] |= 1 << (val & 31);
   3968         }
   3969         else
   3970         {
   3971             cvStartReadSeq( inseq->data.seq, &reader );
   3972 
   3973             for( i = 0; i < reader.seq->total; i++ )
   3974             {
   3975                 CvFileNode* inode = (CvFileNode*)reader.ptr;
   3976                 val = inode->data.i;
   3977                 if( CV_NODE_TYPE(inode->tag) != CV_NODE_INT || (unsigned)val >= (unsigned)n )
   3978                     CV_ERROR( CV_StsOutOfRange, "some of in/not_in elements are out of range" );
   3979 
   3980                 split->subset[val >> 5] |= 1 << (val & 31);
   3981                 CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
   3982             }
   3983         }
   3984 
   3985         // for categorical splits we do not use inversed splits,
   3986         // instead we inverse the variable set in the split
   3987         if( inversed )
   3988             for( i = 0; i < (n + 31) >> 5; i++ )
   3989                 split->subset[i] ^= -1;
   3990     }
   3991     else
   3992     {
   3993         CvFileNode* cmp_node;
   3994         split = data->new_split_ord( vi, 0, 0, 0, 0 );
   3995 
   3996         cmp_node = cvGetFileNodeByName( fs, fnode, "le" );
   3997         if( !cmp_node )
   3998         {
   3999             cmp_node = cvGetFileNodeByName( fs, fnode, "gt" );
   4000             split->inversed = 1;
   4001         }
   4002 
   4003         split->ord.c = (float)cvReadReal( cmp_node );
   4004     }
   4005 
   4006     split->quality = (float)cvReadRealByName( fs, fnode, "quality" );
   4007 
   4008     __END__;
   4009 
   4010     return split;
   4011 }
   4012 
   4013 
   4014 CvDTreeNode* CvDTree::read_node( CvFileStorage* fs, CvFileNode* fnode, CvDTreeNode* parent )
   4015 {
   4016     CvDTreeNode* node = 0;
   4017 
   4018     CV_FUNCNAME( "CvDTree::read_node" );
   4019 
   4020     __BEGIN__;
   4021 
   4022     CvFileNode* splits;
   4023     int i, depth;
   4024 
   4025     if( !fnode || CV_NODE_TYPE(fnode->tag) != CV_NODE_MAP )
   4026         CV_ERROR( CV_StsParseError, "some of the tree elements are not stored properly" );
   4027 
   4028     CV_CALL( node = data->new_node( parent, 0, 0, 0 ));
   4029     depth = cvReadIntByName( fs, fnode, "depth", -1 );
   4030     if( depth != node->depth )
   4031         CV_ERROR( CV_StsParseError, "incorrect node depth" );
   4032 
   4033     node->sample_count = cvReadIntByName( fs, fnode, "sample_count" );
   4034     node->value = cvReadRealByName( fs, fnode, "value" );
   4035     if( data->is_classifier )
   4036         node->class_idx = cvReadIntByName( fs, fnode, "norm_class_idx" );
   4037 
   4038     node->Tn = cvReadIntByName( fs, fnode, "Tn" );
   4039     node->complexity = cvReadIntByName( fs, fnode, "complexity" );
   4040     node->alpha = cvReadRealByName( fs, fnode, "alpha" );
   4041     node->node_risk = cvReadRealByName( fs, fnode, "node_risk" );
   4042     node->tree_risk = cvReadRealByName( fs, fnode, "tree_risk" );
   4043     node->tree_error = cvReadRealByName( fs, fnode, "tree_error" );
   4044 
   4045     splits = cvGetFileNodeByName( fs, fnode, "splits" );
   4046     if( splits )
   4047     {
   4048         CvSeqReader reader;
   4049         CvDTreeSplit* last_split = 0;
   4050 
   4051         if( CV_NODE_TYPE(splits->tag) != CV_NODE_SEQ )
   4052             CV_ERROR( CV_StsParseError, "splits tag must stored as a sequence" );
   4053 
   4054         cvStartReadSeq( splits->data.seq, &reader );
   4055         for( i = 0; i < reader.seq->total; i++ )
   4056         {
   4057             CvDTreeSplit* split;
   4058             CV_CALL( split = read_split( fs, (CvFileNode*)reader.ptr ));
   4059             if( !last_split )
   4060                 node->split = last_split = split;
   4061             else
   4062                 last_split = last_split->next = split;
   4063 
   4064             CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
   4065         }
   4066     }
   4067 
   4068     __END__;
   4069 
   4070     return node;
   4071 }
   4072 
   4073 
   4074 void CvDTree::read_tree_nodes( CvFileStorage* fs, CvFileNode* fnode )
   4075 {
   4076     CV_FUNCNAME( "CvDTree::read_tree_nodes" );
   4077 
   4078     __BEGIN__;
   4079 
   4080     CvSeqReader reader;
   4081     CvDTreeNode _root;
   4082     CvDTreeNode* parent = &_root;
   4083     int i;
   4084     parent->left = parent->right = parent->parent = 0;
   4085 
   4086     cvStartReadSeq( fnode->data.seq, &reader );
   4087 
   4088     for( i = 0; i < reader.seq->total; i++ )
   4089     {
   4090         CvDTreeNode* node;
   4091 
   4092         CV_CALL( node = read_node( fs, (CvFileNode*)reader.ptr, parent != &_root ? parent : 0 ));
   4093         if( !parent->left )
   4094             parent->left = node;
   4095         else
   4096             parent->right = node;
   4097         if( node->split )
   4098             parent = node;
   4099         else
   4100         {
   4101             while( parent && parent->right )
   4102                 parent = parent->parent;
   4103         }
   4104 
   4105         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
   4106     }
   4107 
   4108     root = _root.left;
   4109 
   4110     __END__;
   4111 }
   4112 
   4113 
   4114 void CvDTree::read( CvFileStorage* fs, CvFileNode* fnode )
   4115 {
   4116     CvDTreeTrainData* _data = new CvDTreeTrainData();
   4117     _data->read_params( fs, fnode );
   4118 
   4119     read( fs, fnode, _data );
   4120     get_var_importance();
   4121 }
   4122 
   4123 
   4124 // a special entry point for reading weak decision trees from the tree ensembles
   4125 void CvDTree::read( CvFileStorage* fs, CvFileNode* node, CvDTreeTrainData* _data )
   4126 {
   4127     CV_FUNCNAME( "CvDTree::read" );
   4128 
   4129     __BEGIN__;
   4130 
   4131     CvFileNode* tree_nodes;
   4132 
   4133     clear();
   4134     data = _data;
   4135 
   4136     tree_nodes = cvGetFileNodeByName( fs, node, "nodes" );
   4137     if( !tree_nodes || CV_NODE_TYPE(tree_nodes->tag) != CV_NODE_SEQ )
   4138         CV_ERROR( CV_StsParseError, "nodes tag is missing" );
   4139 
   4140     pruned_tree_idx = cvReadIntByName( fs, node, "best_tree_idx", -1 );
   4141     read_tree_nodes( fs, tree_nodes );
   4142 
   4143     __END__;
   4144 }
   4145 
   4146 Mat CvDTree::getVarImportance()
   4147 {
   4148     return cvarrToMat(get_var_importance());
   4149 }
   4150 
   4151 /* End of file. */
   4152