Home | History | Annotate | Download | only in cpp
      1 #include "opencv2/imgcodecs.hpp"
      2 #include "opencv2/highgui/highgui.hpp"
      3 #include "opencv2/imgproc/imgproc.hpp"
      4 
      5 #include <iostream>
      6 
      7 using namespace std;
      8 using namespace cv;
      9 
     10 static void help()
     11 {
     12     cout << "\nThis program demonstrates GrabCut segmentation -- select an object in a region\n"
     13             "and then grabcut will attempt to segment it out.\n"
     14             "Call:\n"
     15             "./grabcut <image_name>\n"
     16         "\nSelect a rectangular area around the object you want to segment\n" <<
     17         "\nHot keys: \n"
     18         "\tESC - quit the program\n"
     19         "\tr - restore the original image\n"
     20         "\tn - next iteration\n"
     21         "\n"
     22         "\tleft mouse button - set rectangle\n"
     23         "\n"
     24         "\tCTRL+left mouse button - set GC_BGD pixels\n"
     25         "\tSHIFT+left mouse button - set GC_FGD pixels\n"
     26         "\n"
     27         "\tCTRL+right mouse button - set GC_PR_BGD pixels\n"
     28         "\tSHIFT+right mouse button - set GC_PR_FGD pixels\n" << endl;
     29 }
     30 
     31 const Scalar RED = Scalar(0,0,255);
     32 const Scalar PINK = Scalar(230,130,255);
     33 const Scalar BLUE = Scalar(255,0,0);
     34 const Scalar LIGHTBLUE = Scalar(255,255,160);
     35 const Scalar GREEN = Scalar(0,255,0);
     36 
     37 const int BGD_KEY = EVENT_FLAG_CTRLKEY;
     38 const int FGD_KEY = EVENT_FLAG_SHIFTKEY;
     39 
     40 static void getBinMask( const Mat& comMask, Mat& binMask )
     41 {
     42     if( comMask.empty() || comMask.type()!=CV_8UC1 )
     43         CV_Error( Error::StsBadArg, "comMask is empty or has incorrect type (not CV_8UC1)" );
     44     if( binMask.empty() || binMask.rows!=comMask.rows || binMask.cols!=comMask.cols )
     45         binMask.create( comMask.size(), CV_8UC1 );
     46     binMask = comMask & 1;
     47 }
     48 
     49 class GCApplication
     50 {
     51 public:
     52     enum{ NOT_SET = 0, IN_PROCESS = 1, SET = 2 };
     53     static const int radius = 2;
     54     static const int thickness = -1;
     55 
     56     void reset();
     57     void setImageAndWinName( const Mat& _image, const string& _winName );
     58     void showImage() const;
     59     void mouseClick( int event, int x, int y, int flags, void* param );
     60     int nextIter();
     61     int getIterCount() const { return iterCount; }
     62 private:
     63     void setRectInMask();
     64     void setLblsInMask( int flags, Point p, bool isPr );
     65 
     66     const string* winName;
     67     const Mat* image;
     68     Mat mask;
     69     Mat bgdModel, fgdModel;
     70 
     71     uchar rectState, lblsState, prLblsState;
     72     bool isInitialized;
     73 
     74     Rect rect;
     75     vector<Point> fgdPxls, bgdPxls, prFgdPxls, prBgdPxls;
     76     int iterCount;
     77 };
     78 
     79 void GCApplication::reset()
     80 {
     81     if( !mask.empty() )
     82         mask.setTo(Scalar::all(GC_BGD));
     83     bgdPxls.clear(); fgdPxls.clear();
     84     prBgdPxls.clear();  prFgdPxls.clear();
     85 
     86     isInitialized = false;
     87     rectState = NOT_SET;
     88     lblsState = NOT_SET;
     89     prLblsState = NOT_SET;
     90     iterCount = 0;
     91 }
     92 
     93 void GCApplication::setImageAndWinName( const Mat& _image, const string& _winName  )
     94 {
     95     if( _image.empty() || _winName.empty() )
     96         return;
     97     image = &_image;
     98     winName = &_winName;
     99     mask.create( image->size(), CV_8UC1);
    100     reset();
    101 }
    102 
    103 void GCApplication::showImage() const
    104 {
    105     if( image->empty() || winName->empty() )
    106         return;
    107 
    108     Mat res;
    109     Mat binMask;
    110     if( !isInitialized )
    111         image->copyTo( res );
    112     else
    113     {
    114         getBinMask( mask, binMask );
    115         image->copyTo( res, binMask );
    116     }
    117 
    118     vector<Point>::const_iterator it;
    119     for( it = bgdPxls.begin(); it != bgdPxls.end(); ++it )
    120         circle( res, *it, radius, BLUE, thickness );
    121     for( it = fgdPxls.begin(); it != fgdPxls.end(); ++it )
    122         circle( res, *it, radius, RED, thickness );
    123     for( it = prBgdPxls.begin(); it != prBgdPxls.end(); ++it )
    124         circle( res, *it, radius, LIGHTBLUE, thickness );
    125     for( it = prFgdPxls.begin(); it != prFgdPxls.end(); ++it )
    126         circle( res, *it, radius, PINK, thickness );
    127 
    128     if( rectState == IN_PROCESS || rectState == SET )
    129         rectangle( res, Point( rect.x, rect.y ), Point(rect.x + rect.width, rect.y + rect.height ), GREEN, 2);
    130 
    131     imshow( *winName, res );
    132 }
    133 
    134 void GCApplication::setRectInMask()
    135 {
    136     CV_Assert( !mask.empty() );
    137     mask.setTo( GC_BGD );
    138     rect.x = max(0, rect.x);
    139     rect.y = max(0, rect.y);
    140     rect.width = min(rect.width, image->cols-rect.x);
    141     rect.height = min(rect.height, image->rows-rect.y);
    142     (mask(rect)).setTo( Scalar(GC_PR_FGD) );
    143 }
    144 
    145 void GCApplication::setLblsInMask( int flags, Point p, bool isPr )
    146 {
    147     vector<Point> *bpxls, *fpxls;
    148     uchar bvalue, fvalue;
    149     if( !isPr )
    150     {
    151         bpxls = &bgdPxls;
    152         fpxls = &fgdPxls;
    153         bvalue = GC_BGD;
    154         fvalue = GC_FGD;
    155     }
    156     else
    157     {
    158         bpxls = &prBgdPxls;
    159         fpxls = &prFgdPxls;
    160         bvalue = GC_PR_BGD;
    161         fvalue = GC_PR_FGD;
    162     }
    163     if( flags & BGD_KEY )
    164     {
    165         bpxls->push_back(p);
    166         circle( mask, p, radius, bvalue, thickness );
    167     }
    168     if( flags & FGD_KEY )
    169     {
    170         fpxls->push_back(p);
    171         circle( mask, p, radius, fvalue, thickness );
    172     }
    173 }
    174 
    175 void GCApplication::mouseClick( int event, int x, int y, int flags, void* )
    176 {
    177     // TODO add bad args check
    178     switch( event )
    179     {
    180     case EVENT_LBUTTONDOWN: // set rect or GC_BGD(GC_FGD) labels
    181         {
    182             bool isb = (flags & BGD_KEY) != 0,
    183                  isf = (flags & FGD_KEY) != 0;
    184             if( rectState == NOT_SET && !isb && !isf )
    185             {
    186                 rectState = IN_PROCESS;
    187                 rect = Rect( x, y, 1, 1 );
    188             }
    189             if ( (isb || isf) && rectState == SET )
    190                 lblsState = IN_PROCESS;
    191         }
    192         break;
    193     case EVENT_RBUTTONDOWN: // set GC_PR_BGD(GC_PR_FGD) labels
    194         {
    195             bool isb = (flags & BGD_KEY) != 0,
    196                  isf = (flags & FGD_KEY) != 0;
    197             if ( (isb || isf) && rectState == SET )
    198                 prLblsState = IN_PROCESS;
    199         }
    200         break;
    201     case EVENT_LBUTTONUP:
    202         if( rectState == IN_PROCESS )
    203         {
    204             rect = Rect( Point(rect.x, rect.y), Point(x,y) );
    205             rectState = SET;
    206             setRectInMask();
    207             CV_Assert( bgdPxls.empty() && fgdPxls.empty() && prBgdPxls.empty() && prFgdPxls.empty() );
    208             showImage();
    209         }
    210         if( lblsState == IN_PROCESS )
    211         {
    212             setLblsInMask(flags, Point(x,y), false);
    213             lblsState = SET;
    214             showImage();
    215         }
    216         break;
    217     case EVENT_RBUTTONUP:
    218         if( prLblsState == IN_PROCESS )
    219         {
    220             setLblsInMask(flags, Point(x,y), true);
    221             prLblsState = SET;
    222             showImage();
    223         }
    224         break;
    225     case EVENT_MOUSEMOVE:
    226         if( rectState == IN_PROCESS )
    227         {
    228             rect = Rect( Point(rect.x, rect.y), Point(x,y) );
    229             CV_Assert( bgdPxls.empty() && fgdPxls.empty() && prBgdPxls.empty() && prFgdPxls.empty() );
    230             showImage();
    231         }
    232         else if( lblsState == IN_PROCESS )
    233         {
    234             setLblsInMask(flags, Point(x,y), false);
    235             showImage();
    236         }
    237         else if( prLblsState == IN_PROCESS )
    238         {
    239             setLblsInMask(flags, Point(x,y), true);
    240             showImage();
    241         }
    242         break;
    243     }
    244 }
    245 
    246 int GCApplication::nextIter()
    247 {
    248     if( isInitialized )
    249         grabCut( *image, mask, rect, bgdModel, fgdModel, 1 );
    250     else
    251     {
    252         if( rectState != SET )
    253             return iterCount;
    254 
    255         if( lblsState == SET || prLblsState == SET )
    256             grabCut( *image, mask, rect, bgdModel, fgdModel, 1, GC_INIT_WITH_MASK );
    257         else
    258             grabCut( *image, mask, rect, bgdModel, fgdModel, 1, GC_INIT_WITH_RECT );
    259 
    260         isInitialized = true;
    261     }
    262     iterCount++;
    263 
    264     bgdPxls.clear(); fgdPxls.clear();
    265     prBgdPxls.clear(); prFgdPxls.clear();
    266 
    267     return iterCount;
    268 }
    269 
    270 GCApplication gcapp;
    271 
    272 static void on_mouse( int event, int x, int y, int flags, void* param )
    273 {
    274     gcapp.mouseClick( event, x, y, flags, param );
    275 }
    276 
    277 int main( int argc, char** argv )
    278 {
    279     if( argc!=2 )
    280     {
    281         help();
    282         return 1;
    283     }
    284     string filename = argv[1];
    285     if( filename.empty() )
    286     {
    287         cout << "\nDurn, couldn't read in " << argv[1] << endl;
    288         return 1;
    289     }
    290     Mat image = imread( filename, 1 );
    291     if( image.empty() )
    292     {
    293         cout << "\n Durn, couldn't read image filename " << filename << endl;
    294         return 1;
    295     }
    296 
    297     help();
    298 
    299     const string winName = "image";
    300     namedWindow( winName, WINDOW_AUTOSIZE );
    301     setMouseCallback( winName, on_mouse, 0 );
    302 
    303     gcapp.setImageAndWinName( image, winName );
    304     gcapp.showImage();
    305 
    306     for(;;)
    307     {
    308         int c = waitKey(0);
    309         switch( (char) c )
    310         {
    311         case '\x1b':
    312             cout << "Exiting ..." << endl;
    313             goto exit_main;
    314         case 'r':
    315             cout << endl;
    316             gcapp.reset();
    317             gcapp.showImage();
    318             break;
    319         case 'n':
    320             int iterCount = gcapp.getIterCount();
    321             cout << "<" << iterCount << "... ";
    322             int newIterCount = gcapp.nextIter();
    323             if( newIterCount > iterCount )
    324             {
    325                 gcapp.showImage();
    326                 cout << iterCount << ">" << endl;
    327             }
    328             else
    329                 cout << "rect must be determined>" << endl;
    330             break;
    331         }
    332     }
    333 
    334 exit_main:
    335     destroyWindow( winName );
    336     return 0;
    337 }
    338