/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistributions of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistributions in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"

namespace cv
{
namespace ml
{

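// Lower bound applied to the covariance eigenvalues (see decomposeCovs() and mStep())
// so that each per-cluster covariance stays positive definite and safely invertible.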
const double minEigenValue = DBL_EPSILON;

class CV_EXPORTS EMImpl : public EM
{
public:

    int nclusters;
    int covMatType;
    TermCriteria termCrit;

    CV_IMPL_PROPERTY_S(TermCriteria, TermCriteria, termCrit)

    void setClustersNumber(int val)
    {
        nclusters = val;
        CV_Assert(nclusters > 1);
    }

    int getClustersNumber() const
    {
        return nclusters;
    }

    void setCovarianceMatrixType(int val)
    {
        covMatType = val;
        CV_Assert(covMatType == COV_MAT_SPHERICAL ||
                  covMatType == COV_MAT_DIAGONAL ||
                  covMatType == COV_MAT_GENERIC);
    }

    int getCovarianceMatrixType() const
    {
        return covMatType;
    }

    EMImpl()
    {
        nclusters = DEFAULT_NCLUSTERS;
        covMatType = EM::COV_MAT_DIAGONAL;
        termCrit = TermCriteria(TermCriteria::COUNT+TermCriteria::EPS, EM::DEFAULT_MAX_ITERS, 1e-6);
    }

    virtual ~EMImpl() {}

    void clear()
    {
        trainSamples.release();
        trainProbs.release();
        trainLogLikelihoods.release();
        trainLabels.release();

        weights.release();
        means.release();
        covs.clear();

        covsEigenValues.clear();
        invCovsEigenValues.clear();
        covsRotateMats.clear();

        logWeightDivDet.release();
    }

    bool train(const Ptr<TrainData>& data, int)
    {
        Mat samples = data->getTrainSamples(), labels;
        return trainEM(samples, labels, noArray(), noArray());
    }

    bool trainEM(InputArray samples,
                 OutputArray logLikelihoods,
                 OutputArray labels,
                 OutputArray probs)
    {
        Mat samplesMat = samples.getMat();
        setTrainData(START_AUTO_STEP, samplesMat, 0, 0, 0, 0);
        return doTrain(START_AUTO_STEP, logLikelihoods, labels, probs);
    }

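    // Start EM from the E-step: initial means are required (checkTrainData() enforces
    // this), while initial covariances and weights are optional; when both are supplied
    // they are used directly instead of the k-means initialization.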
    bool trainE(InputArray samples,
                InputArray _means0,
                InputArray _covs0,
                InputArray _weights0,
                OutputArray logLikelihoods,
                OutputArray labels,
                OutputArray probs)
    {
        Mat samplesMat = samples.getMat();
        std::vector<Mat> covs0;
        _covs0.getMatVector(covs0);

        Mat means0 = _means0.getMat(), weights0 = _weights0.getMat();

        setTrainData(START_E_STEP, samplesMat, 0, !_means0.empty() ? &means0 : 0,
                     !_covs0.empty() ? &covs0 : 0, !_weights0.empty() ? &weights0 : 0);
        return doTrain(START_E_STEP, logLikelihoods, labels, probs);
    }

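    // Start EM from the M-step: the caller supplies the initial per-sample cluster
    // probabilities (responsibilities) in _probs0, from which the first M-step
    // estimates the means, covariances and weights.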
    bool trainM(InputArray samples,
                InputArray _probs0,
                OutputArray logLikelihoods,
                OutputArray labels,
                OutputArray probs)
    {
        Mat samplesMat = samples.getMat();
        Mat probs0 = _probs0.getMat();

        setTrainData(START_M_STEP, samplesMat, !_probs0.empty() ? &probs0 : 0, 0, 0, 0);
        return doTrain(START_M_STEP, logLikelihoods, labels, probs);
    }

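    // Batch prediction: when _outputs is requested it receives, for every sample, the
    // posterior probability of each mixture component; the return value is the index
    // of the most probable component for the first sample.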
    float predict(InputArray _inputs, OutputArray _outputs, int) const
    {
        bool needprobs = _outputs.needed();
        Mat samples = _inputs.getMat(), probs, probsrow;
        int ptype = CV_32F;
        float firstres = 0.f;
        int i, nsamples = samples.rows;

        if( needprobs )
        {
            if( _outputs.fixedType() )
                ptype = _outputs.type();
            _outputs.create(samples.rows, nclusters, ptype);
            probs = _outputs.getMat(); // grab the newly created output so probsrow below writes into it
        }
        else
            nsamples = std::min(nsamples, 1);

        for( i = 0; i < nsamples; i++ )
        {
            if( needprobs )
                probsrow = probs.row(i);
            Vec2d res = computeProbabilities(samples.row(i), needprobs ? &probsrow : 0, ptype);
            if( i == 0 )
                firstres = (float)res[1];
        }
        return firstres;
    }

    Vec2d predict2(InputArray _sample, OutputArray _probs) const
    {
        int ptype = CV_32F;
        Mat sample = _sample.getMat();
        CV_Assert(isTrained());

        CV_Assert(!sample.empty());
        if(sample.type() != CV_64FC1)
        {
            Mat tmp;
            sample.convertTo(tmp, CV_64FC1);
            sample = tmp;
        }
        sample = sample.reshape(1, 1); // reshape() returns a new header, so the result must be assigned

        Mat probs;
        if( _probs.needed() )
        {
            if( _probs.fixedType() )
                ptype = _probs.type();
            _probs.create(1, nclusters, ptype);
            probs = _probs.getMat();
        }

        return computeProbabilities(sample, !probs.empty() ? &probs : 0, ptype);
    }

    bool isTrained() const
    {
        return !means.empty();
    }

    bool isClassifier() const
    {
        return true;
    }

    int getVarCount() const
    {
        return means.cols;
    }

    String getDefaultName() const
    {
        return "opencv_ml_em";
    }

    static void checkTrainData(int startStep, const Mat& samples,
                               int nclusters, int covMatType, const Mat* probs, const Mat* means,
                               const std::vector<Mat>* covs, const Mat* weights)
    {
        // Check samples.
        CV_Assert(!samples.empty());
        CV_Assert(samples.channels() == 1);

        int nsamples = samples.rows;
        int dim = samples.cols;

        // Check training params.
        CV_Assert(nclusters > 0);
        CV_Assert(nclusters <= nsamples);
        CV_Assert(startStep == START_AUTO_STEP ||
                  startStep == START_E_STEP ||
                  startStep == START_M_STEP);
        CV_Assert(covMatType == COV_MAT_GENERIC ||
                  covMatType == COV_MAT_DIAGONAL ||
                  covMatType == COV_MAT_SPHERICAL);

        CV_Assert(!probs ||
            (!probs->empty() &&
             probs->rows == nsamples && probs->cols == nclusters &&
             (probs->type() == CV_32FC1 || probs->type() == CV_64FC1)));

        CV_Assert(!weights ||
            (!weights->empty() &&
             (weights->cols == 1 || weights->rows == 1) && static_cast<int>(weights->total()) == nclusters &&
             (weights->type() == CV_32FC1 || weights->type() == CV_64FC1)));

        CV_Assert(!means ||
            (!means->empty() &&
             means->rows == nclusters && means->cols == dim &&
             means->channels() == 1));

        CV_Assert(!covs ||
            (!covs->empty() &&
             static_cast<int>(covs->size()) == nclusters));
        if(covs)
        {
            const Size covSize(dim, dim);
            for(size_t i = 0; i < covs->size(); i++)
            {
                const Mat& m = (*covs)[i];
                CV_Assert(!m.empty() && m.size() == covSize && (m.channels() == 1));
            }
        }

        if(startStep == START_E_STEP)
        {
            CV_Assert(means);
        }
        else if(startStep == START_M_STEP)
        {
            CV_Assert(probs);
        }
    }

    static void preprocessSampleData(const Mat& src, Mat& dst, int dstType, bool isAlwaysClone)
    {
        if(src.type() == dstType && !isAlwaysClone)
            dst = src;
        else
            src.convertTo(dst, dstType);
    }

    static void preprocessProbability(Mat& probs)
    {
        max(probs, 0., probs);

        const double uniformProbability = (double)(1./probs.cols);
        for(int y = 0; y < probs.rows; y++)
        {
            Mat sampleProbs = probs.row(y);

            double maxVal = 0;
            minMaxLoc(sampleProbs, 0, &maxVal);
            if(maxVal < FLT_EPSILON)
                sampleProbs.setTo(uniformProbability);
            else
                normalize(sampleProbs, sampleProbs, 1, 0, NORM_L1);
        }
    }

    void setTrainData(int startStep, const Mat& samples,
                      const Mat* probs0,
                      const Mat* means0,
                      const std::vector<Mat>* covs0,
                      const Mat* weights0)
    {
        clear();

        checkTrainData(startStep, samples, nclusters, covMatType, probs0, means0, covs0, weights0);

        bool isKMeansInit = (startStep == START_AUTO_STEP) || (startStep == START_E_STEP && (covs0 == 0 || weights0 == 0));
        // Set checked data
        preprocessSampleData(samples, trainSamples, isKMeansInit ? CV_32FC1 : CV_64FC1, false);

        // set probs
        if(probs0 && startStep == START_M_STEP)
        {
            preprocessSampleData(*probs0, trainProbs, CV_64FC1, true);
            preprocessProbability(trainProbs);
        }

        // set weights
        if(weights0 && (startStep == START_E_STEP && covs0))
        {
            weights0->convertTo(weights, CV_64FC1);
            weights = weights.reshape(1,1); // reshape() returns a new header and must be assigned back
            preprocessProbability(weights);
        }

        // set means
        if(means0 && (startStep == START_E_STEP/* || startStep == START_AUTO_STEP*/))
            means0->convertTo(means, isKMeansInit ? CV_32FC1 : CV_64FC1);

        // set covs
        if(covs0 && (startStep == START_E_STEP && weights0))
        {
            covs.resize(nclusters);
            for(size_t i = 0; i < covs0->size(); i++)
                (*covs0)[i].convertTo(covs[i], CV_64FC1);
        }
    }

    void decomposeCovs()
    {
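        // Decompose each covariance matrix as cov_k = U * diag(w) * U' (the SVD of a
        // symmetric positive semi-definite matrix coincides with its eigendecomposition),
        // clamp the eigenvalues from below by minEigenValue, and cache their inverses so
        // computeProbabilities() can evaluate Mahalanobis distances cheaply.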
        CV_Assert(!covs.empty());
        covsEigenValues.resize(nclusters);
        if(covMatType == COV_MAT_GENERIC)
            covsRotateMats.resize(nclusters);
        invCovsEigenValues.resize(nclusters);
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            CV_Assert(!covs[clusterIndex].empty());

            SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);

            if(covMatType == COV_MAT_SPHERICAL)
            {
                double maxSingularVal = svd.w.at<double>(0);
                covsEigenValues[clusterIndex] = Mat(1, 1, CV_64FC1, Scalar(maxSingularVal));
            }
            else if(covMatType == COV_MAT_DIAGONAL)
            {
                covsEigenValues[clusterIndex] = svd.w;
            }
            else //COV_MAT_GENERIC
            {
                covsEigenValues[clusterIndex] = svd.w;
                covsRotateMats[clusterIndex] = svd.u;
            }
            max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);
            invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
        }
    }

    void clusterTrainSamples()
    {
        int nsamples = trainSamples.rows;

        // Cluster samples, compute/update means

        // Convert samples and means to 32F, because kmeans requires this type.
        Mat trainSamplesFlt, meansFlt;
        if(trainSamples.type() != CV_32FC1)
            trainSamples.convertTo(trainSamplesFlt, CV_32FC1);
        else
            trainSamplesFlt = trainSamples;
        if(!means.empty())
        {
            if(means.type() != CV_32FC1)
                means.convertTo(meansFlt, CV_32FC1);
            else
                meansFlt = means;
        }

        Mat labels;
        kmeans(trainSamplesFlt, nclusters, labels,
               TermCriteria(TermCriteria::COUNT, means.empty() ? 10 : 1, 0.5),
               10, KMEANS_PP_CENTERS, meansFlt);

        // Convert samples and means back to 64F.
        CV_Assert(meansFlt.type() == CV_32FC1);
        if(trainSamples.type() != CV_64FC1)
        {
            Mat trainSamplesBuffer;
            trainSamplesFlt.convertTo(trainSamplesBuffer, CV_64FC1);
            trainSamples = trainSamplesBuffer;
        }
        meansFlt.convertTo(means, CV_64FC1);

        // Compute weights and covs
        weights = Mat(1, nclusters, CV_64FC1, Scalar(0));
        covs.resize(nclusters);
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            Mat clusterSamples;
            for(int sampleIndex = 0; sampleIndex < nsamples; sampleIndex++)
            {
                if(labels.at<int>(sampleIndex) == clusterIndex)
                {
                    const Mat sample = trainSamples.row(sampleIndex);
                    clusterSamples.push_back(sample);
                }
            }
            CV_Assert(!clusterSamples.empty());

            calcCovarMatrix(clusterSamples, covs[clusterIndex], means.row(clusterIndex),
                CV_COVAR_NORMAL + CV_COVAR_ROWS + CV_COVAR_USE_AVG + CV_COVAR_SCALE, CV_64FC1);
            weights.at<double>(clusterIndex) = static_cast<double>(clusterSamples.rows)/static_cast<double>(nsamples);
        }

        decomposeCovs();
    }

    void computeLogWeightDivDet()
    {
        CV_Assert(!covsEigenValues.empty());

        Mat logWeights;
        cv::max(weights, DBL_MIN, weights);
        log(weights, logWeights);

        logWeightDivDet.create(1, nclusters, CV_64FC1);
        // note: logWeightDivDet = log(weight_k) - 0.5 * log(|det(cov_k)|)
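        // Caching this term once per E-step (before looping over the samples) is valid
        // because, for a Gaussian component,
        //     log(weight_k * N(x | mean_k, cov_k)) = log(weight_k) - 0.5*log(|det(cov_k)|)
        //         - 0.5*d*log(2*pi) - 0.5*(x - mean_k)' cov_k^(-1) (x - mean_k),
        // and only the last two terms depend on the sample x (see computeProbabilities()).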

        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            double logDetCov = 0.;
            const int evalCount = static_cast<int>(covsEigenValues[clusterIndex].total());
            for(int di = 0; di < evalCount; di++)
                logDetCov += std::log(covsEigenValues[clusterIndex].at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0));

            logWeightDivDet.at<double>(clusterIndex) = logWeights.at<double>(clusterIndex) - 0.5 * logDetCov;
        }
    }

    bool doTrain(int startStep, OutputArray logLikelihoods, OutputArray labels, OutputArray probs)
    {
        int dim = trainSamples.cols;
        // Compute any missing initial parameters (means, covs, weights) via k-means in the START_E_STEP and START_AUTO_STEP cases
        if(startStep != START_M_STEP)
        {
            if(covs.empty())
            {
                CV_Assert(weights.empty());
                clusterTrainSamples();
            }
        }

        if(!covs.empty() && covsEigenValues.empty())
        {
            CV_Assert(invCovsEigenValues.empty());
            decomposeCovs();
        }

        if(startStep == START_M_STEP)
            mStep();

        double trainLogLikelihood, prevTrainLogLikelihood = 0.;
        int maxIters = (termCrit.type & TermCriteria::MAX_ITER) ?
            termCrit.maxCount : DEFAULT_MAX_ITERS;
        double epsilon = (termCrit.type & TermCriteria::EPS) ? termCrit.epsilon : 0.;

        for(int iter = 0; ; iter++)
        {
            eStep();
            trainLogLikelihood = sum(trainLogLikelihoods)[0];

            if(iter >= maxIters - 1)
                break;

            double trainLogLikelihoodDelta = trainLogLikelihood - prevTrainLogLikelihood;
            if( iter != 0 &&
                (trainLogLikelihoodDelta < -DBL_EPSILON ||
                 trainLogLikelihoodDelta < epsilon * std::fabs(trainLogLikelihood)))
                break;

            mStep();

            prevTrainLogLikelihood = trainLogLikelihood;
        }

        if( trainLogLikelihood <= -DBL_MAX/10000. )
        {
            clear();
            return false;
        }

        // postprocess covs
        covs.resize(nclusters);
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            if(covMatType == COV_MAT_SPHERICAL)
            {
                covs[clusterIndex].create(dim, dim, CV_64FC1);
                setIdentity(covs[clusterIndex], Scalar(covsEigenValues[clusterIndex].at<double>(0)));
            }
            else if(covMatType == COV_MAT_DIAGONAL)
            {
                covs[clusterIndex] = Mat::diag(covsEigenValues[clusterIndex]);
            }
        }

        if(labels.needed())
            trainLabels.copyTo(labels);
        if(probs.needed())
            trainProbs.copyTo(probs);
        if(logLikelihoods.needed())
            trainLogLikelihoods.copyTo(logLikelihoods);

        trainSamples.release();
        trainProbs.release();
        trainLabels.release();
        trainLogLikelihoods.release();

        return true;
    }

    Vec2d computeProbabilities(const Mat& sample, Mat* probs, int ptype) const
    {
        // L_ik = log(weight_k) - 0.5 * log(|det(cov_k)|) - 0.5 * (x_i - mean_k)' cov_k^(-1) (x_i - mean_k)
        // q = arg max_k(L_ik)
        // probs_ik = exp(L_ik - L_iq) / (1 + sum_{j != q} exp(L_ij - L_iq))
        // see Alex Smola's blog http://blog.smola.org/page/2 for
        // details on the log-sum-exp trick
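        // The returned Vec2d holds (log-likelihood of the sample under the mixture,
        // index of the most probable component); when requested, *probs receives the
        // normalized posteriors exp(L_ik - L_iq) / sum_j exp(L_ij - L_iq).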

        int stype = sample.type();
        CV_Assert(!means.empty());
        CV_Assert((stype == CV_32F || stype == CV_64F) && (ptype == CV_32F || ptype == CV_64F));
        CV_Assert(sample.size() == Size(means.cols, 1));

        int dim = sample.cols;

        Mat L(1, nclusters, CV_64FC1), centeredSample(1, dim, CV_64F);
        int i, label = 0;
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            const double* mptr = means.ptr<double>(clusterIndex);
            double* dptr = centeredSample.ptr<double>();
            if( stype == CV_32F )
            {
                const float* sptr = sample.ptr<float>();
                for( i = 0; i < dim; i++ )
                    dptr[i] = sptr[i] - mptr[i];
            }
            else
            {
                const double* sptr = sample.ptr<double>();
                for( i = 0; i < dim; i++ )
                    dptr[i] = sptr[i] - mptr[i];
            }

            Mat rotatedCenteredSample = covMatType != COV_MAT_GENERIC ?
                    centeredSample : centeredSample * covsRotateMats[clusterIndex];

            double Lval = 0;
            for(int di = 0; di < dim; di++)
            {
                double w = invCovsEigenValues[clusterIndex].at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0);
                double val = rotatedCenteredSample.at<double>(di);
                Lval += w * val * val;
            }
            CV_DbgAssert(!logWeightDivDet.empty());
            L.at<double>(clusterIndex) = logWeightDivDet.at<double>(clusterIndex) - 0.5 * Lval;

            if(L.at<double>(clusterIndex) > L.at<double>(label))
                label = clusterIndex;
        }

        double maxLVal = L.at<double>(label);
        double expDiffSum = 0;
        for( i = 0; i < L.cols; i++ )
        {
            double v = std::exp(L.at<double>(i) - maxLVal);
            L.at<double>(i) = v;
            expDiffSum += v; // sum_j(exp(L_ij - L_iq))
        }

        if(probs)
            L.convertTo(*probs, ptype, 1./expDiffSum);

        Vec2d res;
        res[0] = std::log(expDiffSum) + maxLVal - 0.5 * dim * CV_LOG2PI;
        res[1] = label;

        return res;
    }

    void eStep()
    {
        // Compute probs_ik from means_k, covs_k and weights_k.
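        // For each sample x_i this evaluates the responsibilities
        //     probs_ik = weight_k * N(x_i | mean_k, cov_k) / sum_j weight_j * N(x_i | mean_j, cov_j)
        // together with the per-sample log-likelihood and the most probable label.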
        trainProbs.create(trainSamples.rows, nclusters, CV_64FC1);
        trainLabels.create(trainSamples.rows, 1, CV_32SC1);
        trainLogLikelihoods.create(trainSamples.rows, 1, CV_64FC1);

        computeLogWeightDivDet();

        CV_DbgAssert(trainSamples.type() == CV_64FC1);
        CV_DbgAssert(means.type() == CV_64FC1);

        for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
        {
            Mat sampleProbs = trainProbs.row(sampleIndex);
            Vec2d res = computeProbabilities(trainSamples.row(sampleIndex), &sampleProbs, CV_64F);
            trainLogLikelihoods.at<double>(sampleIndex) = res[0];
            trainLabels.at<int>(sampleIndex) = static_cast<int>(res[1]);
        }
    }

    void mStep()
    {
        // Update means_k, covs_k and weights_k from probs_ik
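        // Standard EM M-step updates, with p_ik = trainProbs(i, k) and N = trainSamples.rows:
        //     weight_k = sum_i p_ik / N
        //     mean_k   = sum_i p_ik * x_i / sum_i p_ik
        //     cov_k    = sum_i p_ik * (x_i - mean_k) * (x_i - mean_k)' / sum_i p_ik
        // restricted to a diagonal matrix for COV_MAT_DIAGONAL and to a scaled identity
        // matrix for COV_MAT_SPHERICAL.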
        int dim = trainSamples.cols;

        // Update weights
        // not normalized first
        reduce(trainProbs, weights, 0, CV_REDUCE_SUM);

        // Update means
        means.create(nclusters, dim, CV_64FC1);
        means = Scalar(0);

        const double minPosWeight = trainSamples.rows * DBL_EPSILON;
        double minWeight = DBL_MAX;
        int minWeightClusterIndex = -1;
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            if(weights.at<double>(clusterIndex) <= minPosWeight)
                continue;

            if(weights.at<double>(clusterIndex) < minWeight)
            {
                minWeight = weights.at<double>(clusterIndex);
                minWeightClusterIndex = clusterIndex;
            }

            Mat clusterMean = means.row(clusterIndex);
            for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
                clusterMean += trainProbs.at<double>(sampleIndex, clusterIndex) * trainSamples.row(sampleIndex);
            clusterMean /= weights.at<double>(clusterIndex);
        }

        // Update covsEigenValues and invCovsEigenValues
        covs.resize(nclusters);
        covsEigenValues.resize(nclusters);
        if(covMatType == COV_MAT_GENERIC)
            covsRotateMats.resize(nclusters);
        invCovsEigenValues.resize(nclusters);
        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            if(weights.at<double>(clusterIndex) <= minPosWeight)
                continue;

            if(covMatType != COV_MAT_SPHERICAL)
                covsEigenValues[clusterIndex].create(1, dim, CV_64FC1);
            else
                covsEigenValues[clusterIndex].create(1, 1, CV_64FC1);

            if(covMatType == COV_MAT_GENERIC)
                covs[clusterIndex].create(dim, dim, CV_64FC1);

            Mat clusterCov = covMatType != COV_MAT_GENERIC ?
                covsEigenValues[clusterIndex] : covs[clusterIndex];

            clusterCov = Scalar(0);

            Mat centeredSample;
            for(int sampleIndex = 0; sampleIndex < trainSamples.rows; sampleIndex++)
            {
                centeredSample = trainSamples.row(sampleIndex) - means.row(clusterIndex);

                if(covMatType == COV_MAT_GENERIC)
                    clusterCov += trainProbs.at<double>(sampleIndex, clusterIndex) * centeredSample.t() * centeredSample;
                else
                {
                    double p = trainProbs.at<double>(sampleIndex, clusterIndex);
                    for(int di = 0; di < dim; di++ )
                    {
                        double val = centeredSample.at<double>(di);
                        clusterCov.at<double>(covMatType != COV_MAT_SPHERICAL ? di : 0) += p*val*val;
                    }
                }
            }

            if(covMatType == COV_MAT_SPHERICAL)
                clusterCov /= dim;

            clusterCov /= weights.at<double>(clusterIndex);

            // Update covsRotateMats for COV_MAT_GENERIC only
            if(covMatType == COV_MAT_GENERIC)
            {
                SVD svd(covs[clusterIndex], SVD::MODIFY_A + SVD::FULL_UV);
                covsEigenValues[clusterIndex] = svd.w;
                covsRotateMats[clusterIndex] = svd.u;
            }

            max(covsEigenValues[clusterIndex], minEigenValue, covsEigenValues[clusterIndex]);

            // update invCovsEigenValues
            invCovsEigenValues[clusterIndex] = 1./covsEigenValues[clusterIndex];
        }

        for(int clusterIndex = 0; clusterIndex < nclusters; clusterIndex++)
        {
            if(weights.at<double>(clusterIndex) <= minPosWeight)
            {
                Mat clusterMean = means.row(clusterIndex);
                means.row(minWeightClusterIndex).copyTo(clusterMean);
                covs[minWeightClusterIndex].copyTo(covs[clusterIndex]);
                covsEigenValues[minWeightClusterIndex].copyTo(covsEigenValues[clusterIndex]);
                if(covMatType == COV_MAT_GENERIC)
                    covsRotateMats[minWeightClusterIndex].copyTo(covsRotateMats[clusterIndex]);
                invCovsEigenValues[minWeightClusterIndex].copyTo(invCovsEigenValues[clusterIndex]);
            }
        }

        // Normalize weights
        weights /= trainSamples.rows;
    }

    void write_params(FileStorage& fs) const
    {
        fs << "nclusters" << nclusters;
        fs << "cov_mat_type" << (covMatType == COV_MAT_SPHERICAL ? String("spherical") :
                                 covMatType == COV_MAT_DIAGONAL ? String("diagonal") :
                                 covMatType == COV_MAT_GENERIC ? String("generic") :
                                 format("unknown_%d", covMatType));
        writeTermCrit(fs, termCrit);
    }

    void write(FileStorage& fs) const
    {
        fs << "training_params" << "{";
        write_params(fs);
        fs << "}";
        fs << "weights" << weights;
        fs << "means" << means;

        size_t i, n = covs.size();

        fs << "covs" << "[";
        for( i = 0; i < n; i++ )
            fs << covs[i];
        fs << "]";
    }

    void read_params(const FileNode& fn)
    {
        nclusters = (int)fn["nclusters"];
        String s = (String)fn["cov_mat_type"];
        covMatType = s == "spherical" ? COV_MAT_SPHERICAL :
                     s == "diagonal" ? COV_MAT_DIAGONAL :
                     s == "generic" ? COV_MAT_GENERIC : -1;
        CV_Assert(covMatType >= 0);
        termCrit = readTermCrit(fn);
    }

    void read(const FileNode& fn)
    {
        clear();
        read_params(fn["training_params"]);

        fn["weights"] >> weights;
        fn["means"] >> means;

        FileNode cfn = fn["covs"];
        FileNodeIterator cfn_it = cfn.begin();
        int i, n = (int)cfn.size();
        covs.resize(n);

        for( i = 0; i < n; i++, ++cfn_it )
            (*cfn_it) >> covs[i];

        decomposeCovs();
        computeLogWeightDivDet();
    }

    Mat getWeights() const { return weights; }
    Mat getMeans() const { return means; }
    void getCovs(std::vector<Mat>& _covs) const
    {
        _covs.resize(covs.size());
        std::copy(covs.begin(), covs.end(), _covs.begin());
    }

    // all inner matrices have type CV_64FC1
    Mat trainSamples;
    Mat trainProbs;
    Mat trainLogLikelihoods;
    Mat trainLabels;

    Mat weights;
    Mat means;
    std::vector<Mat> covs;

    std::vector<Mat> covsEigenValues;
    std::vector<Mat> covsRotateMats;
    std::vector<Mat> invCovsEigenValues;
    Mat logWeightDivDet;
};

Ptr<EM> EM::create()
{
    return makePtr<EMImpl>();
}
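
// A minimal usage sketch of the EM model implemented above (illustrative only;
// "samples" stands for a hypothetical CV_32FC1 matrix with one sample per row):
//
//     Ptr<EM> em = EM::create();
//     em->setClustersNumber(3);
//     em->setCovarianceMatrixType(EM::COV_MAT_DIAGONAL);
//     Mat logLikelihoods, labels, probs;
//     em->trainEM(samples, logLikelihoods, labels, probs);
//     Vec2d res = em->predict2(samples.row(0), noArray());
//     // res[0]: log-likelihood of the sample under the mixture,
//     // res[1]: index of the most probable mixture component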

} // namespace ml
} // namespace cv

/* End of file. */