Home | History | Annotate | Download | only in src
      1 /*M///////////////////////////////////////////////////////////////////////////////////////
      2 //
      3 //  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
      4 //
      5 //  By downloading, copying, installing or using the software you agree to this license.
      6 //  If you do not agree to this license, do not download, install,
      7 //  copy or use the software.
      8 //
      9 //
     10 //                        Intel License Agreement
     11 //                For Open Source Computer Vision Library
     12 //
     13 // Copyright (C) 2000, Intel Corporation, all rights reserved.
     14 // Third party copyrights are property of their respective owners.
     15 //
     16 // Redistribution and use in source and binary forms, with or without modification,
     17 // are permitted provided that the following conditions are met:
     18 //
     19 //   * Redistribution's of source code must retain the above copyright notice,
     20 //     this list of conditions and the following disclaimer.
     21 //
     22 //   * Redistribution's in binary form must reproduce the above copyright notice,
     23 //     this list of conditions and the following disclaimer in the documentation
     24 //     and/or other materials provided with the distribution.
     25 //
     26 //   * The name of Intel Corporation may not be used to endorse or promote products
     27 //     derived from this software without specific prior written permission.
     28 //
     29 // This software is provided by the copyright holders and contributors "as is" and
     30 // any express or implied warranties, including, but not limited to, the implied
     31 // warranties of merchantability and fitness for a particular purpose are disclaimed.
     32 // In no event shall the Intel Corporation or contributors be liable for any direct,
     33 // indirect, incidental, special, exemplary, or consequential damages
     34 // (including, but not limited to, procurement of substitute goods or services;
     35 // loss of use, data, or profits; or business interruption) however caused
     36 // and on any theory of liability, whether in contract, strict liability,
     37 // or tort (including negligence or otherwise) arising in any way out of
     38 // the use of this software, even if advised of the possibility of such damage.
     39 //
     40 //M*/
     41 
     42 /*
     43     Partially based on Yossi Rubner code:
     44     =========================================================================
     45     emd.c
     46 
     47     Last update: 3/14/98
     48 
     49     An implementation of the Earth Movers Distance.
     50     Based on the solution for the Transportation problem as described in
     51     "Introduction to Mathematical Programming" by F. S. Hillier and
     52     G. J. Lieberman, McGraw-Hill, 1990.
     53 
     54     Copyright (C) 1998 Yossi Rubner
     55     Computer Science Department, Stanford University
     56     E-Mail: rubner (at) cs.stanford.edu   URL: http://vision.stanford.edu/~rubner
     57     ==========================================================================
     58 */
     59 #include "_cv.h"
     60 
     61 #define MAX_ITERATIONS 500
     62 #define CV_EMD_INF   ((float)1e20)
     63 #define CV_EMD_EPS   ((float)1e-5)
     64 
     65 /* CvNode1D is used for lists, representing 1D sparse array */
     66 typedef struct CvNode1D
     67 {
     68     float val;
     69     struct CvNode1D *next;
     70 }
     71 CvNode1D;
     72 
     73 /* CvNode2D is used for lists, representing 2D sparse matrix */
     74 typedef struct CvNode2D
     75 {
     76     float val;
     77     struct CvNode2D *next[2];  /* next row & next column */
     78     int i, j;
     79 }
     80 CvNode2D;
     81 
     82 
     83 typedef struct CvEMDState
     84 {
     85     int ssize, dsize;
     86 
     87     float **cost;
     88     CvNode2D *_x;
     89     CvNode2D *end_x;
     90     CvNode2D *enter_x;
     91     char **is_x;
     92 
     93     CvNode2D **rows_x;
     94     CvNode2D **cols_x;
     95 
     96     CvNode1D *u;
     97     CvNode1D *v;
     98 
     99     int* idx1;
    100     int* idx2;
    101 
    102     /* find_loop buffers */
    103     CvNode2D **loop;
    104     char *is_used;
    105 
    106     /* russel buffers */
    107     float *s;
    108     float *d;
    109     float **delta;
    110 
    111     float weight, max_cost;
    112     char *buffer;
    113 }
    114 CvEMDState;
    115 
    116 /* static function declaration */
    117 static CvStatus icvInitEMD( const float *signature1, int size1,
    118                             const float *signature2, int size2,
    119                             int dims, CvDistanceFunction dist_func, void *user_param,
    120                             const float* cost, int cost_step,
    121                             CvEMDState * state, float *lower_bound,
    122                             char *local_buffer, int local_buffer_size );
    123 
    124 static CvStatus icvFindBasicVariables( float **cost, char **is_x,
    125                                        CvNode1D * u, CvNode1D * v, int ssize, int dsize );
    126 
    127 static float icvIsOptimal( float **cost, char **is_x,
    128                            CvNode1D * u, CvNode1D * v,
    129                            int ssize, int dsize, CvNode2D * enter_x );
    130 
    131 static void icvRussel( CvEMDState * state );
    132 
    133 
    134 static CvStatus icvNewSolution( CvEMDState * state );
    135 static int icvFindLoop( CvEMDState * state );
    136 
    137 static void icvAddBasicVariable( CvEMDState * state,
    138                                  int min_i, int min_j,
    139                                  CvNode1D * prev_u_min_i,
    140                                  CvNode1D * prev_v_min_j,
    141                                  CvNode1D * u_head );
    142 
    143 static float icvDistL2( const float *x, const float *y, void *user_param );
    144 static float icvDistL1( const float *x, const float *y, void *user_param );
    145 static float icvDistC( const float *x, const float *y, void *user_param );
    146 
    147 /* The main function */
    148 CV_IMPL float
    149 cvCalcEMD2( const CvArr* signature_arr1,
    150             const CvArr* signature_arr2,
    151             int dist_type,
    152             CvDistanceFunction dist_func,
    153             const CvArr* cost_matrix,
    154             CvArr* flow_matrix,
    155             float *lower_bound,
    156             void *user_param )
    157 {
    158     char local_buffer[16384];
    159     char *local_buffer_ptr = (char *)cvAlignPtr(local_buffer,16);
    160     CvEMDState state;
    161     float emd = 0;
    162 
    163     CV_FUNCNAME( "cvCalcEMD2" );
    164 
    165     memset( &state, 0, sizeof(state));
    166 
    167     __BEGIN__;
    168 
    169     double total_cost = 0;
    170     CvStatus result = CV_NO_ERR;
    171     float eps, min_delta;
    172     CvNode2D *xp = 0;
    173     CvMat sign_stub1, *signature1 = (CvMat*)signature_arr1;
    174     CvMat sign_stub2, *signature2 = (CvMat*)signature_arr2;
    175     CvMat cost_stub, *cost = &cost_stub;
    176     CvMat flow_stub, *flow = (CvMat*)flow_matrix;
    177     int dims, size1, size2;
    178 
    179     CV_CALL( signature1 = cvGetMat( signature1, &sign_stub1 ));
    180     CV_CALL( signature2 = cvGetMat( signature2, &sign_stub2 ));
    181 
    182     if( signature1->cols != signature2->cols )
    183         CV_ERROR( CV_StsUnmatchedSizes, "The arrays must have equal number of columns (which is number of dimensions but 1)" );
    184 
    185     dims = signature1->cols - 1;
    186     size1 = signature1->rows;
    187     size2 = signature2->rows;
    188 
    189     if( !CV_ARE_TYPES_EQ( signature1, signature2 ))
    190         CV_ERROR( CV_StsUnmatchedFormats, "The array must have equal types" );
    191 
    192     if( CV_MAT_TYPE( signature1->type ) != CV_32FC1 )
    193         CV_ERROR( CV_StsUnsupportedFormat, "The signatures must be 32fC1" );
    194 
    195     if( flow )
    196     {
    197         CV_CALL( flow = cvGetMat( flow, &flow_stub ));
    198 
    199         if( flow->rows != size1 || flow->cols != size2 )
    200             CV_ERROR( CV_StsUnmatchedSizes,
    201             "The flow matrix size does not match to the signatures' sizes" );
    202 
    203         if( CV_MAT_TYPE( flow->type ) != CV_32FC1 )
    204             CV_ERROR( CV_StsUnsupportedFormat, "The flow matrix must be 32fC1" );
    205     }
    206 
    207     cost->data.fl = 0;
    208     cost->step = 0;
    209 
    210     if( dist_type < 0 )
    211     {
    212         if( cost_matrix )
    213         {
    214             if( dist_func )
    215                 CV_ERROR( CV_StsBadArg,
    216                 "Only one of cost matrix or distance function should be non-NULL in case of user-defined distance" );
    217 
    218             if( lower_bound )
    219                 CV_ERROR( CV_StsBadArg,
    220                 "The lower boundary can not be calculated if the cost matrix is used" );
    221 
    222             CV_CALL( cost = cvGetMat( cost_matrix, &cost_stub ));
    223             if( cost->rows != size1 || cost->cols != size2 )
    224                 CV_ERROR( CV_StsUnmatchedSizes,
    225                 "The cost matrix size does not match to the signatures' sizes" );
    226 
    227             if( CV_MAT_TYPE( cost->type ) != CV_32FC1 )
    228                 CV_ERROR( CV_StsUnsupportedFormat, "The cost matrix must be 32fC1" );
    229         }
    230         else if( !dist_func )
    231             CV_ERROR( CV_StsNullPtr, "In case of user-defined distance Distance function is undefined" );
    232     }
    233     else
    234     {
    235         if( dims == 0 )
    236             CV_ERROR( CV_StsBadSize,
    237             "Number of dimensions can be 0 only if a user-defined metric is used" );
    238         user_param = (void *) (size_t)dims;
    239         switch (dist_type)
    240         {
    241         case CV_DIST_L1:
    242             dist_func = icvDistL1;
    243             break;
    244         case CV_DIST_L2:
    245             dist_func = icvDistL2;
    246             break;
    247         case CV_DIST_C:
    248             dist_func = icvDistC;
    249             break;
    250         default:
    251             CV_ERROR( CV_StsBadFlag, "Bad or unsupported metric type" );
    252         }
    253     }
    254 
    255     IPPI_CALL( result = icvInitEMD( signature1->data.fl, size1,
    256                                     signature2->data.fl, size2,
    257                                     dims, dist_func, user_param,
    258                                     cost->data.fl, cost->step,
    259                                     &state, lower_bound, local_buffer_ptr,
    260                                     sizeof( local_buffer ) - 16 ));
    261 
    262     if( result > 0 && lower_bound )
    263     {
    264         emd = *lower_bound;
    265         EXIT;
    266     }
    267 
    268     eps = CV_EMD_EPS * state.max_cost;
    269 
    270     /* if ssize = 1 or dsize = 1 then we are done, else ... */
    271     if( state.ssize > 1 && state.dsize > 1 )
    272     {
    273         int itr;
    274 
    275         for( itr = 1; itr < MAX_ITERATIONS; itr++ )
    276         {
    277             /* find basic variables */
    278             result = icvFindBasicVariables( state.cost, state.is_x,
    279                                             state.u, state.v, state.ssize, state.dsize );
    280             if( result < 0 )
    281                 break;
    282 
    283             /* check for optimality */
    284             min_delta = icvIsOptimal( state.cost, state.is_x,
    285                                       state.u, state.v,
    286                                       state.ssize, state.dsize, state.enter_x );
    287 
    288             if( min_delta == CV_EMD_INF )
    289             {
    290                 CV_ERROR( CV_StsNoConv, "" );
    291             }
    292 
    293             /* if no negative deltamin, we found the optimal solution */
    294             if( min_delta >= -eps )
    295                 break;
    296 
    297             /* improve solution */
    298             IPPI_CALL( icvNewSolution( &state ));
    299         }
    300     }
    301 
    302     /* compute the total flow */
    303     for( xp = state._x; xp < state.end_x; xp++ )
    304     {
    305         float val = xp->val;
    306         int i = xp->i;
    307         int j = xp->j;
    308         int ci = state.idx1[i];
    309         int cj = state.idx2[j];
    310 
    311         if( xp != state.enter_x && ci >= 0 && cj >= 0 )
    312         {
    313             total_cost += (double)val * state.cost[i][j];
    314             if( flow )
    315                 ((float*)(flow->data.ptr + flow->step*ci))[cj] = val;
    316         }
    317     }
    318 
    319     emd = (float) (total_cost / state.weight);
    320 
    321     __END__;
    322 
    323     if( state.buffer && state.buffer != local_buffer_ptr )
    324         cvFree( &state.buffer );
    325 
    326     return emd;
    327 }
    328 
    329 
    330 /************************************************************************************\
    331 *          initialize structure, allocate buffers and generate initial golution      *
    332 \************************************************************************************/
    333 static CvStatus
    334 icvInitEMD( const float* signature1, int size1,
    335             const float* signature2, int size2,
    336             int dims, CvDistanceFunction dist_func, void* user_param,
    337             const float* cost, int cost_step,
    338             CvEMDState* state, float* lower_bound,
    339             char* local_buffer, int local_buffer_size )
    340 {
    341     float s_sum = 0, d_sum = 0, diff;
    342     int i, j;
    343     int ssize = 0, dsize = 0;
    344     int equal_sums = 1;
    345     int buffer_size;
    346     float max_cost = 0;
    347     char *buffer, *buffer_end;
    348 
    349     memset( state, 0, sizeof( *state ));
    350     assert( cost_step % sizeof(float) == 0 );
    351     cost_step /= sizeof(float);
    352 
    353     /* calculate buffer size */
    354     buffer_size = (size1+1) * (size2+1) * (sizeof( float ) +    /* cost */
    355                                    sizeof( char ) +     /* is_x */
    356                                    sizeof( float )) +   /* delta matrix */
    357         (size1 + size2 + 2) * (sizeof( CvNode2D ) + /* _x */
    358                            sizeof( CvNode2D * ) +  /* cols_x & rows_x */
    359                            sizeof( CvNode1D ) + /* u & v */
    360                            sizeof( float ) + /* s & d */
    361                            sizeof( int ) + sizeof(CvNode2D*)) +  /* idx1 & idx2 */
    362         (size1+1) * (sizeof( float * ) + sizeof( char * ) + /* rows pointers for */
    363                  sizeof( float * )) + 256;      /*  cost, is_x and delta */
    364 
    365     if( buffer_size < (int) (dims * 2 * sizeof( float )))
    366     {
    367         buffer_size = dims * 2 * sizeof( float );
    368     }
    369 
    370     /* allocate buffers */
    371     if( local_buffer != 0 && local_buffer_size >= buffer_size )
    372     {
    373         buffer = local_buffer;
    374     }
    375     else
    376     {
    377         buffer = (char*)cvAlloc( buffer_size );
    378         if( !buffer )
    379             return CV_OUTOFMEM_ERR;
    380     }
    381 
    382     state->buffer = buffer;
    383     buffer_end = buffer + buffer_size;
    384 
    385     state->idx1 = (int*) buffer;
    386     buffer += (size1 + 1) * sizeof( int );
    387 
    388     state->idx2 = (int*) buffer;
    389     buffer += (size2 + 1) * sizeof( int );
    390 
    391     state->s = (float *) buffer;
    392     buffer += (size1 + 1) * sizeof( float );
    393 
    394     state->d = (float *) buffer;
    395     buffer += (size2 + 1) * sizeof( float );
    396 
    397     /* sum up the supply and demand */
    398     for( i = 0; i < size1; i++ )
    399     {
    400         float weight = signature1[i * (dims + 1)];
    401 
    402         if( weight > 0 )
    403         {
    404             s_sum += weight;
    405             state->s[ssize] = weight;
    406             state->idx1[ssize++] = i;
    407 
    408         }
    409         else if( weight < 0 )
    410             return CV_BADRANGE_ERR;
    411     }
    412 
    413     for( i = 0; i < size2; i++ )
    414     {
    415         float weight = signature2[i * (dims + 1)];
    416 
    417         if( weight > 0 )
    418         {
    419             d_sum += weight;
    420             state->d[dsize] = weight;
    421             state->idx2[dsize++] = i;
    422         }
    423         else if( weight < 0 )
    424             return CV_BADRANGE_ERR;
    425     }
    426 
    427     if( ssize == 0 || dsize == 0 )
    428         return CV_BADRANGE_ERR;
    429 
    430     /* if supply different than the demand, add a zero-cost dummy cluster */
    431     diff = s_sum - d_sum;
    432     if( fabs( diff ) >= CV_EMD_EPS * s_sum )
    433     {
    434         equal_sums = 0;
    435         if( diff < 0 )
    436         {
    437             state->s[ssize] = -diff;
    438             state->idx1[ssize++] = -1;
    439         }
    440         else
    441         {
    442             state->d[dsize] = diff;
    443             state->idx2[dsize++] = -1;
    444         }
    445     }
    446 
    447     state->ssize = ssize;
    448     state->dsize = dsize;
    449     state->weight = s_sum > d_sum ? s_sum : d_sum;
    450 
    451     if( lower_bound && equal_sums )     /* check lower bound */
    452     {
    453         int sz1 = size1 * (dims + 1), sz2 = size2 * (dims + 1);
    454         float lb = 0;
    455 
    456         float* xs = (float *) buffer;
    457         float* xd = xs + dims;
    458 
    459         memset( xs, 0, dims*sizeof(xs[0]));
    460         memset( xd, 0, dims*sizeof(xd[0]));
    461 
    462         for( j = 0; j < sz1; j += dims + 1 )
    463         {
    464             float weight = signature1[j];
    465             for( i = 0; i < dims; i++ )
    466                 xs[i] += signature1[j + i + 1] * weight;
    467         }
    468 
    469         for( j = 0; j < sz2; j += dims + 1 )
    470         {
    471             float weight = signature2[j];
    472             for( i = 0; i < dims; i++ )
    473                 xd[i] += signature2[j + i + 1] * weight;
    474         }
    475 
    476         lb = dist_func( xs, xd, user_param ) / state->weight;
    477         i = *lower_bound <= lb;
    478         *lower_bound = lb;
    479         if( i )
    480             return ( CvStatus ) 1;
    481     }
    482 
    483     /* assign pointers */
    484     state->is_used = (char *) buffer;
    485     /* init delta matrix */
    486     state->delta = (float **) buffer;
    487     buffer += ssize * sizeof( float * );
    488 
    489     for( i = 0; i < ssize; i++ )
    490     {
    491         state->delta[i] = (float *) buffer;
    492         buffer += dsize * sizeof( float );
    493     }
    494 
    495     state->loop = (CvNode2D **) buffer;
    496     buffer += (ssize + dsize + 1) * sizeof(CvNode2D*);
    497 
    498     state->_x = state->end_x = (CvNode2D *) buffer;
    499     buffer += (ssize + dsize) * sizeof( CvNode2D );
    500 
    501     /* init cost matrix */
    502     state->cost = (float **) buffer;
    503     buffer += ssize * sizeof( float * );
    504 
    505     /* compute the distance matrix */
    506     for( i = 0; i < ssize; i++ )
    507     {
    508         int ci = state->idx1[i];
    509 
    510         state->cost[i] = (float *) buffer;
    511         buffer += dsize * sizeof( float );
    512 
    513         if( ci >= 0 )
    514         {
    515             for( j = 0; j < dsize; j++ )
    516             {
    517                 int cj = state->idx2[j];
    518                 if( cj < 0 )
    519                     state->cost[i][j] = 0;
    520                 else
    521                 {
    522                     float val;
    523                     if( dist_func )
    524                     {
    525                         val = dist_func( signature1 + ci * (dims + 1) + 1,
    526                                          signature2 + cj * (dims + 1) + 1,
    527                                          user_param );
    528                     }
    529                     else
    530                     {
    531                         assert( cost );
    532                         val = cost[cost_step*ci + cj];
    533                     }
    534                     state->cost[i][j] = val;
    535                     if( max_cost < val )
    536                         max_cost = val;
    537                 }
    538             }
    539         }
    540         else
    541         {
    542             for( j = 0; j < dsize; j++ )
    543                 state->cost[i][j] = 0;
    544         }
    545     }
    546 
    547     state->max_cost = max_cost;
    548 
    549     memset( buffer, 0, buffer_end - buffer );
    550 
    551     state->rows_x = (CvNode2D **) buffer;
    552     buffer += ssize * sizeof( CvNode2D * );
    553 
    554     state->cols_x = (CvNode2D **) buffer;
    555     buffer += dsize * sizeof( CvNode2D * );
    556 
    557     state->u = (CvNode1D *) buffer;
    558     buffer += ssize * sizeof( CvNode1D );
    559 
    560     state->v = (CvNode1D *) buffer;
    561     buffer += dsize * sizeof( CvNode1D );
    562 
    563     /* init is_x matrix */
    564     state->is_x = (char **) buffer;
    565     buffer += ssize * sizeof( char * );
    566 
    567     for( i = 0; i < ssize; i++ )
    568     {
    569         state->is_x[i] = buffer;
    570         buffer += dsize;
    571     }
    572 
    573     assert( buffer <= buffer_end );
    574 
    575     icvRussel( state );
    576 
    577     state->enter_x = (state->end_x)++;
    578     return CV_NO_ERR;
    579 }
    580 
    581 
    582 /****************************************************************************************\
    583 *                              icvFindBasicVariables                                   *
    584 \****************************************************************************************/
    585 static CvStatus
    586 icvFindBasicVariables( float **cost, char **is_x,
    587                        CvNode1D * u, CvNode1D * v, int ssize, int dsize )
    588 {
    589     int i, j, found;
    590     int u_cfound, v_cfound;
    591     CvNode1D u0_head, u1_head, *cur_u, *prev_u;
    592     CvNode1D v0_head, v1_head, *cur_v, *prev_v;
    593 
    594     /* initialize the rows list (u) and the columns list (v) */
    595     u0_head.next = u;
    596     for( i = 0; i < ssize; i++ )
    597     {
    598         u[i].next = u + i + 1;
    599     }
    600     u[ssize - 1].next = 0;
    601     u1_head.next = 0;
    602 
    603     v0_head.next = ssize > 1 ? v + 1 : 0;
    604     for( i = 1; i < dsize; i++ )
    605     {
    606         v[i].next = v + i + 1;
    607     }
    608     v[dsize - 1].next = 0;
    609     v1_head.next = 0;
    610 
    611     /* there are ssize+dsize variables but only ssize+dsize-1 independent equations,
    612        so set v[0]=0 */
    613     v[0].val = 0;
    614     v1_head.next = v;
    615     v1_head.next->next = 0;
    616 
    617     /* loop until all variables are found */
    618     u_cfound = v_cfound = 0;
    619     while( u_cfound < ssize || v_cfound < dsize )
    620     {
    621         found = 0;
    622         if( v_cfound < dsize )
    623         {
    624             /* loop over all marked columns */
    625             prev_v = &v1_head;
    626 
    627             for( found |= (cur_v = v1_head.next) != 0; cur_v != 0; cur_v = cur_v->next )
    628             {
    629                 float cur_v_val = cur_v->val;
    630 
    631                 j = (int)(cur_v - v);
    632                 /* find the variables in column j */
    633                 prev_u = &u0_head;
    634                 for( cur_u = u0_head.next; cur_u != 0; )
    635                 {
    636                     i = (int)(cur_u - u);
    637                     if( is_x[i][j] )
    638                     {
    639                         /* compute u[i] */
    640                         cur_u->val = cost[i][j] - cur_v_val;
    641                         /* ...and add it to the marked list */
    642                         prev_u->next = cur_u->next;
    643                         cur_u->next = u1_head.next;
    644                         u1_head.next = cur_u;
    645                         cur_u = prev_u->next;
    646                     }
    647                     else
    648                     {
    649                         prev_u = cur_u;
    650                         cur_u = cur_u->next;
    651                     }
    652                 }
    653                 prev_v->next = cur_v->next;
    654                 v_cfound++;
    655             }
    656         }
    657 
    658         if( u_cfound < ssize )
    659         {
    660             /* loop over all marked rows */
    661             prev_u = &u1_head;
    662             for( found |= (cur_u = u1_head.next) != 0; cur_u != 0; cur_u = cur_u->next )
    663             {
    664                 float cur_u_val = cur_u->val;
    665                 float *_cost;
    666                 char *_is_x;
    667 
    668                 i = (int)(cur_u - u);
    669                 _cost = cost[i];
    670                 _is_x = is_x[i];
    671                 /* find the variables in rows i */
    672                 prev_v = &v0_head;
    673                 for( cur_v = v0_head.next; cur_v != 0; )
    674                 {
    675                     j = (int)(cur_v - v);
    676                     if( _is_x[j] )
    677                     {
    678                         /* compute v[j] */
    679                         cur_v->val = _cost[j] - cur_u_val;
    680                         /* ...and add it to the marked list */
    681                         prev_v->next = cur_v->next;
    682                         cur_v->next = v1_head.next;
    683                         v1_head.next = cur_v;
    684                         cur_v = prev_v->next;
    685                     }
    686                     else
    687                     {
    688                         prev_v = cur_v;
    689                         cur_v = cur_v->next;
    690                     }
    691                 }
    692                 prev_u->next = cur_u->next;
    693                 u_cfound++;
    694             }
    695         }
    696 
    697         if( !found )
    698         {
    699             return CV_NOTDEFINED_ERR;
    700         }
    701     }
    702 
    703     return CV_NO_ERR;
    704 }
    705 
    706 
    707 /****************************************************************************************\
    708 *                                   icvIsOptimal                                       *
    709 \****************************************************************************************/
    710 static float
    711 icvIsOptimal( float **cost, char **is_x,
    712               CvNode1D * u, CvNode1D * v, int ssize, int dsize, CvNode2D * enter_x )
    713 {
    714     float delta, min_delta = CV_EMD_INF;
    715     int i, j, min_i = 0, min_j = 0;
    716 
    717     /* find the minimal cij-ui-vj over all i,j */
    718     for( i = 0; i < ssize; i++ )
    719     {
    720         float u_val = u[i].val;
    721         float *_cost = cost[i];
    722         char *_is_x = is_x[i];
    723 
    724         for( j = 0; j < dsize; j++ )
    725         {
    726             if( !_is_x[j] )
    727             {
    728                 delta = _cost[j] - u_val - v[j].val;
    729                 if( min_delta > delta )
    730                 {
    731                     min_delta = delta;
    732                     min_i = i;
    733                     min_j = j;
    734                 }
    735             }
    736         }
    737     }
    738 
    739     enter_x->i = min_i;
    740     enter_x->j = min_j;
    741 
    742     return min_delta;
    743 }
    744 
    745 /****************************************************************************************\
    746 *                                   icvNewSolution                                     *
    747 \****************************************************************************************/
    748 static CvStatus
    749 icvNewSolution( CvEMDState * state )
    750 {
    751     int i, j;
    752     float min_val = CV_EMD_INF;
    753     int steps;
    754     CvNode2D head, *cur_x, *next_x, *leave_x = 0;
    755     CvNode2D *enter_x = state->enter_x;
    756     CvNode2D **loop = state->loop;
    757 
    758     /* enter the new basic variable */
    759     i = enter_x->i;
    760     j = enter_x->j;
    761     state->is_x[i][j] = 1;
    762     enter_x->next[0] = state->rows_x[i];
    763     enter_x->next[1] = state->cols_x[j];
    764     enter_x->val = 0;
    765     state->rows_x[i] = enter_x;
    766     state->cols_x[j] = enter_x;
    767 
    768     /* find a chain reaction */
    769     steps = icvFindLoop( state );
    770 
    771     if( steps == 0 )
    772         return CV_NOTDEFINED_ERR;
    773 
    774     /* find the largest value in the loop */
    775     for( i = 1; i < steps; i += 2 )
    776     {
    777         float temp = loop[i]->val;
    778 
    779         if( min_val > temp )
    780         {
    781             leave_x = loop[i];
    782             min_val = temp;
    783         }
    784     }
    785 
    786     /* update the loop */
    787     for( i = 0; i < steps; i += 2 )
    788     {
    789         float temp0 = loop[i]->val + min_val;
    790         float temp1 = loop[i + 1]->val - min_val;
    791 
    792         loop[i]->val = temp0;
    793         loop[i + 1]->val = temp1;
    794     }
    795 
    796     /* remove the leaving basic variable */
    797     i = leave_x->i;
    798     j = leave_x->j;
    799     state->is_x[i][j] = 0;
    800 
    801     head.next[0] = state->rows_x[i];
    802     cur_x = &head;
    803     while( (next_x = cur_x->next[0]) != leave_x )
    804     {
    805         cur_x = next_x;
    806         assert( cur_x );
    807     }
    808     cur_x->next[0] = next_x->next[0];
    809     state->rows_x[i] = head.next[0];
    810 
    811     head.next[1] = state->cols_x[j];
    812     cur_x = &head;
    813     while( (next_x = cur_x->next[1]) != leave_x )
    814     {
    815         cur_x = next_x;
    816         assert( cur_x );
    817     }
    818     cur_x->next[1] = next_x->next[1];
    819     state->cols_x[j] = head.next[1];
    820 
    821     /* set enter_x to be the new empty slot */
    822     state->enter_x = leave_x;
    823 
    824     return CV_NO_ERR;
    825 }
    826 
    827 
    828 
    829 /****************************************************************************************\
    830 *                                    icvFindLoop                                       *
    831 \****************************************************************************************/
    832 static int
    833 icvFindLoop( CvEMDState * state )
    834 {
    835     int i, steps = 1;
    836     CvNode2D *new_x;
    837     CvNode2D **loop = state->loop;
    838     CvNode2D *enter_x = state->enter_x, *_x = state->_x;
    839     char *is_used = state->is_used;
    840 
    841     memset( is_used, 0, state->ssize + state->dsize );
    842 
    843     new_x = loop[0] = enter_x;
    844     is_used[enter_x - _x] = 1;
    845     steps = 1;
    846 
    847     do
    848     {
    849         if( (steps & 1) == 1 )
    850         {
    851             /* find an unused x in the row */
    852             new_x = state->rows_x[new_x->i];
    853             while( new_x != 0 && is_used[new_x - _x] )
    854                 new_x = new_x->next[0];
    855         }
    856         else
    857         {
    858             /* find an unused x in the column, or the entering x */
    859             new_x = state->cols_x[new_x->j];
    860             while( new_x != 0 && is_used[new_x - _x] && new_x != enter_x )
    861                 new_x = new_x->next[1];
    862             if( new_x == enter_x )
    863                 break;
    864         }
    865 
    866         if( new_x != 0 )        /* found the next x */
    867         {
    868             /* add x to the loop */
    869             loop[steps++] = new_x;
    870             is_used[new_x - _x] = 1;
    871         }
    872         else                    /* didn't find the next x */
    873         {
    874             /* backtrack */
    875             do
    876             {
    877                 i = steps & 1;
    878                 new_x = loop[steps - 1];
    879                 do
    880                 {
    881                     new_x = new_x->next[i];
    882                 }
    883                 while( new_x != 0 && is_used[new_x - _x] );
    884 
    885                 if( new_x == 0 )
    886                 {
    887                     is_used[loop[--steps] - _x] = 0;
    888                 }
    889             }
    890             while( new_x == 0 && steps > 0 );
    891 
    892             is_used[loop[steps - 1] - _x] = 0;
    893             loop[steps - 1] = new_x;
    894             is_used[new_x - _x] = 1;
    895         }
    896     }
    897     while( steps > 0 );
    898 
    899     return steps;
    900 }
    901 
    902 
    903 
    904 /****************************************************************************************\
    905 *                                        icvRussel                                     *
    906 \****************************************************************************************/
    907 static void
    908 icvRussel( CvEMDState * state )
    909 {
    910     int i, j, min_i = -1, min_j = -1;
    911     float min_delta, diff;
    912     CvNode1D u_head, *cur_u, *prev_u;
    913     CvNode1D v_head, *cur_v, *prev_v;
    914     CvNode1D *prev_u_min_i = 0, *prev_v_min_j = 0, *remember;
    915     CvNode1D *u = state->u, *v = state->v;
    916     int ssize = state->ssize, dsize = state->dsize;
    917     float eps = CV_EMD_EPS * state->max_cost;
    918     float **cost = state->cost;
    919     float **delta = state->delta;
    920 
    921     /* initialize the rows list (ur), and the columns list (vr) */
    922     u_head.next = u;
    923     for( i = 0; i < ssize; i++ )
    924     {
    925         u[i].next = u + i + 1;
    926     }
    927     u[ssize - 1].next = 0;
    928 
    929     v_head.next = v;
    930     for( i = 0; i < dsize; i++ )
    931     {
    932         v[i].val = -CV_EMD_INF;
    933         v[i].next = v + i + 1;
    934     }
    935     v[dsize - 1].next = 0;
    936 
    937     /* find the maximum row and column values (ur[i] and vr[j]) */
    938     for( i = 0; i < ssize; i++ )
    939     {
    940         float u_val = -CV_EMD_INF;
    941         float *cost_row = cost[i];
    942 
    943         for( j = 0; j < dsize; j++ )
    944         {
    945             float temp = cost_row[j];
    946 
    947             if( u_val < temp )
    948                 u_val = temp;
    949             if( v[j].val < temp )
    950                 v[j].val = temp;
    951         }
    952         u[i].val = u_val;
    953     }
    954 
    955     /* compute the delta matrix */
    956     for( i = 0; i < ssize; i++ )
    957     {
    958         float u_val = u[i].val;
    959         float *delta_row = delta[i];
    960         float *cost_row = cost[i];
    961 
    962         for( j = 0; j < dsize; j++ )
    963         {
    964             delta_row[j] = cost_row[j] - u_val - v[j].val;
    965         }
    966     }
    967 
    968     /* find the basic variables */
    969     do
    970     {
    971         /* find the smallest delta[i][j] */
    972         min_i = -1;
    973         min_delta = CV_EMD_INF;
    974         prev_u = &u_head;
    975         for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
    976         {
    977             i = (int)(cur_u - u);
    978             float *delta_row = delta[i];
    979 
    980             prev_v = &v_head;
    981             for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
    982             {
    983                 j = (int)(cur_v - v);
    984                 if( min_delta > delta_row[j] )
    985                 {
    986                     min_delta = delta_row[j];
    987                     min_i = i;
    988                     min_j = j;
    989                     prev_u_min_i = prev_u;
    990                     prev_v_min_j = prev_v;
    991                 }
    992                 prev_v = cur_v;
    993             }
    994             prev_u = cur_u;
    995         }
    996 
    997         if( min_i < 0 )
    998             break;
    999 
   1000         /* add x[min_i][min_j] to the basis, and adjust supplies and cost */
   1001         remember = prev_u_min_i->next;
   1002         icvAddBasicVariable( state, min_i, min_j, prev_u_min_i, prev_v_min_j, &u_head );
   1003 
   1004         /* update the necessary delta[][] */
   1005         if( remember == prev_u_min_i->next )    /* line min_i was deleted */
   1006         {
   1007             for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
   1008             {
   1009                 j = (int)(cur_v - v);
   1010                 if( cur_v->val == cost[min_i][j] )      /* column j needs updating */
   1011                 {
   1012                     float max_val = -CV_EMD_INF;
   1013 
   1014                     /* find the new maximum value in the column */
   1015                     for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
   1016                     {
   1017                         float temp = cost[cur_u - u][j];
   1018 
   1019                         if( max_val < temp )
   1020                             max_val = temp;
   1021                     }
   1022 
   1023                     /* if needed, adjust the relevant delta[*][j] */
   1024                     diff = max_val - cur_v->val;
   1025                     cur_v->val = max_val;
   1026                     if( fabs( diff ) < eps )
   1027                     {
   1028                         for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
   1029                             delta[cur_u - u][j] += diff;
   1030                     }
   1031                 }
   1032             }
   1033         }
   1034         else                    /* column min_j was deleted */
   1035         {
   1036             for( cur_u = u_head.next; cur_u != 0; cur_u = cur_u->next )
   1037             {
   1038                 i = (int)(cur_u - u);
   1039                 if( cur_u->val == cost[i][min_j] )      /* row i needs updating */
   1040                 {
   1041                     float max_val = -CV_EMD_INF;
   1042 
   1043                     /* find the new maximum value in the row */
   1044                     for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
   1045                     {
   1046                         float temp = cost[i][cur_v - v];
   1047 
   1048                         if( max_val < temp )
   1049                             max_val = temp;
   1050                     }
   1051 
   1052                     /* if needed, adjust the relevant delta[i][*] */
   1053                     diff = max_val - cur_u->val;
   1054                     cur_u->val = max_val;
   1055 
   1056                     if( fabs( diff ) < eps )
   1057                     {
   1058                         for( cur_v = v_head.next; cur_v != 0; cur_v = cur_v->next )
   1059                             delta[i][cur_v - v] += diff;
   1060                     }
   1061                 }
   1062             }
   1063         }
   1064     }
   1065     while( u_head.next != 0 || v_head.next != 0 );
   1066 }
   1067 
   1068 
   1069 
   1070 /****************************************************************************************\
   1071 *                                   icvAddBasicVariable                                *
   1072 \****************************************************************************************/
   1073 static void
   1074 icvAddBasicVariable( CvEMDState * state,
   1075                      int min_i, int min_j,
   1076                      CvNode1D * prev_u_min_i, CvNode1D * prev_v_min_j, CvNode1D * u_head )
   1077 {
   1078     float temp;
   1079     CvNode2D *end_x = state->end_x;
   1080 
   1081     if( state->s[min_i] < state->d[min_j] + state->weight * CV_EMD_EPS )
   1082     {                           /* supply exhausted */
   1083         temp = state->s[min_i];
   1084         state->s[min_i] = 0;
   1085         state->d[min_j] -= temp;
   1086     }
   1087     else                        /* demand exhausted */
   1088     {
   1089         temp = state->d[min_j];
   1090         state->d[min_j] = 0;
   1091         state->s[min_i] -= temp;
   1092     }
   1093 
   1094     /* x(min_i,min_j) is a basic variable */
   1095     state->is_x[min_i][min_j] = 1;
   1096 
   1097     end_x->val = temp;
   1098     end_x->i = min_i;
   1099     end_x->j = min_j;
   1100     end_x->next[0] = state->rows_x[min_i];
   1101     end_x->next[1] = state->cols_x[min_j];
   1102     state->rows_x[min_i] = end_x;
   1103     state->cols_x[min_j] = end_x;
   1104     state->end_x = end_x + 1;
   1105 
   1106     /* delete supply row only if the empty, and if not last row */
   1107     if( state->s[min_i] == 0 && u_head->next->next != 0 )
   1108         prev_u_min_i->next = prev_u_min_i->next->next;  /* remove row from list */
   1109     else
   1110         prev_v_min_j->next = prev_v_min_j->next->next;  /* remove column from list */
   1111 }
   1112 
   1113 
   1114 /****************************************************************************************\
   1115 *                                  standard  metrics                                     *
   1116 \****************************************************************************************/
   /*
    * icvDistL1: Manhattan (L1) distance between two float vectors.
    * The vector length is smuggled through user_param (a size_t cast to a
    * pointer).  Absolute differences are summed in double and narrowed to
    * float once, on return.
    */
   static float
   icvDistL1( const float *x, const float *y, void *user_param )
   {
       const int dims = (int)(size_t)user_param;
       double sum = 0;
       int k = 0;

       while( k < dims )
       {
           sum += fabs( (double)(x[k] - y[k]) );
           k++;
       }
       return (float)sum;
   }
   1131 
   /*
    * icvDistL2: Euclidean (L2) distance between two float vectors.
    * The vector length is smuggled through user_param (a size_t cast to a
    * pointer).  Squared differences are summed in double.  The sum is
    * narrowed to float and then passed to cvSqrt.
    */
   static float
   icvDistL2( const float *x, const float *y, void *user_param )
   {
       const int dims = (int)(size_t)user_param;
       double sum_sq = 0;
       int k = 0;

       while( k < dims )
       {
           const double gap = x[k] - y[k];

           sum_sq += gap * gap;
           k++;
       }
       return cvSqrt( (float)sum_sq );
   }
   1146 
   /*
    * icvDistC: Chebyshev (L-infinity) distance between two float vectors,
    * i.e. the largest absolute per-coordinate difference.  The result starts
    * at 0, so dims <= 0 yields 0.  The vector length is smuggled through
    * user_param (a size_t cast to a pointer).
    */
   static float
   icvDistC( const float *x, const float *y, void *user_param )
   {
       const int dims = (int)(size_t)user_param;
       double largest = 0;
       int k;

       for( k = 0; k < dims; ++k )
       {
           const double gap = fabs( x[k] - y[k] );

           largest = ( gap > largest ) ? gap : largest;
       }
       return (float)largest;
   }
   1162 
   1163 /* End of file. */
   1164 
   1165