Home | History | Annotate | Download | only in src
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //            Intel License Agreement
     11 //
     12 // Copyright (C) 2000, Intel Corporation, all rights reserved.
     13 // Third party copyrights are property of their respective owners.
     14 //
     15 // Redistribution and use in source and binary forms, with or without modification,
     16 // are permitted provided that the following conditions are met:
     17 //
     18 //   * Redistribution's of source code must retain the above copyright notice,
     19 //     this list of conditions and the following disclaimer.
     20 //
     21 //   * Redistribution's in binary form must reproduce the above copyright notice,
     22 //     this list of conditions and the following disclaimer in the documentation
     23 //     and/or other materials provided with the distribution.
     24 //
     25 //   * The name of Intel Corporation may not be used to endorse or promote products
     26 //     derived from this software without specific prior written permission.
     27 //
     28 // This software is provided by the copyright holders and contributors "as is" and
     29 // any express or implied warranties, including, but not limited to, the implied
     30 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     31 // In no event shall the Intel Corporation or contributors be liable for any direct,
     32 // indirect, incidental, special, exemplary, or consequential damages
     33 // (including, but not limited to, procurement of substitute goods or services;
     34 // loss of use, data, or profits; or business interruption) however caused
     35 // and on any theory of liability, whether in contract, strict liability,
     36 // or tort (including negligence or otherwise) arising in any way out of
     37 // the use of this software, even if advised of the possibility of such damage.
     38 //
     39 //M*/
     40 
     41 #include "_ml.h"
     42 
     43 CvNormalBayesClassifier::CvNormalBayesClassifier()
     44 {
     45     var_count = var_all = 0;
     46     var_idx = 0;
     47     cls_labels = 0;
     48     count = 0;
     49     sum = 0;
     50     productsum = 0;
     51     avg = 0;
     52     inv_eigen_values = 0;
     53     cov_rotate_mats = 0;
     54     c = 0;
     55     default_model_name = "my_nb";
     56 }
     57 
     58 
     59 void CvNormalBayesClassifier::clear()
     60 {
     61     if( cls_labels )
     62     {
     63         for( int cls = 0; cls < cls_labels->cols; cls++ )
     64         {
     65             cvReleaseMat( &count[cls] );
     66             cvReleaseMat( &sum[cls] );
     67             cvReleaseMat( &productsum[cls] );
     68             cvReleaseMat( &avg[cls] );
     69             cvReleaseMat( &inv_eigen_values[cls] );
     70             cvReleaseMat( &cov_rotate_mats[cls] );
     71         }
     72     }
     73 
     74     cvReleaseMat( &cls_labels );
     75     cvReleaseMat( &var_idx );
     76     cvReleaseMat( &c );
     77     cvFree( &count );
     78 }
     79 
     80 
     81 CvNormalBayesClassifier::~CvNormalBayesClassifier()
     82 {
     83     clear();
     84 }
     85 
     86 
     87 CvNormalBayesClassifier::CvNormalBayesClassifier(
     88     const CvMat* _train_data, const CvMat* _responses,
     89     const CvMat* _var_idx, const CvMat* _sample_idx )
     90 {
     91     var_count = var_all = 0;
     92     var_idx = 0;
     93     cls_labels = 0;
     94     count = 0;
     95     sum = 0;
     96     productsum = 0;
     97     avg = 0;
     98     inv_eigen_values = 0;
     99     cov_rotate_mats = 0;
    100     c = 0;
    101     default_model_name = "my_nb";
    102 
    103     train( _train_data, _responses, _var_idx, _sample_idx );
    104 }
    105 
    106 
    107 bool CvNormalBayesClassifier::train( const CvMat* _train_data, const CvMat* _responses,
    108                             const CvMat* _var_idx, const CvMat* _sample_idx, bool update )
    109 {
    110     const float min_variation = FLT_EPSILON;
    111     bool result = false;
    112     CvMat* responses   = 0;
    113     const float** train_data = 0;
    114     CvMat* __cls_labels = 0;
    115     CvMat* __var_idx = 0;
    116     CvMat* cov = 0;
    117 
    118     CV_FUNCNAME( "CvNormalBayesClassifier::train" );
    119 
    120     __BEGIN__;
    121 
    122     int cls, nsamples = 0, _var_count = 0, _var_all = 0, nclasses = 0;
    123     int s, c1, c2;
    124     const int* responses_data;
    125 
    126     CV_CALL( cvPrepareTrainData( 0,
    127         _train_data, CV_ROW_SAMPLE, _responses, CV_VAR_CATEGORICAL,
    128         _var_idx, _sample_idx, false, &train_data,
    129         &nsamples, &_var_count, &_var_all, &responses,
    130         &__cls_labels, &__var_idx ));
    131 
    132     if( !update )
    133     {
    134         const size_t mat_size = sizeof(CvMat*);
    135         size_t data_size;
    136 
    137         clear();
    138 
    139         var_idx = __var_idx;
    140         cls_labels = __cls_labels;
    141         __var_idx = __cls_labels = 0;
    142         var_count = _var_count;
    143         var_all = _var_all;
    144 
    145         nclasses = cls_labels->cols;
    146         data_size = nclasses*6*mat_size;
    147 
    148         CV_CALL( count = (CvMat**)cvAlloc( data_size ));
    149         memset( count, 0, data_size );
    150 
    151         sum             = count      + nclasses;
    152         productsum      = sum        + nclasses;
    153         avg             = productsum + nclasses;
    154         inv_eigen_values= avg        + nclasses;
    155         cov_rotate_mats = inv_eigen_values         + nclasses;
    156 
    157         CV_CALL( c = cvCreateMat( 1, nclasses, CV_64FC1 ));
    158 
    159         for( cls = 0; cls < nclasses; cls++ )
    160         {
    161             CV_CALL(count[cls]            = cvCreateMat( 1, var_count, CV_32SC1 ));
    162             CV_CALL(sum[cls]              = cvCreateMat( 1, var_count, CV_64FC1 ));
    163             CV_CALL(productsum[cls]       = cvCreateMat( var_count, var_count, CV_64FC1 ));
    164             CV_CALL(avg[cls]              = cvCreateMat( 1, var_count, CV_64FC1 ));
    165             CV_CALL(inv_eigen_values[cls] = cvCreateMat( 1, var_count, CV_64FC1 ));
    166             CV_CALL(cov_rotate_mats[cls]  = cvCreateMat( var_count, var_count, CV_64FC1 ));
    167             CV_CALL(cvZero( count[cls] ));
    168             CV_CALL(cvZero( sum[cls] ));
    169             CV_CALL(cvZero( productsum[cls] ));
    170             CV_CALL(cvZero( avg[cls] ));
    171             CV_CALL(cvZero( inv_eigen_values[cls] ));
    172             CV_CALL(cvZero( cov_rotate_mats[cls] ));
    173         }
    174     }
    175     else
    176     {
    177         // check that the new training data has the same dimensionality etc.
    178         if( _var_count != var_count || _var_all != var_all || !(!_var_idx && !var_idx ||
    179             _var_idx && var_idx && cvNorm(_var_idx,var_idx,CV_C) < DBL_EPSILON) )
    180             CV_ERROR( CV_StsBadArg,
    181             "The new training data is inconsistent with the original training data" );
    182 
    183         if( cls_labels->cols != __cls_labels->cols ||
    184             cvNorm(cls_labels, __cls_labels, CV_C) > DBL_EPSILON )
    185             CV_ERROR( CV_StsNotImplemented,
    186             "In the current implementation the new training data must have absolutely "
    187             "the same set of class labels as used in the original training data" );
    188 
    189         nclasses = cls_labels->cols;
    190     }
    191 
    192     responses_data = responses->data.i;
    193     CV_CALL( cov = cvCreateMat( _var_count, _var_count, CV_64FC1 ));
    194 
    195     /* process train data (count, sum , productsum) */
    196     for( s = 0; s < nsamples; s++ )
    197     {
    198         cls = responses_data[s];
    199         int* count_data = count[cls]->data.i;
    200         double* sum_data = sum[cls]->data.db;
    201         double* prod_data = productsum[cls]->data.db;
    202         const float* train_vec = train_data[s];
    203 
    204         for( c1 = 0; c1 < _var_count; c1++, prod_data += _var_count )
    205         {
    206             double val1 = train_vec[c1];
    207             sum_data[c1] += val1;
    208             count_data[c1]++;
    209             for( c2 = c1; c2 < _var_count; c2++ )
    210                 prod_data[c2] += train_vec[c2]*val1;
    211         }
    212     }
    213 
    214     /* calculate avg, covariance matrix, c */
    215     for( cls = 0; cls < nclasses; cls++ )
    216     {
    217         double det = 1;
    218         int i, j;
    219         CvMat* w = inv_eigen_values[cls];
    220         int* count_data = count[cls]->data.i;
    221         double* avg_data = avg[cls]->data.db;
    222         double* sum1 = sum[cls]->data.db;
    223 
    224         cvCompleteSymm( productsum[cls], 0 );
    225 
    226         for( j = 0; j < _var_count; j++ )
    227         {
    228             int n = count_data[j];
    229             avg_data[j] = n ? sum1[j] / n : 0.;
    230         }
    231 
    232         count_data = count[cls]->data.i;
    233         avg_data = avg[cls]->data.db;
    234         sum1 = sum[cls]->data.db;
    235 
    236         for( i = 0; i < _var_count; i++ )
    237         {
    238             double* avg2_data = avg[cls]->data.db;
    239             double* sum2 = sum[cls]->data.db;
    240             double* prod_data = productsum[cls]->data.db + i*_var_count;
    241             double* cov_data = cov->data.db + i*_var_count;
    242             double s1val = sum1[j];
    243             double avg1 = avg_data[i];
    244             int count = count_data[i];
    245 
    246             for( j = 0; j <= i; j++ )
    247             {
    248                 double avg2 = avg2_data[j];
    249                 double cov_val = prod_data[j] - avg1 * sum2[j] - avg2 * s1val + avg1 * avg2 * count;
    250                 cov_val = (count > 1) ? cov_val / (count - 1) : cov_val;
    251                 cov_data[j] = cov_val;
    252             }
    253         }
    254 
    255         CV_CALL( cvCompleteSymm( cov, 1 ));
    256         CV_CALL( cvSVD( cov, w, cov_rotate_mats[cls], 0, CV_SVD_U_T ));
    257         CV_CALL( cvMaxS( w, min_variation, w ));
    258         for( j = 0; j < _var_count; j++ )
    259             det *= w->data.db[j];
    260 
    261         CV_CALL( cvDiv( NULL, w, w ));
    262         c->data.db[cls] = log( det );
    263     }
    264 
    265     result = true;
    266 
    267     __END__;
    268 
    269     if( !result || cvGetErrStatus() < 0 )
    270         clear();
    271 
    272     cvReleaseMat( &cov );
    273     cvReleaseMat( &__cls_labels );
    274     cvReleaseMat( &__var_idx );
    275     cvFree( &train_data );
    276 
    277     return result;
    278 }
    279 
    280 
    281 float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results ) const
    282 {
    283     float value = 0;
    284     void* buffer = 0;
    285     int allocated_buffer = 0;
    286 
    287     CV_FUNCNAME( "CvNormalBayesClassifier::predict" );
    288 
    289     __BEGIN__;
    290 
    291     int i, j, k, cls = -1, _var_count, nclasses;
    292     double opt = FLT_MAX;
    293     CvMat diff;
    294     int rtype = 0, rstep = 0, size;
    295     const int* vidx = 0;
    296 
    297     nclasses = cls_labels->cols;
    298     _var_count = avg[0]->cols;
    299 
    300     if( !CV_IS_MAT(samples) || CV_MAT_TYPE(samples->type) != CV_32FC1 || samples->cols != var_all )
    301         CV_ERROR( CV_StsBadArg,
    302         "The input samples must be 32f matrix with the number of columns = var_all" );
    303 
    304     if( samples->rows > 1 && !results )
    305         CV_ERROR( CV_StsNullPtr,
    306         "When the number of input samples is >1, the output vector of results must be passed" );
    307 
    308     if( results )
    309     {
    310         if( !CV_IS_MAT(results) || CV_MAT_TYPE(results->type) != CV_32FC1 &&
    311         CV_MAT_TYPE(results->type) != CV_32SC1 ||
    312         results->cols != 1 && results->rows != 1 ||
    313         results->cols + results->rows - 1 != samples->rows )
    314         CV_ERROR( CV_StsBadArg, "The output array must be integer or floating-point vector "
    315         "with the number of elements = number of rows in the input matrix" );
    316 
    317         rtype = CV_MAT_TYPE(results->type);
    318         rstep = CV_IS_MAT_CONT(results->type) ? 1 : results->step/CV_ELEM_SIZE(rtype);
    319     }
    320 
    321     if( var_idx )
    322         vidx = var_idx->data.i;
    323 
    324 // allocate memory and initializing headers for calculating
    325     size = sizeof(double) * (nclasses + var_count);
    326     if( size <= CV_MAX_LOCAL_SIZE )
    327         buffer = cvStackAlloc( size );
    328     else
    329     {
    330         CV_CALL( buffer = cvAlloc( size ));
    331         allocated_buffer = 1;
    332     }
    333 
    334     diff = cvMat( 1, var_count, CV_64FC1, buffer );
    335 
    336     for( k = 0; k < samples->rows; k++ )
    337     {
    338         int ival;
    339 
    340         for( i = 0; i < nclasses; i++ )
    341         {
    342             double cur = c->data.db[i];
    343             CvMat* u = cov_rotate_mats[i];
    344             CvMat* w = inv_eigen_values[i];
    345             const double* avg_data = avg[i]->data.db;
    346             const float* x = (const float*)(samples->data.ptr + samples->step*k);
    347 
    348             // cov = u w u'  -->  cov^(-1) = u w^(-1) u'
    349             for( j = 0; j < _var_count; j++ )
    350                 diff.data.db[j] = avg_data[j] - x[vidx ? vidx[j] : j];
    351 
    352             CV_CALL(cvGEMM( &diff, u, 1, 0, 0, &diff, CV_GEMM_B_T ));
    353             for( j = 0; j < _var_count; j++ )
    354             {
    355                 double d = diff.data.db[j];
    356                 cur += d*d*w->data.db[j];
    357             }
    358 
    359             if( cur < opt )
    360             {
    361                 cls = i;
    362                 opt = cur;
    363             }
    364             /* probability = exp( -0.5 * cur ) */
    365         }
    366 
    367         ival = cls_labels->data.i[cls];
    368         if( results )
    369         {
    370             if( rtype == CV_32SC1 )
    371                 results->data.i[k*rstep] = ival;
    372             else
    373                 results->data.fl[k*rstep] = (float)ival;
    374         }
    375         if( k == 0 )
    376             value = (float)ival;
    377 
    378         /*if( _probs )
    379         {
    380             CV_CALL( cvConvertScale( &expo, &expo, -0.5 ));
    381             CV_CALL( cvExp( &expo, &expo ));
    382             if( _probs->cols == 1 )
    383                 CV_CALL( cvReshape( &expo, &expo, 1, nclasses ));
    384             CV_CALL( cvConvertScale( &expo, _probs, 1./cvSum( &expo ).val[0] ));
    385         }*/
    386     }
    387 
    388     __END__;
    389 
    390     if( allocated_buffer )
    391         cvFree( &buffer );
    392 
    393     return value;
    394 }
    395 
    396 
    397 void CvNormalBayesClassifier::write( CvFileStorage* fs, const char* name )
    398 {
    399     CV_FUNCNAME( "CvNormalBayesClassifier::write" );
    400 
    401     __BEGIN__;
    402 
    403     int nclasses, i;
    404 
    405     nclasses = cls_labels->cols;
    406 
    407     cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_NBAYES );
    408 
    409     CV_CALL( cvWriteInt( fs, "var_count", var_count ));
    410     CV_CALL( cvWriteInt( fs, "var_all", var_all ));
    411 
    412     if( var_idx )
    413         CV_CALL( cvWrite( fs, "var_idx", var_idx ));
    414     CV_CALL( cvWrite( fs, "cls_labels", cls_labels ));
    415 
    416     CV_CALL( cvStartWriteStruct( fs, "count", CV_NODE_SEQ ));
    417     for( i = 0; i < nclasses; i++ )
    418         CV_CALL( cvWrite( fs, NULL, count[i] ));
    419     CV_CALL( cvEndWriteStruct( fs ));
    420 
    421     CV_CALL( cvStartWriteStruct( fs, "sum", CV_NODE_SEQ ));
    422     for( i = 0; i < nclasses; i++ )
    423         CV_CALL( cvWrite( fs, NULL, sum[i] ));
    424     CV_CALL( cvEndWriteStruct( fs ));
    425 
    426     CV_CALL( cvStartWriteStruct( fs, "productsum", CV_NODE_SEQ ));
    427     for( i = 0; i < nclasses; i++ )
    428         CV_CALL( cvWrite( fs, NULL, productsum[i] ));
    429     CV_CALL( cvEndWriteStruct( fs ));
    430 
    431     CV_CALL( cvStartWriteStruct( fs, "avg", CV_NODE_SEQ ));
    432     for( i = 0; i < nclasses; i++ )
    433         CV_CALL( cvWrite( fs, NULL, avg[i] ));
    434     CV_CALL( cvEndWriteStruct( fs ));
    435 
    436     CV_CALL( cvStartWriteStruct( fs, "inv_eigen_values", CV_NODE_SEQ ));
    437     for( i = 0; i < nclasses; i++ )
    438         CV_CALL( cvWrite( fs, NULL, inv_eigen_values[i] ));
    439     CV_CALL( cvEndWriteStruct( fs ));
    440 
    441     CV_CALL( cvStartWriteStruct( fs, "cov_rotate_mats", CV_NODE_SEQ ));
    442     for( i = 0; i < nclasses; i++ )
    443         CV_CALL( cvWrite( fs, NULL, cov_rotate_mats[i] ));
    444     CV_CALL( cvEndWriteStruct( fs ));
    445 
    446     CV_CALL( cvWrite( fs, "c", c ));
    447 
    448     cvEndWriteStruct( fs );
    449 
    450     __END__;
    451 }
    452 
    453 
    454 void CvNormalBayesClassifier::read( CvFileStorage* fs, CvFileNode* root_node )
    455 {
    456     bool ok = false;
    457     CV_FUNCNAME( "CvNormalBayesClassifier::read" );
    458 
    459     __BEGIN__;
    460 
    461     int nclasses, i;
    462     size_t data_size;
    463     CvFileNode* node;
    464     CvSeq* seq;
    465     CvSeqReader reader;
    466 
    467     clear();
    468 
    469     CV_CALL( var_count = cvReadIntByName( fs, root_node, "var_count", -1 ));
    470     CV_CALL( var_all = cvReadIntByName( fs, root_node, "var_all", -1 ));
    471     CV_CALL( var_idx = (CvMat*)cvReadByName( fs, root_node, "var_idx" ));
    472     CV_CALL( cls_labels = (CvMat*)cvReadByName( fs, root_node, "cls_labels" ));
    473     if( !cls_labels )
    474         CV_ERROR( CV_StsParseError, "No \"cls_labels\" in NBayes classifier" );
    475     if( cls_labels->cols < 1 )
    476         CV_ERROR( CV_StsBadArg, "Number of classes is less 1" );
    477     if( var_count <= 0 )
    478         CV_ERROR( CV_StsParseError,
    479         "The field \"var_count\" of NBayes classifier is missing" );
    480     nclasses = cls_labels->cols;
    481 
    482     data_size = nclasses*6*sizeof(CvMat*);
    483     CV_CALL( count = (CvMat**)cvAlloc( data_size ));
    484     memset( count, 0, data_size );
    485 
    486     sum = count + nclasses;
    487     productsum  = sum  + nclasses;
    488     avg = productsum + nclasses;
    489     inv_eigen_values = avg + nclasses;
    490     cov_rotate_mats = inv_eigen_values + nclasses;
    491 
    492     CV_CALL( node = cvGetFileNodeByName( fs, root_node, "count" ));
    493     seq = node->data.seq;
    494     if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
    495         CV_ERROR( CV_StsBadArg, "" );
    496     CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
    497     for( i = 0; i < nclasses; i++ )
    498     {
    499         CV_CALL( count[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
    500         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    501     }
    502 
    503     CV_CALL( node = cvGetFileNodeByName( fs, root_node, "sum" ));
    504     seq = node->data.seq;
    505     if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
    506         CV_ERROR( CV_StsBadArg, "" );
    507     CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
    508     for( i = 0; i < nclasses; i++ )
    509     {
    510         CV_CALL( sum[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
    511         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    512     }
    513 
    514     CV_CALL( node = cvGetFileNodeByName( fs, root_node, "productsum" ));
    515     seq = node->data.seq;
    516     if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
    517         CV_ERROR( CV_StsBadArg, "" );
    518     CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
    519     for( i = 0; i < nclasses; i++ )
    520     {
    521         CV_CALL( productsum[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
    522         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    523     }
    524 
    525     CV_CALL( node = cvGetFileNodeByName( fs, root_node, "avg" ));
    526     seq = node->data.seq;
    527     if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
    528         CV_ERROR( CV_StsBadArg, "" );
    529     CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
    530     for( i = 0; i < nclasses; i++ )
    531     {
    532         CV_CALL( avg[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
    533         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    534     }
    535 
    536     CV_CALL( node = cvGetFileNodeByName( fs, root_node, "inv_eigen_values" ));
    537     seq = node->data.seq;
    538     if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
    539         CV_ERROR( CV_StsBadArg, "" );
    540     CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
    541     for( i = 0; i < nclasses; i++ )
    542     {
    543         CV_CALL( inv_eigen_values[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
    544         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    545     }
    546 
    547     CV_CALL( node = cvGetFileNodeByName( fs, root_node, "cov_rotate_mats" ));
    548     seq = node->data.seq;
    549     if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses)
    550         CV_ERROR( CV_StsBadArg, "" );
    551     CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
    552     for( i = 0; i < nclasses; i++ )
    553     {
    554         CV_CALL( cov_rotate_mats[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
    555         CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    556     }
    557 
    558     CV_CALL( c = (CvMat*)cvReadByName( fs, root_node, "c" ));
    559 
    560     ok = true;
    561 
    562     __END__;
    563 
    564     if( !ok )
    565         clear();
    566 }
    567 
    568 /* End of file. */
    569 
    570