Home | History | Annotate | Download | only in src
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                        Intel License Agreement
     11 //                For Open Source Computer Vision Library
     12 //
     13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
     14 // Third party copyrights are property of their respective owners.
     15 //
     16 // Redistribution and use in source and binary forms, with or without modification,
     17 // are permitted provided that the following conditions are met:
     18 //
     19 //   * Redistribution's of source code must retain the above copyright notice,
     20 //     this list of conditions and the following disclaimer.
     21 //
     22 //   * Redistribution's in binary form must reproduce the above copyright notice,
     23 //     this list of conditions and the following disclaimer in the documentation
     24 //     and/or other materials provided with the distribution.
     25 //
     26 //   * The name of Intel Corporation may not be used to endorse or promote products
     27 //     derived from this software without specific prior written permission.
     28 //
     29 // This software is provided by the copyright holders and contributors "as is" and
     30 // any express or implied warranties, including, but not limited to, the implied
     31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     32 // In no event shall the Intel Corporation or contributors be liable for any direct,
     33 // indirect, incidental, special, exemplary, or consequential damages
     34 // (including, but not limited to, procurement of substitute goods or services;
     35 // loss of use, data, or profits; or business interruption) however caused
     36 // and on any theory of liability, whether in contract, strict liability,
     37 // or tort (including negligence or otherwise) arising in any way out of
     38 // the use of this software, even if advised of the possibility of such damage.
     39 //
     40 //M*/
     41 
     42 /*
     43     Partially based on Yossi Rubner code:
     44     =========================================================================
     45     emd.c
     46 
     47     Last update: 3/14/98
     48 
     49     An implementation of the Earth Mover's Distance.
     50     Based on the solution of the Transportation problem as described in
     51     "Introduction to Mathematical Programming" by F. S. Hillier and
     52     G. J. Lieberman, McGraw-Hill, 1990.
     53 
     54     Copyright (C) 1998 Yossi Rubner
     55     Computer Science Department, Stanford University
     56     E-Mail: rubner (at) cs.stanford.edu   URL: http://vision.stanford.edu/~rubner
     57     ==========================================================================
     58 */
     59 #include "precomp.hpp"
     60 
/* Iteration cap for the simplex loop in cvCalcEMD2, which runs
   itr = 1 .. MAX_ITERATIONS-1 and then reports whatever solution it has. */
#define MAX_ITERATIONS 500
/* Sentinel "infinity": initial value of the minimum searches; icvIsOptimal
   returns it unchanged when there is no non-basic cell. */
#define CV_EMD_INF   ((float)1e20)
/* Relative tolerance: multiplied by max_cost for the optimality test and by
   the supply total when deciding whether supply and demand are balanced. */
#define CV_EMD_EPS   ((float)1e-5)
     64 
/* CvNode1D is used for lists, representing 1D sparse array.
   Among other uses it stores the dual variables u (one per supply row) and
   v (one per demand column); icvFindBasicVariables threads the nodes into
   "unknown" / "just determined" lists through `next`. */
typedef struct CvNode1D
{
    float val;                /* stored value (e.g. a dual variable) */
    struct CvNode1D *next;    /* next node in the current list, 0 at the end */
}
CvNode1D;
     72 
/* CvNode2D is used for lists, representing 2D sparse matrix.
   Each node is one basic variable of the transportation solution: `val` is
   its flow and (i, j) its row/column in the reduced cost matrix (mapped back
   to signature rows through idx1/idx2).  Every node is linked into the list
   of its row (state->rows_x[i]) and of its column (state->cols_x[j]). */
typedef struct CvNode2D
{
    float val;
    struct CvNode2D *next[2];  /* next[0]: next node in the same row; next[1]: next node in the same column */
    int i, j;
}
CvNode2D;
     81 
     82 
/* Working state of one EMD computation.  icvInitEMD carves every pointer
   member out of a single scratch allocation whose start is `buffer`. */
