Home | History | Annotate | Download | only in src
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                        Intel License Agreement
     11 //
     12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
     13 // Third party copyrights are property of their respective owners.
     14 //
     15 // Redistribution and use in source and binary forms, with or without modification,
     16 // are permitted provided that the following conditions are met:
     17 //
     18 //   * Redistribution's of source code must retain the above copyright notice,
     19 //     this list of conditions and the following disclaimer.
     20 //
     21 //   * Redistribution's in binary form must reproduce the above copyright notice,
     22 //     this list of conditions and the following disclaimer in the documentation
     23 //     and/or other materials provided with the distribution.
     24 //
     25 //   * The name of Intel Corporation may not be used to endorse or promote products
     26 //     derived from this software without specific prior written permission.
     27 //
     28 // This software is provided by the copyright holders and contributors "as is" and
     29 // any express or implied warranties, including, but not limited to, the implied
     30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     31 // In no event shall the Intel Corporation or contributors be liable for any direct,
     32 // indirect, incidental, special, exemplary, or consequential damages
     33 // (including, but not limited to, procurement of substitute goods or services;
     34 // loss of use, data, or profits; or business interruption) however caused
     35 // and on any theory of liability, whether in contract, strict liability,
     36 // or tort (including negligence or otherwise) arising in any way out of
     37 // the use of this software, even if advised of the possibility of such damage.
     38 //
     39 //M*/
     40 
     41 #include "_ml.h"
     42 
     43 CvForestTree::CvForestTree()
     44 {
     45     forest = NULL;
     46 }
     47 
     48 
     49 CvForestTree::~CvForestTree()
     50 {
     51     clear();
     52 }
     53 
     54 
     55 bool CvForestTree::train( CvDTreeTrainData* _data,
     56                           const CvMat* _subsample_idx,
     57                           CvRTrees* _forest )
     58 {
     59     bool result = false;
     60 
     61     CV_FUNCNAME( "CvForestTree::train" );
     62 
     63     __BEGIN__;
     64 
     65 
     66     clear();
     67     forest = _forest;
     68 
     69     data = _data;
     70     data->shared = true;
     71     CV_CALL(result = do_train(_subsample_idx));
     72 
     73     __END__;
     74 
     75     return result;
     76 }
     77 
     78 
     79 bool
     80 CvForestTree::train( const CvMat*, int, const CvMat*, const CvMat*,
     81                     const CvMat*, const CvMat*, const CvMat*, CvDTreeParams )
     82 {
     83     assert(0);
     84     return false;
     85 }
     86 
     87 
     88 bool
     89 CvForestTree::train( CvDTreeTrainData*, const CvMat* )
     90 {
     91     assert(0);
     92     return false;
     93 }
     94 
     95 
     96 CvDTreeSplit* CvForestTree::find_best_split( CvDTreeNode* node )
     97 {
     98     int vi;
     99     CvDTreeSplit *best_split = 0, *split = 0, *t;
    100 
    101     CV_FUNCNAME("CvForestTree::find_best_split");
    102     __BEGIN__;
    103 
    104     CvMat* active_var_mask = 0;
    105     if( forest )
    106     {
    107         int var_count;
    108         CvRNG* rng = forest->get_rng();
    109 
    110         active_var_mask = forest->get_active_var_mask();
    111         var_count = active_var_mask->cols;
    112 
    113         CV_ASSERT( var_count == data->var_count );
    114 
    115         for( vi = 0; vi < var_count; vi++ )
    116         {
    117             uchar temp;
    118             int i1 = cvRandInt(rng) % var_count;
    119             int i2 = cvRandInt(rng) % var_count;
    120             CV_SWAP( active_var_mask->data.ptr[i1],
    121                 active_var_mask->data.ptr[i2], temp );
    122         }
    123     }
    124     for( vi = 0; vi < data->var_count; vi++ )
    125     {
    126         int ci = data->var_type->data.i[vi];
    127         if( node->num_valid[vi] <= 1
    128             || (active_var_mask && !active_var_mask->data.ptr[vi]) )
    129             continue;
    130 
    131         if( data->is_classifier )
    132         {
    133             if( ci >= 0 )
    134                 split = find_split_cat_class( node, vi );
    135             else
    136                 split = find_split_ord_class( node, vi );
    137         }
    138         else
    139         {
    140             if( ci >= 0 )
    141                 split = find_split_cat_reg( node, vi );
    142             else
    143                 split = find_split_ord_reg( node, vi );
    144         }
    145 
    146         if( split )
    147         {
    148             if( !best_split || best_split->quality < split->quality )
    149                 CV_SWAP( best_split, split, t );
    150             if( split )
    151                 cvSetRemoveByPtr( data->split_heap, split );
    152         }
    153     }
    154 
    155     __END__;
    156 
    157     return best_split;
    158 }
    159 
    160 
    161 void CvForestTree::read( CvFileStorage* fs, CvFileNode* fnode, CvRTrees* _forest, CvDTreeTrainData* _data )
    162 {
    163     CvDTree::read( fs, fnode, _data );
    164     forest = _forest;
    165 }
    166 
    167 
    168 void CvForestTree::read( CvFileStorage*, CvFileNode* )
    169 {
    170     assert(0);
    171 }
    172 
    173 void CvForestTree::read( CvFileStorage* _fs, CvFileNode* _node,
    174                          CvDTreeTrainData* _data )
    175 {
    176     CvDTree::read( _fs, _node, _data );
    177 }
    178 
    179 
    180 //////////////////////////////////////////////////////////////////////////////////////////
    181 //                                  Random trees                                        //
    182 //////////////////////////////////////////////////////////////////////////////////////////
    183 
    184 CvRTrees::CvRTrees()
    185 {
    186     nclasses         = 0;
    187     oob_error        = 0;
    188     ntrees           = 0;
    189     trees            = NULL;
    190     data             = NULL;
    191     active_var_mask  = NULL;
    192     var_importance   = NULL;
    193     rng = cvRNG(0xffffffff);
    194     default_model_name = "my_random_trees";
    195 }
    196 
    197 
    198 void CvRTrees::clear()
    199 {
    200     int k;
    201     for( k = 0; k < ntrees; k++ )
    202         delete trees[k];
    203     cvFree( &trees );
    204 
    205     delete data;
    206     data = 0;
    207 
    208     cvReleaseMat( &active_var_mask );
    209     cvReleaseMat( &var_importance );
    210     ntrees = 0;
    211 }
    212 
    213 
    214 CvRTrees::~CvRTrees()
    215 {
    216     clear();
    217 }
    218 
    219 
    220 CvMat* CvRTrees::get_active_var_mask()
    221 {
    222     return active_var_mask;
    223 }
    224 
    225 
    226 CvRNG* CvRTrees::get_rng()
    227 {
    228     return &rng;
    229 }
    230 
    231 bool CvRTrees::train( const CvMat* _train_data, int _tflag,
    232                         const CvMat* _responses, const CvMat* _var_idx,
    233                         const CvMat* _sample_idx, const CvMat* _var_type,
    234                         const CvMat* _missing_mask, CvRTParams params )
    235 {
    236     bool result = false;
    237 
    238     CV_FUNCNAME("CvRTrees::train");
    239     __BEGIN__;
    240 
    241     int var_count = 0;
    242 
    243     clear();
    244 
    245     CvDTreeParams tree_params( params.max_depth, params.min_sample_count,
    246         params.regression_accuracy, params.use_surrogates, params.max_categories,
    247         params.cv_folds, params.use_1se_rule, false, params.priors );
    248 
    249     data = new CvDTreeTrainData();
    250     CV_CALL(data->set_data( _train_data, _tflag, _responses, _var_idx,
    251         _sample_idx, _var_type, _missing_mask, tree_params, true));
    252 
    253     var_count = data->var_count;
    254     if( params.nactive_vars > var_count )
    255         params.nactive_vars = var_count;
    256     else if( params.nactive_vars == 0 )
    257         params.nactive_vars = (int)sqrt((double)var_count);
    258     else if( params.nactive_vars < 0 )
    259         CV_ERROR( CV_StsBadArg, "<nactive_vars> must be non-negative" );
    260     params.term_crit = cvCheckTermCriteria( params.term_crit, 0.1, 1000 );
    261 
    262     // Create mask of active variables at the tree nodes
    263     CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));
    264     if( params.calc_var_importance )
    265     {
    266         CV_CALL(var_importance  = cvCreateMat( 1, var_count, CV_32FC1 ));
    267         cvZero(var_importance);
    268     }
    269     { // initialize active variables mask
    270         CvMat submask1, submask2;
    271         cvGetCols( active_var_mask, &submask1, 0, params.nactive_vars );
    272         cvGetCols( active_var_mask, &submask2, params.nactive_vars, var_count );
    273         cvSet( &submask1, cvScalar(1) );
    274         cvZero( &submask2 );
    275     }
    276 
    277     CV_CALL(result = grow_forest( params.term_crit ));
    278 
    279     result = true;
    280 
    281     __END__;
    282 
    283     return result;
    284 }
    285 
    286 
    287 bool CvRTrees::grow_forest( const CvTermCriteria term_crit )
    288 {
    289     bool result = false;
    290 
    291     CvMat* sample_idx_mask_for_tree = 0;
    292     CvMat* sample_idx_for_tree      = 0;
    293 
    294     CvMat* oob_sample_votes	   = 0;
    295     CvMat* oob_responses       = 0;
    296 
    297     float* oob_samples_perm_ptr= 0;
    298 
    299     float* samples_ptr     = 0;
    300     uchar* missing_ptr     = 0;
    301     float* true_resp_ptr   = 0;
    302 
    303     CV_FUNCNAME("CvRTrees::grow_forest");
    304     __BEGIN__;
    305 
    306     const int max_ntrees = term_crit.max_iter;
    307     const double max_oob_err = term_crit.epsilon;
    308 
    309     const int dims = data->var_count;
    310     float maximal_response = 0;
    311 
    312     // oob_predictions_sum[i] = sum of predicted values for the i-th sample
    313     // oob_num_of_predictions[i] = number of summands
    314     //                            (number of predictions for the i-th sample)
    315     // initialize these variable to avoid warning C4701
    316     CvMat oob_predictions_sum = cvMat( 1, 1, CV_32FC1 );
    317     CvMat oob_num_of_predictions = cvMat( 1, 1, CV_32FC1 );
    318 
    319     nsamples = data->sample_count;
    320     nclasses = data->get_num_classes();
    321 
    322     trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*max_ntrees );
    323     memset( trees, 0, sizeof(trees[0])*max_ntrees );
    324 
    325     if( data->is_classifier )
    326     {
    327         CV_CALL(oob_sample_votes = cvCreateMat( nsamples, nclasses, CV_32SC1 ));
    328         cvZero(oob_sample_votes);
    329     }
    330     else
    331     {
    332         // oob_responses[0,i] = oob_predictions_sum[i]
    333         //    = sum of predicted values for the i-th sample
    334         // oob_responses[1,i] = oob_num_of_predictions[i]
    335         //    = number of summands (number of predictions for the i-th sample)
    336         CV_CALL(oob_responses = cvCreateMat( 2, nsamples, CV_32FC1 ));
    337         cvZero(oob_responses);
    338         cvGetRow( oob_responses, &oob_predictions_sum, 0 );
    339         cvGetRow( oob_responses, &oob_num_of_predictions, 1 );
    340     }
    341     CV_CALL(sample_idx_mask_for_tree = cvCreateMat( 1, nsamples, CV_8UC1 ));
    342     CV_CALL(sample_idx_for_tree      = cvCreateMat( 1, nsamples, CV_32SC1 ));
    343     CV_CALL(oob_samples_perm_ptr     = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
    344     CV_CALL(samples_ptr              = (float*)cvAlloc( sizeof(float)*nsamples*dims ));
    345     CV_CALL(missing_ptr              = (uchar*)cvAlloc( sizeof(uchar)*nsamples*dims ));
    346     CV_CALL(true_resp_ptr            = (float*)cvAlloc( sizeof(float)*nsamples ));
    347 
    348     CV_CALL(data->get_vectors( 0, samples_ptr, missing_ptr, true_resp_ptr ));
    349     {
    350         double minval, maxval;
    351         CvMat responses = cvMat(1, nsamples, CV_32FC1, true_resp_ptr);
    352         cvMinMaxLoc( &responses, &minval, &maxval );
    353         maximal_response = (float)MAX( MAX( fabs(minval), fabs(maxval) ), 0 );
    354     }
    355 
    356     ntrees = 0;
    357     while( ntrees < max_ntrees )
    358     {
    359         int i, oob_samples_count = 0;
    360         double ncorrect_responses = 0; // used for estimation of variable importance
    361         CvMat sample, missing;
    362         CvForestTree* tree = 0;
    363 
    364         cvZero( sample_idx_mask_for_tree );
    365         for( i = 0; i < nsamples; i++ ) //form sample for creation one tree
    366         {
    367             int idx = cvRandInt( &rng ) % nsamples;
    368             sample_idx_for_tree->data.i[i] = idx;
    369             sample_idx_mask_for_tree->data.ptr[idx] = 0xFF;
    370         }
    371 
    372         trees[ntrees] = new CvForestTree();
    373         tree = trees[ntrees];
    374         CV_CALL(tree->train( data, sample_idx_for_tree, this ));
    375 
    376         // form array of OOB samples indices and get these samples
    377         sample   = cvMat( 1, dims, CV_32FC1, samples_ptr );
    378         missing  = cvMat( 1, dims, CV_8UC1,  missing_ptr );
    379 
    380         oob_error = 0;
    381         for( i = 0; i < nsamples; i++,
    382             sample.data.fl += dims, missing.data.ptr += dims )
    383         {
    384             CvDTreeNode* predicted_node = 0;
    385             // check if the sample is OOB
    386             if( sample_idx_mask_for_tree->data.ptr[i] )
    387                 continue;
    388 
    389             // predict oob samples
    390             if( !predicted_node )
    391                 CV_CALL(predicted_node = tree->predict(&sample, &missing, true));
    392 
    393             if( !data->is_classifier ) //regression
    394             {
    395                 double avg_resp, resp = predicted_node->value;
    396                 oob_predictions_sum.data.fl[i] += (float)resp;
    397                 oob_num_of_predictions.data.fl[i] += 1;
    398 
    399                 // compute oob error
    400                 avg_resp = oob_predictions_sum.data.fl[i]/oob_num_of_predictions.data.fl[i];
    401                 avg_resp -= true_resp_ptr[i];
    402                 oob_error += avg_resp*avg_resp;
    403                 resp = (resp - true_resp_ptr[i])/maximal_response;
    404                 ncorrect_responses += exp( -resp*resp );
    405             }
    406             else //classification
    407             {
    408                 double prdct_resp;
    409                 CvPoint max_loc;
    410                 CvMat votes;
    411 
    412                 cvGetRow(oob_sample_votes, &votes, i);
    413                 votes.data.i[predicted_node->class_idx]++;
    414 
    415                 // compute oob error
    416                 cvMinMaxLoc( &votes, 0, 0, 0, &max_loc );
    417 
    418                 prdct_resp = data->cat_map->data.i[max_loc.x];
    419                 oob_error += (fabs(prdct_resp - true_resp_ptr[i]) < FLT_EPSILON) ? 0 : 1;
    420 
    421                 ncorrect_responses += cvRound(predicted_node->value - true_resp_ptr[i]) == 0;
    422             }
    423             oob_samples_count++;
    424         }
    425         if( oob_samples_count > 0 )
    426             oob_error /= (double)oob_samples_count;
    427 
    428         // estimate variable importance
    429         if( var_importance && oob_samples_count > 0 )
    430         {
    431             int m;
    432 
    433             memcpy( oob_samples_perm_ptr, samples_ptr, dims*nsamples*sizeof(float));
    434             for( m = 0; m < dims; m++ )
    435             {
    436                 double ncorrect_responses_permuted = 0;
    437                 // randomly permute values of the m-th variable in the oob samples
    438                 float* mth_var_ptr = oob_samples_perm_ptr + m;
    439 
    440                 for( i = 0; i < nsamples; i++ )
    441                 {
    442                     int i1, i2;
    443                     float temp;
    444 
    445                     if( sample_idx_mask_for_tree->data.ptr[i] ) //the sample is not OOB
    446                         continue;
    447                     i1 = cvRandInt( &rng ) % nsamples;
    448                     i2 = cvRandInt( &rng ) % nsamples;
    449                     CV_SWAP( mth_var_ptr[i1*dims], mth_var_ptr[i2*dims], temp );
    450 
    451                     // turn values of (m-1)-th variable, that were permuted
    452                     // at the previous iteration, untouched
    453                     if( m > 1 )
    454                         oob_samples_perm_ptr[i*dims+m-1] = samples_ptr[i*dims+m-1];
    455                 }
    456 
    457                 // predict "permuted" cases and calculate the number of votes for the
    458                 // correct class in the variable-m-permuted oob data
    459                 sample  = cvMat( 1, dims, CV_32FC1, oob_samples_perm_ptr );
    460                 missing = cvMat( 1, dims, CV_8UC1, missing_ptr );
    461                 for( i = 0; i < nsamples; i++,
    462                     sample.data.fl += dims, missing.data.ptr += dims )
    463                 {
    464                     double predct_resp, true_resp;
    465 
    466                     if( sample_idx_mask_for_tree->data.ptr[i] ) //the sample is not OOB
    467                         continue;
    468 
    469                     predct_resp = tree->predict(&sample, &missing, true)->value;
    470                     true_resp   = true_resp_ptr[i];
    471                     if( data->is_classifier )
    472                         ncorrect_responses_permuted += cvRound(true_resp - predct_resp) == 0;
    473                     else
    474                     {
    475                         true_resp = (true_resp - predct_resp)/maximal_response;
    476                         ncorrect_responses_permuted += exp( -true_resp*true_resp );
    477                     }
    478                 }
    479                 var_importance->data.fl[m] += (float)(ncorrect_responses
    480                     - ncorrect_responses_permuted);
    481             }
    482         }
    483         ntrees++;
    484         if( term_crit.type != CV_TERMCRIT_ITER && oob_error < max_oob_err )
    485             break;
    486     }
    487     if( var_importance )
    488         CV_CALL(cvConvertScale( var_importance, var_importance, 1./ntrees/nsamples ));
    489 
    490     result = true;
    491 
    492     __END__;
    493 
    494     cvReleaseMat( &sample_idx_mask_for_tree );
    495     cvReleaseMat( &sample_idx_for_tree );
    496     cvReleaseMat( &oob_sample_votes );
    497     cvReleaseMat( &oob_responses );
    498 
    499     cvFree( &oob_samples_perm_ptr );
    500     cvFree( &samples_ptr );
    501     cvFree( &missing_ptr );
    502     cvFree( &true_resp_ptr );
    503 
    504     return result;
    505 }
    506 
    507 
    508 const CvMat* CvRTrees::get_var_importance()
    509 {
    510     return var_importance;
    511 }
    512 
    513 
    514 float CvRTrees::get_proximity( const CvMat* sample1, const CvMat* sample2,
    515                               const CvMat* missing1, const CvMat* missing2 ) const
    516 {
    517     float result = 0;
    518 
    519     CV_FUNCNAME( "CvRTrees::get_proximity" );
    520 
    521     __BEGIN__;
    522 
    523     int i;
    524     for( i = 0; i < ntrees; i++ )
    525         result += trees[i]->predict( sample1, missing1 ) ==
    526         trees[i]->predict( sample2, missing2 ) ?  1 : 0;
    527     result = result/(float)ntrees;
    528 
    529     __END__;
    530 
    531     return result;
    532 }
    533 
    534 
    535 float CvRTrees::predict( const CvMat* sample, const CvMat* missing ) const
    536 {
    537     double result = -1;
    538 
    539     CV_FUNCNAME("CvRTrees::predict");
    540     __BEGIN__;
    541 
    542     int k;
    543 
    544     if( nclasses > 0 ) //classification
    545     {
    546         int max_nvotes = 0;
    547         int* votes = (int*)alloca( sizeof(int)*nclasses );
    548         memset( votes, 0, sizeof(*votes)*nclasses );
    549         for( k = 0; k < ntrees; k++ )
    550         {
    551             CvDTreeNode* predicted_node = trees[k]->predict( sample, missing );
    552             int nvotes;
    553             int class_idx = predicted_node->class_idx;
    554             CV_ASSERT( 0 <= class_idx && class_idx < nclasses );
    555 
    556             nvotes = ++votes[class_idx];
    557             if( nvotes > max_nvotes )
    558             {
    559                 max_nvotes = nvotes;
    560                 result = predicted_node->value;
    561             }
    562         }
    563     }
    564     else // regression
    565     {
    566         result = 0;
    567         for( k = 0; k < ntrees; k++ )
    568             result += trees[k]->predict( sample, missing )->value;
    569         result /= (double)ntrees;
    570     }
    571 
    572     __END__;
    573 
    574     return (float)result;
    575 }
    576 
    577 
    578 void CvRTrees::write( CvFileStorage* fs, const char* name )
    579 {
    580     CV_FUNCNAME( "CvRTrees::write" );
    581 
    582     __BEGIN__;
    583 
    584     int k;
    585 
    586     if( ntrees < 1 || !trees || nsamples < 1 )
    587         CV_ERROR( CV_StsBadArg, "Invalid CvRTrees object" );
    588 
    589     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_RTREES );
    590 
    591     cvWriteInt( fs, "nclasses", nclasses );
    592     cvWriteInt( fs, "nsamples", nsamples );
    593     cvWriteInt( fs, "nactive_vars", (int)cvSum(active_var_mask).val[0] );
    594     cvWriteReal( fs, "oob_error", oob_error );
    595 
    596     if( var_importance )
    597         cvWrite( fs, "var_importance", var_importance );
    598 
    599     cvWriteInt( fs, "ntrees", ntrees );
    600 
    601     CV_CALL(data->write_params( fs ));
    602 
    603     cvStartWriteStruct( fs, "trees", CV_NODE_SEQ );
    604 
    605     for( k = 0; k < ntrees; k++ )
    606     {
    607         cvStartWriteStruct( fs, 0, CV_NODE_MAP );
    608         CV_CALL( trees[k]->write( fs ));
    609         cvEndWriteStruct( fs );
    610     }
    611 
    612     cvEndWriteStruct( fs ); //trees
    613     cvEndWriteStruct( fs ); //CV_TYPE_NAME_ML_RTREES
    614 
    615     __END__;
    616 }
    617 
    618 
    619 void CvRTrees::read( CvFileStorage* fs, CvFileNode* fnode )
    620 {
    621     CV_FUNCNAME( "CvRTrees::read" );
    622 
    623     __BEGIN__;
    624 
    625     int nactive_vars, var_count, k;
    626     CvSeqReader reader;
    627     CvFileNode* trees_fnode = 0;
    628 
    629     clear();
    630 
    631     nclasses     = cvReadIntByName( fs, fnode, "nclasses", -1 );
    632     nsamples     = cvReadIntByName( fs, fnode, "nsamples" );
    633     nactive_vars = cvReadIntByName( fs, fnode, "nactive_vars", -1 );
    634     oob_error    = cvReadRealByName(fs, fnode, "oob_error", -1 );
    635     ntrees       = cvReadIntByName( fs, fnode, "ntrees", -1 );
    636 
    637     var_importance = (CvMat*)cvReadByName( fs, fnode, "var_importance" );
    638 
    639     if( nclasses < 0 || nsamples <= 0 || nactive_vars < 0 || oob_error < 0 || ntrees <= 0)
    640         CV_ERROR( CV_StsParseError, "Some <nclasses>, <nsamples>, <var_count>, "
    641         "<nactive_vars>, <oob_error>, <ntrees> of tags are missing" );
    642 
    643     rng = CvRNG( -1 );
    644 
    645     trees = (CvForestTree**)cvAlloc( sizeof(trees[0])*ntrees );
    646     memset( trees, 0, sizeof(trees[0])*ntrees );
    647 
    648     data = new CvDTreeTrainData();
    649     data->read_params( fs, fnode );
    650     data->shared = true;
    651 
    652     trees_fnode = cvGetFileNodeByName( fs, fnode, "trees" );
    653     if( !trees_fnode || !CV_NODE_IS_SEQ(trees_fnode->tag) )
    654         CV_ERROR( CV_StsParseError, "<trees> tag is missing" );
    655 
    656     cvStartReadSeq( trees_fnode->data.seq, &reader );
    657     if( reader.seq->total != ntrees )
    658         CV_ERROR( CV_StsParseError,
    659         "<ntrees> is not equal to the number of trees saved in file" );
    660 
    661     for( k = 0; k < ntrees; k++ )
    662     {
    663         trees[k] = new CvForestTree();
    664         CV_CALL(trees[k]->read( fs, (CvFileNode*)reader.ptr, this, data ));
    665         CV_NEXT_SEQ_ELEM( reader.seq->elem_size, reader );
    666     }
    667 
    668     var_count = data->var_count;
    669     CV_CALL(active_var_mask = cvCreateMat( 1, var_count, CV_8UC1 ));
    670     {
    671         // initialize active variables mask
    672         CvMat submask1, submask2;
    673         cvGetCols( active_var_mask, &submask1, 0, nactive_vars );
    674         cvGetCols( active_var_mask, &submask2, nactive_vars, var_count );
    675         cvSet( &submask1, cvScalar(1) );
    676         cvZero( &submask2 );
    677     }
    678 
    679     __END__;
    680 }
    681 
    682 
    683 int CvRTrees::get_tree_count() const
    684 {
    685     return ntrees;
    686 }
    687 
    688 CvForestTree* CvRTrees::get_tree(int i) const
    689 {
    690     return (unsigned)i < (unsigned)ntrees ? trees[i] : 0;
    691 }
    692 
    693 // End of file.
    694