Home | History | Annotate | Download | only in src
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                        Intel License Agreement
     11 //                For Open Source Computer Vision Library
     12 //
     13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
     14 // Third party copyrights are property of their respective owners.
     15 //
     16 // Redistribution and use in source and binary forms, with or without modification,
     17 // are permitted provided that the following conditions are met:
     18 //
     19 //   * Redistribution's of source code must retain the above copyright notice,
     20 //     this list of conditions and the following disclaimer.
     21 //
     22 //   * Redistribution's in binary form must reproduce the above copyright notice,
     23 //     this list of conditions and the following disclaimer in the documentation
     24 //     and/or other materials provided with the distribution.
     25 //
     26 //   * The name of Intel Corporation may not be used to endorse or promote products
     27 //     derived from this software without specific prior written permission.
     28 //
     29 // This software is provided by the copyright holders and contributors "as is" and
     30 // any express or implied warranties, including, but not limited to, the implied
     31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     32 // In no event shall the Intel Corporation or contributors be liable for any direct,
     33 // indirect, incidental, special, exemplary, or consequential damages
     34 // (including, but not limited to, procurement of substitute goods or services;
     35 // loss of use, data, or profits; or business interruption) however caused
     36 // and on any theory of liability, whether in contract, strict liability,
     37 // or tort (including negligence or otherwise) arising in any way out of
     38 // the use of this software, even if advised of the possibility of such damage.
     39 //
     40 //M*/
     41 
     42 #include "precomp.hpp"
     43 #include "gcgraph.hpp"
     44 #include <limits>
     45 
     46 using namespace cv;
     47 
     48 /*
This is an implementation of the GrabCut image segmentation algorithm described in
"GrabCut - Interactive Foreground Extraction using Iterated Graph Cuts"
by Carsten Rother, Vladimir Kolmogorov and Andrew Blake.
     52  */
     53 
     54 /*
     55  GMM - Gaussian Mixture Model
     56 */
     57 class GMM
     58 {
     59 public:
     60     static const int componentsCount = 5;
     61 
     62     GMM( Mat& _model );
     63     double operator()( const Vec3d color ) const;
     64     double operator()( int ci, const Vec3d color ) const;
     65     int whichComponent( const Vec3d color ) const;
     66 
     67     void initLearning();
     68     void addSample( int ci, const Vec3d color );
     69     void endLearning();
     70 
     71 private:
     72     void calcInverseCovAndDeterm( int ci );
     73     Mat model;
     74     double* coefs;
     75     double* mean;
     76     double* cov;
     77 
     78     double inverseCovs[componentsCount][3][3];
     79     double covDeterms[componentsCount];
     80 
     81     double sums[componentsCount][3];
     82     double prods[componentsCount][3][3];
     83     int sampleCounts[componentsCount];
     84     int totalSampleCount;
     85 };
     86 
     87 GMM::GMM( Mat& _model )
     88 {
     89     const int modelSize = 3/*mean*/ + 9/*covariance*/ + 1/*component weight*/;
     90     if( _model.empty() )
     91     {
     92         _model.create( 1, modelSize*componentsCount, CV_64FC1 );
     93         _model.setTo(Scalar(0));
     94     }
     95     else if( (_model.type() != CV_64FC1) || (_model.rows != 1) || (_model.cols != modelSize*componentsCount) )
     96         CV_Error( CV_StsBadArg, "_model must have CV_64FC1 type, rows == 1 and cols == 13*componentsCount" );
     97 
     98     model = _model;
     99 
    100     coefs = model.ptr<double>(0);
    101     mean = coefs + componentsCount;
    102     cov = mean + 3*componentsCount;
    103 
    104     for( int ci = 0; ci < componentsCount; ci++ )
    105         if( coefs[ci] > 0 )
    106              calcInverseCovAndDeterm( ci );
    107 }
    108 
    109 double GMM::operator()( const Vec3d color ) const
    110 {
    111     double res = 0;
    112     for( int ci = 0; ci < componentsCount; ci++ )
    113         res += coefs[ci] * (*this)(ci, color );
    114     return res;
    115 }
    116 
    117 double GMM::operator()( int ci, const Vec3d color ) const
    118 {
    119     double res = 0;
    120     if( coefs[ci] > 0 )
    121     {
    122         CV_Assert( covDeterms[ci] > std::numeric_limits<double>::epsilon() );
    123         Vec3d diff = color;
    124         double* m = mean + 3*ci;
    125         diff[0] -= m[0]; diff[1] -= m[1]; diff[2] -= m[2];
    126         double mult = diff[0]*(diff[0]*inverseCovs[ci][0][0] + diff[1]*inverseCovs[ci][1][0] + diff[2]*inverseCovs[ci][2][0])
    127                    + diff[1]*(diff[0]*inverseCovs[ci][0][1] + diff[1]*inverseCovs[ci][1][1] + diff[2]*inverseCovs[ci][2][1])
    128                    + diff[2]*(diff[0]*inverseCovs[ci][0][2] + diff[1]*inverseCovs[ci][1][2] + diff[2]*inverseCovs[ci][2][2]);
    129         res = 1.0f/sqrt(covDeterms[ci]) * exp(-0.5f*mult);
    130     }
    131     return res;
    132 }
    133 
    134 int GMM::whichComponent( const Vec3d color ) const
    135 {
    136     int k = 0;
    137     double max = 0;
    138 
    139     for( int ci = 0; ci < componentsCount; ci++ )
    140     {
    141         double p = (*this)( ci, color );
    142         if( p > max )
    143         {
    144             k = ci;
    145             max = p;
    146         }
    147     }
    148     return k;
    149 }
    150 
    151 void GMM::initLearning()
    152 {
    153     for( int ci = 0; ci < componentsCount; ci++)
    154     {
    155         sums[ci][0] = sums[ci][1] = sums[ci][2] = 0;
    156         prods[ci][0][0] = prods[ci][0][1] = prods[ci][0][2] = 0;
    157         prods[ci][1][0] = prods[ci][1][1] = prods[ci][1][2] = 0;
    158         prods[ci][2][0] = prods[ci][2][1] = prods[ci][2][2] = 0;
    159         sampleCounts[ci] = 0;
    160     }
    161     totalSampleCount = 0;
    162 }
    163 
    164 void GMM::addSample( int ci, const Vec3d color )
    165 {
    166     sums[ci][0] += color[0]; sums[ci][1] += color[1]; sums[ci][2] += color[2];
    167     prods[ci][0][0] += color[0]*color[0]; prods[ci][0][1] += color[0]*color[1]; prods[ci][0][2] += color[0]*color[2];
    168     prods[ci][1][0] += color[1]*color[0]; prods[ci][1][1] += color[1]*color[1]; prods[ci][1][2] += color[1]*color[2];
    169     prods[ci][2][0] += color[2]*color[0]; prods[ci][2][1] += color[2]*color[1]; prods[ci][2][2] += color[2]*color[2];
    170     sampleCounts[ci]++;
    171     totalSampleCount++;
    172 }
    173 
    174 void GMM::endLearning()
    175 {
    176     const double variance = 0.01;
    177     for( int ci = 0; ci < componentsCount; ci++ )
    178     {
    179         int n = sampleCounts[ci];
    180         if( n == 0 )
    181             coefs[ci] = 0;
    182         else
    183         {
    184             coefs[ci] = (double)n/totalSampleCount;
    185 
    186             double* m = mean + 3*ci;
    187             m[0] = sums[ci][0]/n; m[1] = sums[ci][1]/n; m[2] = sums[ci][2]/n;
    188 
    189             double* c = cov + 9*ci;
    190             c[0] = prods[ci][0][0]/n - m[0]*m[0]; c[1] = prods[ci][0][1]/n - m[0]*m[1]; c[2] = prods[ci][0][2]/n - m[0]*m[2];
    191             c[3] = prods[ci][1][0]/n - m[1]*m[0]; c[4] = prods[ci][1][1]/n - m[1]*m[1]; c[5] = prods[ci][1][2]/n - m[1]*m[2];
    192             c[6] = prods[ci][2][0]/n - m[2]*m[0]; c[7] = prods[ci][2][1]/n - m[2]*m[1]; c[8] = prods[ci][2][2]/n - m[2]*m[2];
    193 
    194             double dtrm = c[0]*(c[4]*c[8]-c[5]*c[7]) - c[1]*(c[3]*c[8]-c[5]*c[6]) + c[2]*(c[3]*c[7]-c[4]*c[6]);
    195             if( dtrm <= std::numeric_limits<double>::epsilon() )
    196             {
    197                 // Adds the white noise to avoid singular covariance matrix.
    198                 c[0] += variance;
    199                 c[4] += variance;
    200                 c[8] += variance;
    201             }
    202 
    203             calcInverseCovAndDeterm(ci);
    204         }
    205     }
    206 }
    207 
    208 void GMM::calcInverseCovAndDeterm( int ci )
    209 {
    210     if( coefs[ci] > 0 )
    211     {
    212         double *c = cov + 9*ci;
    213         double dtrm =
    214               covDeterms[ci] = c[0]*(c[4]*c[8]-c[5]*c[7]) - c[1]*(c[3]*c[8]-c[5]*c[6]) + c[2]*(c[3]*c[7]-c[4]*c[6]);
    215 
    216         CV_Assert( dtrm > std::numeric_limits<double>::epsilon() );
    217         inverseCovs[ci][0][0] =  (c[4]*c[8] - c[5]*c[7]) / dtrm;
    218         inverseCovs[ci][1][0] = -(c[3]*c[8] - c[5]*c[6]) / dtrm;
    219         inverseCovs[ci][2][0] =  (c[3]*c[7] - c[4]*c[6]) / dtrm;
    220         inverseCovs[ci][0][1] = -(c[1]*c[8] - c[2]*c[7]) / dtrm;
    221         inverseCovs[ci][1][1] =  (c[0]*c[8] - c[2]*c[6]) / dtrm;
    222         inverseCovs[ci][2][1] = -(c[0]*c[7] - c[1]*c[6]) / dtrm;
    223         inverseCovs[ci][0][2] =  (c[1]*c[5] - c[2]*c[4]) / dtrm;
    224         inverseCovs[ci][1][2] = -(c[0]*c[5] - c[2]*c[3]) / dtrm;
    225         inverseCovs[ci][2][2] =  (c[0]*c[4] - c[1]*c[3]) / dtrm;
    226     }
    227 }
    228 
    229 /*
    230   Calculate beta - parameter of GrabCut algorithm.
    231   beta = 1/(2*avg(sqr(||color[i] - color[j]||)))
    232 */
    233 static double calcBeta( const Mat& img )
    234 {
    235     double beta = 0;
    236     for( int y = 0; y < img.rows; y++ )
    237     {
    238         for( int x = 0; x < img.cols; x++ )
    239         {
    240             Vec3d color = img.at<Vec3b>(y,x);
    241             if( x>0 ) // left
    242             {
    243                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y,x-1);
    244                 beta += diff.dot(diff);
    245             }
    246             if( y>0 && x>0 ) // upleft
    247             {
    248                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x-1);
    249                 beta += diff.dot(diff);
    250             }
    251             if( y>0 ) // up
    252             {
    253                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x);
    254                 beta += diff.dot(diff);
    255             }
    256             if( y>0 && x<img.cols-1) // upright
    257             {
    258                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x+1);
    259                 beta += diff.dot(diff);
    260             }
    261         }
    262     }
    263     if( beta <= std::numeric_limits<double>::epsilon() )
    264         beta = 0;
    265     else
    266         beta = 1.f / (2 * beta/(4*img.cols*img.rows - 3*img.cols - 3*img.rows + 2) );
    267 
    268     return beta;
    269 }
    270 
    271 /*
    272   Calculate weights of noterminal vertices of graph.
    273   beta and gamma - parameters of GrabCut algorithm.
    274  */
    275 static void calcNWeights( const Mat& img, Mat& leftW, Mat& upleftW, Mat& upW, Mat& uprightW, double beta, double gamma )
    276 {
    277     const double gammaDivSqrt2 = gamma / std::sqrt(2.0f);
    278     leftW.create( img.rows, img.cols, CV_64FC1 );
    279     upleftW.create( img.rows, img.cols, CV_64FC1 );
    280     upW.create( img.rows, img.cols, CV_64FC1 );
    281     uprightW.create( img.rows, img.cols, CV_64FC1 );
    282     for( int y = 0; y < img.rows; y++ )
    283     {
    284         for( int x = 0; x < img.cols; x++ )
    285         {
    286             Vec3d color = img.at<Vec3b>(y,x);
    287             if( x-1>=0 ) // left
    288             {
    289                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y,x-1);
    290                 leftW.at<double>(y,x) = gamma * exp(-beta*diff.dot(diff));
    291             }
    292             else
    293                 leftW.at<double>(y,x) = 0;
    294             if( x-1>=0 && y-1>=0 ) // upleft
    295             {
    296                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x-1);
    297                 upleftW.at<double>(y,x) = gammaDivSqrt2 * exp(-beta*diff.dot(diff));
    298             }
    299             else
    300                 upleftW.at<double>(y,x) = 0;
    301             if( y-1>=0 ) // up
    302             {
    303                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x);
    304                 upW.at<double>(y,x) = gamma * exp(-beta*diff.dot(diff));
    305             }
    306             else
    307                 upW.at<double>(y,x) = 0;
    308             if( x+1<img.cols && y-1>=0 ) // upright
    309             {
    310                 Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x+1);
    311                 uprightW.at<double>(y,x) = gammaDivSqrt2 * exp(-beta*diff.dot(diff));
    312             }
    313             else
    314                 uprightW.at<double>(y,x) = 0;
    315         }
    316     }
    317 }
    318 
    319 /*
    320   Check size, type and element values of mask matrix.
    321  */
    322 static void checkMask( const Mat& img, const Mat& mask )
    323 {
    324     if( mask.empty() )
    325         CV_Error( CV_StsBadArg, "mask is empty" );
    326     if( mask.type() != CV_8UC1 )
    327         CV_Error( CV_StsBadArg, "mask must have CV_8UC1 type" );
    328     if( mask.cols != img.cols || mask.rows != img.rows )
    329         CV_Error( CV_StsBadArg, "mask must have as many rows and cols as img" );
    330     for( int y = 0; y < mask.rows; y++ )
    331     {
    332         for( int x = 0; x < mask.cols; x++ )
    333         {
    334             uchar val = mask.at<uchar>(y,x);
    335             if( val!=GC_BGD && val!=GC_FGD && val!=GC_PR_BGD && val!=GC_PR_FGD )
    336                 CV_Error( CV_StsBadArg, "mask element value must be equel"
    337                     "GC_BGD or GC_FGD or GC_PR_BGD or GC_PR_FGD" );
    338         }
    339     }
    340 }
    341 
    342 /*
    343   Initialize mask using rectangular.
    344 */
    345 static void initMaskWithRect( Mat& mask, Size imgSize, Rect rect )
    346 {
    347     mask.create( imgSize, CV_8UC1 );
    348     mask.setTo( GC_BGD );
    349 
    350     rect.x = std::max(0, rect.x);
    351     rect.y = std::max(0, rect.y);
    352     rect.width = std::min(rect.width, imgSize.width-rect.x);
    353     rect.height = std::min(rect.height, imgSize.height-rect.y);
    354 
    355     (mask(rect)).setTo( Scalar(GC_PR_FGD) );
    356 }
    357 
    358 /*
    359   Initialize GMM background and foreground models using kmeans algorithm.
    360 */
    361 static void initGMMs( const Mat& img, const Mat& mask, GMM& bgdGMM, GMM& fgdGMM )
    362 {
    363     const int kMeansItCount = 10;
    364     const int kMeansType = KMEANS_PP_CENTERS;
    365 
    366     Mat bgdLabels, fgdLabels;
    367     std::vector<Vec3f> bgdSamples, fgdSamples;
    368     Point p;
    369     for( p.y = 0; p.y < img.rows; p.y++ )
    370     {
    371         for( p.x = 0; p.x < img.cols; p.x++ )
    372         {
    373             if( mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD )
    374                 bgdSamples.push_back( (Vec3f)img.at<Vec3b>(p) );
    375             else // GC_FGD | GC_PR_FGD
    376                 fgdSamples.push_back( (Vec3f)img.at<Vec3b>(p) );
    377         }
    378     }
    379     CV_Assert( !bgdSamples.empty() && !fgdSamples.empty() );
    380     Mat _bgdSamples( (int)bgdSamples.size(), 3, CV_32FC1, &bgdSamples[0][0] );
    381     kmeans( _bgdSamples, GMM::componentsCount, bgdLabels,
    382             TermCriteria( CV_TERMCRIT_ITER, kMeansItCount, 0.0), 0, kMeansType );
    383     Mat _fgdSamples( (int)fgdSamples.size(), 3, CV_32FC1, &fgdSamples[0][0] );
    384     kmeans( _fgdSamples, GMM::componentsCount, fgdLabels,
    385             TermCriteria( CV_TERMCRIT_ITER, kMeansItCount, 0.0), 0, kMeansType );
    386 
    387     bgdGMM.initLearning();
    388     for( int i = 0; i < (int)bgdSamples.size(); i++ )
    389         bgdGMM.addSample( bgdLabels.at<int>(i,0), bgdSamples[i] );
    390     bgdGMM.endLearning();
    391 
    392     fgdGMM.initLearning();
    393     for( int i = 0; i < (int)fgdSamples.size(); i++ )
    394         fgdGMM.addSample( fgdLabels.at<int>(i,0), fgdSamples[i] );
    395     fgdGMM.endLearning();
    396 }
    397 
    398 /*
    399   Assign GMMs components for each pixel.
    400 */
    401 static void assignGMMsComponents( const Mat& img, const Mat& mask, const GMM& bgdGMM, const GMM& fgdGMM, Mat& compIdxs )
    402 {
    403     Point p;
    404     for( p.y = 0; p.y < img.rows; p.y++ )
    405     {
    406         for( p.x = 0; p.x < img.cols; p.x++ )
    407         {
    408             Vec3d color = img.at<Vec3b>(p);
    409             compIdxs.at<int>(p) = mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD ?
    410                 bgdGMM.whichComponent(color) : fgdGMM.whichComponent(color);
    411         }
    412     }
    413 }
    414 
    415 /*
    416   Learn GMMs parameters.
    417 */
    418 static void learnGMMs( const Mat& img, const Mat& mask, const Mat& compIdxs, GMM& bgdGMM, GMM& fgdGMM )
    419 {
    420     bgdGMM.initLearning();
    421     fgdGMM.initLearning();
    422     Point p;
    423     for( int ci = 0; ci < GMM::componentsCount; ci++ )
    424     {
    425         for( p.y = 0; p.y < img.rows; p.y++ )
    426         {
    427             for( p.x = 0; p.x < img.cols; p.x++ )
    428             {
    429                 if( compIdxs.at<int>(p) == ci )
    430                 {
    431                     if( mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD )
    432                         bgdGMM.addSample( ci, img.at<Vec3b>(p) );
    433                     else
    434                         fgdGMM.addSample( ci, img.at<Vec3b>(p) );
    435                 }
    436             }
    437         }
    438     }
    439     bgdGMM.endLearning();
    440     fgdGMM.endLearning();
    441 }
    442 
    443 /*
    444   Construct GCGraph
    445 */
    446 static void constructGCGraph( const Mat& img, const Mat& mask, const GMM& bgdGMM, const GMM& fgdGMM, double lambda,
    447                        const Mat& leftW, const Mat& upleftW, const Mat& upW, const Mat& uprightW,
    448                        GCGraph<double>& graph )
    449 {
    450     int vtxCount = img.cols*img.rows,
    451         edgeCount = 2*(4*img.cols*img.rows - 3*(img.cols + img.rows) + 2);
    452     graph.create(vtxCount, edgeCount);
    453     Point p;
    454     for( p.y = 0; p.y < img.rows; p.y++ )
    455     {
    456         for( p.x = 0; p.x < img.cols; p.x++)
    457         {
    458             // add node
    459             int vtxIdx = graph.addVtx();
    460             Vec3b color = img.at<Vec3b>(p);
    461 
    462             // set t-weights
    463             double fromSource, toSink;
    464             if( mask.at<uchar>(p) == GC_PR_BGD || mask.at<uchar>(p) == GC_PR_FGD )
    465             {
    466                 fromSource = -log( bgdGMM(color) );
    467                 toSink = -log( fgdGMM(color) );
    468             }
    469             else if( mask.at<uchar>(p) == GC_BGD )
    470             {
    471                 fromSource = 0;
    472                 toSink = lambda;
    473             }
    474             else // GC_FGD
    475             {
    476                 fromSource = lambda;
    477                 toSink = 0;
    478             }
    479             graph.addTermWeights( vtxIdx, fromSource, toSink );
    480 
    481             // set n-weights
    482             if( p.x>0 )
    483             {
    484                 double w = leftW.at<double>(p);
    485                 graph.addEdges( vtxIdx, vtxIdx-1, w, w );
    486             }
    487             if( p.x>0 && p.y>0 )
    488             {
    489                 double w = upleftW.at<double>(p);
    490                 graph.addEdges( vtxIdx, vtxIdx-img.cols-1, w, w );
    491             }
    492             if( p.y>0 )
    493             {
    494                 double w = upW.at<double>(p);
    495                 graph.addEdges( vtxIdx, vtxIdx-img.cols, w, w );
    496             }
    497             if( p.x<img.cols-1 && p.y>0 )
    498             {
    499                 double w = uprightW.at<double>(p);
    500                 graph.addEdges( vtxIdx, vtxIdx-img.cols+1, w, w );
    501             }
    502         }
    503     }
    504 }
    505 
    506 /*
    507   Estimate segmentation using MaxFlow algorithm
    508 */
    509 static void estimateSegmentation( GCGraph<double>& graph, Mat& mask )
    510 {
    511     graph.maxFlow();
    512     Point p;
    513     for( p.y = 0; p.y < mask.rows; p.y++ )
    514     {
    515         for( p.x = 0; p.x < mask.cols; p.x++ )
    516         {
    517             if( mask.at<uchar>(p) == GC_PR_BGD || mask.at<uchar>(p) == GC_PR_FGD )
    518             {
    519                 if( graph.inSourceSegment( p.y*mask.cols+p.x /*vertex index*/ ) )
    520                     mask.at<uchar>(p) = GC_PR_FGD;
    521                 else
    522                     mask.at<uchar>(p) = GC_PR_BGD;
    523             }
    524         }
    525     }
    526 }
    527 
    528 void cv::grabCut( InputArray _img, InputOutputArray _mask, Rect rect,
    529                   InputOutputArray _bgdModel, InputOutputArray _fgdModel,
    530                   int iterCount, int mode )
    531 {
    532     Mat img = _img.getMat();
    533     Mat& mask = _mask.getMatRef();
    534     Mat& bgdModel = _bgdModel.getMatRef();
    535     Mat& fgdModel = _fgdModel.getMatRef();
    536 
    537     if( img.empty() )
    538         CV_Error( CV_StsBadArg, "image is empty" );
    539     if( img.type() != CV_8UC3 )
    540         CV_Error( CV_StsBadArg, "image mush have CV_8UC3 type" );
    541 
    542     GMM bgdGMM( bgdModel ), fgdGMM( fgdModel );
    543     Mat compIdxs( img.size(), CV_32SC1 );
    544 
    545     if( mode == GC_INIT_WITH_RECT || mode == GC_INIT_WITH_MASK )
    546     {
    547         if( mode == GC_INIT_WITH_RECT )
    548             initMaskWithRect( mask, img.size(), rect );
    549         else // flag == GC_INIT_WITH_MASK
    550             checkMask( img, mask );
    551         initGMMs( img, mask, bgdGMM, fgdGMM );
    552     }
    553 
    554     if( iterCount <= 0)
    555         return;
    556 
    557     if( mode == GC_EVAL )
    558         checkMask( img, mask );
    559 
    560     const double gamma = 50;
    561     const double lambda = 9*gamma;
    562     const double beta = calcBeta( img );
    563 
    564     Mat leftW, upleftW, upW, uprightW;
    565     calcNWeights( img, leftW, upleftW, upW, uprightW, beta, gamma );
    566 
    567     for( int i = 0; i < iterCount; i++ )
    568     {
    569         GCGraph<double> graph;
    570         assignGMMsComponents( img, mask, bgdGMM, fgdGMM, compIdxs );
    571         learnGMMs( img, mask, compIdxs, bgdGMM, fgdGMM );
    572         constructGCGraph(img, mask, bgdGMM, fgdGMM, lambda, leftW, upleftW, upW, uprightW, graph );
    573         estimateSegmentation( graph, mask );
    574     }
    575 }
    576