typedef struct CvEMDState
{
    int ssize, dsize;       /* rows / columns in use: positive-weight entries plus an optional dummy */

    float **cost;           /* ssize x dsize ground-distance matrix (0 for dummy row/column) */
    CvNode2D *_x;           /* pool of basic-variable nodes */
    CvNode2D *end_x;        /* one past the last pool node in use */
    CvNode2D *enter_x;      /* free pool node that receives the next entering variable */
    char **is_x;            /* is_x[i][j] != 0 iff cell (i, j) is basic */

    CvNode2D **rows_x;      /* per-row list heads of basic variables (linked via next[0]) */
    CvNode2D **cols_x;      /* per-column list heads of basic variables (linked via next[1]) */

    CvNode1D *u;            /* row duals */
    CvNode1D *v;            /* column duals */

    int* idx1;              /* reduced row -> signature1 row, -1 for the dummy supply */
    int* idx2;              /* reduced column -> signature2 row, -1 for the dummy demand */

    /* find_loop buffers */
    CvNode2D **loop;        /* cycle found by icvFindLoop */
    char *is_used;          /* per-node "on current path" flags; icvInitEMD places it over the delta storage */

    /* russel buffers */
    float *s;               /* supplies (positive signature1 weights, plus dummy) */
    float *d;               /* demands (positive signature2 weights, plus dummy) */
    float **delta;

    float weight, max_cost; /* larger of the two weight totals; largest ground distance */
    char *buffer;           /* start of the scratch allocation */
}
CvEMDState;
    115 
    116 /* static function declaration */
    117 static int icvInitEMD( const float *signature1, int size1,
    118                        const float *signature2, int size2,
    119                        int dims, CvDistanceFunction dist_func, void *user_param,
    120                        const float* cost, int cost_step,
    121                        CvEMDState * state, float *lower_bound,
    122                        cv::AutoBuffer<char>& _buffer );
    123 
    124 static int icvFindBasicVariables( float **cost, char **is_x,
    125                                   CvNode1D * u, CvNode1D * v, int ssize, int dsize );
    126 
    127 static float icvIsOptimal( float **cost, char **is_x,
    128                            CvNode1D * u, CvNode1D * v,
    129                            int ssize, int dsize, CvNode2D * enter_x );
    130 
    131 static void icvRussel( CvEMDState * state );
    132 
    133 
    134 static bool icvNewSolution( CvEMDState * state );
    135 static int icvFindLoop( CvEMDState * state );
    136 
    137 static void icvAddBasicVariable( CvEMDState * state,
    138                                  int min_i, int min_j,
    139                                  CvNode1D * prev_u_min_i,
    140                                  CvNode1D * prev_v_min_j,
    141                                  CvNode1D * u_head );
    142 
    143 static float icvDistL2( const float *x, const float *y, void *user_param );
    144 static float icvDistL1( const float *x, const float *y, void *user_param );
    145 static float icvDistC( const float *x, const float *y, void *user_param );
    146 
    147 /* The main function */
    148 CV_IMPL float cvCalcEMD2( const CvArr* signature_arr1,
    149             const CvArr* signature_arr2,
    150             int dist_type,
    151             CvDistanceFunction dist_func,
    152             const CvArr* cost_matrix,
    153             CvArr* flow_matrix,
    154             float *lower_bound,
    155             void *user_param )
    156 {
    157     cv::AutoBuffer<char> local_buf;
    158     CvEMDState state;
    159     float emd = 0;
    160 
    161     memset( &state, 0, sizeof(state));
    162 
    163     double total_cost = 0;
    164     int result = 0;
    165     float eps, min_delta;
    166     CvNode2D *xp = 0;
    167     CvMat sign_stub1, *signature1 = (CvMat*)signature_arr1;
    168     CvMat sign_stub2, *signature2 = (CvMat*)signature_arr2;
    169     CvMat cost_stub, *cost = &cost_stub;
    170     CvMat flow_stub, *flow = (CvMat*)flow_matrix;
    171     int dims, size1, size2;
    172 
    173     signature1 = cvGetMat( signature1, &sign_stub1 );
    174     signature2 = cvGetMat( signature2, &sign_stub2 );
    175 
    176     if( signature1->cols != signature2->cols )
    177         CV_Error( CV_StsUnmatchedSizes, "The arrays must have equal number of columns (which is number of dimensions but 1)" );
    178 
    179     dims = signature1->cols - 1;
    180     size1 = signature1->rows;
    181     size2 = signature2->rows;
    182 
    183     if( !CV_ARE_TYPES_EQ( signature1, signature2 ))
    184         CV_Error( CV_StsUnmatchedFormats, "The array must have equal types" );
    185 
    186     if( CV_MAT_TYPE( signature1->type ) != CV_32FC1 )
    187         CV_Error( CV_StsUnsupportedFormat, "The signatures must be 32fC1" );
    188 
    189     if( flow )
    190     {
    191         flow = cvGetMat( flow, &flow_stub );
    192 
    193         if( flow->rows != size1 || flow->cols != size2 )
    194             CV_Error( CV_StsUnmatchedSizes,
    195             "The flow matrix size does not match to the signatures' sizes" );
    196 
    197         if( CV_MAT_TYPE( flow->type ) != CV_32FC1 )
    198             CV_Error( CV_StsUnsupportedFormat, "The flow matrix must be 32fC1" );
    199     }
    200 
    201     cost->data.fl = 0;
    202     cost->step = 0;
    203 
    204     if( dist_type < 0 )
    205     {
    206         if( cost_matrix )
    207         {
    208             if( dist_func )
    209                 CV_Error( CV_StsBadArg,
    210                 "Only one of cost matrix or distance function should be non-NULL in case of user-defined distance" );
    211 
    212             if( lower_bound )
    213                 CV_Error( CV_StsBadArg,
    214                 "The lower boundary can not be calculated if the cost matrix is used" );
    215 
    216             cost = cvGetMat( cost_matrix, &cost_stub );
    217             if( cost->rows != size1 || cost->cols != size2 )
    218                 CV_Error( CV_StsUnmatchedSizes,
    219                 "The cost matrix size does not match to the signatures' sizes" );
    220 
    221             if( CV_MAT_TYPE( cost->type ) != CV_32FC1 )
    222                 CV_Error( CV_StsUnsupportedFormat, "The cost matrix must be 32fC1" );
    223         }
    224         else if( !dist_func )
    225             CV_Error( CV_StsNullPtr, "In case of user-defined distance Distance function is undefined" );
    226     }
    227     else
    228     {
    229         if( dims == 0 )
    230             CV_Error( CV_StsBadSize,
    231             "Number of dimensions can be 0 only if a user-defined metric is used" );
    232         user_param = (void *) (size_t)dims;
    233         switch (dist_type)
    234         {
    235         case CV_DIST_L1:
    236             dist_func = icvDistL1;
    237             break;
    238         case CV_DIST_L2:
    239             dist_func = icvDistL2;
    240             break;
    241         case CV_DIST_C:
    242             dist_func = icvDistC;
    243             break;
    244         default:
    245             CV_Error( CV_StsBadFlag, "Bad or unsupported metric type" );
    246         }
    247     }
    248 
    249     result = icvInitEMD( signature1->data.fl, size1,
    250                         signature2->data.fl, size2,
    251                         dims, dist_func, user_param,
    252                         cost->data.fl, cost->step,
    253                         &state, lower_bound, local_buf );
    254 
    255     if( result > 0 && lower_bound )
    256     {
    257         emd = *lower_bound;
    258         return emd;
    259     }
    260 
    261     eps = CV_EMD_EPS * state.max_cost;
    262 
    263     /* if ssize = 1 or dsize = 1 then we are done, else ... */
    264     if( state.ssize > 1 && state.dsize > 1 )
    265     {
    266         int itr;
    267 
    268         for( itr = 1; itr < MAX_ITERATIONS; itr++ )
    269         {
    270             /* find basic variables */
    271             result = icvFindBasicVariables( state.cost, state.is_x,
    272                                             state.u, state.v, state.ssize, state.dsize );
    273             if( result < 0 )
    274                 break;
    275 
    276             /* check for optimality */
    277             min_delta = icvIsOptimal( state.cost, state.is_x,
    278                                       state.u, state.v,
    279                                       state.ssize, state.dsize, state.enter_x );
    280 
    281             if( min_delta == CV_EMD_INF )
    282                 CV_Error( CV_StsNoConv, "" );
    283 
    284             /* if no negative deltamin, we found the optimal solution */
    285             if( min_delta >= -eps )
    286                 break;
    287 
    288             /* improve solution */
    289             if(!icvNewSolution( &state ))
    290                 CV_Error( CV_StsNoConv, "" );
    291         }
    292     }
    293 
    294     /* compute the total flow */
    295     for( xp = state._x; xp < state.end_x; xp++ )
    296     {
    297         float val = xp->val;
    298         int i = xp->i;
    299         int j = xp->j;
    300 
    301         if( xp == state.enter_x )
    302           continue;
    303 
    304         int ci = state.idx1[i];
    305         int cj = state.idx2[j];
    306 
    307         if( ci >= 0 && cj >= 0 )
    308         {
    309             total_cost += (double)val * state.cost[i][j];
    310             if( flow )
    311                 ((float*)(flow->data.ptr + flow->step*ci))[cj] = val;
    312         }
    313     }
    314 
    315     emd = (float) (total_cost / state.weight);
    316     return emd;
    317 }
    318 
    319 
    320 /************************************************************************************\
    321 *          initialize structure, allocate buffers and generate initial golution      *
    322 \************************************************************************************/
    323 static int icvInitEMD( const float* signature1, int size1,
    324             const float* signature2, int size2,
    325             int dims, CvDistanceFunction dist_func, void* user_param,
    326             const float* cost, int cost_step,
    327             CvEMDState* state, float* lower_bound,
    328             cv::AutoBuffer<char>& _buffer )
    329 {
    330     float s_sum = 0, d_sum = 0, diff;
    331     int i, j;
    332     int ssize = 0, dsize = 0;
    333     int equal_sums = 1;
    334     int buffer_size;
    335     float max_cost = 0;
    336     char *buffer, *buffer_end;
    337 
    338     memset( state, 0, sizeof( *state ));
    339     assert( cost_step % sizeof(float) == 0 );
    340     cost_step /= sizeof(float);
    341 
    342     /* calculate buffer size */
    343     buffer_size = (size1+1) * (size2+1) * (sizeof( float ) +    /* cost */
    344                                    sizeof( char ) +     /* is_x */
    345                                    sizeof( float )) +   /* delta matrix */
    346         (size1 + size2 + 2) * (sizeof( CvNode2D ) + /* _x */
    347                            sizeof( CvNode2D * ) +  /* cols_x & rows_x */
    348                            sizeof( CvNode1D ) + /* u & v */
    349                            sizeof( float ) + /* s & d */
    350                            sizeof( int ) + sizeof(CvNode2D*)) +  /* idx1 & idx2 */
    351         (size1+1) * (sizeof( float * ) + sizeof( char * ) + /* rows pointers for */
    352                  sizeof( float * )) + 256;      /*  cost, is_x and delta */
    353 
    354     if( buffer_size < (int) (dims * 2 * sizeof( float )))
    355     {
    356         buffer_size = dims * 2 * sizeof( float );
    357     }
    358 
    359     /* allocate buffers */
    360     _buffer.allocate(buffer_size);
    361 
    362     state->buffer = buffer = _buffer;
    363     buffer_end = buffer + buffer_size;
    364 
    365     state->idx1 = (int*) buffer;
    366     buffer += (size1 + 1) * sizeof( int );
    367 
    368     state->idx2 = (int*) buffer;
    369     buffer += (size2 + 1) * sizeof( int );
    370 
    371     state->s = (float *) buffer;
    372     buffer += (size1 + 1) * sizeof( float );
    373 
    374     state->d = (float *) buffer;
    375     buffer += (size2 + 1) * sizeof( float );
    376 
    377     /* sum up the supply and demand */
    378     for( i = 0; i < size1; i++ )
    379     {
    380         float weight = signature1[i * (dims + 1)];
    381 
    382         if( weight > 0 )
    383         {
    384             s_sum += weight;
    385             state->s[ssize] = weight;
    386             state->idx1[ssize++] = i;
    387 
    388         }
    389         else if( weight < 0 )
    390             CV_Error(CV_StsOutOfRange, "");
    391     }
    392 
    393     for( i = 0; i < size2; i++ )
    394     {
    395         float weight = signature2[i * (dims + 1)];
    396 
    397         if( weight > 0 )
    398         {
    399             d_sum += weight;
    400             state->d[dsize] = weight;
    401             state->idx2[dsize++] = i;
    402         }
    403         else if( weight < 0 )
    404             CV_Error(CV_StsOutOfRange, "");
    405     }
    406 
    407     if( ssize == 0 || dsize == 0 )
    408         CV_Error(CV_StsOutOfRange, "");
    409 
    410     /* if supply different than the demand, add a zero-cost dummy cluster */
    411     diff = s_sum - d_sum;
    412     if( fabs( diff ) >= CV_EMD_EPS * s_sum )
    413     {
    414         equal_sums = 0;
    415         if( diff < 0 )
    416         {
    417             state->s[ssize] = -diff;
    418             state->idx1[ssize++] = -1;
    419         }
    420         else
    421         {
    422             state->d[dsize] = diff;
    423             state->idx2[dsize++] = -1;
    424         }
    425     }
    426 
    427     state->ssize = ssize;
    428     state->dsize = dsize;
    429     state->weight = s_sum > d_sum ? s_sum : d_sum;
    430 
    431     if( lower_bound && equal_sums )     /* check lower bound */
    432     {
    433         int sz1 = size1 * (dims + 1), sz2 = size2 * (dims + 1);
    434         float lb = 0;
    435 
    436         float* xs = (float *) buffer;
    437         float* xd = xs + dims;
    438 
    439         memset( xs, 0, dims*sizeof(xs[0]));
    440         memset( xd, 0, dims*sizeof(xd[0]));
    441 
    442         for( j = 0; j < sz1; j += dims + 1 )
    443         {
    444             float weight = signature1[j];
    445             for( i = 0; i < dims; i++ )
    446                 xs[i] += signature1[j + i + 1] * weight;
    447         }
    448 
    449         for( j = 0; j < sz2; j += dims + 1 )
    450         {
    451             float weight = signature2[j];
    452             for( i = 0; i < dims; i++ )
    453                 xd[i] += signature2[j + i + 1] * weight;
    454         }
    455 
    456         lb = dist_func( xs, xd, user_param ) / state->weight;
    457         i = *lower_bound <= lb;
    458         *lower_bound = lb;
    459         if( i )
    460             return 1;
    461     }
    462 
    463     /* assign pointers */
    464     state->is_used = (char *) buffer;
    465     /* init delta matrix */
    466     state->delta = (float **) buffer;
    467     buffer += ssize * sizeof( float * );
    468 
    469     for( i = 0; i < ssize; i++ )
    470     {
    471         state->delta[i] = (float *) buffer;
    472         buffer += dsize * sizeof( float );
    473     }
    474 
    475     state->loop = (CvNode2D **) buffer;
    476     buffer += (ssize + dsize + 1) * sizeof(CvNode2D*);
    477 
    478     state->_x = state->end_x = (CvNode2D *) buffer;
    479     buffer += (ssize + dsize) * sizeof( CvNode2D );
    480 
    481     /* init cost matrix */
    482     state->cost = (float **) buffer;
    483     buffer += ssize * sizeof( float * );
    484 
    485     /* compute the distance matrix */
    486     for( i = 0; i < ssize; i++ )
    487     {
    488         int ci = state->idx1[i];
    489 
    490         state->cost[i] = (float *) buffer;
    491         buffer += dsize * sizeof( float );
    492 
    493         if( ci >= 0 )
    494         {
    495             for( j = 0; j < dsize; j++ )
    496             {
    497                 int cj = state->idx2[j];
    498                 if( cj < 0 )
    499                     state->cost[i][j] = 0;
    500                 else
    501                 {
    502                     float val;
    503                     if( dist_func )
    504                     {
    505                         val = dist_func( signature1 + ci * (dims + 1) + 1,
    506                                          signature2 + cj * (dims + 1) + 1,
    507                                          user_param );
    508                     }
    509                     else
    510                     {
    511                         assert( cost );
    512                         val = cost[cost_step*ci + cj];
    513                     }
    514                     state->cost[i][j] = val;
    515                     if( max_cost < val )
    516                         max_cost = val;
    517                 }
    518             }
    519         }
    520         else
    521         {
    522             for( j = 0; j < dsize; j++ )
    523                 state->cost[i][j] = 0;
    524         }
    525     }
    526 
    527     state->max_cost = max_cost;
    528 
    529     memset( buffer, 0, buffer_end - buffer );
    530 
    531     state->rows_x = (CvNode2D **) buffer;
    532     buffer += ssize * sizeof( CvNode2D * );
    533 
    534     state->cols_x = (CvNode2D **) buffer;
    535     buffer += dsize * sizeof( CvNode2D * );
    536 
    537     state->u = (CvNode1D *) buffer;
    538     buffer += ssize * sizeof( CvNode1D );
    539 
    540     state->v = (CvNode1D *) buffer;
    541     buffer += dsize * sizeof( CvNode1D );
    542 
    543     /* init is_x matrix */
    544     state->is_x = (char **) buffer;
    545     buffer += ssize * sizeof( char * );
    546 
    547     for( i = 0; i < ssize; i++ )
    548     {
    549         state->is_x[i] = buffer;
    550         buffer += dsize;
    551     }
    552 
    553     assert( buffer <= buffer_end );
    554 
    555     icvRussel( state );
    556 
    557     state->enter_x = (state->end_x)++;
    558     return 0;
    559 }
    560 
    561 
    562 /****************************************************************************************\
    563 *                              icvFindBasicVariables                                   *
    564 \****************************************************************************************/
    565 static int icvFindBasicVariables( float **cost, char **is_x,
    566                        CvNode1D * u, CvNode1D * v, int ssize, int dsize )
    567 {
    568     int i, j, found;
    569     int u_cfound, v_cfound;
    570     CvNode1D u0_head, u1_head, *cur_u, *prev_u;
    571     CvNode1D v0_head, v1_head, *cur_v, *prev_v;
    572 
    573     /* initialize the rows list (u) and the columns list (v) */
    574     u0_head.next = u;
    575     for( i = 0; i < ssize; i++ )
    576     {
    577         u[i].next = u + i + 1;
    578     }
    579     u[ssize - 1].next = 0;
    580     u1_head.next = 0;
    581 
    582     v0_head.next = ssize > 1 ? v + 1 : 0;
    583     for( i = 1; i < dsize; i++ )
    584     {
    585         v[i].next = v + i + 1;
    586     }
    587     v[dsize - 1].next = 0;
    588     v1_head.next = 0;
    589 
    590     /* there are ssize+dsize variables but only ssize+dsize-1 independent equations,
    591        so set v[0]=0 */
    592     v[0].val = 0;
    593     v1_head.next = v;
    594     v1_head.next->next = 0;
    595 
    596     /* loop until all variables are found */
    597     u_cfound = v_cfound = 0;
    598     while( u_cfound < ssize || v_cfound < dsize )
    599     {
    600         found = 0;
    601         if( v_cfound < dsize )
    602         {
    603             /* loop over all marked columns */
    604             prev_v = &v1_head;
    605 
    606             for( found |= (cur_v = v1_head.next) != 0; cur_v != 0; cur_v = cur_v->next )
    607             {
    608                 float cur_v_val = cur_v->val;
    609 
    610                 j = (int)(cur_v - v);
    611                 /* find the variables in column j */
    612                 prev_u = &u0_head;
    613                 for( cur_u = u0_head.next; cur_u != 0; )
    614                 {
    615                     i = (int)(cur_u - u);
    616                     if( is_x[i][j] )
    617                     {
    618                         /* compute u[i] */
    619                         cur_u->val = cost[i][j] - cur_v_val;
    620                         /* ...and add it to the marked list */
    621                         prev_u->next = cur_u->next;
    622                         cur_u->next = u1_head.next;
    623                         u1_head.next = cur_u;
    624                         cur_u = prev_u->next;
    625                     }
    626                     else
    627                     {
    628                         prev_u = cur_u;
    629                         cur_u = cur_u->next;
    630                     }
    631                 }
    632                 prev_v->next = cur_v->next;
    633                 v_cfound++;
    634             }
    635         }
    636 
    637         if( u_cfound < ssize )
    638         {
    639             /* loop over all marked rows */
    640             prev_u = &u1_head;
    641             for( found |= (cur_u = u1_head.next) != 0; cur_u != 0; cur_u = cur_u->next )
    642             {
    643                 float cur_u_val = cur_u->val;
    644                 float *_cost;
    645                 char *_is_x;
    646 
    647                 i = (int)(cur_u - u);
    648                 _cost = cost[i];
    649                 _is_x = is_x[i];
    650                 /* find the variables in rows i */
    651                 prev_v = &v0_head;
    652                 for( cur_v = v0_head.next; cur_v != 0; )
    653                 {
    654                     j = (int)(cur_v - v);
    655                     if( _is_x[j] )
    656                     {
    657                         /* compute v[j] */
    658                         cur_v->val = _cost[j] - cur_u_val;
    659                         /* ...and add it to the marked list */
    660                         prev_v->next = cur_v->next;
    661                         cur_v->next = v1_head.next;
    662                         v1_head.next = cur_v;
    663                         cur_v = prev_v->next;
    664                     }
    665                     else
    666                     {
    667                         prev_v = cur_v;
    668                         cur_v = cur_v->next;
    669                     }
    670                 }
    671                 prev_u->next = cur_u->next;
    672                 u_cfound++;
    673             }
    674         }
    675 
    676         if( !found )
    677             return -1;
    678     }
    679 
    680     return 0;
    681 }
    682 
    683 
    684 /****************************************************************************************\
    685 *                                   icvIsOptimal                                       *
    686 \****************************************************************************************/
    687 static float
    688 icvIsOptimal( float **cost, char **is_x,
    689               CvNode1D * u, CvNode1D * v, int ssize, int dsize, CvNode2D * enter_x )
    690 {
    691     float delta, min_delta = CV_EMD_INF;
    692     int i, j, min_i = 0, min_j = 0;
    693 
    694     /* find the minimal cij-ui-vj over all i,j */
    695     for( i = 0; i < ssize; i++ )
    696     {
    697         float u_val = u[i].val;
    698         float *_cost = cost[i];
    699         char *_is_x = is_x[i];
    700 
    701         for( j = 0; j < dsize; j++ )
    702         {
    703             if( !_is_x[j] )
    704             {
    705                 delta = _cost[j] - u_val - v[j].val;
    706                 if( min_delta > delta )
    707                 {
    708                     min_delta = delta;
    709                     min_i = i;
    710                     min_j = j;
    711                 }
    712             }
    713         }
    714     }
    715 
    716     enter_x->i = min_i;
    717     enter_x->j = min_j;
    718 
    719     return min_delta;
    720 }
    721 
    722 /****************************************************************************************\
    723 *                                   icvNewSolution                                     *
    724 \****************************************************************************************/
    725 static bool
    726 icvNewSolution( CvEMDState * state )
    727 {
    728     int i, j;
    729     float min_val = CV_EMD_INF;
    730     int steps;
    731     CvNode2D head, *cur_x, *next_x, *leave_x = 0;
    732     CvNode2D *enter_x = state->enter_x;
    733     CvNode2D **loop = state->loop;
    734 
    735     /* enter the new basic variable */
    736     i = enter_x->i;
    737     j = enter_x->j;
    738     state->is_x[i][j] = 1;
    739     enter_x->next[0] = state->rows_x[i];
    740     enter_x->next[1] = state->cols_x[j];
    741     enter_x->val = 0;
    742     state->rows_x[i] = enter_x;
    743     state->cols_x[j] = enter_x;
    744 
    745     /* find a chain reaction */
    746     steps = icvFindLoop( state );
    747 
    748     if( steps == 0 )
    749         return false;
    750 
    751     /* find the largest value in the loop */
    752     for( i = 1; i < steps; i += 2 )
    753     {
    754         float temp = loop[i]->val;
    755 
    756         if( min_val > temp )
    757         {
    758             leave_x = loop[i];
    759             min_val = temp;
    760         }
    761     }
    762 
    763     /* update the loop */
    764     for( i = 0; i < steps; i += 2 )
    765     {
    766         float temp0 = loop[i]->val + min_val;
    767         float temp1 = loop[i + 1]->val - min_val;
    768 
    769         loop[i]->val = temp0;
    770         loop[i + 1]->val = temp1;
    771     }
    772 
    773     /* remove the leaving basic variable */
    774     i = leave_x->i;
    775     j = leave_x->j;
    776     state->is_x[i][j] = 0;
    777 
    778     head.next[0] = state->rows_x[i];
    779     cur_x = &head;
    780     while( (next_x = cur_x->next[0]) != leave_x )
    781     {
    782         cur_x = next_x;
    783         assert( cur_x );
    784     }
    785     cur_x->next[0] = next_x->next[0];
    786     state->rows_x[i] = head.next[0];
    787 
    788     head.next[1] = state->cols_x[j];
    789     cur_x = &head;
    790     while( (next_x = cur_x->next[1]) != leave_x )
    791     {
    792         cur_x = next_x;
    793         assert( cur_x );
    794     }
    795     cur_x->next[1] = next_x->next[1];
    796     state->cols_x[j] = head.next[1];
    797 
    798     /* set enter_x to be the new empty slot */
    799     state->enter_x = leave_x;
    800 
    801     return true;
    802 }
    803 
    804 
    805 
    806 /****************************************************************************************\
    807 *                                    icvFindLoop                                       *
    808 \****************************************************************************************/
    809 static int
    810 icvFindLoop( CvEMDState * state )
    811 {
    812     int i, steps = 1;
    813     CvNode2D *new_x;
    814     CvNode2D **loop = state->loop;
    815     CvNode2D *enter_x = state->enter_x, *_x = state->_x;
    816     char *is_used = state->is_used;
    817 
    818     memset( is_used, 0, state->ssize + state->dsize );
    819 
    820     new_x = loop[0] = enter_x;
    821     is_used[enter_x - _x] = 1;
    822     steps = 1;
    823 
    824     do
    825     {
    826         if( (steps & 1) == 1 )
    827         {
    828             /* find an unused x in the row */
    829             new_x = state->rows_x[new_x->i];
    830             while( new_x != 0 && is_used[new_x - _x] )
    831                 new_x = new_x->next[0];
    832         }
    833         else
    834         {
    835             /* find an unused x in the column, or the entering x */
    836             new_x = state->cols_x[new_x->j];
    837             while( new_x != 0 && is_used[new_x - _x] && new_x != enter_x )
    838                 new_x = new_x->next[1];
    839             if( new_x == enter_x )
    840                 break;
    841         }
    842 
    843         if( new_x != 0 )        /* found the next x */
    844         {
    845             /* add x to the loop */
    846             loop[steps++] = new_x;
    847             is_used[new_x - _x] = 1;
    848         }
    849         else                    /* didn't find the next x */
    850         {
    851             /* backtrack */
    852             do
    853             {
    854                 i = steps & 1;
    855                 new_x = loop[steps - 1];
    856                 do
    857                 {
    858                     new_x = new_x->next[i];
    859                 }
    860                 while( new_x != 0 && is_used[new_x - _x] );
    861 
    862                 if( new_x == 0 )
    863                 {
    864                     is_used[loop[--steps] - _x] = 0;
    865                 }
    866             }
    867             while( new_x == 0 && steps > 0 );
    868 
    869             is_used[loop[steps - 1] - _x] = 0;
    870             loop[steps - 1] = new_x;
    871             is_used[new_x - _x] = 1;
    872         }
    873     }
    874     while( steps > 0 );
    875 
    876     return steps;
    877 }
    878 
    879 
    880 
    881 /****************************************************************************************\
    882 *                                        icvRussel                                     *
    883 \****************************************************************************************/
    884 static void
    885 icvRussel( CvEMDState * state )
    886 {
    887     int i, j, min_i = -1, min_j = -1;
    888     float min_delta, diff;
    889     CvNode1D u_head, *cur_u, *prev_u;
    890     CvNode1D v_head, *cur_v, *prev_v;
    891     CvNode1D *prev_u_min_i = 0, *prev_v_min_j = 0, *remember;
    892     CvNode1D *u = state->u, *v = state->v;
    893     int ssize = state->ssize, dsize = state->dsize;
    894     float eps = CV_EMD_EPS * state->max_cost;
    895     float **cost = state->cost;
    896     float **delta = state->delta;
    897 
    898     /* initialize the rows list (ur), and the columns list (vr) */
    899     u_head.next = u;
    900     for( i = 0; i < ssize; i++ )
    901     {
    902         u[i].next = u + i + 1;
    903     }
    904     u[ssize - 1].next = 0;
    905 
    906     v_head.next = v;
    907     for( i = 0; i < dsize; i++ )
    908     {
    909         v[i].val = -CV_EMD_INF;
    910         v[i].next = v + i + 1;
    911     }
    912     v[dsize - 1].next = 0;
    913 
    914     /* find the maximum row and column values (ur[i] and vr[j]) */
    915     for( i = 0; i < ssize; i++ )
    916     {
    917         float u_val = -CV_EMD_INF;
    918         float *cost_row = cost[i];
    919 
    920         for( j = 0; j < dsize; j++ )
    921         {
    922             float temp = cost_row[j];
    923 
    924             if( u_val < temp )
    925                 u_val = temp;
    926             if( v[j].val < temp )
    927                 v[j].val = temp;
    928         }
    929         u[i].val = u_val;
    930     }
    931 
    932     /* compute the delta matrix */
    933     for( i = 0; i < ssize; i++ )
    934     {
    935         float u_val = u[i].val;
    936         float *delta_row = delta[i];
    937         float *cost_row = cost[i];
    938 
    939         for( j = 0; j < dsize; j++ )
    940         {
    941             delta_row[j] = cost_row[j] - u_val - v[j].val;
    942         }
    943     }
    944 
    945     /* find the basic variables */
    946     do
    947     {
    948         /* find the smallest delta[i][j] */
    949         min_i = -1;
    950         min_delta = CV_EMD_INF;
    951         prev_u = &u_head;
    952         for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
    953         {
    954             i = (int)(cur_u - u);
    955             float *delta_row = delta[i];
    956 
    957             prev_v = &v_head;
    958             for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
    959             {
    960                 j = (int)(cur_v - v);
    961                 if( min_delta > delta_row[j] )
    962                 {
    963                     min_delta = delta_row[j];
    964                     min_i = i;
    965                     min_j = j;
    966                     prev_u_min_i = prev_u;
    967                     prev_v_min_j = prev_v;
    968                 }
    969                 prev_v = cur_v;
    970             }
    971             prev_u = cur_u;
    972         }
    973 
    974         if( min_i < 0 )
    975             break;
    976 
    977         /* add x[min_i][min_j] to the basis, and adjust supplies and cost */
    978         remember = prev_u_min_i->next;
    979         icvAddBasicVariable( state, min_i, min_j, prev_u_min_i, prev_v_min_j, &u_head );
    980 
    981         /* update the necessary delta[][] */
    982         if( remember == prev_u_min_i->next )    /* line min_i was deleted */
    983         {
    984             for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
    985             {
    986                 j = (int)(cur_v - v);
    987                 if( cur_v->val == cost[min_i][j] )      /* column j needs updating */
    988                 {
    989                     float max_val = -CV_EMD_INF;
    990 
    991                     /* find the new maximum value in the column */
    992                     for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
    993                     {
    994                         float temp = cost[cur_u - u][j];
    995 
    996                         if( max_val < temp )
    997                             max_val = temp;
    998                     }
    999 
   1000                     /* if needed, adjust the relevant delta[*][j] */
   1001                     diff = max_val - cur_v->val;
   1002                     cur_v->val = max_val;
   1003                     if( fabs( diff ) < eps )
   1004                     {
   1005                         for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
   1006                             delta[cur_u - u][j] += diff;
   1007                     }
   1008                 }
   1009             }
   1010         }
   1011         else                    /* column min_j was deleted */
   1012         {
   1013             for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
   1014             {
   1015                 i = (int)(cur_u - u);
   1016                 if( cur_u->val == cost[i][min_j] )      /* row i needs updating */
   1017                 {
   1018                     float max_val = -CV_EMD_INF;
   1019 
   1020                     /* find the new maximum value in the row */
   1021                     for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
   1022                     {
   1023                         float temp = cost[i][cur_v - v];
   1024 
   1025                         if( max_val < temp )
   1026                             max_val = temp;
   1027                     }
   1028 
   1029                     /* if needed, adjust the relevant delta[i][*] */
   1030                     diff = max_val - cur_u->val;
   1031                     cur_u->val = max_val;
   1032 
   1033                     if( fabs( diff ) < eps )
   1034                     {
   1035                         for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
   1036                             delta[i][cur_v - v] += diff;
   1037                     }
   1038                 }
   1039             }
   1040         }
   1041     }
   1042     while( u_head.next != 0 || v_head.next != 0 );
   1043 }
   1044 
   1045 
   1046 
   1047 /****************************************************************************************\
   1048 *                                   icvAddBasicVariable                                *
   1049 \****************************************************************************************/
   1050 static void
   1051 icvAddBasicVariable( CvEMDState * state,
   1052                      int min_i, int min_j,
   1053                      CvNode1D * prev_u_min_i, CvNode1D * prev_v_min_j, CvNode1D * u_head )
   1054 {
   1055     float temp;
   1056     CvNode2D *end_x = state->end_x;
   1057 
   1058     if( state->s[min_i] < state->d[min_j] + state->weight * CV_EMD_EPS )
   1059     {                           /* supply exhausted */
   1060         temp = state->s[min_i];
   1061         state->s[min_i] = 0;
   1062         state->d[min_j] -= temp;
   1063     }
   1064     else                        /* demand exhausted */
   1065     {
   1066         temp = state->d[min_j];
   1067         state->d[min_j] = 0;
   1068         state->s[min_i] -= temp;
   1069     }
   1070 
   1071     /* x(min_i,min_j) is a basic variable */
   1072     state->is_x[min_i][min_j] = 1;
   1073 
   1074     end_x->val = temp;
   1075     end_x->i = min_i;
   1076     end_x->j = min_j;
   1077     end_x->next[0] = state->rows_x[min_i];
   1078     end_x->next[1] = state->cols_x[min_j];
   1079     state->rows_x[min_i] = end_x;
   1080     state->cols_x[min_j] = end_x;
   1081     state->end_x = end_x + 1;
   1082 
   1083     /* delete supply row only if the empty, and if not last row */
   1084     if( state->s[min_i] == 0 && u_head->next->next != 0 )
   1085         prev_u_min_i->next = prev_u_min_i->next->next;  /* remove row from list */
   1086     else
   1087         prev_v_min_j->next = prev_v_min_j->next->next;  /* remove column from list */
   1088 }
   1089 
   1090 
   1091 /****************************************************************************************\
   1092 *                                  standard  metrics                                     *
   1093 \****************************************************************************************/
   1094 static float
   1095 icvDistL1( const float *x, const float *y, void *user_param )
   1096 {
   1097     int i, dims = (int)(size_t)user_param;
   1098     double s = 0;
   1099 
   1100     for( i = 0; i < dims; i++ )
   1101     {
   1102         double t = x[i] - y[i];
   1103 
   1104         s += fabs( t );
   1105     }
   1106     return (float)s;
   1107 }
   1108 
   1109 static float
   1110 icvDistL2( const float *x, const float *y, void *user_param )
   1111 {
   1112     int i, dims = (int)(size_t)user_param;
   1113     double s = 0;
   1114 
   1115     for( i = 0; i < dims; i++ )
   1116     {
   1117         double t = x[i] - y[i];
   1118 
   1119         s += t * t;
   1120     }
   1121     return cvSqrt( (float)s );
   1122 }
   1123 
   1124 static float
   1125 icvDistC( const float *x, const float *y, void *user_param )
   1126 {
   1127     int i, dims = (int)(size_t)user_param;
   1128     double s = 0;
   1129 
   1130     for( i = 0; i < dims; i++ )
   1131     {
   1132         double t = fabs( x[i] - y[i] );
   1133 
   1134         if( s < t )
   1135             s = t;
   1136     }
   1137     return (float)s;
   1138 }
   1139 
   1140 
   1141 float cv::EMD( InputArray _signature1, InputArray _signature2,
   1142                int distType, InputArray _cost,
   1143                float* lowerBound, OutputArray _flow )
   1144 {
   1145     Mat signature1 = _signature1.getMat(), signature2 = _signature2.getMat();
   1146     Mat cost = _cost.getMat(), flow;
   1147 
   1148     CvMat _csignature1 = signature1;
   1149     CvMat _csignature2 = signature2;
   1150     CvMat _ccost = cost, _cflow;
   1151     if( _flow.needed() )
   1152     {
   1153         _flow.create(signature1.rows, signature2.rows, CV_32F);
   1154         flow = _flow.getMat();
   1155         flow = Scalar::all(0);
   1156         _cflow = flow;
   1157     }
   1158 
   1159     return cvCalcEMD2( &_csignature1, &_csignature2, distType, 0, cost.empty() ? 0 : &_ccost,
   1160                        _flow.needed() ? &_cflow : 0, lowerBound, 0 );
   1161 }
   1162 
   1163 /* End of file. */
   1164