/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

#include "precomp.hpp"
#include "gcgraph.hpp"
#include <limits>

using namespace cv;

/*
This is an implementation of the GrabCut image segmentation algorithm described in
"GrabCut - Interactive Foreground Extraction using Iterated Graph Cuts"
by Carsten Rother, Vladimir Kolmogorov and Andrew Blake.
*/
/*
  GMM - Gaussian Mixture Model
*/
class GMM
{
public:
    static const int componentsCount = 5;

    GMM( Mat& _model );
    double operator()( const Vec3d color ) const;
    double operator()( int ci, const Vec3d color ) const;
    int whichComponent( const Vec3d color ) const;

    void initLearning();
    void addSample( int ci, const Vec3d color );
    void endLearning();

private:
    void calcInverseCovAndDeterm( int ci );
    Mat model;
    double* coefs;
    double* mean;
    double* cov;

    double inverseCovs[componentsCount][3][3];
    double covDeterms[componentsCount];

    double sums[componentsCount][3];
    double prods[componentsCount][3][3];
    int sampleCounts[componentsCount];
    int totalSampleCount;
};

GMM::GMM( Mat& _model )
{
    const int modelSize = 3/*mean*/ + 9/*covariance*/ + 1/*component weight*/;
    if( _model.empty() )
    {
        _model.create( 1, modelSize*componentsCount, CV_64FC1 );
        _model.setTo(Scalar(0));
    }
    else if( (_model.type() != CV_64FC1) || (_model.rows != 1) || (_model.cols != modelSize*componentsCount) )
        CV_Error( CV_StsBadArg, "_model must have CV_64FC1 type, rows == 1 and cols == 13*componentsCount" );

    model = _model;

    coefs = model.ptr<double>(0);
    mean = coefs + componentsCount;
    cov = mean + 3*componentsCount;

    for( int ci = 0; ci < componentsCount; ci++ )
        if( coefs[ci] > 0 )
            calcInverseCovAndDeterm( ci );
}

double GMM::operator()( const Vec3d color ) const
{
    double res = 0;
    for( int ci = 0; ci < componentsCount; ci++ )
        res += coefs[ci] * (*this)(ci, color );
    return res;
}

double GMM::operator()( int ci, const Vec3d color ) const
{
    double res = 0;
    if( coefs[ci] > 0 )
    {
        CV_Assert( covDeterms[ci] > std::numeric_limits<double>::epsilon() );
        Vec3d diff = color;
        double* m = mean + 3*ci;
        diff[0] -= m[0]; diff[1] -= m[1]; diff[2] -= m[2];
        double mult = diff[0]*(diff[0]*inverseCovs[ci][0][0] + diff[1]*inverseCovs[ci][1][0] + diff[2]*inverseCovs[ci][2][0])
                    + diff[1]*(diff[0]*inverseCovs[ci][0][1] + diff[1]*inverseCovs[ci][1][1] + diff[2]*inverseCovs[ci][2][1])
                    + diff[2]*(diff[0]*inverseCovs[ci][0][2] + diff[1]*inverseCovs[ci][1][2] + diff[2]*inverseCovs[ci][2][2]);
        res = 1.0f/sqrt(covDeterms[ci]) * exp(-0.5f*mult);
    }
    return res;
}

int GMM::whichComponent( const Vec3d color ) const
{
    int k = 0;
    double max = 0;

    for( int ci = 0; ci < componentsCount; ci++ )
    {
        double p = (*this)( ci, color );
        if( p > max )
        {
            k = ci;
            max = p;
        }
    }
    return k;
}

void GMM::initLearning()
{
    for( int ci = 0; ci < componentsCount; ci++)
    {
        sums[ci][0] = sums[ci][1] = sums[ci][2] = 0;
        prods[ci][0][0] = prods[ci][0][1] = prods[ci][0][2] = 0;
        prods[ci][1][0] = prods[ci][1][1] = prods[ci][1][2] = 0;
        prods[ci][2][0] = prods[ci][2][1] = prods[ci][2][2] = 0;
        sampleCounts[ci] = 0;
    }
    totalSampleCount = 0;
}

void GMM::addSample( int ci, const Vec3d color )
{
    sums[ci][0] += color[0]; sums[ci][1] += color[1]; sums[ci][2] += color[2];
    prods[ci][0][0] += color[0]*color[0]; prods[ci][0][1] += color[0]*color[1]; prods[ci][0][2] += color[0]*color[2];
    prods[ci][1][0] += color[1]*color[0]; prods[ci][1][1] += color[1]*color[1]; prods[ci][1][2] += color[1]*color[2];
    prods[ci][2][0] += color[2]*color[0]; prods[ci][2][1] += color[2]*color[1]; prods[ci][2][2] += color[2]*color[2];
    sampleCounts[ci]++;
    totalSampleCount++;
}
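// endLearning() (below) converts the accumulated per-component sums and outer
// products into the component weight (fraction of samples), mean and covariance,
// i.e. the usual maximum-likelihood estimates for a Gaussian fitted to the
// samples assigned to that component.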
void GMM::endLearning()
{
    const double variance = 0.01;
    for( int ci = 0; ci < componentsCount; ci++ )
    {
        int n = sampleCounts[ci];
        if( n == 0 )
            coefs[ci] = 0;
        else
        {
            coefs[ci] = (double)n/totalSampleCount;

            double* m = mean + 3*ci;
            m[0] = sums[ci][0]/n; m[1] = sums[ci][1]/n; m[2] = sums[ci][2]/n;

            double* c = cov + 9*ci;
            c[0] = prods[ci][0][0]/n - m[0]*m[0]; c[1] = prods[ci][0][1]/n - m[0]*m[1]; c[2] = prods[ci][0][2]/n - m[0]*m[2];
            c[3] = prods[ci][1][0]/n - m[1]*m[0]; c[4] = prods[ci][1][1]/n - m[1]*m[1]; c[5] = prods[ci][1][2]/n - m[1]*m[2];
            c[6] = prods[ci][2][0]/n - m[2]*m[0]; c[7] = prods[ci][2][1]/n - m[2]*m[1]; c[8] = prods[ci][2][2]/n - m[2]*m[2];

            double dtrm = c[0]*(c[4]*c[8]-c[5]*c[7]) - c[1]*(c[3]*c[8]-c[5]*c[6]) + c[2]*(c[3]*c[7]-c[4]*c[6]);
            if( dtrm <= std::numeric_limits<double>::epsilon() )
            {
                // Add white noise to avoid a singular covariance matrix.
                c[0] += variance;
                c[4] += variance;
                c[8] += variance;
            }

            calcInverseCovAndDeterm(ci);
        }
    }
}

void GMM::calcInverseCovAndDeterm( int ci )
{
    if( coefs[ci] > 0 )
    {
        double *c = cov + 9*ci;
        double dtrm =
              covDeterms[ci] = c[0]*(c[4]*c[8]-c[5]*c[7]) - c[1]*(c[3]*c[8]-c[5]*c[6]) + c[2]*(c[3]*c[7]-c[4]*c[6]);

        CV_Assert( dtrm > std::numeric_limits<double>::epsilon() );
        inverseCovs[ci][0][0] =  (c[4]*c[8] - c[5]*c[7]) / dtrm;
        inverseCovs[ci][1][0] = -(c[3]*c[8] - c[5]*c[6]) / dtrm;
        inverseCovs[ci][2][0] =  (c[3]*c[7] - c[4]*c[6]) / dtrm;
        inverseCovs[ci][0][1] = -(c[1]*c[8] - c[2]*c[7]) / dtrm;
        inverseCovs[ci][1][1] =  (c[0]*c[8] - c[2]*c[6]) / dtrm;
        inverseCovs[ci][2][1] = -(c[0]*c[7] - c[1]*c[6]) / dtrm;
        inverseCovs[ci][0][2] =  (c[1]*c[5] - c[2]*c[4]) / dtrm;
        inverseCovs[ci][1][2] = -(c[0]*c[5] - c[2]*c[3]) / dtrm;
        inverseCovs[ci][2][2] =  (c[0]*c[4] - c[1]*c[3]) / dtrm;
    }
}

/*
  Calculate beta - parameter of the GrabCut algorithm.
  beta = 1/(2*avg(sqr(||color[i] - color[j]||)))
*/
static double calcBeta( const Mat& img )
{
    double beta = 0;
    for( int y = 0; y < img.rows; y++ )
    {
        for( int x = 0; x < img.cols; x++ )
        {
            Vec3d color = img.at<Vec3b>(y,x);
            if( x>0 ) // left
            {
                Vec3d diff = color - (Vec3d)img.at<Vec3b>(y,x-1);
                beta += diff.dot(diff);
            }
            if( y>0 && x>0 ) // upleft
            {
                Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x-1);
                beta += diff.dot(diff);
            }
            if( y>0 ) // up
            {
                Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x);
                beta += diff.dot(diff);
            }
            if( y>0 && x<img.cols-1) // upright
            {
                Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x+1);
                beta += diff.dot(diff);
            }
        }
    }
    if( beta <= std::numeric_limits<double>::epsilon() )
        beta = 0;
    else
        beta = 1.f / (2 * beta/(4*img.cols*img.rows - 3*img.cols - 3*img.rows + 2) );

    return beta;
}
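// Note on calcBeta(): the divisor 4*cols*rows - 3*cols - 3*rows + 2 equals the
// number of neighbour pairs visited in the loop above (left, upleft, up and
// upright for each pixel, minus the pairs that would fall outside the image),
// so beta is the reciprocal of twice the average squared colour difference
// between 8-connected neighbours.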
/*
  Calculate the weights of the non-terminal vertices (n-links) of the graph.
  beta and gamma are parameters of the GrabCut algorithm.
*/
static void calcNWeights( const Mat& img, Mat& leftW, Mat& upleftW, Mat& upW, Mat& uprightW, double beta, double gamma )
{
    const double gammaDivSqrt2 = gamma / std::sqrt(2.0f);
    leftW.create( img.rows, img.cols, CV_64FC1 );
    upleftW.create( img.rows, img.cols, CV_64FC1 );
    upW.create( img.rows, img.cols, CV_64FC1 );
    uprightW.create( img.rows, img.cols, CV_64FC1 );
    for( int y = 0; y < img.rows; y++ )
    {
        for( int x = 0; x < img.cols; x++ )
        {
            Vec3d color = img.at<Vec3b>(y,x);
            if( x-1>=0 ) // left
            {
                Vec3d diff = color - (Vec3d)img.at<Vec3b>(y,x-1);
                leftW.at<double>(y,x) = gamma * exp(-beta*diff.dot(diff));
            }
            else
                leftW.at<double>(y,x) = 0;
            if( x-1>=0 && y-1>=0 ) // upleft
            {
                Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x-1);
                upleftW.at<double>(y,x) = gammaDivSqrt2 * exp(-beta*diff.dot(diff));
            }
            else
                upleftW.at<double>(y,x) = 0;
            if( y-1>=0 ) // up
            {
                Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x);
                upW.at<double>(y,x) = gamma * exp(-beta*diff.dot(diff));
            }
            else
                upW.at<double>(y,x) = 0;
            if( x+1<img.cols && y-1>=0 ) // upright
            {
                Vec3d diff = color - (Vec3d)img.at<Vec3b>(y-1,x+1);
                uprightW.at<double>(y,x) = gammaDivSqrt2 * exp(-beta*diff.dot(diff));
            }
            else
                uprightW.at<double>(y,x) = 0;
        }
    }
}

/*
  Check the size, type and element values of the mask matrix.
*/
static void checkMask( const Mat& img, const Mat& mask )
{
    if( mask.empty() )
        CV_Error( CV_StsBadArg, "mask is empty" );
    if( mask.type() != CV_8UC1 )
        CV_Error( CV_StsBadArg, "mask must have CV_8UC1 type" );
    if( mask.cols != img.cols || mask.rows != img.rows )
        CV_Error( CV_StsBadArg, "mask must have as many rows and cols as img" );
    for( int y = 0; y < mask.rows; y++ )
    {
        for( int x = 0; x < mask.cols; x++ )
        {
            uchar val = mask.at<uchar>(y,x);
            if( val!=GC_BGD && val!=GC_FGD && val!=GC_PR_BGD && val!=GC_PR_FGD )
                CV_Error( CV_StsBadArg, "mask element value must be equal to "
                    "GC_BGD or GC_FGD or GC_PR_BGD or GC_PR_FGD" );
        }
    }
}

/*
  Initialize the mask using a rectangle.
*/
static void initMaskWithRect( Mat& mask, Size imgSize, Rect rect )
{
    mask.create( imgSize, CV_8UC1 );
    mask.setTo( GC_BGD );

    rect.x = std::max(0, rect.x);
    rect.y = std::max(0, rect.y);
    rect.width = std::min(rect.width, imgSize.width-rect.x);
    rect.height = std::min(rect.height, imgSize.height-rect.y);

    (mask(rect)).setTo( Scalar(GC_PR_FGD) );
}
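// The mask uses OpenCV's GrabCut labels: GC_BGD / GC_FGD mark pixels fixed by the
// user as background / foreground, while GC_PR_BGD / GC_PR_FGD mark pixels whose
// label is only "probable" and may be flipped by the segmentation below.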
/*
  Initialize the GMM background and foreground models using the kmeans algorithm.
*/
static void initGMMs( const Mat& img, const Mat& mask, GMM& bgdGMM, GMM& fgdGMM )
{
    const int kMeansItCount = 10;
    const int kMeansType = KMEANS_PP_CENTERS;

    Mat bgdLabels, fgdLabels;
    std::vector<Vec3f> bgdSamples, fgdSamples;
    Point p;
    for( p.y = 0; p.y < img.rows; p.y++ )
    {
        for( p.x = 0; p.x < img.cols; p.x++ )
        {
            if( mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD )
                bgdSamples.push_back( (Vec3f)img.at<Vec3b>(p) );
            else // GC_FGD | GC_PR_FGD
                fgdSamples.push_back( (Vec3f)img.at<Vec3b>(p) );
        }
    }
    CV_Assert( !bgdSamples.empty() && !fgdSamples.empty() );
    Mat _bgdSamples( (int)bgdSamples.size(), 3, CV_32FC1, &bgdSamples[0][0] );
    kmeans( _bgdSamples, GMM::componentsCount, bgdLabels,
            TermCriteria( CV_TERMCRIT_ITER, kMeansItCount, 0.0), 0, kMeansType );
    Mat _fgdSamples( (int)fgdSamples.size(), 3, CV_32FC1, &fgdSamples[0][0] );
    kmeans( _fgdSamples, GMM::componentsCount, fgdLabels,
            TermCriteria( CV_TERMCRIT_ITER, kMeansItCount, 0.0), 0, kMeansType );

    bgdGMM.initLearning();
    for( int i = 0; i < (int)bgdSamples.size(); i++ )
        bgdGMM.addSample( bgdLabels.at<int>(i,0), bgdSamples[i] );
    bgdGMM.endLearning();

    fgdGMM.initLearning();
    for( int i = 0; i < (int)fgdSamples.size(); i++ )
        fgdGMM.addSample( fgdLabels.at<int>(i,0), fgdSamples[i] );
    fgdGMM.endLearning();
}

/*
  Assign a GMM component to each pixel.
*/
static void assignGMMsComponents( const Mat& img, const Mat& mask, const GMM& bgdGMM, const GMM& fgdGMM, Mat& compIdxs )
{
    Point p;
    for( p.y = 0; p.y < img.rows; p.y++ )
    {
        for( p.x = 0; p.x < img.cols; p.x++ )
        {
            Vec3d color = img.at<Vec3b>(p);
            compIdxs.at<int>(p) = mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD ?
                bgdGMM.whichComponent(color) : fgdGMM.whichComponent(color);
        }
    }
}
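// assignGMMsComponents() and learnGMMs() together form the per-iteration GMM
// update: each pixel is first assigned to the most likely component of the GMM
// matching its current background/foreground label, and the GMM parameters are
// then re-estimated from those assignments.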
/*
  Learn the GMM parameters.
*/
static void learnGMMs( const Mat& img, const Mat& mask, const Mat& compIdxs, GMM& bgdGMM, GMM& fgdGMM )
{
    bgdGMM.initLearning();
    fgdGMM.initLearning();
    Point p;
    for( int ci = 0; ci < GMM::componentsCount; ci++ )
    {
        for( p.y = 0; p.y < img.rows; p.y++ )
        {
            for( p.x = 0; p.x < img.cols; p.x++ )
            {
                if( compIdxs.at<int>(p) == ci )
                {
                    if( mask.at<uchar>(p) == GC_BGD || mask.at<uchar>(p) == GC_PR_BGD )
                        bgdGMM.addSample( ci, img.at<Vec3b>(p) );
                    else
                        fgdGMM.addSample( ci, img.at<Vec3b>(p) );
                }
            }
        }
    }
    bgdGMM.endLearning();
    fgdGMM.endLearning();
}

/*
  Construct the GCGraph.
*/
static void constructGCGraph( const Mat& img, const Mat& mask, const GMM& bgdGMM, const GMM& fgdGMM, double lambda,
                              const Mat& leftW, const Mat& upleftW, const Mat& upW, const Mat& uprightW,
                              GCGraph<double>& graph )
{
    int vtxCount = img.cols*img.rows,
        edgeCount = 2*(4*img.cols*img.rows - 3*(img.cols + img.rows) + 2);
    graph.create(vtxCount, edgeCount);
    Point p;
    for( p.y = 0; p.y < img.rows; p.y++ )
    {
        for( p.x = 0; p.x < img.cols; p.x++)
        {
            // add node
            int vtxIdx = graph.addVtx();
            Vec3b color = img.at<Vec3b>(p);

            // set t-weights
            double fromSource, toSink;
            if( mask.at<uchar>(p) == GC_PR_BGD || mask.at<uchar>(p) == GC_PR_FGD )
            {
                fromSource = -log( bgdGMM(color) );
                toSink = -log( fgdGMM(color) );
            }
            else if( mask.at<uchar>(p) == GC_BGD )
            {
                fromSource = 0;
                toSink = lambda;
            }
            else // GC_FGD
            {
                fromSource = lambda;
                toSink = 0;
            }
            graph.addTermWeights( vtxIdx, fromSource, toSink );

            // set n-weights
            if( p.x>0 )
            {
                double w = leftW.at<double>(p);
                graph.addEdges( vtxIdx, vtxIdx-1, w, w );
            }
            if( p.x>0 && p.y>0 )
            {
                double w = upleftW.at<double>(p);
                graph.addEdges( vtxIdx, vtxIdx-img.cols-1, w, w );
            }
            if( p.y>0 )
            {
                double w = upW.at<double>(p);
                graph.addEdges( vtxIdx, vtxIdx-img.cols, w, w );
            }
            if( p.x<img.cols-1 && p.y>0 )
            {
                double w = uprightW.at<double>(p);
                graph.addEdges( vtxIdx, vtxIdx-img.cols+1, w, w );
            }
        }
    }
}

/*
  Estimate the segmentation using the MaxFlow algorithm.
*/
static void estimateSegmentation( GCGraph<double>& graph, Mat& mask )
{
    graph.maxFlow();
    Point p;
    for( p.y = 0; p.y < mask.rows; p.y++ )
    {
        for( p.x = 0; p.x < mask.cols; p.x++ )
        {
            if( mask.at<uchar>(p) == GC_PR_BGD || mask.at<uchar>(p) == GC_PR_FGD )
            {
                if( graph.inSourceSegment( p.y*mask.cols+p.x /*vertex index*/ ) )
                    mask.at<uchar>(p) = GC_PR_FGD;
                else
                    mask.at<uchar>(p) = GC_PR_BGD;
            }
        }
    }
}
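// cv::grabCut() ties the pieces together: the mask and GMMs are initialized from a
// rectangle or a user-supplied mask, and each iteration then assigns GMM components,
// re-learns the GMMs, builds the graph and runs max-flow/min-cut to relabel the
// "probable" pixels.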
void cv::grabCut( InputArray _img, InputOutputArray _mask, Rect rect,
                  InputOutputArray _bgdModel, InputOutputArray _fgdModel,
                  int iterCount, int mode )
{
    Mat img = _img.getMat();
    Mat& mask = _mask.getMatRef();
    Mat& bgdModel = _bgdModel.getMatRef();
    Mat& fgdModel = _fgdModel.getMatRef();

    if( img.empty() )
        CV_Error( CV_StsBadArg, "image is empty" );
    if( img.type() != CV_8UC3 )
        CV_Error( CV_StsBadArg, "image must have CV_8UC3 type" );

    GMM bgdGMM( bgdModel ), fgdGMM( fgdModel );
    Mat compIdxs( img.size(), CV_32SC1 );

    if( mode == GC_INIT_WITH_RECT || mode == GC_INIT_WITH_MASK )
    {
        if( mode == GC_INIT_WITH_RECT )
            initMaskWithRect( mask, img.size(), rect );
        else // mode == GC_INIT_WITH_MASK
            checkMask( img, mask );
        initGMMs( img, mask, bgdGMM, fgdGMM );
    }

    if( iterCount <= 0)
        return;

    if( mode == GC_EVAL )
        checkMask( img, mask );

    const double gamma = 50;
    const double lambda = 9*gamma;
    const double beta = calcBeta( img );

    Mat leftW, upleftW, upW, uprightW;
    calcNWeights( img, leftW, upleftW, upW, uprightW, beta, gamma );

    for( int i = 0; i < iterCount; i++ )
    {
        GCGraph<double> graph;
        assignGMMsComponents( img, mask, bgdGMM, fgdGMM, compIdxs );
        learnGMMs( img, mask, compIdxs, bgdGMM, fgdGMM );
        constructGCGraph(img, mask, bgdGMM, fgdGMM, lambda, leftW, upleftW, upW, uprightW, graph );
        estimateSegmentation( graph, mask );
    }
}
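/*
  Example usage (illustrative sketch only, not part of the library): a caller would
  typically let grabCut initialize everything from a rectangle and then keep the
  pixels labelled as definite or probable foreground. The image path and rectangle
  below are placeholders.

    #include <opencv2/opencv.hpp>

    void exampleGrabCut()
    {
        cv::Mat img = cv::imread("input.jpg");   // placeholder image
        cv::Rect rect(50, 50, 200, 200);         // placeholder ROI around the object
        cv::Mat mask, bgdModel, fgdModel;
        cv::grabCut(img, mask, rect, bgdModel, fgdModel, 5, cv::GC_INIT_WITH_RECT);

        // Keep GC_FGD and GC_PR_FGD pixels as foreground.
        cv::Mat fgMask = (mask == cv::GC_FGD) | (mask == cv::GC_PR_FGD);
        cv::Mat foreground;
        img.copyTo(foreground, fgMask);
    }
*/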