1 /*M/////////////////////////////////////////////////////////////////////////////////////// 2 // 3 // IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING. 4 // 5 // By downloading, copying, installing or using the software you agree to this license. 6 // If you do not agree to this license, do not download, install, 7 // copy or use the software. 8 // 9 // 10 // Intel License Agreement 11 // 12 // Copyright (C) 2000, Intel Corporation, all rights reserved. 13 // Third party copyrights are property of their respective owners. 14 // 15 // Redistribution and use in source and binary forms, with or without modification, 16 // are permitted provided that the following conditions are met: 17 // 18 // * Redistribution's of source code must retain the above copyright notice, 19 // this list of conditions and the following disclaimer. 20 // 21 // * Redistribution's in binary form must reproduce the above copyright notice, 22 // this list of conditions and the following disclaimer in the documentation 23 // and/or other materials provided with the distribution. 24 // 25 // * The name of Intel Corporation may not be used to endorse or promote products 26 // derived from this software without specific prior written permission. 27 // 28 // This software is provided by the copyright holders and contributors "as is" and 29 // any express or implied warranties, including, but not limited to, the implied 30 // warranties of merchantability and fitness for a particular purpose are disclaimed. 31 // In no event shall the Intel Corporation or contributors be liable for any direct, 32 // indirect, incidental, special, exemplary, or consequential damages 33 // (including, but not limited to, procurement of substitute goods or services; 34 // loss of use, data, or profits; or business interruption) however caused 35 // and on any theory of liability, whether in contract, strict liability, 36 // or tort (including negligence or otherwise) arising in any way out of 37 // the use of this software, even if advised of the possibility of such damage. 38 // 39 //M*/ 40 41 #include "_ml.h" 42 43 CvNormalBayesClassifier::CvNormalBayesClassifier() 44 { 45 var_count = var_all = 0; 46 var_idx = 0; 47 cls_labels = 0; 48 count = 0; 49 sum = 0; 50 productsum = 0; 51 avg = 0; 52 inv_eigen_values = 0; 53 cov_rotate_mats = 0; 54 c = 0; 55 default_model_name = "my_nb"; 56 } 57 58 59 void CvNormalBayesClassifier::clear() 60 { 61 if( cls_labels ) 62 { 63 for( int cls = 0; cls < cls_labels->cols; cls++ ) 64 { 65 cvReleaseMat( &count[cls] ); 66 cvReleaseMat( &sum[cls] ); 67 cvReleaseMat( &productsum[cls] ); 68 cvReleaseMat( &avg[cls] ); 69 cvReleaseMat( &inv_eigen_values[cls] ); 70 cvReleaseMat( &cov_rotate_mats[cls] ); 71 } 72 } 73 74 cvReleaseMat( &cls_labels ); 75 cvReleaseMat( &var_idx ); 76 cvReleaseMat( &c ); 77 cvFree( &count ); 78 } 79 80 81 CvNormalBayesClassifier::~CvNormalBayesClassifier() 82 { 83 clear(); 84 } 85 86 87 CvNormalBayesClassifier::CvNormalBayesClassifier( 88 const CvMat* _train_data, const CvMat* _responses, 89 const CvMat* _var_idx, const CvMat* _sample_idx ) 90 { 91 var_count = var_all = 0; 92 var_idx = 0; 93 cls_labels = 0; 94 count = 0; 95 sum = 0; 96 productsum = 0; 97 avg = 0; 98 inv_eigen_values = 0; 99 cov_rotate_mats = 0; 100 c = 0; 101 default_model_name = "my_nb"; 102 103 train( _train_data, _responses, _var_idx, _sample_idx ); 104 } 105 106 107 bool CvNormalBayesClassifier::train( const CvMat* _train_data, const CvMat* _responses, 108 const CvMat* _var_idx, const CvMat* _sample_idx, bool update ) 109 { 110 const float min_variation = FLT_EPSILON; 111 bool result = false; 112 CvMat* responses = 0; 113 const float** train_data = 0; 114 CvMat* __cls_labels = 0; 115 CvMat* __var_idx = 0; 116 CvMat* cov = 0; 117 118 CV_FUNCNAME( "CvNormalBayesClassifier::train" ); 119 120 __BEGIN__; 121 122 int cls, nsamples = 0, _var_count = 0, _var_all = 0, nclasses = 0; 123 int s, c1, c2; 124 const int* responses_data; 125 126 CV_CALL( cvPrepareTrainData( 0, 127 _train_data, CV_ROW_SAMPLE, _responses, CV_VAR_CATEGORICAL, 128 _var_idx, _sample_idx, false, &train_data, 129 &nsamples, &_var_count, &_var_all, &responses, 130 &__cls_labels, &__var_idx )); 131 132 if( !update ) 133 { 134 const size_t mat_size = sizeof(CvMat*); 135 size_t data_size; 136 137 clear(); 138 139 var_idx = __var_idx; 140 cls_labels = __cls_labels; 141 __var_idx = __cls_labels = 0; 142 var_count = _var_count; 143 var_all = _var_all; 144 145 nclasses = cls_labels->cols; 146 data_size = nclasses*6*mat_size; 147 148 CV_CALL( count = (CvMat**)cvAlloc( data_size )); 149 memset( count, 0, data_size ); 150 151 sum = count + nclasses; 152 productsum = sum + nclasses; 153 avg = productsum + nclasses; 154 inv_eigen_values= avg + nclasses; 155 cov_rotate_mats = inv_eigen_values + nclasses; 156 157 CV_CALL( c = cvCreateMat( 1, nclasses, CV_64FC1 )); 158 159 for( cls = 0; cls < nclasses; cls++ ) 160 { 161 CV_CALL(count[cls] = cvCreateMat( 1, var_count, CV_32SC1 )); 162 CV_CALL(sum[cls] = cvCreateMat( 1, var_count, CV_64FC1 )); 163 CV_CALL(productsum[cls] = cvCreateMat( var_count, var_count, CV_64FC1 )); 164 CV_CALL(avg[cls] = cvCreateMat( 1, var_count, CV_64FC1 )); 165 CV_CALL(inv_eigen_values[cls] = cvCreateMat( 1, var_count, CV_64FC1 )); 166 CV_CALL(cov_rotate_mats[cls] = cvCreateMat( var_count, var_count, CV_64FC1 )); 167 CV_CALL(cvZero( count[cls] )); 168 CV_CALL(cvZero( sum[cls] )); 169 CV_CALL(cvZero( productsum[cls] )); 170 CV_CALL(cvZero( avg[cls] )); 171 CV_CALL(cvZero( inv_eigen_values[cls] )); 172 CV_CALL(cvZero( cov_rotate_mats[cls] )); 173 } 174 } 175 else 176 { 177 // check that the new training data has the same dimensionality etc. 178 if( _var_count != var_count || _var_all != var_all || !(!_var_idx && !var_idx || 179 _var_idx && var_idx && cvNorm(_var_idx,var_idx,CV_C) < DBL_EPSILON) ) 180 CV_ERROR( CV_StsBadArg, 181 "The new training data is inconsistent with the original training data" ); 182 183 if( cls_labels->cols != __cls_labels->cols || 184 cvNorm(cls_labels, __cls_labels, CV_C) > DBL_EPSILON ) 185 CV_ERROR( CV_StsNotImplemented, 186 "In the current implementation the new training data must have absolutely " 187 "the same set of class labels as used in the original training data" ); 188 189 nclasses = cls_labels->cols; 190 } 191 192 responses_data = responses->data.i; 193 CV_CALL( cov = cvCreateMat( _var_count, _var_count, CV_64FC1 )); 194 195 /* process train data (count, sum , productsum) */ 196 for( s = 0; s < nsamples; s++ ) 197 { 198 cls = responses_data[s]; 199 int* count_data = count[cls]->data.i; 200 double* sum_data = sum[cls]->data.db; 201 double* prod_data = productsum[cls]->data.db; 202 const float* train_vec = train_data[s]; 203 204 for( c1 = 0; c1 < _var_count; c1++, prod_data += _var_count ) 205 { 206 double val1 = train_vec[c1]; 207 sum_data[c1] += val1; 208 count_data[c1]++; 209 for( c2 = c1; c2 < _var_count; c2++ ) 210 prod_data[c2] += train_vec[c2]*val1; 211 } 212 } 213 214 /* calculate avg, covariance matrix, c */ 215 for( cls = 0; cls < nclasses; cls++ ) 216 { 217 double det = 1; 218 int i, j; 219 CvMat* w = inv_eigen_values[cls]; 220 int* count_data = count[cls]->data.i; 221 double* avg_data = avg[cls]->data.db; 222 double* sum1 = sum[cls]->data.db; 223 224 cvCompleteSymm( productsum[cls], 0 ); 225 226 for( j = 0; j < _var_count; j++ ) 227 { 228 int n = count_data[j]; 229 avg_data[j] = n ? sum1[j] / n : 0.; 230 } 231 232 count_data = count[cls]->data.i; 233 avg_data = avg[cls]->data.db; 234 sum1 = sum[cls]->data.db; 235 236 for( i = 0; i < _var_count; i++ ) 237 { 238 double* avg2_data = avg[cls]->data.db; 239 double* sum2 = sum[cls]->data.db; 240 double* prod_data = productsum[cls]->data.db + i*_var_count; 241 double* cov_data = cov->data.db + i*_var_count; 242 double s1val = sum1[j]; 243 double avg1 = avg_data[i]; 244 int count = count_data[i]; 245 246 for( j = 0; j <= i; j++ ) 247 { 248 double avg2 = avg2_data[j]; 249 double cov_val = prod_data[j] - avg1 * sum2[j] - avg2 * s1val + avg1 * avg2 * count; 250 cov_val = (count > 1) ? cov_val / (count - 1) : cov_val; 251 cov_data[j] = cov_val; 252 } 253 } 254 255 CV_CALL( cvCompleteSymm( cov, 1 )); 256 CV_CALL( cvSVD( cov, w, cov_rotate_mats[cls], 0, CV_SVD_U_T )); 257 CV_CALL( cvMaxS( w, min_variation, w )); 258 for( j = 0; j < _var_count; j++ ) 259 det *= w->data.db[j]; 260 261 CV_CALL( cvDiv( NULL, w, w )); 262 c->data.db[cls] = log( det ); 263 } 264 265 result = true; 266 267 __END__; 268 269 if( !result || cvGetErrStatus() < 0 ) 270 clear(); 271 272 cvReleaseMat( &cov ); 273 cvReleaseMat( &__cls_labels ); 274 cvReleaseMat( &__var_idx ); 275 cvFree( &train_data ); 276 277 return result; 278 } 279 280 281 float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results ) const 282 { 283 float value = 0; 284 void* buffer = 0; 285 int allocated_buffer = 0; 286 287 CV_FUNCNAME( "CvNormalBayesClassifier::predict" ); 288 289 __BEGIN__; 290 291 int i, j, k, cls = -1, _var_count, nclasses; 292 double opt = FLT_MAX; 293 CvMat diff; 294 int rtype = 0, rstep = 0, size; 295 const int* vidx = 0; 296 297 nclasses = cls_labels->cols; 298 _var_count = avg[0]->cols; 299 300 if( !CV_IS_MAT(samples) || CV_MAT_TYPE(samples->type) != CV_32FC1 || samples->cols != var_all ) 301 CV_ERROR( CV_StsBadArg, 302 "The input samples must be 32f matrix with the number of columns = var_all" ); 303 304 if( samples->rows > 1 && !results ) 305 CV_ERROR( CV_StsNullPtr, 306 "When the number of input samples is >1, the output vector of results must be passed" ); 307 308 if( results ) 309 { 310 if( !CV_IS_MAT(results) || CV_MAT_TYPE(results->type) != CV_32FC1 && 311 CV_MAT_TYPE(results->type) != CV_32SC1 || 312 results->cols != 1 && results->rows != 1 || 313 results->cols + results->rows - 1 != samples->rows ) 314 CV_ERROR( CV_StsBadArg, "The output array must be integer or floating-point vector " 315 "with the number of elements = number of rows in the input matrix" ); 316 317 rtype = CV_MAT_TYPE(results->type); 318 rstep = CV_IS_MAT_CONT(results->type) ? 1 : results->step/CV_ELEM_SIZE(rtype); 319 } 320 321 if( var_idx ) 322 vidx = var_idx->data.i; 323 324 // allocate memory and initializing headers for calculating 325 size = sizeof(double) * (nclasses + var_count); 326 if( size <= CV_MAX_LOCAL_SIZE ) 327 buffer = cvStackAlloc( size ); 328 else 329 { 330 CV_CALL( buffer = cvAlloc( size )); 331 allocated_buffer = 1; 332 } 333 334 diff = cvMat( 1, var_count, CV_64FC1, buffer ); 335 336 for( k = 0; k < samples->rows; k++ ) 337 { 338 int ival; 339 340 for( i = 0; i < nclasses; i++ ) 341 { 342 double cur = c->data.db[i]; 343 CvMat* u = cov_rotate_mats[i]; 344 CvMat* w = inv_eigen_values[i]; 345 const double* avg_data = avg[i]->data.db; 346 const float* x = (const float*)(samples->data.ptr + samples->step*k); 347 348 // cov = u w u' --> cov^(-1) = u w^(-1) u' 349 for( j = 0; j < _var_count; j++ ) 350 diff.data.db[j] = avg_data[j] - x[vidx ? vidx[j] : j]; 351 352 CV_CALL(cvGEMM( &diff, u, 1, 0, 0, &diff, CV_GEMM_B_T )); 353 for( j = 0; j < _var_count; j++ ) 354 { 355 double d = diff.data.db[j]; 356 cur += d*d*w->data.db[j]; 357 } 358 359 if( cur < opt ) 360 { 361 cls = i; 362 opt = cur; 363 } 364 /* probability = exp( -0.5 * cur ) */ 365 } 366 367 ival = cls_labels->data.i[cls]; 368 if( results ) 369 { 370 if( rtype == CV_32SC1 ) 371 results->data.i[k*rstep] = ival; 372 else 373 results->data.fl[k*rstep] = (float)ival; 374 } 375 if( k == 0 ) 376 value = (float)ival; 377 378 /*if( _probs ) 379 { 380 CV_CALL( cvConvertScale( &expo, &expo, -0.5 )); 381 CV_CALL( cvExp( &expo, &expo )); 382 if( _probs->cols == 1 ) 383 CV_CALL( cvReshape( &expo, &expo, 1, nclasses )); 384 CV_CALL( cvConvertScale( &expo, _probs, 1./cvSum( &expo ).val[0] )); 385 }*/ 386 } 387 388 __END__; 389 390 if( allocated_buffer ) 391 cvFree( &buffer ); 392 393 return value; 394 } 395 396 397 void CvNormalBayesClassifier::write( CvFileStorage* fs, const char* name ) 398 { 399 CV_FUNCNAME( "CvNormalBayesClassifier::write" ); 400 401 __BEGIN__; 402 403 int nclasses, i; 404 405 nclasses = cls_labels->cols; 406 407 cvStartWriteStruct( fs, name, CV_NODE_MAP, CV_TYPE_NAME_ML_NBAYES ); 408 409 CV_CALL( cvWriteInt( fs, "var_count", var_count )); 410 CV_CALL( cvWriteInt( fs, "var_all", var_all )); 411 412 if( var_idx ) 413 CV_CALL( cvWrite( fs, "var_idx", var_idx )); 414 CV_CALL( cvWrite( fs, "cls_labels", cls_labels )); 415 416 CV_CALL( cvStartWriteStruct( fs, "count", CV_NODE_SEQ )); 417 for( i = 0; i < nclasses; i++ ) 418 CV_CALL( cvWrite( fs, NULL, count[i] )); 419 CV_CALL( cvEndWriteStruct( fs )); 420 421 CV_CALL( cvStartWriteStruct( fs, "sum", CV_NODE_SEQ )); 422 for( i = 0; i < nclasses; i++ ) 423 CV_CALL( cvWrite( fs, NULL, sum[i] )); 424 CV_CALL( cvEndWriteStruct( fs )); 425 426 CV_CALL( cvStartWriteStruct( fs, "productsum", CV_NODE_SEQ )); 427 for( i = 0; i < nclasses; i++ ) 428 CV_CALL( cvWrite( fs, NULL, productsum[i] )); 429 CV_CALL( cvEndWriteStruct( fs )); 430 431 CV_CALL( cvStartWriteStruct( fs, "avg", CV_NODE_SEQ )); 432 for( i = 0; i < nclasses; i++ ) 433 CV_CALL( cvWrite( fs, NULL, avg[i] )); 434 CV_CALL( cvEndWriteStruct( fs )); 435 436 CV_CALL( cvStartWriteStruct( fs, "inv_eigen_values", CV_NODE_SEQ )); 437 for( i = 0; i < nclasses; i++ ) 438 CV_CALL( cvWrite( fs, NULL, inv_eigen_values[i] )); 439 CV_CALL( cvEndWriteStruct( fs )); 440 441 CV_CALL( cvStartWriteStruct( fs, "cov_rotate_mats", CV_NODE_SEQ )); 442 for( i = 0; i < nclasses; i++ ) 443 CV_CALL( cvWrite( fs, NULL, cov_rotate_mats[i] )); 444 CV_CALL( cvEndWriteStruct( fs )); 445 446 CV_CALL( cvWrite( fs, "c", c )); 447 448 cvEndWriteStruct( fs ); 449 450 __END__; 451 } 452 453 454 void CvNormalBayesClassifier::read( CvFileStorage* fs, CvFileNode* root_node ) 455 { 456 bool ok = false; 457 CV_FUNCNAME( "CvNormalBayesClassifier::read" ); 458 459 __BEGIN__; 460 461 int nclasses, i; 462 size_t data_size; 463 CvFileNode* node; 464 CvSeq* seq; 465 CvSeqReader reader; 466 467 clear(); 468 469 CV_CALL( var_count = cvReadIntByName( fs, root_node, "var_count", -1 )); 470 CV_CALL( var_all = cvReadIntByName( fs, root_node, "var_all", -1 )); 471 CV_CALL( var_idx = (CvMat*)cvReadByName( fs, root_node, "var_idx" )); 472 CV_CALL( cls_labels = (CvMat*)cvReadByName( fs, root_node, "cls_labels" )); 473 if( !cls_labels ) 474 CV_ERROR( CV_StsParseError, "No \"cls_labels\" in NBayes classifier" ); 475 if( cls_labels->cols < 1 ) 476 CV_ERROR( CV_StsBadArg, "Number of classes is less 1" ); 477 if( var_count <= 0 ) 478 CV_ERROR( CV_StsParseError, 479 "The field \"var_count\" of NBayes classifier is missing" ); 480 nclasses = cls_labels->cols; 481 482 data_size = nclasses*6*sizeof(CvMat*); 483 CV_CALL( count = (CvMat**)cvAlloc( data_size )); 484 memset( count, 0, data_size ); 485 486 sum = count + nclasses; 487 productsum = sum + nclasses; 488 avg = productsum + nclasses; 489 inv_eigen_values = avg + nclasses; 490 cov_rotate_mats = inv_eigen_values + nclasses; 491 492 CV_CALL( node = cvGetFileNodeByName( fs, root_node, "count" )); 493 seq = node->data.seq; 494 if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses) 495 CV_ERROR( CV_StsBadArg, "" ); 496 CV_CALL( cvStartReadSeq( seq, &reader, 0 )); 497 for( i = 0; i < nclasses; i++ ) 498 { 499 CV_CALL( count[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr )); 500 CV_NEXT_SEQ_ELEM( seq->elem_size, reader ); 501 } 502 503 CV_CALL( node = cvGetFileNodeByName( fs, root_node, "sum" )); 504 seq = node->data.seq; 505 if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses) 506 CV_ERROR( CV_StsBadArg, "" ); 507 CV_CALL( cvStartReadSeq( seq, &reader, 0 )); 508 for( i = 0; i < nclasses; i++ ) 509 { 510 CV_CALL( sum[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr )); 511 CV_NEXT_SEQ_ELEM( seq->elem_size, reader ); 512 } 513 514 CV_CALL( node = cvGetFileNodeByName( fs, root_node, "productsum" )); 515 seq = node->data.seq; 516 if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses) 517 CV_ERROR( CV_StsBadArg, "" ); 518 CV_CALL( cvStartReadSeq( seq, &reader, 0 )); 519 for( i = 0; i < nclasses; i++ ) 520 { 521 CV_CALL( productsum[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr )); 522 CV_NEXT_SEQ_ELEM( seq->elem_size, reader ); 523 } 524 525 CV_CALL( node = cvGetFileNodeByName( fs, root_node, "avg" )); 526 seq = node->data.seq; 527 if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses) 528 CV_ERROR( CV_StsBadArg, "" ); 529 CV_CALL( cvStartReadSeq( seq, &reader, 0 )); 530 for( i = 0; i < nclasses; i++ ) 531 { 532 CV_CALL( avg[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr )); 533 CV_NEXT_SEQ_ELEM( seq->elem_size, reader ); 534 } 535 536 CV_CALL( node = cvGetFileNodeByName( fs, root_node, "inv_eigen_values" )); 537 seq = node->data.seq; 538 if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses) 539 CV_ERROR( CV_StsBadArg, "" ); 540 CV_CALL( cvStartReadSeq( seq, &reader, 0 )); 541 for( i = 0; i < nclasses; i++ ) 542 { 543 CV_CALL( inv_eigen_values[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr )); 544 CV_NEXT_SEQ_ELEM( seq->elem_size, reader ); 545 } 546 547 CV_CALL( node = cvGetFileNodeByName( fs, root_node, "cov_rotate_mats" )); 548 seq = node->data.seq; 549 if( !CV_NODE_IS_SEQ(node->tag) || seq->total != nclasses) 550 CV_ERROR( CV_StsBadArg, "" ); 551 CV_CALL( cvStartReadSeq( seq, &reader, 0 )); 552 for( i = 0; i < nclasses; i++ ) 553 { 554 CV_CALL( cov_rotate_mats[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr )); 555 CV_NEXT_SEQ_ELEM( seq->elem_size, reader ); 556 } 557 558 CV_CALL( c = (CvMat*)cvReadByName( fs, root_node, "c" )); 559 560 ok = true; 561 562 __END__; 563 564 if( !ok ) 565 clear(); 566 } 567 568 /* End of file. */ 569 570